00001 
00002 
00003 
00004 
00005 
00006 
00007 
00008 
00009 
00010 
00011 
00012 
00013 
00014 
00015 
00016 
00017 
00018 
00019 
00020 
00021 
00022 
00023 
00024 
00025 
00026 
00027 
00028 
00029 
00030 
00031 
00032 
00033 
00034 
00035 
00036 
00037 
00038  
00039 
00040 
00041 
00042 
00043 
00044 
00045 
00046 
00047 
00048 
00049 
00050 
00051 
00052 
00053 
00054 
00055 
00056 
00057 
00058 
00059 
00060 
00061 
00062 
00063 
00064 
00065 
00066 
00067 
00068 
00069 
00070 
00071 
00072 
00073 
00074 
00075 
00076 
00077 
00078 
00079 
00080 
00081 
00082 
00083 
00084 
00085 
00086 
00087 
00088 
00089 
00090 
00091 
00092 
00093 
00094 
00095 
00096 
00097 
00098 
00099 
00100 
00101 
00102 
00103 
00104 
00105 
00106 
00107 
00108 
00109 
00110 
00111 
00112 
00113 
00114 
00115 
00116 
00117 
00118 
00119 
00120 
00121 
00122 
00123 
00124 
00125 
00126 
00127 
00128 
00129 
00130 
00131 
00132 
00133 
00134 
00135 
00136 
00137 
00138 
00139 
00140 
00141 
00142 
00143 
00144 
00145 
00146 
00147 
00148 
00149 
00150 
00151 
00152 
00153 
00154 
00155 
00156 
00157 
00158 
00159 
00160 
00161 
00162 
00163 
00164 
00165 
00166 
00167 
00168 
00169 
00170 
00171 
00172 
00173 
00174 
00175 
00176 
00177 
00178 
00179 
00180 
00181 
00182 
00183 
00184 
00185 
00186 
00187 
00188 
00189 
00190 
00191 
00192 
00193 
00194 
00195 
00196 
00197 
00198 
00199 
00200 
00201 
00202 
00203 
00204 
00205 
00206 
00207 
00208 
00209 
00210 
00211 
00212 
00213 
00214 
00215 
00216 
00217 
00218 
00219 
00220 
00221 
00222 
00223 
00224 
00225 
00226 
00227 
00228 
00229 
00230 
00231 
00232 
00233 
00234 
00235 
00236 
00237 
00238 
00239 
00240 
00241 
00242 
00243 
00244 
00245 
00246 
00247 
00248 
00249 
00250 
00251 
00252 
00253 
00254 
00255 
00256 
00257 
00258 
00259 
00260 
00261 
00262 
00263 
00264 
00265 
00266 
00267 
00268 
00269 
00270 
00271 
00272 
00273 
00274 
00275 
00276 
00277 
00278 
00279 
00280 
00281 
00282 
00283 
00284 
00285 
00286 
00287 
00288 
00289 
00290 
00291 
00292 
00293 
00294 
00295 
00296 
00297 
00298 
00299 
00300 
00301 
00302 
00303 
00304 
00305 
00306 
00307 
00308 
00309 
00310 
00311 
00312 
00313 
00314 
00315 
00316 
00317 
00318 
00319 
00320 
00321 
00322 
00323 
00324 
00325 
00326 
00327 
00328 
00331 
#ifndef RANDOMVAR_INC
00332 
#define RANDOMVAR_INC
00333 
00334 
#include <plearn/opt/Optimizer.h>
00335 
00338 
#include "SampleVariable.h"
00339 
00341 
#include <plearn/vmat/VMat.h>
00342 
00343 
namespace PLearn {
00344 
// NOTE(review): a using-directive inside a header leaks every std:: name into
// each file that includes it.  It is kept because code depending on this
// header may rely on it.
using namespace std;

// The classes below refer to each other (e.g. RandomVar returns RVArray,
// RVInstance and ConditionalExpression), so they are declared up front.
class ConditionalExpression;
class RandomVariable;
class RVArray;
class RVInstance;
class RVInstanceArray;
00352 
00354 class RandomVar: 
public PP<RandomVariable>
00355 {
00356  
public:
00357   
RandomVar();
00358   
RandomVar(
int length, 
int width=1);
00359   
RandomVar(
RandomVariable* v);
00360   
RandomVar(
const RandomVar& other);
00361 
00362   
RandomVar(
const Vec& vec);
00363   
RandomVar(
const Mat& mat);
00364   
RandomVar(
const Var& 
var);
00365 
00366   
RandomVar(
const RVArray& vars); 
00367  
00368   
RandomVar operator[](
RandomVar index); 
00369 
#if 0
00370 
  RandomVar operator[](
int i); 
00371 
#endif
00372 
00373   
void operator=(
const RVArray& vars);
00374 
00376   
void operator=(
real f); 
00377   
void operator=(
const Vec& v); 
00378   
void operator=(
const Mat& m);
00379   
void operator=(
const Var& v); 
00380 
00383   
RVInstance operator==(
const Var& v) 
const;
00384 
00386   bool operator==(
const RandomVar& rv)
 const { 
return rv.
ptr == this->ptr; }
00387   bool operator!=(
const RandomVar& rv)
 const { 
return rv.
ptr != this->ptr; }
00388 
00390   
RVArray operator&(
const RandomVar& v) 
const;
00391 
00401   
ConditionalExpression operator|(
RVArray rhs) 
const;
00402 
00404   
ConditionalExpression operator|(
RVInstanceArray rhs) 
const;
00405 
00406 
#if 0
00409 
  RandomVar operator[](RandomVar index);
00410 
  RandomVar operator[](
int i);
00412   
RandomVar operator()(
RandomVar i, 
RandomVar j);
00413   
RandomVar operator()(
int i, 
int j);
00414 
#endif
00415 
00416 };
00417 
00418 typedef RandomVar MatRandomVar;
00419 
00420 
00422 class RVArray: 
public Array<RandomVar>
00423 {
00424  
public:
00425   
RVArray();
00426   
RVArray(
int n, 
int n_extra_allocated=0);
00427   
RVArray(
const Array<RandomVar>& va);
00428   
RVArray(
const RandomVar& v, 
int n_extra_allocated=0);
00429   
RVArray(
const RandomVar& v1, 
const RandomVar& v2, 
int n_extra_allocated=0);
00430   
RVArray(
const RandomVar& v1, 
const RandomVar& v2, 
const RandomVar& v3, 
00431           
int n_extra_allocated=0);
00432 
00433   
int length() 
const;
00434 
00436   
VarArray values() 
const;
00437 
00440 
00443 
00445   
RandomVar operator[](
RandomVar index);
00446 
00447   RandomVar& 
operator[](
int i)
00448     { 
return Array<RandomVar>::operator[](i); }
00449 
00450   const RandomVar& 
operator[](
int i)
 const
00451 
    { 
return Array<RandomVar>::operator[](i); }
00452 
00453   
static int compareRVnumbers(
const RandomVar* v1, 
const RandomVar* v2);
00454 
00457   
void sort();
00458 };
00459 
00460 
00462 class RVInstance 
00463 {
00464  
public:
00465   RandomVar V;
00466   Var v;
00467 
00468   
RVInstance(
const RandomVar& VV, 
const Var& vv);
00469   
RVInstance();
00470 
00471   
RVInstanceArray operator&&(
RVInstance rvi);
00472 
00473   
ConditionalExpression operator|(
RVInstanceArray a);
00474 
00476   
void swap_v_and_Vvalue();
00477 
00478 };
00479 
00480 class RVInstanceArray: 
public Array<RVInstance>
00481 {
00482  
public:
00483   
RVInstanceArray();
00484   
RVInstanceArray(
int n, 
int n_extra_allocated=0);
00485   
RVInstanceArray(
const Array<RVInstance>& a);
00486   
RVInstanceArray(
const RVInstance& v, 
int n_extra_allocated=0);
00487   
RVInstanceArray(
const RVInstance& v1, 
const RVInstance& v2,
00488                   
int n_extra_allocated=0);
00489   
RVInstanceArray(
const RVInstance& v1, 
const RVInstance& v2, 
00490                   
const RVInstance& v3, 
int n_extra_allocated=0);
00491 
00493   
int length() 
const; 
00494 
00497   
RVInstanceArray operator&&(
RVInstance rhs);
00498 
00506   
ConditionalExpression operator|(
RVInstanceArray rhs);
00507 
00509   
RVArray random_variables() 
const;
00510 
00512   
VarArray values() 
const;
00514   
VarArray instances() 
const;
00515 
00517   void swap_v_and_Vvalue()
00518     { 
for (
int i=0;i<
size();i++) (*this)[i].swap_v_and_Vvalue(); }
00519 
00520   
static int compareRVnumbers(
const RVInstance* rvi1, 
const RVInstance* rvi2);
00521 
00524   
void sort();
00525 
00526 };
00527 
00528 class ConditionalExpression
00529 {
00530  
public:
00531   RVInstance LHS;
00532   RVInstanceArray RHS;
00533 
00535   
ConditionalExpression(
RVInstance lhs, 
RVInstanceArray rhs);
00537   
ConditionalExpression(
RVInstance lhs);
00539   
ConditionalExpression(
RandomVar lhs);
00542   
ConditionalExpression(
RVInstanceArray lhs);
00543 };
00544 
00545 class RandomVariable: 
public PPointable
00546 {
00547   
friend class RandomVar;
00548   
friend class RVInstanceArray;
00549   
friend class RVArray;
00550 
00551   
static int rv_counter; 
00552 
00553 
00554 
protected:
00555 
00558   const int rv_number; 
00559 
00560  
public:
00561   const RVArray parents; 
00562 
00567   Var value; 
00568   
00569  
protected:
00574   bool marked;
00575 
00576   bool EMmark; 
00577   bool pmark; 
00578 
00581   bool* 
learn_the_parameters;
00582 
00583  
public:
00585   
RandomVariable(
int thelength, 
int thewidth=1);
00586   
RandomVariable(
const Vec& the_value);
00587   
RandomVariable(
const Mat& the_value);
00588   
RandomVariable(
const Var& the_value);
00590   
RandomVariable(
const RVArray& parents, 
int thelength);
00591   
RandomVariable(
const RVArray& parents, 
int thelength, 
int thewidth);
00592   
00593   
virtual char* 
classname() = 0;
00594 
00595   virtual int length() { 
return value->
length(); }
00596   virtual int width() { 
return value->
width(); }
00597   int nelems() { 
return value->nelems(); }
00598   bool isScalar() { 
return length()==1 && 
width()==1; }
00599   bool isVec() { 
return width()==1 || 
length()==1; }
00600   bool isColumnVec() { 
return width()==1; }
00601   bool isRowVec() { 
return length()==1; }
00602 
00609   
virtual bool isNonRandom() = 0;
00612   inline bool isConstant() { 
return isNonRandom() && 
value->isConstant(); }
00613 
00616   
virtual bool isDiscrete() = 0;
00617 
00622   
RandomVar subVec(
int start, 
int length);
00623 
00625 
00632 
00633 
00638   
virtual void setValueFromParentsValue() = 0;
00639 
00640   void markRHSandSetKnownValues(
const RVInstanceArray& RHS)
00641     {
00642       
for (
int i=0;i<RHS.
size();i++)
00643         RHS[i].V->mark(RHS[i].v);
00644       
setKnownValues();
00645     }
00646       
00654   
virtual void EMBprop(
const Vec obs, 
real posterior) = 0;
00655 
00660   
virtual void EMUpdate();
00661 
00668   
virtual bool canStopEM();
00669 
00672   
virtual void EMTrainingInitialize(
const RVArray& parameters_to_learn);
00673 
00676   
virtual void EMEpochInitialize();
00677 
00681 
00682   virtual void mark(
Var v) { 
marked = 
true; 
value = v; }
00683   virtual void mark() { 
marked = 
true; }
00684   virtual void unmark() { 
marked = 
false; }
00685   
virtual void clearEMmarks();
00686 
00688   
virtual void unmarkAncestors();
00689 
00690   virtual bool isMarked() { 
return marked; }
00691 
00696   
virtual void setKnownValues();
00697 
00703   
virtual Var logP(
const Var& obs, 
const RVInstanceArray& RHS,
00704       
RVInstanceArray* parameters_to_learn=0) = 0;
00705   
virtual Var P(
const Var& obs, 
const RVInstanceArray& RHS);
00706 
00712   
virtual Var ElogP(
const Var& obs, 
RVArray& parameters_to_learn, 
00713                     
const RVInstanceArray& RHS);
00714 
00728 
00729 
00734 
00735   
virtual real EM(
const RVArray& parameters_to_learn,
00736                    
VarArray& prop_path, 
00737                    
VarArray& observedVars, 
00738                    
VMat distr, 
int n_samples, 
00739                    
int max_n_iterations, 
00740                    
real relative_improvement_threshold,
00741                    
bool accept_worsening_likelihood=
false);
00742 
00750   
virtual real epoch(
VarArray& prop_path, 
00751                       
VarArray& observed_vars, 
const VMat& distr, 
00752                       
int n_samples, 
00753                       
bool do_EM_learning=
true);
00754 
00755   
virtual ~RandomVariable();
00756 
00757 };
00758 
00759 
00761 
00771 RandomVar 
operator*(RandomVar a, RandomVar b);
00772 
00779 RandomVar 
operator+(RandomVar a, RandomVar b);
00780 
00783 RandomVar 
operator-(RandomVar a, RandomVar b);
00784 
00787 RandomVar 
operator/(RandomVar a, RandomVar b);
00788 
00790 RandomVar 
exp(RandomVar x);
00791 
00793 RandomVar 
log(RandomVar x);
00794 
00795 RandomVar 
extend(RandomVar v, 
real extension_value = 1.0, 
int n_extend = 1);
00796 
00797 RandomVar 
hconcat(
const RVArray& a);
00798 
00813 
00814 
00819 
00820 
00821 
real EM(ConditionalExpression conditional_expression, 
00822          RVArray parameters_to_learn,
00823          VMat distr, 
int n_samples, 
int max_n_iterations=1, 
00824          
real relative_improvement_threshold=0.001,
00825          
bool accept_worsening_likelihood=
false,
00826          
bool compute_final_train_NLL=
true);
00827 
00828 
real oEM(ConditionalExpression conditional_expression,
00829          RVArray parameters_to_learn,
00830          VMat distr, 
int n_samples, 
int max_n_iterations, 
00831          
real relative_improvement_threshold=0.001,
00832          
bool compute_final_train_NLL=
true);
00833 
00834 
real oEM(ConditionalExpression conditional_expression,
00835          RVArray parameters_to_learn,
00836          VMat distr, 
int n_samples, 
00837          Optimizer& MStepOptimizer,
00838          
int max_n_iterations,
00839          
real relative_improvement_threshold=0.001,
00840          
bool compute_final_train_NLL=
true);
00841 
00851 Var 
logP(ConditionalExpression conditional_expression, 
00852          
bool clearMarksUponReturn=
true,
00853          RVInstanceArray* parameters_to_learn=0);
00854 
00864 Var 
P(ConditionalExpression conditional_expression,
00865          
bool clearMarksUponReturn=
true);
00866 
00872 Var 
ElogP(ConditionalExpression conditional_expression, 
00873     RVInstanceArray& parameters_to_learn,
00874     
bool clearMarksUponReturn=
true);
00875 
00880 RandomVar 
marginalize(
const RandomVar& RV, 
const RandomVar& hiddenRV);
00891 Vec 
sample(ConditionalExpression conditional_expression);
00892 
00897 Var 
Sample(ConditionalExpression conditional_expression);
00898 
00906 
void sample(ConditionalExpression conditional_expression,Mat& samples);
00907 
00909 
00914 RandomVar 
normal(
real mean=0, 
real standard_dev=1, 
int d=1,
00915                  
real minimum_standard_deviation=1e-6);
00916 
00921 RandomVar 
normal(RandomVar mean, RandomVar log_variance,
00922                  
real minimum_standard_deviation=1e-6);
00923 
00929 RandomVar 
mixture(RVArray components, RandomVar log_weights);
00930 
00937 RandomVar 
multinomial(RandomVar log_probabilities);
00938 
00961 class StochasticRandomVariable: 
public RandomVariable
00962 {
00963  
public:
00964   
StochasticRandomVariable(
int length=1);
00965   
StochasticRandomVariable(
const RVArray& params,
int length);
00966   
StochasticRandomVariable(
const RVArray& params,
int length, 
int width);
00967 
00969 
00971   virtual bool isNonRandom() { 
return false; }
00972 
00974   virtual bool isDiscrete() { 
return false; }
00975 
00976   
virtual void setKnownValues();
00977 
00981   
01000 };
01001 
01006 class FunctionalRandomVariable: 
public RandomVariable {
01007  
public:
01008   
FunctionalRandomVariable(
int length);
01009   
FunctionalRandomVariable(
int length, 
int width);
01010   
FunctionalRandomVariable(
const Vec& the_value);
01011   
FunctionalRandomVariable(
const Mat& the_value);
01012   
FunctionalRandomVariable(
const Var& the_value);
01013   
FunctionalRandomVariable(
const RVArray& parents,
int length);
01014   
FunctionalRandomVariable(
const RVArray& parents,
int length, 
int width);
01015 
01017 
01018   
virtual Var logP(
const Var& obs, 
const RVInstanceArray& RHS,
01019       
RVInstanceArray* parameters_to_learn);
01020 
01022   
bool isNonRandom();
01023 
01025   
virtual bool isDiscrete();
01026 
01028 
01032   
01047   
virtual bool invertible(
const Var& obs, 
01048                              
RVInstanceArray& unobserved_parents,
01049                              
Var** JacobianCorrection);
01050 
01052   
virtual void setValueFromParentsValue() = 0;
01053 
01057   
01080 };
01081 
01087 class NonRandomVariable: 
public FunctionalRandomVariable
01088 {
01089 
public:
01092   
NonRandomVariable(
int thelength);
01093   
NonRandomVariable(
int thelength, 
int thewidth);
01098   
NonRandomVariable(
const Var& v);
01099 
01100   virtual char* 
classname() { 
return "NonRandomVariable"; }
01101 
01102   void setValueFromParentsValue() { }
01103   bool invertible(
const Var& obs, 
RVInstanceArray& unobserved_parents,
01104                      
Var** JacobianCorrection) 
01105     { 
return true; }
01106   void EMBprop(
const Vec obs, 
real post) { }
01107 };
01108 
01109 class JointRandomVariable: 
public FunctionalRandomVariable
01110 {
01111  
public:
01112   
JointRandomVariable(
const RVArray& variables);
01113 
01114   virtual char* 
classname() { 
return "JointRandomVariable"; }
01115 
01116   
void setValueFromParentsValue();
01117   
bool invertible(
const Var& obs, 
RVInstanceArray& unobserved_parents,
01118                      
Var** JacobianCorrection);
01119   
void EMBprop(
const Vec obs, 
real post);
01120 };
01121 
01124 class RandomElementOfRandomVariable: 
public FunctionalRandomVariable
01125 {
01126 
public:
01127   
RandomElementOfRandomVariable(
const RandomVar& 
v, 
const RandomVar& 
index);
01128 
01129   virtual char* 
classname() { 
return "RandomElementOfRandomVariable"; }
01130 
01131   
void setValueFromParentsValue();
01132   
bool invertible(
const Var& obs, 
RVInstanceArray& unobserved_parents,
01133                      
Var** JacobianCorrection);
01134   
void EMBprop(
const Vec obs, 
real post);
01135 
01137   inline const RandomVar& 
v() { 
return parents[0]; }
01138   inline const RandomVar& 
index() { 
return parents[1]; }
01139 
01140 };
01141 
01142 
01148 class RVArrayRandomElementRandomVariable: 
public FunctionalRandomVariable
01149 {
01150  
public:
01151   
RVArrayRandomElementRandomVariable(
const RVArray& table, 
const RandomVar& 
index);
01152 
01153   virtual char* 
classname() { 
return "RVArrayRandomElementRandomVariable"; }
01154 
01155   
void setValueFromParentsValue();
01156   
virtual Var logP(
const Var& obs, 
const RVInstanceArray& RHS,
01157       
RVInstanceArray* parameters_to_learn=0);
01158   
void EMBprop(
const Vec obs, 
real post);
01159 
01161   inline const RandomVar& 
index() { 
return parents[parents.
size()-1]; }
01162 
01163 };
01164 
01165 class NegRandomVariable: 
public FunctionalRandomVariable
01166 {
01167  
public:
01168   
NegRandomVariable(
RandomVariable* input);
01169   
01170   virtual char* 
classname() { 
return "NegRandomVariable"; }
01171 
01172   
void setValueFromParentsValue();
01173   
bool invertible(
const Var& obs, 
RVInstanceArray& unobserved_parents,
01174                      
Var** JacobianCorrection);
01175   
void EMBprop(
const Vec obs, 
real post);
01176 };
01177 
01178 class ExpRandomVariable: 
public FunctionalRandomVariable
01179 {
01180  
public:
01181   
ExpRandomVariable(
RandomVar& input);
01182   
01183   virtual char* 
classname() { 
return "ExpRandomVariable"; }
01184 
01185   
void setValueFromParentsValue();
01186   
bool invertible(
const Var& obs, 
RVInstanceArray& unobserved_parents,
01187                      
Var** JacobianCorrection);
01188   
void EMBprop(
const Vec obs, 
real post);
01189 };
01190 
01191 
01192 class LogRandomVariable: 
public FunctionalRandomVariable
01193 {
01194  
public:
01195   
LogRandomVariable(
RandomVar& input);
01196   
01197   virtual char* 
classname() { 
return "LogRandomVariable"; }
01198 
01199   
void setValueFromParentsValue();
01200   
bool invertible(
const Var& obs, 
RVInstanceArray& unobserved_parents,
01201                      
Var** JacobianCorrection);
01202   
void EMBprop(
const Vec obs, 
real post);
01203 };
01204 
01205 class DiagonalNormalRandomVariable: 
public StochasticRandomVariable
01206 {
01211  
protected:
01212   real minimum_variance;
01213   real normfactor; 
01214   bool shared_variance; 
01215 
01216  
public:
01217   
DiagonalNormalRandomVariable(
const RandomVar& mean, 
01218                                
const RandomVar& log_variance,
01219                                
real minimum_standard_deviation = 1e-10);
01220 
01221   virtual char* 
classname() { 
return "DiagonalNormalRandomVariable"; }
01222 
01223   
Var logP(
const Var& obs, 
const RVInstanceArray& RHS,
01224       
RVInstanceArray* parameters_to_learn);
01225   
void setValueFromParentsValue();
01226   
void EMUpdate();
01227   
void EMBprop(
const Vec obs, 
real posterior);
01228   
void EMEpochInitialize();
01229 
01231   inline const RandomVar& 
mean() { 
return parents[0]; }
01232   inline const RandomVar& 
log_variance() { 
return parents[1]; }
01233   inline bool& 
learn_the_mean() { 
return learn_the_parameters[0]; }
01234   inline bool& 
learn_the_variance() { 
return learn_the_parameters[1]; }
01235 
01236  
protected:
01238   Vec mu_num; 
01239   Vec sigma_num; 
01240   real denom; 
01241 };
01242 
01243 class MixtureRandomVariable: 
public StochasticRandomVariable
01244 {
01245  
protected:
01247   RVArray components; 
01248 
01249  
public:
01250   
MixtureRandomVariable(
const RVArray& components,
01251                         
const RandomVar& log_weights);
01252 
01253   virtual char* 
classname() { 
return "MixtureRandomVariable"; }
01254 
01257   inline const RandomVar& 
log_weights() { 
return parents[0]; }
01258   inline bool& 
learn_the_weights() { 
return learn_the_parameters[0]; }
01259 
01260   
virtual Var logP(
const Var& obs, 
const RVInstanceArray& RHS,
01261       
RVInstanceArray* parameters_to_learn);
01262   
virtual Var ElogP(
const Var& obs, 
RVInstanceArray& parameters_to_learn,
01263       
const RVInstanceArray& RHS);
01264 
01265   
virtual void setValueFromParentsValue();
01266   
virtual void EMUpdate();
01267   
virtual void EMBprop(
const Vec obs, 
real posterior);
01268   
virtual void EMEpochInitialize();
01269   
virtual void EMTrainingInitialize(
const RVArray& parameters_to_learn);
01270   
virtual bool isDiscrete();
01271   
virtual bool canStopEM();
01272   
virtual void setKnownValues();
01273   
virtual void unmarkAncestors();
01274   
virtual void clearEMmarks();
01275 
01276  
protected:
01278   Vec posteriors; 
01279   Vec sum_posteriors; 
01280 
01282   VarArray componentsLogP; 
01283   Var lw; 
01284   Var logp; 
01285 };
01286 
01293 class PlusRandomVariable: 
public FunctionalRandomVariable
01294 {
01295  
public:
01296   
PlusRandomVariable(
RandomVar input1, 
RandomVar input2);
01297 
01298   virtual char* 
classname() { 
return "PlusRandomVariable"; }
01299 
01300   
void setValueFromParentsValue();
01301   
bool invertible(
const Var& obs, 
RVInstanceArray& unobserved_parents,
01302                      
Var** JacobianCorrection);
01303   
void EMBprop(
const Vec obs, 
real post);
01304   
void EMTrainingInitialize(
const RVArray& parameters_to_learn);
01305   
void EMEpochInitialize();
01306   
void EMUpdate();
01307 
01309   const RandomVar& 
X0() { 
return parents[0]; }
01310   const RandomVar& 
X1() { 
return parents[1]; }
01311 
01313   bool learn_X0() { 
return learn_the_parameters[0]; }
01314   bool learn_X1() { 
return learn_the_parameters[1]; }
01315   bool learn_something;
01316   RandomVar parent_to_learn; 
01317   RandomVar other_parent; 
01318   Vec numerator;
01319   Vec difference;
01320   real denom;
01321 };
01322 
01329 class MinusRandomVariable: 
public FunctionalRandomVariable
01330 {
01331  
public:
01332   
MinusRandomVariable(
RandomVar input1, 
RandomVar input2);
01333 
01334   virtual char* 
classname() { 
return "MinusRandomVariable"; }
01335 
01336   
void setValueFromParentsValue();
01337   
bool invertible(
const Var& obs, 
RVInstanceArray& unobserved_parents,
01338                      
Var** JacobianCorrection);
01339   
void EMBprop(
const Vec obs, 
real post);
01340   
void EMTrainingInitialize(
const RVArray& parameters_to_learn);
01341   
void EMEpochInitialize();
01342   
void EMUpdate();
01343 
01345   const RandomVar& 
X0() { 
return parents[0]; }
01346   const RandomVar& 
X1() { 
return parents[1]; }
01347 
01349   bool learn_X0() { 
return learn_the_parameters[0]; }
01350   bool learn_X1() { 
return learn_the_parameters[1]; }
01351   bool learn_something;
01352   RandomVar parent_to_learn; 
01353   RandomVar other_parent; 
01354   Vec numerator;
01355   Vec difference;
01356   real denom;
01357 };
01358 
01359 
01364 class ElementWiseDivisionRandomVariable: 
public FunctionalRandomVariable
01365 {
01366  
public:
01367   
ElementWiseDivisionRandomVariable(
RandomVar input1, 
RandomVar input2);
01368 
01369   virtual char* 
classname() { 
return "ElementWiseDivisionRandomVariable"; }
01370 
01371   
void setValueFromParentsValue();
01372   
bool invertible(
const Var& obs, 
RVInstanceArray& unobserved_parents,
01373                      
Var** JacobianCorrection);
01374   
void EMBprop(
const Vec obs, 
real post);
01375   
void EMTrainingInitialize(
const RVArray& parameters_to_learn);
01376   
void EMEpochInitialize();
01377   
void EMUpdate();
01378 
01380   const RandomVar& 
X0() { 
return parents[0]; }
01381   const RandomVar& 
X1() { 
return parents[1]; }
01382 
01383 };
01384 
01385 
01417 class ProductRandomVariable: 
public FunctionalRandomVariable
01418 {
01419  
public:
01420   int m,
n,
l; 
01421 
01422   
ProductRandomVariable(
MatRandomVar input1, 
MatRandomVar input2);
01423 
01424   virtual char* 
classname() { 
return "ProductRandomVariable"; }
01425 
01426   
void setValueFromParentsValue();
01427   
bool invertible(
const Var& obs, 
RVInstanceArray& unobserved_parents,
01428                      
Var** JacobianCorrection);
01429   
void EMBprop(
const Vec obs, 
real post);
01430   
void EMTrainingInitialize(
const RVArray& parameters_to_learn);
01431   
void EMEpochInitialize();
01432   
void EMUpdate();
01433 
01435   const RandomVar& 
X0() { 
return parents[0]; } 
01436   const RandomVar& 
X1() { 
return parents[1]; } 
01437   bool scalars; 
01438 
01440   bool learn_X0() { 
return learn_the_parameters[0]; }
01441   bool learn_X1() { 
return learn_the_parameters[1]; }
01442   bool learn_something;
01443   Mat X0numerator; 
01444   Mat X1numerator; 
01445   Mat denom; 
01446   Mat tmp1; 
01447   Mat tmp2; 
01448   Mat tmp3; 
01449   Vec vtmp3; 
01450   Vec tmp4;
01451 };
01452 
01454 class SubVecRandomVariable: 
public FunctionalRandomVariable
01455 {
01456  
protected:
01457   int start;
01458  
public:
01459   
SubVecRandomVariable(
const RandomVar& parent,
int start, 
int length);
01460   virtual char* 
classname() { 
return "SubvecRandomVariable"; }
01461   
void setValueFromParentsValue();
01462   
bool invertible(
const Var& obs, 
RVInstanceArray& unobserved_parents,
01463                      
Var** JacobianCorrection);
01464   
void EMBprop(
const Vec obs, 
real posterior);
01465 };
01466 
01472 class MultinomialRandomVariable: 
public StochasticRandomVariable
01473 {
01474  
public:
01477   
MultinomialRandomVariable(
const RandomVar& 
log_probabilities);
01478 
01480   inline const RandomVar& 
log_probabilities() { 
return parents[0]; }
01481   inline bool learn_the_probabilities() { 
return learn_the_parameters[0]; }
01482 
01483   virtual char* 
classname() { 
return "MultinomialRandomVariable"; }
01484 
01485   
Var logP(
const Var& obs, 
const RVInstanceArray& RHS, 
01486       
RVInstanceArray* parameters_to_learn);
01487   
void setValueFromParentsValue(); 
01488   
void EMUpdate();
01489   
void EMBprop(
const Vec obs, 
real posterior);
01490   
void EMEpochInitialize();
01491   
bool isDiscrete();
01492 
01493  
protected:
01495   Vec sum_posteriors; 
01496 };
01497 
01498 
01505 class ExtendedRandomVariable: 
public FunctionalRandomVariable
01506 {
01507  
protected:
01508   int n_extend;
01509   real fill_value;
01510  
public:
01511   
ExtendedRandomVariable(
const RandomVar& parent, 
real fill_value=1.0,
int n_extend=1);
01512   virtual char* 
classname() { 
return "ExtendedRandomVariable"; }
01513   
void setValueFromParentsValue();
01514   
bool invertible(
const Var& obs, 
RVInstanceArray& unobserved_parents,
01515                      
Var** JacobianCorrection);
01516   
void EMBprop(
const Vec obs, 
real posterior);
01517 };
01518 
01521 class ConcatColumnsRandomVariable: 
public FunctionalRandomVariable
01522 {
01523  
public:
01524   
ConcatColumnsRandomVariable(
const RVArray& vars);
01525   virtual char* 
classname() { 
return "ConcatColumnsRandomVariable"; }
01526   
void setValueFromParentsValue();
01527   
bool invertible(
const Var& obs, 
RVInstanceArray& unobserved_parents,
01528                      
Var** JacobianCorrection);
01529   
void EMBprop(
const Vec obs, 
real posterior);
01530 };
01531 
01537 
01538 
01539 class RandomVarVMatrix: 
public VMatrix
01540 {
01541  
protected:
01542   RandomVar rv;
01543   Var instance;
01544   VarArray prop_path;
01545 
01546  
public:
01547   
RandomVarVMatrix(
ConditionalExpression conditional_expression);
01548   virtual int nVars() { 
return instance->
length(); }
01549   virtual Vec sample()
01550     {
01551       
prop_path.
fprop();
01552       
return instance->value;
01553     }
01554 };
01555 
01557 
01558 } 
01559 
01560 
#endif
01561