
RandomVar.h

// -*- C++ -*-

// PLearn (A C++ Machine Learning Library)
// Copyright (C) 1998 Pascal Vincent
// Copyright (C) 1999-2002 Pascal Vincent, Yoshua Bengio and University of Montreal
//

// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
//  1. Redistributions of source code must retain the above copyright
//     notice, this list of conditions and the following disclaimer.
//
//  2. Redistributions in binary form must reproduce the above copyright
//     notice, this list of conditions and the following disclaimer in the
//     documentation and/or other materials provided with the distribution.
//
//  3. The name of the authors may not be used to endorse or promote
//     products derived from this software without specific prior written
//     permission.
//
// THIS SOFTWARE IS PROVIDED BY THE AUTHORS ``AS IS'' AND ANY EXPRESS OR
// IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
// OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN
// NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
// TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
// This file is part of the PLearn library. For more information on the PLearn
// library, go to the PLearn Web site at www.plearn.org


/* *******************************************************
 * $Id: RandomVar.h,v 1.6 2004/07/21 16:30:54 chrish42 Exp $
 * AUTHORS: Pascal Vincent & Yoshua Bengio
 * This file is part of the PLearn library.
 ******************************************************* */

/* RandomVar.h

   Random Variables package:

   the RandomVar class, helper classes such as RandomVariable,
   RVInstance, RVArray, RVInstanceArray and ConditionalExpression,
   many subclasses of RandomVariable, plus helper global functions
   and operators.
*/

/* TUTORIAL ON THE RANDOMVAR PACKAGE

   Whereas a Var represents a variable in the mathematical sense, a
   RandomVar represents a random variable in the mathematical sense. A
   random variable is generally defined in terms of other random
   variables, either through a deterministic transformation of them, or
   by having them act as the parameters of its distribution. A RandomVar
   therefore represents a node in a graphical model, and its distribution
   is defined in terms of the values of its parents in the model. The set
   of classes provided here allows one to build such a network of random
   variables, and to perform limited inference, probability computations,
   gradient computations, and learning.
   Examples of use of RandomVars:

   Var u(1), lv(1); // constants, but changing their value will change
   Var w;           // the distributions associated with the RandomVars
   ...
   // things like *, +, log, tanh, etc. must do the proper thing for RandomVars
   RandomVar X = gamma(0.5)*log(normal(u,lv));
   RandomVar Y = tanh(product(w,X) + u);
   RandomVar LW(2); // unnormalized log-weights of the mixture
   RandomVar Z = mixture(Y&X,LW);
   // see the comments on operator&, operator[] and mixtures below
   ...
   // Draw a sample of Z, conditioned on the RHS of the "|":
   Vec x,y; // put some values in x and y
   Vec z = sample(Z|(X==x && Y==y));
   // This is achieved by redefining operator==, operator&& and operator|
   // to build the data structure which the function sample takes as
   // argument. In particular, note that == creates an RVInstance (which
   // contains a RandomVar and a Var instance), && or & makes an array of
   // these structures (an RVInstanceArray), and | builds a
   // ConditionalExpression, which contains an RVInstanceArray for the
   // left-hand side and another for the right-hand side. Note the use of
   // parentheses, required because of the default precedence of these
   // operators.
   ...
   // Using these operators, you can also express the probability of a
   // value for the random variable, or its conditional probability. In
   // fact, the statements below define a functional relationship (in the
   // usual "Variable" sense) between the variable y and the variable p
   // or log_p. Note that no actual numerical value has been computed yet.
   Var y(1);
   Var p = P(Y==y);
   Var log_p = logP(Y==y);
   // or
   Var y,x;
   Var log_p = logP((Y==y)|(X==x)); // note again the ()'s for precedence
   // e.g. of use:
   Vec actual_value_for_x_and_y = ...;
   Func f = log_p(x&y);
   real prob = f(actual_value_for_x_and_y)[0];
   cout << "log(P(Y|X))=" << prob << endl;
   ...
   // In the case of discrete distributions, the whole distribution can be
   // returned by P and logP, which are also defined on RandomVars:
   Vec p = P(Y);
   Vec log_p = logP(Y);
   ...
   // Note that if Y has RandomVar parents, the above can only be computed
   // if these parents are given a particular value, or if they are
   // integrated over (with the function marginalize below). This call
   // will therefore automatically try to marginalize Y (by integrating
   // over the parents which are not observed). To make a RandomVar
   // observed, simply use the conditioning notation
   // V|(X==x && Y==y && Z==z), e.g. to condition V on X=x, Y=y, Z=z.
   ...
   // Similarly to P, logP and sample, other functions of RVs are defined.
   // Construct a Var that is functionally dependent on the Var x
   // and represents the expectation of Y given that X==x:
   Var e = E(Y|(X==x));
   // similarly for the covariance matrix:
   Var v = V(Y|(X==x && Z==z));
   // and for the cumulative distribution function
   // (which here depends on the Vars y and x):
   Var c = P((Y<y)|(X==x));
   // Note that derivatives through all these functional relationships
   // can be computed automatically.
   // For example, to compute the gradients of the log-probability with
   // respect to some parameters W & B & LV (a log-variance):
   RandomVar W,B,X,LV; // all are "non-random", X is the input
   Var w,b,x,lv;       // values that the above will take
   // e.g. to give values to these Vars:
   w[0]=1; b[0]=0; x[0]=3; lv[0]=0;
   RandomVar Y = normal(W*X+B,LV); // the model (i.e. a regression)
   // Establish the functional relationship of interest,
   // which goes from (y & x & w & b & lv) to logp:
   Var logp = logP((Y==y)|(X==x && W==w && B==b && LV==lv));
   // to actually compute logp, do
   logp->fprop_from_all_sources(); // a source is a "constant" Var
   // To compute gradients, use the propagationPath function to find
   // the path of Vars from, say, w&b to logp:
   VarArray prop_path = propagationPath(w&b,logp);
   prop_path.bprop(); // computes dlogP/dparams in w->gradient and b->gradient
   ...
   // By default a RandomVar represents a "non-random" random variable
   // (of class NonRandomVariable, or a FunctionalRandomVariable which
   // depends only on NonRandomVariables). This is not the same as a
   // "constant" variable: it only means that its value is deterministic,
   // but that value may be a (Var) function of other Vars:
   RandomVar X;
   X->value = 1+exp(y+w*z); // y, w and z are Vars here
   ...
   // Parameters of the distributions that define the random variables can
   // be learned, most generally by optimizing a criterion, e.g.
   Mat observed_xy_matrix; // each row = concatenation of an X and a Y observation
   Var cost = -logP((Y==y)|(X==x && Params==params));
   // Below, establish the functional relationship between x & y & params
   // and the "totalcost", which is the sum of "cost" when x & y are
   // "sampled" from the given empirical distribution:
   Var totalcost = meanOf(cost,x&y,VMat(observed_xy_matrix),
                          observed_xy_matrix.length(),
                          params);
   // construct an optimizer that optimizes totalcost by varying params
   ConjugateGradientOptimizer opt(params,totalcost, 1e-5, 1e-3, 100);
   // do the actual optimization
   real train_NLL = opt.optimize();
   // Now we can test on a particular point, setting values for
   // x, y and params and doing an fprop with
   propagationPath(x&y&params,cost).fprop();
   ...
   // Sometimes the parameters can be estimated more efficiently
   // using an internal mechanism for estimation (usually the analytical
   // solution of maximum likelihood, or the EM algorithm):
   real avgNegLogLik = EM(Y|X,W&B&LV,VMat(observed_xy_matrix),
                          observed_xy_matrix.length(),4,0.001);
   // where the first argument specifies which conditional distribution
   // is of interest (ignoring the parameters), the second argument
   // gives the parameters to estimate, and the third one specifies
   // a training set. Note that the order of the variables in the
   // observed_xy_matrix must be (1) inputs: all the variables on the RHS
   // of the conditioning |, (2) outputs: all the variables on the LHS of the |.
   ...
   // Arrays of RVs can be formed with operator&:
   RVArray a = X & Y & V;
   // and they can be automatically cast into a JointRandomVariable
   // (whose value is the concatenation of the values of its parents):
   RandomVar Z = X & Y & V;
   ...
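   // As a quick illustration (a hypothetical continuation, assuming the
   // distributions of X, Y and V above are fully specified), sampling the
   // joint Z yields one draw of each parent, concatenated into a single
   // vector:
   Vec xyv = sample(Z); // length = sum of the lengths of X, Y and V
   ...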
   // Marginals can sometimes be obtained (when X is discrete, or when
   // the integral is feasible analytically and the code knows how to do it...).
   // For example, suppose X is one of the parents of Y. Then
   RandomVar mY = marginalize(Y,X);
   // is a random variable such that P(mY=y) = int_x P(Y=y|X=x)P(X=x) dx.
   // This is obtained by summing over the values of X if it is discrete,
   // by doing the integral if we know how to do it, or otherwise by the
   // Laplace approximation or by some numerical integration method
   // such as Monte-Carlo.
   ...
   // The operator() is defined on RandomVar as follows:
   // if i is an integer, X(i) extracts a scalar RandomVar that
   // corresponds to the i-th element of the vector random variable X
   // (similarly, if the underlying Var is a matrix, X(i,j)
   // extracts the random element (i,j)). These two operators
   // are also defined for the case in which the index is a Var
   // (treated like the integer), and the case in which it is a RandomVar.
   // The last case, X(I), actually represents a mixture of the elements
   // of the vector X, with weights given by the parameters of I
   // (which must be discrete).
   ...
   // The operator[] is defined on RVArrays
   // and allows one to extract the i-th random variable in the array.
   // With I a RandomVar and A an RVArray, A[I] is very interesting because
   // it represents the graphical model of a mixture in which I is the index,
   // and it is not (yet) integrated over.
   RVArray A(3);
   A[0]=X; A[1]=Y; A[2]=Z;
   RandomVar XYZ(A); // joint distribution
   // or equivalently
   RandomVar XYZ = X & Y & Z;
   ...
   // A MultinomialRandomVariable is a subclass of RandomVariable
   // that represents discrete-valued variables in the range 0 ... N-1.
   RandomVar LW(3); // unnormalized log-probabilities of I
   RandomVar I = multinomial(LW); // N=3 possible values here
   // The parameters of a discrete random variable are the "log-probabilities"
   // (more precisely, the discrete probabilities are obtained with a softmax
   // of the parameters, LW here). The discrete random variable will be
   // conditional if LW is not a NonRandomVariable but rather depends
   // on some other RVs.
   ...
   // Now consider a random variable that is obtained by selecting
   // one of several random variables (on the same space). We call
   // such a random variable an RVArrayRandomElementRandomVariable; it is
   // obtained with operator[] acting on an RVArray,
   // with a discrete RandomVariable as argument:
   RandomVar V = A[I]; // takes the distribution of X, Y or Z according to the value of I
   // therefore P(V=v|I=i) = P(A[i]=v)
   ...
   // A mixture is the marginalization of A[I] with
   // respect to the random index I:
   RandomVar LW(3); // unnormalized log-weights of the mixture
   RandomVar M = mixture(A,LW);
   // which is exactly the same thing as
   RandomVar I = multinomial(LW);
   RandomVar M = marginalize(A[I],I);
   ...
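   // A minimal sketch of what can then be done with the mixture M
   // (hypothetical sizes; m is a fresh Var used to hold an observation):
   Vec m_drawn = sample(M);  // one draw from the mixture
   Var m(M->length());
   Var log_pm = logP(M==m);  // log-probability of the mixture at m
   ...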
   // Example of a conditional mixture of n d-dimensional diagonal
   // Gaussians with neural network expectations:
   // (1) define the neural network
   int n=3, d=2; // e.g. 3 components, each 2-dimensional
   int n_inputs=4, n_hidden=5, n_outputs=d*n;
   RandomVar X(n_inputs); // X is the network input RV
   Var x(n_inputs);       // x will be its value
   Var layer1W(n_hidden,n_inputs), layer2W(n_outputs,n_hidden);
   Var layer1bias(n_hidden), layer2bias(n_outputs);
   RandomVar NetOutput = layer2bias+layer2W*tanh(layer1bias+layer1W*X);
   // (2) define the Gaussian mixture
   // (2.1) define the Gaussians
   RVArray normals(n);
   RVArray mu(n), logsigma(n);
   for (int i=0;i<n;i++) {
     mu[i] = NetOutput->subVec(i*d,d); // extract subvector as i-th mean vector
     normals[i] = normal(mu[i],logsigma[i]);
   }
   // (2.2) build the mixture itself
   Var lw(n);
   lw->value.fill(1.0/n);
   RandomVar Y = mixture(normals,lw); // the "target" output random variable
   Var y(d); // its value
   VarArray tunable_parameters =
     lw & layer1W & layer2W & layer1bias & layer2bias;
   // each row of the Mat is the concatenation of an x (input) and a y
   Mat observed_xy_matrix;
   // logP returns the path that goes from all "source" (constant) variables
   // into the computation of the given conditional probability
   Var cost = -logP((Y==y)|(X==x));
   // Note that we don't need to condition on the "non-random" parameters
   // such as the log-weights of the mixture (lw), but they will
   // occur as tunable parameters.
   // Below, the order of x and y in the observed_xy_matrix must
   // match their order in the second argument of meanOf.
   Var totalcost = meanOf(cost,x&y,VMat(observed_xy_matrix),
                          observed_xy_matrix.length(),tunable_parameters);
   ...
   // Example in which some parameters W & B have to be fitted
   // to some data, while the hyper-parameters gamma that control
   // the distribution of W & B should be fitted to maximize
   // the likelihood of the data.
   int npoints = 10; // there are 10 (x,y) pairs in each observation
   Var muW, logvarW, muB, logvarB;    // parameters of the priors
   RandomVar W = normal(muW,logvarW); // prior on W
   RandomVar B = normal(muB,logvarB); // prior on B
   Var log_var(1); // log-variance of Y
   Var ones(npoints);
   ones->value.fill(1.0);
   Var log_vars = ones*log_var; // make a vector of npoints copies of log_var
   Var x(npoints); // input
   Var y(npoints); // output
   Var muXint, logvXint; // parameters of Xinterval
   RandomVar Xinterval = normal(muXint,logvXint); // prior on Xinterval
   RandomVar Y = normal(tanh(W*x+B),log_vars);
   Var cost = -logP(Y==y && Xinterval==vconcat(min(x) & max(x)));
   // note that the above requires marginalizing over W & B
   VarArray gamma = muXint & logvXint & log_var & muW & logvarW & muB & logvarB;
   Var totalcost = meanOf(cost,x&y,VMat(observed_xy_matrix),
                          observed_xy_matrix.length(), gamma);
   ConjugateGradientOptimizer opt(gamma,totalcost, 1e-5, 1e-3, 100);
   real train_NLL = opt.optimize();
   // to obtain a fit of W & B for a particular value of x and y, optimize
   Var w,b;
   Var fitcost =
     -logP((Y==y && Xinterval==vconcat(min(x) & max(x)))|(W==w && B==b));
   ConjugateGradientOptimizer fitopt(w & b,fitcost, 1e-5, 1e-3, 100);
   real fit_NLL = fitopt.optimize();
   // The fitted parameters can then be read from w->value and b->value.
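   ...
   // Finally, a fitted model can also be used generatively. A minimal
   // sketch, reusing the conditional mixture of the first example above
   // (the 100-row Mat of samples is a hypothetical pre-allocated buffer):
   Vec y_drawn = sample(Y|(X==x)); // one draw from P(Y|X=x)
   Mat y_draws(100,d);             // several draws, one per row
   sample(Y|(X==x), y_draws);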
*/


#ifndef RANDOMVAR_INC
#define RANDOMVAR_INC

#include <plearn/opt/Optimizer.h>

#include "SampleVariable.h"

#include <plearn/vmat/VMat.h>

namespace PLearn {
using namespace std;


class RandomVariable;
class RVArray;
class RVInstance;
class RVInstanceArray;
class ConditionalExpression;

//! A smart pointer to a RandomVariable.
class RandomVar: public PP<RandomVariable>
{
public:
    RandomVar();
    RandomVar(int length, int width=1);
    RandomVar(RandomVariable* v);
    RandomVar(const RandomVar& other);

    RandomVar(const Vec& vec);
    RandomVar(const Mat& mat);
    RandomVar(const Var& var);

    //! Builds a JointRandomVariable from the array (see the tutorial).
    RandomVar(const RVArray& vars);

    RandomVar operator[](RandomVar index);
#if 0
    RandomVar operator[](int i);
#endif

    void operator=(const RVArray& vars);

    //! Assign a new (non-random) value to this random variable.
    void operator=(real f);
    void operator=(const Vec& v);
    void operator=(const Mat& m);
    void operator=(const Var& v);

    //! Build the RVInstance that associates this random variable
    //! with the given Var instance (see the tutorial).
    RVInstance operator==(const Var& v) const;

    //! Compare the identity of the underlying RandomVariables.
    bool operator==(const RandomVar& rv) const { return rv.ptr == this->ptr; }
    bool operator!=(const RandomVar& rv) const { return rv.ptr != this->ptr; }

    //! Build an RVArray containing this random variable followed by v.
    RVArray operator&(const RandomVar& v) const;

    //! Build a ConditionalExpression with this random variable on the
    //! left-hand side and rhs on the right-hand side (see the tutorial).
    ConditionalExpression operator|(RVArray rhs) const;

    ConditionalExpression operator|(RVInstanceArray rhs) const;

#if 0
    RandomVar operator[](RandomVar index);
    RandomVar operator[](int i);
    RandomVar operator()(RandomVar i, RandomVar j);
    RandomVar operator()(int i, int j);
#endif

};

typedef RandomVar MatRandomVar;


//! An array of RandomVars.
class RVArray: public Array<RandomVar>
{
public:
    RVArray();
    RVArray(int n, int n_extra_allocated=0);
    RVArray(const Array<RandomVar>& va);
    RVArray(const RandomVar& v, int n_extra_allocated=0);
    RVArray(const RandomVar& v1, const RandomVar& v2, int n_extra_allocated=0);
    RVArray(const RandomVar& v1, const RandomVar& v2, const RandomVar& v3,
            int n_extra_allocated=0);

    int length() const;

    //! The VarArray of the values of the RandomVars in this array.
    VarArray values() const;

    //! A[I], where I is a discrete RandomVar: selects among the
    //! elements of A according to the value of I (see the tutorial).
    RandomVar operator[](RandomVar index);

    RandomVar& operator[](int i)
    { return Array<RandomVar>::operator[](i); }

    const RandomVar& operator[](int i) const
    { return Array<RandomVar>::operator[](i); }

    static int compareRVnumbers(const RandomVar* v1, const RandomVar* v2);

    //! Sort the RandomVars in increasing order of their rv_number.
    void sort();
};


//! The association of a RandomVar with a Var that holds an instance
//! (a value) for it.
class RVInstance
{
public:
    RandomVar V;
    Var v;

    RVInstance(const RandomVar& VV, const Var& vv);
    RVInstance();

    RVInstanceArray operator&&(RVInstance rvi);

    ConditionalExpression operator|(RVInstanceArray a);

    //! Swap the contents of v and V->value.
    void swap_v_and_Vvalue();

};

class RVInstanceArray: public Array<RVInstance>
{
public:
    RVInstanceArray();
    RVInstanceArray(int n, int n_extra_allocated=0);
    RVInstanceArray(const Array<RVInstance>& a);
    RVInstanceArray(const RVInstance& v, int n_extra_allocated=0);
    RVInstanceArray(const RVInstance& v1, const RVInstance& v2,
                    int n_extra_allocated=0);
    RVInstanceArray(const RVInstance& v1, const RVInstance& v2,
                    const RVInstance& v3, int n_extra_allocated=0);

    //! Total length of the instances (the sum of the lengths of the Vars).
    int length() const;

    RVInstanceArray operator&&(RVInstance rhs);
    ConditionalExpression operator|(RVInstanceArray rhs);

    //! The RandomVars of this array.
    RVArray random_variables() const;

    //! The values (V->value) of the RandomVars of this array.
    VarArray values() const;
    //! The instance Vars (v) of this array.
    VarArray instances() const;

    //! Swap the contents of v and V->value for every element.
    void swap_v_and_Vvalue()
    { for (int i=0;i<size();i++) (*this)[i].swap_v_and_Vvalue(); }

    static int compareRVnumbers(const RVInstance* rvi1, const RVInstance* rvi2);

    //! Sort the elements in increasing order of their RV's rv_number.
    void sort();

};

class ConditionalExpression
{
public:
    RVInstance LHS;
    RVInstanceArray RHS;

    ConditionalExpression(RVInstance lhs, RVInstanceArray rhs);
    ConditionalExpression(RVInstance lhs);
    ConditionalExpression(RandomVar lhs);
    ConditionalExpression(RVInstanceArray lhs);
};

class RandomVariable: public PPointable
{
    friend class RandomVar;
    friend class RVInstanceArray;
    friend class RVArray;

    static int rv_counter;


protected:

    //! Unique identifier, assigned from rv_counter at construction.
    const int rv_number;

public:
    const RVArray parents;

    //! The Var that holds the value of this random variable.
    Var value;

protected:
    //! Whether the value of this RV is known, e.g. because it is
    //! observed (conditioned upon).
    bool marked;

    bool EMmark;
    bool pmark;

    //! For each parameter (parent), whether it should be learned.
    bool* learn_the_parameters;

public:
    RandomVariable(int thelength, int thewidth=1);
    RandomVariable(const Vec& the_value);
    RandomVariable(const Mat& the_value);
    RandomVariable(const Var& the_value);
    RandomVariable(const RVArray& parents, int thelength);
    RandomVariable(const RVArray& parents, int thelength, int thewidth);

    virtual char* classname() = 0;

    virtual int length() { return value->length(); }
    virtual int width() { return value->width(); }
    int nelems() { return value->nelems(); }
    bool isScalar() { return length()==1 && width()==1; }
    bool isVec() { return width()==1 || length()==1; }
    bool isColumnVec() { return width()==1; }
    bool isRowVec() { return length()==1; }

    //! Whether this RV is deterministic given its parents (i.e. not the
    //! outcome of a random draw); see the tutorial.
    virtual bool isNonRandom() = 0;
    //! Non-random and with a constant value.
    inline bool isConstant() { return isNonRandom() && value->isConstant(); }

    //! Whether this RV takes values in a discrete set.
    virtual bool isDiscrete() = 0;

    //! The RandomVar for the sub-vector of the given length,
    //! starting at position start.
    RandomVar subVec(int start, int length);

    //! Set this RV's value Var from the values of its parents.
    virtual void setValueFromParentsValue() = 0;

    void markRHSandSetKnownValues(const RVInstanceArray& RHS)
    {
        for (int i=0;i<RHS.size();i++)
            RHS[i].V->mark(RHS[i].v);
        setKnownValues();
    }

    //! Accumulate the EM sufficient statistics associated with the
    //! observation obs, weighted by the given posterior.
    virtual void EMBprop(const Vec obs, real posterior) = 0;

    //! Update the parameters from the statistics accumulated
    //! during an EM epoch.
    virtual void EMUpdate();

    //! Whether the EM iterations can be stopped early for this RV.
    virtual bool canStopEM();

    //! Initialize EM training of the given parameters.
    virtual void EMTrainingInitialize(const RVArray& parameters_to_learn);

    //! Initialize an EM epoch (e.g. clear the accumulated statistics).
    virtual void EMEpochInitialize();

    virtual void mark(Var v) { marked = true; value = v; }
    virtual void mark() { marked = true; }
    virtual void unmark() { marked = false; }
    virtual void clearEMmarks();

    //! Clear the 'marked' flag of this RV and of all its ancestors.
    virtual void unmarkAncestors();

    virtual bool isMarked() { return marked; }

    //! Propagate the 'known value' property from the marked ancestors.
    virtual void setKnownValues();

    //! Return a Var that computes logP(This == obs | RHS).
    virtual Var logP(const Var& obs, const RVInstanceArray& RHS,
                     RVInstanceArray* parameters_to_learn=0) = 0;
    virtual Var P(const Var& obs, const RVInstanceArray& RHS);

    //! Return a Var that computes the expected log-probability
    //! used by the EM algorithm.
    virtual Var ElogP(const Var& obs, RVArray& parameters_to_learn,
                      const RVInstanceArray& RHS);

    virtual real EM(const RVArray& parameters_to_learn,
                    VarArray& prop_path,
                    VarArray& observedVars,
                    VMat distr, int n_samples,
                    int max_n_iterations,
                    real relative_improvement_threshold,
                    bool accept_worsening_likelihood=false);
    //! Perform one EM epoch over the data (or only compute the
    //! likelihood, if do_EM_learning is false); returns the average
    //! negative log-likelihood.
    virtual real epoch(VarArray& prop_path,
                       VarArray& observed_vars, const VMat& distr,
                       int n_samples,
                       bool do_EM_learning=true);

    virtual ~RandomVariable();

};


// *** Global operators and functions that build new RandomVars ***

//! Product of two random variables (a ProductRandomVariable).
RandomVar operator*(RandomVar a, RandomVar b);

//! Sum of two random variables (a PlusRandomVariable).
RandomVar operator+(RandomVar a, RandomVar b);

//! Difference of two random variables (a MinusRandomVariable).
RandomVar operator-(RandomVar a, RandomVar b);

//! Element-wise division (an ElementWiseDivisionRandomVariable).
RandomVar operator/(RandomVar a, RandomVar b);

//! Element-wise exponential (an ExpRandomVariable).
RandomVar exp(RandomVar x);

//! Element-wise natural logarithm (a LogRandomVariable).
RandomVar log(RandomVar x);

//! Extend v with n_extend rows filled with extension_value
//! (an ExtendedRandomVariable).
RandomVar extend(RandomVar v, real extension_value = 1.0, int n_extend = 1);

//! Concatenate the columns of the RandomVars in a
//! (a ConcatColumnsRandomVariable).
RandomVar hconcat(const RVArray& a);

//! Estimate the given parameters by the EM algorithm, so as to maximize
//! over the data in distr the likelihood specified by the conditional
//! expression.
//real EMbyEMBprop(ConditionalExpression conditional_expression,
real EM(ConditionalExpression conditional_expression,
        RVArray parameters_to_learn,
        VMat distr, int n_samples, int max_n_iterations=1,
        real relative_improvement_threshold=0.001,
        bool accept_worsening_likelihood=false,
        bool compute_final_train_NLL=true);

real oEM(ConditionalExpression conditional_expression,
         RVArray parameters_to_learn,
         VMat distr, int n_samples, int max_n_iterations,
         real relative_improvement_threshold=0.001,
         bool compute_final_train_NLL=true);

real oEM(ConditionalExpression conditional_expression,
         RVArray parameters_to_learn,
         VMat distr, int n_samples,
         Optimizer& MStepOptimizer,
         int max_n_iterations,
         real relative_improvement_threshold=0.001,
         bool compute_final_train_NLL=true);

//! Return a Var that computes the log-probability of the given
//! conditional expression (see the tutorial above).
Var logP(ConditionalExpression conditional_expression,
         bool clearMarksUponReturn=true,
         RVInstanceArray* parameters_to_learn=0);

//! Same as logP, but for the probability itself.
Var P(ConditionalExpression conditional_expression,
      bool clearMarksUponReturn=true);

//! Expected log-probability, used by the EM algorithm.
Var ElogP(ConditionalExpression conditional_expression,
          RVInstanceArray& parameters_to_learn,
          bool clearMarksUponReturn=true);

//! Integrate hiddenRV out of RV (see the tutorial above).
RandomVar marginalize(const RandomVar& RV, const RandomVar& hiddenRV);

//! Draw one sample from the given conditional distribution.
Vec sample(ConditionalExpression conditional_expression);

//! Return a Var whose fprop draws a sample.
Var Sample(ConditionalExpression conditional_expression);

//! Draw multiple samples, one per row of the given Mat.
void sample(ConditionalExpression conditional_expression, Mat& samples);

//! A d-dimensional normal with the given (fixed) mean and
//! standard deviation.
RandomVar normal(real mean=0, real standard_dev=1, int d=1,
                 real minimum_standard_deviation=1e-6);

//! A normal whose mean and log-variance are given by RandomVars.
RandomVar normal(RandomVar mean, RandomVar log_variance,
                 real minimum_standard_deviation=1e-6);

//! Mixture of the given components, weighted by softmax(log_weights)
//! (a MixtureRandomVariable).
RandomVar mixture(RVArray components, RandomVar log_weights);

//! Discrete random variable over {0, ..., N-1} with probabilities
//! softmax(log_probabilities) (a MultinomialRandomVariable).
RandomVar multinomial(RandomVar log_probabilities);

//! A StochasticRandomVariable is a RandomVariable that is "really
//! random", i.e. its value is the outcome of a random draw whose
//! distribution is parametrized by its parents.
class StochasticRandomVariable: public RandomVariable
{
public:
    StochasticRandomVariable(int length=1);
    StochasticRandomVariable(const RVArray& params, int length);
    StochasticRandomVariable(const RVArray& params, int length, int width);

    virtual bool isNonRandom() { return false; }

    //! In general, a StochasticRandomVariable is continuous.
    virtual bool isDiscrete() { return false; }

    virtual void setKnownValues();

};

//! A FunctionalRandomVariable is a RandomVariable whose value is a
//! deterministic function of the values of its parents.
class FunctionalRandomVariable: public RandomVariable {
public:
    FunctionalRandomVariable(int length);
    FunctionalRandomVariable(int length, int width);
    FunctionalRandomVariable(const Vec& the_value);
    FunctionalRandomVariable(const Mat& the_value);
    FunctionalRandomVariable(const Var& the_value);
    FunctionalRandomVariable(const RVArray& parents, int length);
    FunctionalRandomVariable(const RVArray& parents, int length, int width);

    virtual Var logP(const Var& obs, const RVInstanceArray& RHS,
                     RVInstanceArray* parameters_to_learn);

    //! Non-random iff all the parents are non-random.
    bool isNonRandom();

    //! Discrete iff the random parents are discrete.
    virtual bool isDiscrete();

    //! Whether the functional relationship can be inverted so as to
    //! recover the value of the unobserved parents from obs (possibly
    //! setting *JacobianCorrection to a correction term for the
    //! change of variables).
    virtual bool invertible(const Var& obs,
                            RVInstanceArray& unobserved_parents,
                            Var** JacobianCorrection);

    virtual void setValueFromParentsValue() = 0;

};

//! A NonRandomVariable has no random parents: its value may be a (Var)
//! function of ordinary Vars, but nothing in it is random.
class NonRandomVariable: public FunctionalRandomVariable
{
public:
    NonRandomVariable(int thelength);
    NonRandomVariable(int thelength, int thewidth);
    NonRandomVariable(const Var& v);

    virtual char* classname() { return "NonRandomVariable"; }

    void setValueFromParentsValue() { }
    bool invertible(const Var& obs, RVInstanceArray& unobserved_parents,
                    Var** JacobianCorrection)
    { return true; }
    void EMBprop(const Vec obs, real post) { }
};

//! The concatenation of several random variables
//! (see operator& in the tutorial above).
class JointRandomVariable: public FunctionalRandomVariable
{
public:
    JointRandomVariable(const RVArray& variables);

    virtual char* classname() { return "JointRandomVariable"; }

    void setValueFromParentsValue();
    bool invertible(const Var& obs, RVInstanceArray& unobserved_parents,
                    Var** JacobianCorrection);
    void EMBprop(const Vec obs, real post);
};

//! Y = X(I): the I-th element of the random variable X,
//! where the index I is itself a RandomVar.
class RandomElementOfRandomVariable: public FunctionalRandomVariable
{
public:
    RandomElementOfRandomVariable(const RandomVar& v, const RandomVar& index);

    virtual char* classname() { return "RandomElementOfRandomVariable"; }

    void setValueFromParentsValue();
    bool invertible(const Var& obs, RVInstanceArray& unobserved_parents,
                    Var** JacobianCorrection);
    void EMBprop(const Vec obs, real post);

    //! Convenience access to the parents.
    inline const RandomVar& v() { return parents[0]; }
    inline const RandomVar& index() { return parents[1]; }

};


//! Y = A[I], where A is an RVArray and I a discrete RandomVar: Y takes
//! the distribution of the I-th element of the array (see the tutorial).
class RVArrayRandomElementRandomVariable: public FunctionalRandomVariable
{
public:
    RVArrayRandomElementRandomVariable(const RVArray& table, const RandomVar& index);

    virtual char* classname() { return "RVArrayRandomElementRandomVariable"; }

    void setValueFromParentsValue();
    virtual Var logP(const Var& obs, const RVInstanceArray& RHS,
                     RVInstanceArray* parameters_to_learn=0);
    void EMBprop(const Vec obs, real post);

    //! The index is the last parent.
    inline const RandomVar& index() { return parents[parents.size()-1]; }

};

class NegRandomVariable: public FunctionalRandomVariable
{
public:
    NegRandomVariable(RandomVariable* input);

    virtual char* classname() { return "NegRandomVariable"; }

    void setValueFromParentsValue();
    bool invertible(const Var& obs, RVInstanceArray& unobserved_parents,
                    Var** JacobianCorrection);
    void EMBprop(const Vec obs, real post);
};

class ExpRandomVariable: public FunctionalRandomVariable
{
public:
    ExpRandomVariable(RandomVar& input);

    virtual char* classname() { return "ExpRandomVariable"; }

    void setValueFromParentsValue();
    bool invertible(const Var& obs, RVInstanceArray& unobserved_parents,
                    Var** JacobianCorrection);
    void EMBprop(const Vec obs, real post);
};


class LogRandomVariable: public FunctionalRandomVariable
{
public:
    LogRandomVariable(RandomVar& input);

    virtual char* classname() { return "LogRandomVariable"; }

    void setValueFromParentsValue();
    bool invertible(const Var& obs, RVInstanceArray& unobserved_parents,
                    Var** JacobianCorrection);
    void EMBprop(const Vec obs, real post);
};

//! A normal distribution with a diagonal covariance matrix,
//! parametrized by its mean and the log of its variance(s).
class DiagonalNormalRandomVariable: public StochasticRandomVariable
{
protected:
    real minimum_variance;
    real normfactor;
    bool shared_variance;

public:
    DiagonalNormalRandomVariable(const RandomVar& mean,
                                 const RandomVar& log_variance,
                                 real minimum_standard_deviation = 1e-10);

    virtual char* classname() { return "DiagonalNormalRandomVariable"; }

    Var logP(const Var& obs, const RVInstanceArray& RHS,
             RVInstanceArray* parameters_to_learn);
    void setValueFromParentsValue();
    void EMUpdate();
    void EMBprop(const Vec obs, real posterior);
    void EMEpochInitialize();

    //! Convenience access to the parents and the learning flags.
    inline const RandomVar& mean() { return parents[0]; }
    inline const RandomVar& log_variance() { return parents[1]; }
    inline bool& learn_the_mean() { return learn_the_parameters[0]; }
    inline bool& learn_the_variance() { return learn_the_parameters[1]; }

protected:
    //! Accumulators for the EM sufficient statistics.
    Vec mu_num;
    Vec sigma_num;
    real denom;
};

class MixtureRandomVariable: public StochasticRandomVariable
{
protected:
    //! The component distributions of the mixture.
    RVArray components;

public:
    MixtureRandomVariable(const RVArray& components,
                          const RandomVar& log_weights);

    virtual char* classname() { return "MixtureRandomVariable"; }

    //! softmax(log_weights()) gives the mixture weights.
    inline const RandomVar& log_weights() { return parents[0]; }
    inline bool& learn_the_weights() { return learn_the_parameters[0]; }

    virtual Var logP(const Var& obs, const RVInstanceArray& RHS,
                     RVInstanceArray* parameters_to_learn);
    virtual Var ElogP(const Var& obs, RVInstanceArray& parameters_to_learn,
                      const RVInstanceArray& RHS);

    virtual void setValueFromParentsValue();
    virtual void EMUpdate();
    virtual void EMBprop(const Vec obs, real posterior);
    virtual void EMEpochInitialize();
    virtual void EMTrainingInitialize(const RVArray& parameters_to_learn);
    virtual bool isDiscrete();
    virtual bool canStopEM();
    virtual void setKnownValues();
    virtual void unmarkAncestors();
    virtual void clearEMmarks();

protected:
    //! Posterior probabilities of the components, used by EM.
    Vec posteriors;
    Vec sum_posteriors;

    //! Cached Vars used to build the logP graph.
    VarArray componentsLogP;
    Var lw;
    Var logp;
};

//! Y = X0 + X1, where either X0 or X1 may be learned (by EM).
class PlusRandomVariable: public FunctionalRandomVariable
{
public:
    PlusRandomVariable(RandomVar input1, RandomVar input2);

    virtual char* classname() { return "PlusRandomVariable"; }

    void setValueFromParentsValue();
    bool invertible(const Var& obs, RVInstanceArray& unobserved_parents,
                    Var** JacobianCorrection);
    void EMBprop(const Vec obs, real post);
    void EMTrainingInitialize(const RVArray& parameters_to_learn);
    void EMEpochInitialize();
    void EMUpdate();

    //! Convenience access to the parents.
    const RandomVar& X0() { return parents[0]; }
    const RandomVar& X1() { return parents[1]; }

    //! Which of the parents are learned.
    bool learn_X0() { return learn_the_parameters[0]; }
    bool learn_X1() { return learn_the_parameters[1]; }
    bool learn_something;
    RandomVar parent_to_learn;
    RandomVar other_parent;
    //! Accumulators for the EM updates.
    Vec numerator;
    Vec difference;
    real denom;
};
//! Y = X0 - X1, where either X0 or X1 may be learned (by EM).
class MinusRandomVariable: public FunctionalRandomVariable
{
public:
    MinusRandomVariable(RandomVar input1, RandomVar input2);

    virtual char* classname() { return "MinusRandomVariable"; }

    void setValueFromParentsValue();
    bool invertible(const Var& obs, RVInstanceArray& unobserved_parents,
                    Var** JacobianCorrection);
    void EMBprop(const Vec obs, real post);
    void EMTrainingInitialize(const RVArray& parameters_to_learn);
    void EMEpochInitialize();
    void EMUpdate();

    //! Convenience access to the parents.
    const RandomVar& X0() { return parents[0]; }
    const RandomVar& X1() { return parents[1]; }

    //! Which of the parents are learned.
    bool learn_X0() { return learn_the_parameters[0]; }
    bool learn_X1() { return learn_the_parameters[1]; }
    bool learn_something;
    RandomVar parent_to_learn;
    RandomVar other_parent;
    //! Accumulators for the EM updates.
    Vec numerator;
    Vec difference;
    real denom;
};


//! Y = X0 / X1, element-wise.
class ElementWiseDivisionRandomVariable: public FunctionalRandomVariable
{
public:
    ElementWiseDivisionRandomVariable(RandomVar input1, RandomVar input2);

    virtual char* classname() { return "ElementWiseDivisionRandomVariable"; }

    void setValueFromParentsValue();
    bool invertible(const Var& obs, RVInstanceArray& unobserved_parents,
                    Var** JacobianCorrection);
    void EMBprop(const Vec obs, real post);
    void EMTrainingInitialize(const RVArray& parameters_to_learn);
    void EMEpochInitialize();
    void EMUpdate();

    //! Convenience access to the parents.
    const RandomVar& X0() { return parents[0]; }
    const RandomVar& X1() { return parents[1]; }

};


//! Y = X0 * X1, the product of two random matrices
//! (of dimensions l x m and m x n).
class ProductRandomVariable: public FunctionalRandomVariable
{
public:
    int m,n,l;

    ProductRandomVariable(MatRandomVar input1, MatRandomVar input2);

    virtual char* classname() { return "ProductRandomVariable"; }

    void setValueFromParentsValue();
    bool invertible(const Var& obs, RVInstanceArray& unobserved_parents,
                    Var** JacobianCorrection);
    void EMBprop(const Vec obs, real post);
    void EMTrainingInitialize(const RVArray& parameters_to_learn);
    void EMEpochInitialize();
    void EMUpdate();

    //! Convenience access to the parents.
    const RandomVar& X0() { return parents[0]; }
    const RandomVar& X1() { return parents[1]; }
    bool scalars; //!< true when both arguments are scalar

    //! Which of the parents are learned.
    bool learn_X0() { return learn_the_parameters[0]; }
    bool learn_X1() { return learn_the_parameters[1]; }
    bool learn_something;
    //! Accumulators and temporaries for the EM updates.
    Mat X0numerator;
    Mat X1numerator;
    Mat denom;
    Mat tmp1;
    Mat tmp2;
    Mat tmp3;
    Vec vtmp3;
    Vec tmp4;
};

//! Y = X.subVec(start, length): a sub-vector of a random vector.
class SubVecRandomVariable: public FunctionalRandomVariable
{
protected:
    int start;
public:
    SubVecRandomVariable(const RandomVar& parent, int start, int length);
    virtual char* classname() { return "SubVecRandomVariable"; }
    void setValueFromParentsValue();
    bool invertible(const Var& obs, RVInstanceArray& unobserved_parents,
                    Var** JacobianCorrection);
    void EMBprop(const Vec obs, real posterior);
};

//! A discrete random variable taking values in {0, ..., N-1}, whose
//! probabilities are softmax(log_probabilities()); see the tutorial.
class MultinomialRandomVariable: public StochasticRandomVariable
{
public:
    MultinomialRandomVariable(const RandomVar& log_probabilities);
    //! Convenience access to the parameters and the learning flag.
    inline const RandomVar& log_probabilities() { return parents[0]; }
    inline bool learn_the_probabilities() { return learn_the_parameters[0]; }

    virtual char* classname() { return "MultinomialRandomVariable"; }

    Var logP(const Var& obs, const RVInstanceArray& RHS,
             RVInstanceArray* parameters_to_learn);
    void setValueFromParentsValue();
    void EMUpdate();
    void EMBprop(const Vec obs, real posterior);
    void EMEpochInitialize();
    bool isDiscrete();

protected:
    //! Accumulator for the EM sufficient statistics.
    Vec sum_posteriors;
};


//! A RandomVariable extended with n_extend rows filled with fill_value
//! (see the global function extend above).
class ExtendedRandomVariable: public FunctionalRandomVariable
{
protected:
    int n_extend;
    real fill_value;
public:
    ExtendedRandomVariable(const RandomVar& parent, real fill_value=1.0, int n_extend=1);
    virtual char* classname() { return "ExtendedRandomVariable"; }
    void setValueFromParentsValue();
    bool invertible(const Var& obs, RVInstanceArray& unobserved_parents,
                    Var** JacobianCorrection);
    void EMBprop(const Vec obs, real posterior);
};

//! The concatenation of the columns of several random matrices
//! (see the global function hconcat above).
class ConcatColumnsRandomVariable: public FunctionalRandomVariable
{
public:
    ConcatColumnsRandomVariable(const RVArray& vars);
    virtual char* classname() { return "ConcatColumnsRandomVariable"; }
    void setValueFromParentsValue();
    bool invertible(const Var& obs, RVInstanceArray& unobserved_parents,
                    Var** JacobianCorrection);
    void EMBprop(const Vec obs, real posterior);
};


//! A VMatrix whose rows are samples drawn from the distribution
//! specified by a conditional expression.
class RandomVarVMatrix: public VMatrix
{
protected:
    RandomVar rv;
    Var instance;
    VarArray prop_path;

public:
    RandomVarVMatrix(ConditionalExpression conditional_expression);
    virtual int nVars() { return instance->length(); }
    virtual Vec sample()
    {
        prop_path.fprop();
        return instance->value;
    }
};


} // end of namespace PLearn

#endif
