00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00048
#ifndef PLearner_INC
00049
#define PLearner_INC
00050
00051
#include <plearn/base/Object.h>
00052
#include <plearn/vmat/VMat.h>
00053
#include <plearn/math/VecStatsCollector.h>
00054
00055
namespace PLearn {
00056
using namespace std;
00057
00061 class PLearner:
public Object
00062 {
00063
00064
private:
00065
00066 typedef Object inherited;
00067
00068 mutable int n_train_costs_;
00069 mutable int n_test_costs_;
00070
00072 mutable Vec tmp_output;
00073
00074
public:
00075
00076
00077
00085 string expdir;
00086
00087 long seed_;
00088 int stage;
00089
00090
00091
00092
00093
00094 int nstages;
00095
00096
00097
00098
00099
00100
00101 bool report_progress;
00102 int verbosity;
00103
00104
protected:
00105
00106
00108
00116 VMat train_set;
00117
00118
00119 int inputsize_,
targetsize_,
weightsize_,
n_examples;
00120
00121 VMat validation_set;
00122
00125 PP<VecStatsCollector> train_stats;
00126
00127
public:
00128
00129
PLearner();
00130
virtual ~PLearner();
00131
00135
virtual void setTrainingSet(
VMat training_set,
bool call_forget=
true);
00136
00138 inline VMat getTrainingSet()
const {
return train_set; }
00139
00141
virtual void setValidationSet(
VMat validset);
00142
00144 VMat getValidationSet()
const {
return validation_set; }
00145
00149
virtual void setTrainStatsCollector(
PP<VecStatsCollector> statscol);
00150
00152 inline PP<VecStatsCollector> getTrainStatsCollector()
00153 {
return train_stats; }
00154
00159
virtual void setExperimentDirectory(
const string& the_expdir);
00160
00162 string getExperimentDirectory()
const {
return expdir; }
00163
00165
virtual int inputsize() const;
00166
00168 virtual
int targetsize() const;
00169
00173 virtual
int outputsize() const =0;
00174
00175 public:
00176
00179 virtual
void build();
00180
00181 protected:
00183 virtual
void build_from_train_set() {}
00184
00185
private:
00205
void build_();
00206
00207
public:
00210
00223
virtual void forget() =0;
00224
00225
00228
00258
virtual void train() =0;
00259
00260
00263
virtual void computeOutput(
const Vec& input,
Vec& output)
const =0;
00264
00271
virtual void computeCostsFromOutputs(
const Vec& input,
const Vec& output,
00272
const Vec& target,
Vec& costs)
const =0;
00273
00277
virtual void computeOutputAndCosts(
const Vec& input,
const Vec& target,
00278
Vec& output,
Vec& costs)
const;
00279
00283
virtual void computeCostsOnly(
const Vec& input,
const Vec& target,
Vec& costs)
const;
00284
00285
00289
virtual void use(
VMat testset,
VMat outputs)
const;
00290
00293
virtual void useOnTrain(
Mat& outputs)
const;
00294
00298
virtual void test(
VMat testset,
PP<VecStatsCollector> test_stats,
00299
VMat testoutputs=0,
VMat testcosts=0)
const;
00300
00303
virtual TVec<string> getTestCostNames() const =0;
00304
00308 virtual
TVec<
string> getTrainCostNames() const =0;
00309
00312 virtual
int nTestCosts() const;
00313
00316 virtual
int nTrainCosts() const;
00317
00320
int getTestCostIndex(const
string& costname) const;
00321
00324
int getTrainCostIndex(const
string& costname) const;
00325
00328 virtual
void resetInternalState();
00329
00332 virtual
bool isStatefulLearner() const;
00333
00334 protected:
00335
00336 static
void declareOptions(
OptionList& ol);
00337 virtual
void makeDeepCopyFromShallowCopy(CopiesMap& copies);
00338
00339 public:
00340
00343 PLEARN_DECLARE_ABSTRACT_OBJECT(
PLearner);
00344
00350 virtual
void matlabSave(const
string& matlab_subdir){}
00351 };
00352
00353
DECLARE_OBJECT_PTR(PLearner);
00354
00359
00360
00361
00362
00363
00364 }
00365
00366
#endif
00367
00368
00369
00370
00371