Main Page | Namespace List | Class Hierarchy | Alphabetical List | Class List | File List | Namespace Members | Class Members | File Members

Variable.h

Go to the documentation of this file.
00001 // -*- C++ -*- 00002 00003 // PLearn (A C++ Machine Learning Library) 00004 // Copyright (C) 1998 Pascal Vincent 00005 // Copyright (C) 1999-2002 Pascal Vincent, Yoshua Bengio and University of Montreal 00006 // 00007 00008 // Redistribution and use in source and binary forms, with or without 00009 // modification, are permitted provided that the following conditions are met: 00010 // 00011 // 1. Redistributions of source code must retain the above copyright 00012 // notice, this list of conditions and the following disclaimer. 00013 // 00014 // 2. Redistributions in binary form must reproduce the above copyright 00015 // notice, this list of conditions and the following disclaimer in the 00016 // documentation and/or other materials provided with the distribution. 00017 // 00018 // 3. The name of the authors may not be used to endorse or promote 00019 // products derived from this software without specific prior written 00020 // permission. 00021 // 00022 // THIS SOFTWARE IS PROVIDED BY THE AUTHORS ``AS IS'' AND ANY EXPRESS OR 00023 // IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES 00024 // OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN 00025 // NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 00026 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED 00027 // TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 00028 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 00029 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 00030 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00031 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00032 // 00033 // This file is part of the PLearn library. For more information on the PLearn 00034 // library, go to the PLearn Web site at www.plearn.org 00035 00036 00037 00038 00039 /* ******************************************************* 00040 * $Id: Variable.h,v 1.17 2004/07/26 21:04:14 chrish42 Exp $ 00041 * This file is part of the PLearn library. 00042 ******************************************************* */ 00043 00044 #ifndef Variable_INC 00045 #define Variable_INC 00046 00047 #include <plearn/math/TMat.h> 00048 #include <plearn/base/Object.h> 00049 00050 namespace PLearn { 00051 using namespace std; 00052 00053 class Variable; 00054 class VarArray; 00055 class RandomVariable; 00056 class RandomVar; 00057 00058 class Var: public PP<Variable> 00059 { 00060 friend class RandomVariable; 00061 friend class RandomVar; 00062 00063 public: 00064 Var(); 00065 Var(Variable* v); 00066 Var(Variable* v, const char* name); 00067 Var(const Var& other); 00068 Var(const Var& other, const char* name); 00069 explicit Var(int the_length, int width_=1); 00070 Var(int the_length, int the_width, const char* name); 00071 Var(int the_length, const char* name); 00072 explicit Var(const Vec& vec, bool vertival=true); 00073 explicit Var(const Mat& mat); 00074 00075 int length() const; 00076 int width() const; 00077 00078 Var subVec(int start, int len, bool transpose=false) const; 00079 Var subMat(int i, int j, int sublength, int subwidth, bool transpose=false) const; 00080 Var row(int i, bool transpose=false) const; 00081 Var column(int j, bool transpose=false) const; 00082 Var operator()(int i) const; 00083 Var operator()(int i, int j) const; 00084 00086 Var operator[](int i) const; 00087 Var operator[](Var i) const; 00088 00090 Var operator()(Var index) const; 00091 00093 Var operator()(Var i, Var j) const; 00094 00095 void operator=(real f); 00096 void operator=(const Vec& v); 00097 void operator=(const Mat& m); 00098 }; 00099 00100 class Variable: public Object 00101 { 00102 public: 00103 typedef Object inherited; 00104 00105 protected: 00107 Variable() : varnum(++nvars), marked(false), varname(), allows_partial_update(false), 00108 gradient_status(0), valuedata(0), gradientdata(0), min_value(-FLT_MAX), 00109 max_value(FLT_MAX), dont_bprop_here(false) {} 00110 00111 static void declareOptions(OptionList & ol); 00112 00113 friend class Var; 00114 friend class RandomVariable; 00115 friend class ProductRandomVariable; 00116 friend class Function; 00117 00118 friend class UnaryVariable; 00119 friend class BinaryVariable; 00120 friend class NaryVariable; 00121 00122 public: 00123 static int nvars; 00124 int varnum; 00125 00126 protected: 00127 bool marked; 00128 string varname; 00129 00130 protected: 00131 bool allows_partial_update; 00132 int gradient_status; 00133 TVec<int> rows_to_update; 00134 00135 public: 00136 Vec value; 00137 Vec gradient; 00138 Mat matValue; 00139 Mat matGradient; 00140 Vec rValue; 00141 Mat matRValue; 00142 Mat matDiagHessian; 00143 00144 real* valuedata; 00145 real* gradientdata; 00146 real min_value, max_value; 00147 Var g; 00148 Vec diaghessian; 00149 real* diaghessiandata; 00150 real* rvaluedata; 00151 bool dont_bprop_here; 00152 00153 public: 00154 Variable(int thelength, int thewidth); 00155 Variable(const Mat& m); 00156 00157 int length() const { return matValue.length(); } 00158 int width() const { return matValue.width(); } 00159 int size() const { return matValue.size(); } // length*width 00160 int nelems() const { return size(); } 00161 00168 virtual void recomputeSize(int& l, int& w) const; 00169 00173 void resize(int l, int w); 00174 00179 void sizeprop(); 00180 00182 virtual void setParents(const VarArray& parents); 00183 00185 Variable(const Variable& v); 00186 00187 private: 00188 void build_(); 00189 public: 00190 virtual void build(); 00191 00192 bool isScalar() const { return length()==1 && width()==1; } 00193 bool isVec() const { return length()==1 || width()==1; } 00194 bool isColumnVec() const { return width()==1; } 00195 bool isRowVec() const { return length()==1; } 00196 00197 PLEARN_DECLARE_ABSTRACT_OBJECT(Variable); 00198 00199 virtual void makeDeepCopyFromShallowCopy(map<const void*, void*>& copies); 00200 00202 virtual void fprop() =0; 00204 00206 inline void sizefprop() 00207 { sizeprop(); fprop(); } 00208 00209 virtual void bprop() =0; 00215 virtual void bbprop(); 00217 virtual void fbprop(); 00219 virtual void fbbprop(); 00221 virtual void symbolicBprop(); 00222 00223 virtual void rfprop(); 00224 00225 virtual void copyValueInto(Vec v) { v << value; } 00226 virtual void copyGradientInto(Vec g) { g << gradient; } 00227 00228 virtual void print(ostream& out) const; 00229 00232 string getName() const; 00234 void setName(const string& the_name); 00235 bool nameIsSet() { return varname.size()>0; } 00236 00240 Mat defineGradientLocation(const Mat& m); 00241 00242 virtual void printInfo(bool print_gradient=false) = 0; 00243 virtual void printInfos(bool print_gradient=false); 00244 00245 Var subVec(int start, int len, bool transpose=false); 00246 Var subMat(int i, int j, int sublength, int subwidth, bool transpose=false); 00247 Var row(int i, bool transpose=false) { return subMat(i,0,1,width(),transpose); } 00248 Var column(int j, bool transpose=false) { return subMat(0,j,length(),1,transpose); } 00249 00250 void setDontBpropHere(bool val) { dont_bprop_here = val; } 00251 void setKeepPositive() { min_value = 0; } 00252 void setMinValue(real minv=-FLT_MAX) { min_value = minv; } 00253 void setMaxValue(real maxv=FLT_MAX) { max_value = maxv; } 00254 void setBoxConstraint(real minv, real maxv) { min_value = minv; max_value = maxv; } 00255 00256 void setMark() { marked = true; } 00257 void clearMark() { marked = false; } 00258 bool isMarked() { return marked; } 00259 00260 void fillGradient(real value) { gradient.fill(value); } 00261 void clearGradient() 00262 { 00263 if(!allows_partial_update) 00264 gradient.clear(); 00265 else 00266 { 00267 for (int r=0;r<rows_to_update.length();r++) 00268 { 00269 int row = rows_to_update[r]; 00270 matGradient.row(row).clear(); 00271 } 00272 rows_to_update.resize(0); 00273 gradient_status=0; 00274 } 00275 } 00276 void clearDiagHessian(); 00277 void clearSymbolicGradient() { g = Var(); } 00278 00284 bool update(real step_size, Vec direction_vec, real coeff = 1.0, real b = 0.0); 00285 00290 bool update(Vec step_sizes, Vec direction_vec, real coeff = 1.0, real b = 0.0); 00291 00293 inline void updateAndClear(); 00294 00299 bool update(real step_size); 00300 00302 void allowPartialUpdates() 00303 { 00304 allows_partial_update=true; 00305 rows_to_update.resize(length()); // make sure that there are always enough elements 00306 rows_to_update.resize(0); 00307 gradient_status=0; 00308 } 00309 00311 void disallowPartialUpdates() 00312 { 00313 allows_partial_update = false; 00314 gradient_status=2; 00315 } 00316 00318 void updateRow(int row) 00319 { 00320 if (gradient_status!=2 && allows_partial_update && !rows_to_update.contains(row)) 00321 { 00322 rows_to_update.append(row); 00323 if (gradient_status==0) gradient_status=1; 00324 } 00325 } 00326 00327 00333 bool update(Vec new_value); 00334 00339 real maxUpdate(Vec direction); 00340 00345 virtual bool markPath() =0; 00346 00349 virtual void buildPath(VarArray& proppath) =0; 00350 00351 virtual void oldread(istream& in); 00352 virtual void write(ostream& out) const; 00353 00354 00355 00356 00357 void copyFrom(const Vec& v) { value << v; } 00358 void copyTo(Vec& v) { v << value; } 00359 void copyGradientFrom(const Vec& v) { gradient << v; } 00360 void copyGradientTo(Vec& v) { v << gradient; } 00361 void makeSharedValue(real* x, int n); 00362 void makeSharedGradient(real* x, int n); 00363 00364 void makeSharedValue(PP<Storage<real> > storage, int offset_=0); 00365 void makeSharedGradient(PP<Storage<real> > storage, int offset_=0); 00366 void makeSharedValue(Vec& v, int offset_=0); 00367 void makeSharedGradient(Vec& v, int offset_=0); 00368 00369 void copyRValueFrom(const Vec& v) { resizeRValue(); rValue << v; } 00370 void copyRValueTo(Vec& v) { resizeRValue(); v << rValue; } 00371 void makeSharedRValue(real* x, int n); 00372 void makeSharedRValue(PP<Storage<real> > storage, int offset_=0); 00373 void makeSharedRValue(Vec& v, int offset_=0); 00374 00375 virtual bool isConstant() { return false; } 00376 00382 virtual void fprop_from_all_sources(); 00383 00386 virtual VarArray sources() = 0; 00387 00390 virtual VarArray random_sources() = 0; 00391 00393 virtual VarArray ancestors() = 0; 00395 virtual void unmarkAncestors() = 0; 00396 00399 virtual VarArray parents() = 0; 00400 00402 virtual void accg(Var v); 00403 00406 virtual void verifyGradient(real step=0.001); 00407 00409 virtual void resizeDiagHessian(); 00410 00411 virtual void resizeRValue(); 00412 }; 00413 00414 DECLARE_OBJECT_PTR(Variable); 00415 DECLARE_OBJECT_PP(Var, Variable); 00416 00417 // set value += gradient and clears the gradient 00418 inline void Variable::updateAndClear() 00419 { 00420 for(int i=0; i<nelems(); i++) 00421 valuedata[i] += gradientdata[i]; 00422 gradient.clear(); 00423 } 00424 00425 void varDeepCopyField(Var& field, CopiesMap& copies); 00426 00427 00428 inline Var Var::row(int i, bool transpose) const 00429 { 00430 return subMat(i, 0, 1, width(), transpose); 00431 } 00432 00433 inline Var Var::column(int j, bool transpose) const 00434 { 00435 return subMat(0, j, length(), 1, transpose); 00436 } 00437 00438 inline Var Var::operator()(int i) const 00439 { 00440 return row(i, false); 00441 } 00442 00443 inline Var Var::operator()(int i, int j) const 00444 { 00445 return subMat(i, j, 1, 1); 00446 } 00447 00448 } // end of namespace PLearn 00449 00450 #endif 00451

Generated on Tue Aug 17 16:10:11 2004 for PLearn by doxygen 1.3.7