Variable.h
Go to the documentation of this file.
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
#ifndef Variable_INC
00045
#define Variable_INC
00046
00047
#include <plearn/math/TMat.h>
00048
#include <plearn/base/Object.h>
00049
00050
namespace PLearn {
00051
using namespace std;
00052
00053
class Variable;
00054
class VarArray;
00055
class RandomVariable;
00056
class RandomVar;
00057
00058 class Var:
public PP<Variable>
00059 {
00060
friend class RandomVariable;
00061
friend class RandomVar;
00062
00063
public:
00064
Var();
00065
Var(
Variable* v);
00066
Var(
Variable* v,
const char* name);
00067
Var(
const Var& other);
00068
Var(
const Var& other,
const char* name);
00069
explicit Var(
int the_length,
int width_=1);
00070
Var(
int the_length,
int the_width,
const char* name);
00071
Var(
int the_length,
const char* name);
00072
explicit Var(
const Vec& vec,
bool vertival=
true);
00073
explicit Var(
const Mat& mat);
00074
00075
int length()
const;
00076
int width()
const;
00077
00078
Var subVec(
int start,
int len,
bool transpose=
false)
const;
00079
Var subMat(
int i,
int j,
int sublength,
int subwidth,
bool transpose=
false)
const;
00080
Var row(
int i,
bool transpose=
false)
const;
00081
Var column(
int j,
bool transpose=
false)
const;
00082
Var operator()(
int i)
const;
00083
Var operator()(
int i,
int j)
const;
00084
00086
Var operator[](
int i)
const;
00087
Var operator[](
Var i)
const;
00088
00090
Var operator()(
Var index)
const;
00091
00093
Var operator()(
Var i,
Var j)
const;
00094
00095
void operator=(
real f);
00096
void operator=(
const Vec& v);
00097
void operator=(
const Mat& m);
00098 };
00099
00100 class Variable:
public Object
00101 {
00102
public:
00103 typedef Object inherited;
00104
00105
protected:
00107 Variable() :
varnum(++
nvars),
marked(false),
varname(),
allows_partial_update(false),
00108
gradient_status(0),
valuedata(0),
gradientdata(0),
min_value(-FLT_MAX),
00109
max_value(FLT_MAX),
dont_bprop_here(false) {}
00110
00111
static void declareOptions(
OptionList & ol);
00112
00113
friend class Var;
00114
friend class RandomVariable;
00115
friend class ProductRandomVariable;
00116
friend class Function;
00117
00118
friend class UnaryVariable;
00119
friend class BinaryVariable;
00120
friend class NaryVariable;
00121
00122
public:
00123
static int nvars;
00124 int varnum;
00125
00126
protected:
00127 bool marked;
00128 string varname;
00129
00130
protected:
00131 bool allows_partial_update;
00132 int gradient_status;
00133 TVec<int> rows_to_update;
00134
00135
public:
00136 Vec value;
00137 Vec gradient;
00138 Mat matValue;
00139 Mat matGradient;
00140 Vec rValue;
00141 Mat matRValue;
00142 Mat matDiagHessian;
00143
00144 real*
valuedata;
00145 real*
gradientdata;
00146 real min_value,
max_value;
00147 Var g;
00148 Vec diaghessian;
00149 real*
diaghessiandata;
00150 real*
rvaluedata;
00151 bool dont_bprop_here;
00152
00153
public:
00154
Variable(
int thelength,
int thewidth);
00155
Variable(
const Mat& m);
00156
00157 int length()
const {
return matValue.
length(); }
00158 int width()
const {
return matValue.
width(); }
00159 int size()
const {
return matValue.
size(); }
00160 int nelems()
const {
return size(); }
00161
00168
virtual void recomputeSize(
int& l,
int& w)
const;
00169
00173
void resize(
int l,
int w);
00174
00179
void sizeprop();
00180
00182
virtual void setParents(
const VarArray& parents);
00183
00185
Variable(
const Variable& v);
00186
00187
private:
00188
void build_();
00189
public:
00190
virtual void build();
00191
00192 bool isScalar()
const {
return length()==1 &&
width()==1; }
00193 bool isVec()
const {
return length()==1 ||
width()==1; }
00194 bool isColumnVec()
const {
return width()==1; }
00195 bool isRowVec()
const {
return length()==1; }
00196
00197
PLEARN_DECLARE_ABSTRACT_OBJECT(
Variable);
00198
00199
virtual void makeDeepCopyFromShallowCopy(map<const void*, void*>& copies);
00200
00202
virtual void fprop() =0;
00204
00206 inline void sizefprop()
00207 {
sizeprop();
fprop(); }
00208
00209
virtual void bprop() =0;
00215
virtual void bbprop();
00217
virtual void fbprop();
00219
virtual void fbbprop();
00221
virtual void symbolicBprop();
00222
00223
virtual void rfprop();
00224
00225 virtual void copyValueInto(
Vec v) { v <<
value; }
00226 virtual void copyGradientInto(
Vec g) { g <<
gradient; }
00227
00228
virtual void print(ostream& out)
const;
00229
00232
string getName() const;
00234
void setName(const
string& the_name);
00235 bool nameIsSet() {
return varname.size()>0; }
00236
00240
Mat defineGradientLocation(
const Mat& m);
00241
00242
virtual void printInfo(
bool print_gradient=
false) = 0;
00243
virtual void printInfos(
bool print_gradient=
false);
00244
00245
Var subVec(
int start,
int len,
bool transpose=
false);
00246
Var subMat(
int i,
int j,
int sublength,
int subwidth,
bool transpose=
false);
00247 Var row(
int i,
bool transpose=
false) {
return subMat(i,0,1,
width(),
transpose); }
00248 Var column(
int j,
bool transpose=
false) {
return subMat(0,j,
length(),1,
transpose); }
00249
00250 void setDontBpropHere(
bool val) {
dont_bprop_here =
val; }
00251 void setKeepPositive() {
min_value = 0; }
00252 void setMinValue(
real minv=-FLT_MAX) {
min_value = minv; }
00253 void setMaxValue(
real maxv=FLT_MAX) {
max_value = maxv; }
00254 void setBoxConstraint(
real minv,
real maxv) {
min_value = minv;
max_value = maxv; }
00255
00256 void setMark() {
marked =
true; }
00257 void clearMark() {
marked =
false; }
00258 bool isMarked() {
return marked; }
00259
00260 void fillGradient(
real value) {
gradient.
fill(value); }
00261 void clearGradient()
00262 {
00263
if(!
allows_partial_update)
00264
gradient.
clear();
00265
else
00266 {
00267
for (
int r=0;r<
rows_to_update.
length();r++)
00268 {
00269
int row =
rows_to_update[r];
00270
matGradient.
row(row).clear();
00271 }
00272
rows_to_update.
resize(0);
00273
gradient_status=0;
00274 }
00275 }
00276
void clearDiagHessian();
00277 void clearSymbolicGradient() {
g =
Var(); }
00278
00284
bool update(
real step_size,
Vec direction_vec,
real coeff = 1.0,
real b = 0.0);
00285
00290
bool update(
Vec step_sizes,
Vec direction_vec,
real coeff = 1.0,
real b = 0.0);
00291
00293
inline void updateAndClear();
00294
00299
bool update(
real step_size);
00300
00302 void allowPartialUpdates()
00303 {
00304
allows_partial_update=
true;
00305
rows_to_update.
resize(
length());
00306
rows_to_update.
resize(0);
00307
gradient_status=0;
00308 }
00309
00311 void disallowPartialUpdates()
00312 {
00313
allows_partial_update =
false;
00314
gradient_status=2;
00315 }
00316
00318 void updateRow(
int row)
00319 {
00320
if (
gradient_status!=2 &&
allows_partial_update && !
rows_to_update.
contains(row))
00321 {
00322
rows_to_update.
append(row);
00323
if (
gradient_status==0)
gradient_status=1;
00324 }
00325 }
00326
00327
00333
bool update(
Vec new_value);
00334
00339
real maxUpdate(
Vec direction);
00340
00345
virtual bool markPath() =0;
00346
00349
virtual void buildPath(
VarArray& proppath) =0;
00350
00351
virtual void oldread(istream& in);
00352
virtual void write(ostream& out)
const;
00353
00354
00355
00356
00357 void copyFrom(
const Vec& v) {
value << v; }
00358 void copyTo(
Vec& v) { v <<
value; }
00359 void copyGradientFrom(
const Vec& v) {
gradient << v; }
00360 void copyGradientTo(
Vec& v) { v <<
gradient; }
00361
void makeSharedValue(
real* x,
int n);
00362
void makeSharedGradient(
real* x,
int n);
00363
00364
void makeSharedValue(
PP<
Storage<real> > storage,
int offset_=0);
00365
void makeSharedGradient(
PP<
Storage<real> > storage,
int offset_=0);
00366
void makeSharedValue(
Vec& v,
int offset_=0);
00367
void makeSharedGradient(
Vec& v,
int offset_=0);
00368
00369 void copyRValueFrom(
const Vec& v) {
resizeRValue();
rValue << v; }
00370 void copyRValueTo(
Vec& v) {
resizeRValue(); v <<
rValue; }
00371
void makeSharedRValue(
real* x,
int n);
00372
void makeSharedRValue(
PP<
Storage<real> > storage,
int offset_=0);
00373
void makeSharedRValue(
Vec& v,
int offset_=0);
00374
00375 virtual bool isConstant() {
return false; }
00376
00382
virtual void fprop_from_all_sources();
00383
00386
virtual VarArray sources() = 0;
00387
00390
virtual VarArray random_sources() = 0;
00391
00393
virtual VarArray ancestors() = 0;
00395
virtual void unmarkAncestors() = 0;
00396
00399
virtual VarArray parents() = 0;
00400
00402
virtual void accg(
Var v);
00403
00406
virtual void verifyGradient(
real step=0.001);
00407
00409
virtual void resizeDiagHessian();
00410
00411
virtual void resizeRValue();
00412 };
00413
00414
// PLearn object-system declarations for Variable (macro defined elsewhere in PLearn).
DECLARE_OBJECT_PTR(Variable);
00415
// Associates the smart-pointer type Var with Variable in PLearn's object system
// (macro defined elsewhere in PLearn; exact expansion not shown here).
DECLARE_OBJECT_PP(Var, Variable);
00416
00417
00418 inline void Variable::updateAndClear()
00419 {
00420
for(
int i=0; i<
nelems(); i++)
00421
valuedata[i] +=
gradientdata[i];
00422
gradient.
clear();
00423 }
00424
00425
//! Deep-copy helper for a Var field, using the `copies` map (CopiesMap, declared
//! elsewhere). Defined elsewhere; presumably used by
//! makeDeepCopyFromShallowCopy implementations — confirm.
void varDeepCopyField(
Var& field, CopiesMap& copies);
00426
00427
00428 inline Var Var::row(
int i,
bool transpose)
const
00429
{
00430
return subMat(i, 0, 1,
width(),
transpose);
00431 }
00432
00433 inline Var Var::column(
int j,
bool transpose)
const
00434
{
00435
return subMat(0, j,
length(), 1,
transpose);
00436 }
00437
00438 inline Var Var::operator()(
int i)
const
00439
{
00440
return row(i,
false);
00441 }
00442
00443 inline Var Var::operator()(
int i,
int j)
const
00444
{
00445
return subMat(i, j, 1, 1);
00446 }
00447
00448 }
00449
00450
#endif
00451
Generated on Tue Aug 17 16:10:11 2004 for PLearn by doxygen 1.3.7