00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067
00068
00069
00070
00071
00072
00073
00074
00075
00076
00077
00078
00079
00080
00081
00082
00083
00084
00085
00086
00087
00088
00089
00090
00091
00092
00093
00094
00095
00096
00097
00098
00099
00100
00101
00102
00103
00104
00105
00106
00107
00108
00109
00110
00111
00112
00113
00114
00115
00116
00117
00118
00119
00120
00121
00122
00123
00124
00125
00126
00127
00128
00129
00130
00131
00132
00133
00134
00135
00136
00137
00138
00139
00140
00141
00142
00143
00144
00145
00146
00147
00148
00149
00150
00151
00152
00153
00154
00155
00156
00157
00158
00159
00160
00161
00162
00163
00164
00165
00166
00167
00168
00169
00170
00171
00172
00173
00174
00175
00176
00177
00178
00179
00180
00181
00182
00183
00184
00185
00186
00187
00188
00189
00190
00191
00192
00193
00194
00195
00196
00197
00198
00199
00200
00201
00202
00203
00204
00205
00206
00207
00208
00209
00210
00211
00212
00213
00214
00215
00216
00217
00218
00219
00220
00221
00222
00223
00224
00225
00226
00227
00228
00229
00230
00231
00232
00233
00234
00235
00236
00237
00238
00239
00240
00241
00242
00243
00244
00245
00246
00247
00248
00249
00250
00251
00252
00253
00254
00255
00256
00257
00258
00259
00260
00261
00262
00263
00264
00265
00266
00267
00268
00269
00270
00271
00272
00273
00274
00275
00276
00277
00278
00279
00280
00281
00282
00283
00284
00285
00286
00287
00288
00289
00290
00291
00292
00293
00294
00295
00296
00297
00298
00299
00300
00301
00302
00303
00304
00305
00306
00307
00308
00309
00310
00311
00312
00313
00314
00315
00316
00317
00318
00319
00320
00321
00322
00323
00324
00325
00326
00327
00328
00331
#ifndef RANDOMVAR_INC
00332
#define RANDOMVAR_INC
00333
00334
#include <plearn/opt/Optimizer.h>
00335
00338
#include "SampleVariable.h"
00339
00341
#include <plearn/vmat/VMat.h>
00342
00343
namespace PLearn {
00344
using namespace std;
00345
00346
00347
class RandomVariable;
00348
class RVArray;
00349
class RVInstance;
00350
class RVInstanceArray;
00351
class ConditionalExpression;
00352
00354 class RandomVar:
public PP<RandomVariable>
00355 {
00356
public:
00357
RandomVar();
00358
RandomVar(
int length,
int width=1);
00359
RandomVar(
RandomVariable* v);
00360
RandomVar(
const RandomVar& other);
00361
00362
RandomVar(
const Vec& vec);
00363
RandomVar(
const Mat& mat);
00364
RandomVar(
const Var&
var);
00365
00366
RandomVar(
const RVArray& vars);
00367
00368
RandomVar operator[](
RandomVar index);
00369
#if 0
00370
RandomVar operator[](
int i);
00371
#endif
00372
00373
void operator=(
const RVArray& vars);
00374
00376
void operator=(
real f);
00377
void operator=(
const Vec& v);
00378
void operator=(
const Mat& m);
00379
void operator=(
const Var& v);
00380
00383
RVInstance operator==(
const Var& v)
const;
00384
00386 bool operator==(
const RandomVar& rv)
const {
return rv.
ptr == this->ptr; }
00387 bool operator!=(
const RandomVar& rv)
const {
return rv.
ptr != this->ptr; }
00388
00390
RVArray operator&(
const RandomVar& v)
const;
00391
00401
ConditionalExpression operator|(
RVArray rhs)
const;
00402
00404
ConditionalExpression operator|(
RVInstanceArray rhs)
const;
00405
00406
#if 0
00409
RandomVar operator[](RandomVar index);
00410
RandomVar operator[](
int i);
00412
RandomVar operator()(
RandomVar i,
RandomVar j);
00413
RandomVar operator()(
int i,
int j);
00414
#endif
00415
00416 };
00417
00418 typedef RandomVar MatRandomVar;
00419
00420
00422 class RVArray:
public Array<RandomVar>
00423 {
00424
public:
00425
RVArray();
00426
RVArray(
int n,
int n_extra_allocated=0);
00427
RVArray(
const Array<RandomVar>& va);
00428
RVArray(
const RandomVar& v,
int n_extra_allocated=0);
00429
RVArray(
const RandomVar& v1,
const RandomVar& v2,
int n_extra_allocated=0);
00430
RVArray(
const RandomVar& v1,
const RandomVar& v2,
const RandomVar& v3,
00431
int n_extra_allocated=0);
00432
00433
int length()
const;
00434
00436
VarArray values()
const;
00437
00440
00443
00445
RandomVar operator[](
RandomVar index);
00446
00447 RandomVar&
operator[](
int i)
00448 {
return Array<RandomVar>::operator[](i); }
00449
00450 const RandomVar&
operator[](
int i)
const
00451
{
return Array<RandomVar>::operator[](i); }
00452
00453
static int compareRVnumbers(
const RandomVar* v1,
const RandomVar* v2);
00454
00457
void sort();
00458 };
00459
00460
00462 class RVInstance
00463 {
00464
public:
00465 RandomVar V;
00466 Var v;
00467
00468
RVInstance(
const RandomVar& VV,
const Var& vv);
00469
RVInstance();
00470
00471
RVInstanceArray operator&&(
RVInstance rvi);
00472
00473
ConditionalExpression operator|(
RVInstanceArray a);
00474
00476
void swap_v_and_Vvalue();
00477
00478 };
00479
00480 class RVInstanceArray:
public Array<RVInstance>
00481 {
00482
public:
00483
RVInstanceArray();
00484
RVInstanceArray(
int n,
int n_extra_allocated=0);
00485
RVInstanceArray(
const Array<RVInstance>& a);
00486
RVInstanceArray(
const RVInstance& v,
int n_extra_allocated=0);
00487
RVInstanceArray(
const RVInstance& v1,
const RVInstance& v2,
00488
int n_extra_allocated=0);
00489
RVInstanceArray(
const RVInstance& v1,
const RVInstance& v2,
00490
const RVInstance& v3,
int n_extra_allocated=0);
00491
00493
int length()
const;
00494
00497
RVInstanceArray operator&&(
RVInstance rhs);
00498
00506
ConditionalExpression operator|(
RVInstanceArray rhs);
00507
00509
RVArray random_variables()
const;
00510
00512
VarArray values()
const;
00514
VarArray instances()
const;
00515
00517 void swap_v_and_Vvalue()
00518 {
for (
int i=0;i<
size();i++) (*this)[i].swap_v_and_Vvalue(); }
00519
00520
static int compareRVnumbers(
const RVInstance* rvi1,
const RVInstance* rvi2);
00521
00524
void sort();
00525
00526 };
00527
00528 class ConditionalExpression
00529 {
00530
public:
00531 RVInstance LHS;
00532 RVInstanceArray RHS;
00533
00535
ConditionalExpression(
RVInstance lhs,
RVInstanceArray rhs);
00537
ConditionalExpression(
RVInstance lhs);
00539
ConditionalExpression(
RandomVar lhs);
00542
ConditionalExpression(
RVInstanceArray lhs);
00543 };
00544
00545 class RandomVariable:
public PPointable
00546 {
00547
friend class RandomVar;
00548
friend class RVInstanceArray;
00549
friend class RVArray;
00550
00551
static int rv_counter;
00552
00553
00554
protected:
00555
00558 const int rv_number;
00559
00560
public:
00561 const RVArray parents;
00562
00567 Var value;
00568
00569
protected:
00574 bool marked;
00575
00576 bool EMmark;
00577 bool pmark;
00578
00581 bool*
learn_the_parameters;
00582
00583
public:
00585
RandomVariable(
int thelength,
int thewidth=1);
00586
RandomVariable(
const Vec& the_value);
00587
RandomVariable(
const Mat& the_value);
00588
RandomVariable(
const Var& the_value);
00590
RandomVariable(
const RVArray& parents,
int thelength);
00591
RandomVariable(
const RVArray& parents,
int thelength,
int thewidth);
00592
00593
virtual char*
classname() = 0;
00594
00595 virtual int length() {
return value->
length(); }
00596 virtual int width() {
return value->
width(); }
00597 int nelems() {
return value->nelems(); }
00598 bool isScalar() {
return length()==1 &&
width()==1; }
00599 bool isVec() {
return width()==1 ||
length()==1; }
00600 bool isColumnVec() {
return width()==1; }
00601 bool isRowVec() {
return length()==1; }
00602
00609
virtual bool isNonRandom() = 0;
00612 inline bool isConstant() {
return isNonRandom() &&
value->isConstant(); }
00613
00616
virtual bool isDiscrete() = 0;
00617
00622
RandomVar subVec(
int start,
int length);
00623
00625
00632
00633
00638
virtual void setValueFromParentsValue() = 0;
00639
00640 void markRHSandSetKnownValues(
const RVInstanceArray& RHS)
00641 {
00642
for (
int i=0;i<RHS.
size();i++)
00643 RHS[i].V->mark(RHS[i].v);
00644
setKnownValues();
00645 }
00646
00654
virtual void EMBprop(
const Vec obs,
real posterior) = 0;
00655
00660
virtual void EMUpdate();
00661
00668
virtual bool canStopEM();
00669
00672
virtual void EMTrainingInitialize(
const RVArray& parameters_to_learn);
00673
00676
virtual void EMEpochInitialize();
00677
00681
00682 virtual void mark(
Var v) {
marked =
true;
value = v; }
00683 virtual void mark() {
marked =
true; }
00684 virtual void unmark() {
marked =
false; }
00685
virtual void clearEMmarks();
00686
00688
virtual void unmarkAncestors();
00689
00690 virtual bool isMarked() {
return marked; }
00691
00696
virtual void setKnownValues();
00697
00703
virtual Var logP(
const Var& obs,
const RVInstanceArray& RHS,
00704
RVInstanceArray* parameters_to_learn=0) = 0;
00705
virtual Var P(
const Var& obs,
const RVInstanceArray& RHS);
00706
00712
virtual Var ElogP(
const Var& obs,
RVArray& parameters_to_learn,
00713
const RVInstanceArray& RHS);
00714
00728
00729
00734
00735
virtual real EM(
const RVArray& parameters_to_learn,
00736
VarArray& prop_path,
00737
VarArray& observedVars,
00738
VMat distr,
int n_samples,
00739
int max_n_iterations,
00740
real relative_improvement_threshold,
00741
bool accept_worsening_likelihood=
false);
00742
00750
virtual real epoch(
VarArray& prop_path,
00751
VarArray& observed_vars,
const VMat& distr,
00752
int n_samples,
00753
bool do_EM_learning=
true);
00754
00755
virtual ~RandomVariable();
00756
00757 };
00758
00759
00761
00771 RandomVar
operator*(RandomVar a, RandomVar b);
00772
00779 RandomVar
operator+(RandomVar a, RandomVar b);
00780
00783 RandomVar
operator-(RandomVar a, RandomVar b);
00784
00787 RandomVar
operator/(RandomVar a, RandomVar b);
00788
00790 RandomVar
exp(RandomVar x);
00791
00793 RandomVar
log(RandomVar x);
00794
00795 RandomVar
extend(RandomVar v,
real extension_value = 1.0,
int n_extend = 1);
00796
00797 RandomVar
hconcat(
const RVArray& a);
00798
00813
00814
00819
00820
00821
real EM(ConditionalExpression conditional_expression,
00822 RVArray parameters_to_learn,
00823 VMat distr,
int n_samples,
int max_n_iterations=1,
00824
real relative_improvement_threshold=0.001,
00825
bool accept_worsening_likelihood=
false,
00826
bool compute_final_train_NLL=
true);
00827
00828
real oEM(ConditionalExpression conditional_expression,
00829 RVArray parameters_to_learn,
00830 VMat distr,
int n_samples,
int max_n_iterations,
00831
real relative_improvement_threshold=0.001,
00832
bool compute_final_train_NLL=
true);
00833
00834
real oEM(ConditionalExpression conditional_expression,
00835 RVArray parameters_to_learn,
00836 VMat distr,
int n_samples,
00837 Optimizer& MStepOptimizer,
00838
int max_n_iterations,
00839
real relative_improvement_threshold=0.001,
00840
bool compute_final_train_NLL=
true);
00841
00851 Var
logP(ConditionalExpression conditional_expression,
00852
bool clearMarksUponReturn=
true,
00853 RVInstanceArray* parameters_to_learn=0);
00854
00864 Var
P(ConditionalExpression conditional_expression,
00865
bool clearMarksUponReturn=
true);
00866
00872 Var
ElogP(ConditionalExpression conditional_expression,
00873 RVInstanceArray& parameters_to_learn,
00874
bool clearMarksUponReturn=
true);
00875
00880 RandomVar
marginalize(
const RandomVar& RV,
const RandomVar& hiddenRV);
00891 Vec
sample(ConditionalExpression conditional_expression);
00892
00897 Var
Sample(ConditionalExpression conditional_expression);
00898
00906
void sample(ConditionalExpression conditional_expression,Mat& samples);
00907
00909
00914 RandomVar
normal(
real mean=0,
real standard_dev=1,
int d=1,
00915
real minimum_standard_deviation=1e-6);
00916
00921 RandomVar
normal(RandomVar mean, RandomVar log_variance,
00922
real minimum_standard_deviation=1e-6);
00923
00929 RandomVar
mixture(RVArray components, RandomVar log_weights);
00930
00937 RandomVar
multinomial(RandomVar log_probabilities);
00938
00961 class StochasticRandomVariable:
public RandomVariable
00962 {
00963
public:
00964
StochasticRandomVariable(
int length=1);
00965
StochasticRandomVariable(
const RVArray& params,
int length);
00966
StochasticRandomVariable(
const RVArray& params,
int length,
int width);
00967
00969
00971 virtual bool isNonRandom() {
return false; }
00972
00974 virtual bool isDiscrete() {
return false; }
00975
00976
virtual void setKnownValues();
00977
00981
01000 };
01001
01006 class FunctionalRandomVariable:
public RandomVariable {
01007
public:
01008
FunctionalRandomVariable(
int length);
01009
FunctionalRandomVariable(
int length,
int width);
01010
FunctionalRandomVariable(
const Vec& the_value);
01011
FunctionalRandomVariable(
const Mat& the_value);
01012
FunctionalRandomVariable(
const Var& the_value);
01013
FunctionalRandomVariable(
const RVArray& parents,
int length);
01014
FunctionalRandomVariable(
const RVArray& parents,
int length,
int width);
01015
01017
01018
virtual Var logP(
const Var& obs,
const RVInstanceArray& RHS,
01019
RVInstanceArray* parameters_to_learn);
01020
01022
bool isNonRandom();
01023
01025
virtual bool isDiscrete();
01026
01028
01032
01047
virtual bool invertible(
const Var& obs,
01048
RVInstanceArray& unobserved_parents,
01049
Var** JacobianCorrection);
01050
01052
virtual void setValueFromParentsValue() = 0;
01053
01057
01080 };
01081
01087 class NonRandomVariable:
public FunctionalRandomVariable
01088 {
01089
public:
01092
NonRandomVariable(
int thelength);
01093
NonRandomVariable(
int thelength,
int thewidth);
01098
NonRandomVariable(
const Var& v);
01099
01100 virtual char*
classname() {
return "NonRandomVariable"; }
01101
01102 void setValueFromParentsValue() { }
01103 bool invertible(
const Var& obs,
RVInstanceArray& unobserved_parents,
01104
Var** JacobianCorrection)
01105 {
return true; }
01106 void EMBprop(
const Vec obs,
real post) { }
01107 };
01108
01109 class JointRandomVariable:
public FunctionalRandomVariable
01110 {
01111
public:
01112
JointRandomVariable(
const RVArray& variables);
01113
01114 virtual char*
classname() {
return "JointRandomVariable"; }
01115
01116
void setValueFromParentsValue();
01117
bool invertible(
const Var& obs,
RVInstanceArray& unobserved_parents,
01118
Var** JacobianCorrection);
01119
void EMBprop(
const Vec obs,
real post);
01120 };
01121
01124 class RandomElementOfRandomVariable:
public FunctionalRandomVariable
01125 {
01126
public:
01127
RandomElementOfRandomVariable(
const RandomVar&
v,
const RandomVar&
index);
01128
01129 virtual char*
classname() {
return "RandomElementOfRandomVariable"; }
01130
01131
void setValueFromParentsValue();
01132
bool invertible(
const Var& obs,
RVInstanceArray& unobserved_parents,
01133
Var** JacobianCorrection);
01134
void EMBprop(
const Vec obs,
real post);
01135
01137 inline const RandomVar&
v() {
return parents[0]; }
01138 inline const RandomVar&
index() {
return parents[1]; }
01139
01140 };
01141
01142
01148 class RVArrayRandomElementRandomVariable:
public FunctionalRandomVariable
01149 {
01150
public:
01151
RVArrayRandomElementRandomVariable(
const RVArray& table,
const RandomVar&
index);
01152
01153 virtual char*
classname() {
return "RVArrayRandomElementRandomVariable"; }
01154
01155
void setValueFromParentsValue();
01156
virtual Var logP(
const Var& obs,
const RVInstanceArray& RHS,
01157
RVInstanceArray* parameters_to_learn=0);
01158
void EMBprop(
const Vec obs,
real post);
01159
01161 inline const RandomVar&
index() {
return parents[parents.
size()-1]; }
01162
01163 };
01164
01165 class NegRandomVariable:
public FunctionalRandomVariable
01166 {
01167
public:
01168
NegRandomVariable(
RandomVariable* input);
01169
01170 virtual char*
classname() {
return "NegRandomVariable"; }
01171
01172
void setValueFromParentsValue();
01173
bool invertible(
const Var& obs,
RVInstanceArray& unobserved_parents,
01174
Var** JacobianCorrection);
01175
void EMBprop(
const Vec obs,
real post);
01176 };
01177
01178 class ExpRandomVariable:
public FunctionalRandomVariable
01179 {
01180
public:
01181
ExpRandomVariable(
RandomVar& input);
01182
01183 virtual char*
classname() {
return "ExpRandomVariable"; }
01184
01185
void setValueFromParentsValue();
01186
bool invertible(
const Var& obs,
RVInstanceArray& unobserved_parents,
01187
Var** JacobianCorrection);
01188
void EMBprop(
const Vec obs,
real post);
01189 };
01190
01191
01192 class LogRandomVariable:
public FunctionalRandomVariable
01193 {
01194
public:
01195
LogRandomVariable(
RandomVar& input);
01196
01197 virtual char*
classname() {
return "LogRandomVariable"; }
01198
01199
void setValueFromParentsValue();
01200
bool invertible(
const Var& obs,
RVInstanceArray& unobserved_parents,
01201
Var** JacobianCorrection);
01202
void EMBprop(
const Vec obs,
real post);
01203 };
01204
01205 class DiagonalNormalRandomVariable:
public StochasticRandomVariable
01206 {
01211
protected:
01212 real minimum_variance;
01213 real normfactor;
01214 bool shared_variance;
01215
01216
public:
01217
DiagonalNormalRandomVariable(
const RandomVar& mean,
01218
const RandomVar& log_variance,
01219
real minimum_standard_deviation = 1e-10);
01220
01221 virtual char*
classname() {
return "DiagonalNormalRandomVariable"; }
01222
01223
Var logP(
const Var& obs,
const RVInstanceArray& RHS,
01224
RVInstanceArray* parameters_to_learn);
01225
void setValueFromParentsValue();
01226
void EMUpdate();
01227
void EMBprop(
const Vec obs,
real posterior);
01228
void EMEpochInitialize();
01229
01231 inline const RandomVar&
mean() {
return parents[0]; }
01232 inline const RandomVar&
log_variance() {
return parents[1]; }
01233 inline bool&
learn_the_mean() {
return learn_the_parameters[0]; }
01234 inline bool&
learn_the_variance() {
return learn_the_parameters[1]; }
01235
01236
protected:
01238 Vec mu_num;
01239 Vec sigma_num;
01240 real denom;
01241 };
01242
01243 class MixtureRandomVariable:
public StochasticRandomVariable
01244 {
01245
protected:
01247 RVArray components;
01248
01249
public:
01250
MixtureRandomVariable(
const RVArray& components,
01251
const RandomVar& log_weights);
01252
01253 virtual char*
classname() {
return "MixtureRandomVariable"; }
01254
01257 inline const RandomVar&
log_weights() {
return parents[0]; }
01258 inline bool&
learn_the_weights() {
return learn_the_parameters[0]; }
01259
01260
virtual Var logP(
const Var& obs,
const RVInstanceArray& RHS,
01261
RVInstanceArray* parameters_to_learn);
01262
virtual Var ElogP(
const Var& obs,
RVInstanceArray& parameters_to_learn,
01263
const RVInstanceArray& RHS);
01264
01265
virtual void setValueFromParentsValue();
01266
virtual void EMUpdate();
01267
virtual void EMBprop(
const Vec obs,
real posterior);
01268
virtual void EMEpochInitialize();
01269
virtual void EMTrainingInitialize(
const RVArray& parameters_to_learn);
01270
virtual bool isDiscrete();
01271
virtual bool canStopEM();
01272
virtual void setKnownValues();
01273
virtual void unmarkAncestors();
01274
virtual void clearEMmarks();
01275
01276
protected:
01278 Vec posteriors;
01279 Vec sum_posteriors;
01280
01282 VarArray componentsLogP;
01283 Var lw;
01284 Var logp;
01285 };
01286
01293 class PlusRandomVariable:
public FunctionalRandomVariable
01294 {
01295
public:
01296
PlusRandomVariable(
RandomVar input1,
RandomVar input2);
01297
01298 virtual char*
classname() {
return "PlusRandomVariable"; }
01299
01300
void setValueFromParentsValue();
01301
bool invertible(
const Var& obs,
RVInstanceArray& unobserved_parents,
01302
Var** JacobianCorrection);
01303
void EMBprop(
const Vec obs,
real post);
01304
void EMTrainingInitialize(
const RVArray& parameters_to_learn);
01305
void EMEpochInitialize();
01306
void EMUpdate();
01307
01309 const RandomVar&
X0() {
return parents[0]; }
01310 const RandomVar&
X1() {
return parents[1]; }
01311
01313 bool learn_X0() {
return learn_the_parameters[0]; }
01314 bool learn_X1() {
return learn_the_parameters[1]; }
01315 bool learn_something;
01316 RandomVar parent_to_learn;
01317 RandomVar other_parent;
01318 Vec numerator;
01319 Vec difference;
01320 real denom;
01321 };
01322
01329 class MinusRandomVariable:
public FunctionalRandomVariable
01330 {
01331
public:
01332
MinusRandomVariable(
RandomVar input1,
RandomVar input2);
01333
01334 virtual char*
classname() {
return "MinusRandomVariable"; }
01335
01336
void setValueFromParentsValue();
01337
bool invertible(
const Var& obs,
RVInstanceArray& unobserved_parents,
01338
Var** JacobianCorrection);
01339
void EMBprop(
const Vec obs,
real post);
01340
void EMTrainingInitialize(
const RVArray& parameters_to_learn);
01341
void EMEpochInitialize();
01342
void EMUpdate();
01343
01345 const RandomVar&
X0() {
return parents[0]; }
01346 const RandomVar&
X1() {
return parents[1]; }
01347
01349 bool learn_X0() {
return learn_the_parameters[0]; }
01350 bool learn_X1() {
return learn_the_parameters[1]; }
01351 bool learn_something;
01352 RandomVar parent_to_learn;
01353 RandomVar other_parent;
01354 Vec numerator;
01355 Vec difference;
01356 real denom;
01357 };
01358
01359
01364 class ElementWiseDivisionRandomVariable:
public FunctionalRandomVariable
01365 {
01366
public:
01367
ElementWiseDivisionRandomVariable(
RandomVar input1,
RandomVar input2);
01368
01369 virtual char*
classname() {
return "ElementWiseDivisionRandomVariable"; }
01370
01371
void setValueFromParentsValue();
01372
bool invertible(
const Var& obs,
RVInstanceArray& unobserved_parents,
01373
Var** JacobianCorrection);
01374
void EMBprop(
const Vec obs,
real post);
01375
void EMTrainingInitialize(
const RVArray& parameters_to_learn);
01376
void EMEpochInitialize();
01377
void EMUpdate();
01378
01380 const RandomVar&
X0() {
return parents[0]; }
01381 const RandomVar&
X1() {
return parents[1]; }
01382
01383 };
01384
01385
01417 class ProductRandomVariable:
public FunctionalRandomVariable
01418 {
01419
public:
01420 int m,
n,
l;
01421
01422
ProductRandomVariable(
MatRandomVar input1,
MatRandomVar input2);
01423
01424 virtual char*
classname() {
return "ProductRandomVariable"; }
01425
01426
void setValueFromParentsValue();
01427
bool invertible(
const Var& obs,
RVInstanceArray& unobserved_parents,
01428
Var** JacobianCorrection);
01429
void EMBprop(
const Vec obs,
real post);
01430
void EMTrainingInitialize(
const RVArray& parameters_to_learn);
01431
void EMEpochInitialize();
01432
void EMUpdate();
01433
01435 const RandomVar&
X0() {
return parents[0]; }
01436 const RandomVar&
X1() {
return parents[1]; }
01437 bool scalars;
01438
01440 bool learn_X0() {
return learn_the_parameters[0]; }
01441 bool learn_X1() {
return learn_the_parameters[1]; }
01442 bool learn_something;
01443 Mat X0numerator;
01444 Mat X1numerator;
01445 Mat denom;
01446 Mat tmp1;
01447 Mat tmp2;
01448 Mat tmp3;
01449 Vec vtmp3;
01450 Vec tmp4;
01451 };
01452
01454 class SubVecRandomVariable:
public FunctionalRandomVariable
01455 {
01456
protected:
01457 int start;
01458
public:
01459
SubVecRandomVariable(
const RandomVar& parent,
int start,
int length);
01460 virtual char*
classname() {
return "SubvecRandomVariable"; }
01461
void setValueFromParentsValue();
01462
bool invertible(
const Var& obs,
RVInstanceArray& unobserved_parents,
01463
Var** JacobianCorrection);
01464
void EMBprop(
const Vec obs,
real posterior);
01465 };
01466
01472 class MultinomialRandomVariable:
public StochasticRandomVariable
01473 {
01474
public:
01477
MultinomialRandomVariable(
const RandomVar&
log_probabilities);
01478
01480 inline const RandomVar&
log_probabilities() {
return parents[0]; }
01481 inline bool learn_the_probabilities() {
return learn_the_parameters[0]; }
01482
01483 virtual char*
classname() {
return "MultinomialRandomVariable"; }
01484
01485
Var logP(
const Var& obs,
const RVInstanceArray& RHS,
01486
RVInstanceArray* parameters_to_learn);
01487
void setValueFromParentsValue();
01488
void EMUpdate();
01489
void EMBprop(
const Vec obs,
real posterior);
01490
void EMEpochInitialize();
01491
bool isDiscrete();
01492
01493
protected:
01495 Vec sum_posteriors;
01496 };
01497
01498
01505 class ExtendedRandomVariable:
public FunctionalRandomVariable
01506 {
01507
protected:
01508 int n_extend;
01509 real fill_value;
01510
public:
01511
ExtendedRandomVariable(
const RandomVar& parent,
real fill_value=1.0,
int n_extend=1);
01512 virtual char*
classname() {
return "ExtendedRandomVariable"; }
01513
void setValueFromParentsValue();
01514
bool invertible(
const Var& obs,
RVInstanceArray& unobserved_parents,
01515
Var** JacobianCorrection);
01516
void EMBprop(
const Vec obs,
real posterior);
01517 };
01518
01521 class ConcatColumnsRandomVariable:
public FunctionalRandomVariable
01522 {
01523
public:
01524
ConcatColumnsRandomVariable(
const RVArray& vars);
01525 virtual char*
classname() {
return "ConcatColumnsRandomVariable"; }
01526
void setValueFromParentsValue();
01527
bool invertible(
const Var& obs,
RVInstanceArray& unobserved_parents,
01528
Var** JacobianCorrection);
01529
void EMBprop(
const Vec obs,
real posterior);
01530 };
01531
01537
01538
01539 class RandomVarVMatrix:
public VMatrix
01540 {
01541
protected:
01542 RandomVar rv;
01543 Var instance;
01544 VarArray prop_path;
01545
01546
public:
01547
RandomVarVMatrix(
ConditionalExpression conditional_expression);
01548 virtual int nVars() {
return instance->
length(); }
01549 virtual Vec sample()
01550 {
01551
prop_path.
fprop();
01552
return instance->value;
01553 }
01554 };
01555
01557
01558 }
01559
01560
#endif
01561