00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00040
#ifndef GaussMix_INC
00041
#define GaussMix_INC
00042
00043
#include "PDistribution.h"
00044
00045
namespace PLearn {
00046
using namespace std;
00047
00048 class GaussMix:
public PDistribution
00049 {
00050
00051
private:
00052
00053 typedef PDistribution inherited;
00054
00056 Vec sample_row,
log_likelihood_post;
00057 mutable Vec x_minus_mu_x,
mu_target,
log_likelihood_dens;
00058
00059
protected:
00060
00061
00062
00063
00064
00065 Mat eigenvalues;
00066 TVec<Mat> eigenvectors;
00067 int D;
00068 Mat diags;
00069 Vec log_coeff;
00070 Vec log_p_j_x;
00071 Vec log_p_x_j_alphaj;
00072 Mat mu_y_x;
00073 int n_eigen_computed;
00074 int nsamples;
00075 Vec p_j_x;
00076
00077
00078
00079 TVec<Mat> cov_x;
00080 TVec<Mat> cov_y_x;
00081 Mat eigenvalues_x;
00082 Mat eigenvalues_y_x;
00083 TVec<Mat> eigenvectors_x;
00084 TVec<Mat> eigenvectors_y_x;
00085 TVec<Mat> full_cov;
00086 TVec<Mat> y_x_mat;
00087
00090 Mat posteriors;
00091
00094 Vec initial_weights;
00095
00098 Mat updated_weights;
00099
00100
public:
00101
00102
00103
00104
00105
00106 real alpha_min;
00107 real epsilon;
00108 int kmeans_iterations;
00109 int L;
00110 int n_eigen;
00111 real sigma_min;
00112 string type;
00113
00114
00115 Vec alpha;
00116 Mat mu;
00117 Vec sigma;
00118
00119
00120
00121
00122
00124
GaussMix();
00125
00126
00127
00128
00129
00130
protected:
00131
00133
virtual void computeMeansAndCovariances();
00134
00138
virtual real computeLogLikelihood(
const Vec& y,
int j,
bool is_input =
false)
const;
00139
00141
virtual void computePosteriors();
00142
00146
virtual bool computeWeights();
00147
00150
virtual void generateFromGaussian(
Vec& s,
int given_gaussian)
const;
00151
00153
virtual void precomputeStuff();
00154
00158
virtual void replaceGaussian(
int j);
00159
00161
void resizeStuffBeforeTraining();
00162
00165
void updateSampleWeights();
00166
00167
public:
00168
00170
virtual void generate(
Vec& s)
const;
00171
00172
virtual void resetGenerator(
long g_seed)
const;
00173
00174
private:
00175
00177
void build_();
00178
00179
protected:
00180
00182
static void declareOptions(
OptionList& ol);
00183
00185
void kmeans(
VMat samples,
int nclust,
TVec<int> & clust_idx,
Mat & clust,
int maxit=9999);
00186
00187
public:
00188
00191
virtual void forget();
00192
00193
00194
virtual void build();
00195
00197
virtual void makeDeepCopyFromShallowCopy(map<const void*, void*>& copies);
00198
00200
PLEARN_DECLARE_OBJECT(
GaussMix);
00201
00202
00203
00204
00205
00207
virtual void train();
00208
00209
00210
00211
00212
00214
virtual void setInput(
const Vec& input)
const;
00215
00218
virtual void updateFromConditionalSorting();
00219
00221
virtual real log_density(
const Vec& y)
const;
00222
00224
virtual real survival_fn(
const Vec& y)
const;
00225
00227
virtual real cdf(
const Vec& y)
const;
00228
00230
virtual void expectation(
Vec& mu)
const;
00231
00233
virtual void variance(
Mat& cov)
const;
00234
00236
int getNEigenComputed() const;
00237
Mat getEigenvectors(
int j) const;
00238
Vec getEigenvals(
int j) const;
00239
00240 };
00241
00242
00243 DECLARE_OBJECT_PTR(
GaussMix);
00244
00245 }
00246
00247 #endif