00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
#include "PLearner.h"
00046
00047
namespace PLearn {
00048
using namespace std;
00049
00050 PLearner::PLearner()
00051 :
00052 n_train_costs_(-1),
00053 n_test_costs_(-1),
00054 seed_(-1),
00055 stage(0), nstages(1),
00056 report_progress(true),
00057 verbosity(1),
00058 inputsize_(-1),
00059 targetsize_(-1),
00060 weightsize_(-1)
00061 {}
00062
00063
// Registers PLearner in the PLearn object system as an abstract base class
// (no direct instantiation; provides serialization/help metadata).
PLEARN_IMPLEMENT_ABSTRACT_OBJECT(
    PLearner,
    "The base class for all PLearn learners.",
    "");
00067
00068 void PLearner::makeDeepCopyFromShallowCopy(
CopiesMap& copies)
00069 {
00070 inherited::makeDeepCopyFromShallowCopy(copies);
00071
deepCopyField(
tmp_output, copies);
00072
00073
00074
deepCopyField(
train_stats, copies);
00075 }
00076
00077 void PLearner::declareOptions(
OptionList& ol)
00078 {
00079
declareOption(ol,
"expdir", &PLearner::expdir, OptionBase::buildoption,
00080
"Path of the directory associated with this learner, in which\n"
00081
"it should save any file it wishes to create. \n"
00082
"The directory will be created if it does not already exist.\n"
00083
"If expdir is the empty string (the default), then the learner \n"
00084
"should not create *any* file. Note that, anyway, most file creation and \n"
00085
"reporting are handled at the level of the PTester class rather than \n"
00086
"at the learner's. \n");
00087
00088
declareOption(ol,
"seed", &PLearner::seed_, OptionBase::buildoption,
00089
"The initial seed for the random number generator used to initialize this learner's parameters\n"
00090
"as typically done in the forget() method... \n"
00091
"If -1 is provided, then a 'random' seed is chosen based on time of day, insuring that\n"
00092
"different experiments may yield different results.\n"
00093
"With a given seed, forget() should always initialize the parameters to the same values.");
00094
00095
declareOption(ol,
"stage", &PLearner::stage, OptionBase::learntoption,
00096
"The current training stage, since last fresh initialization (forget()): \n"
00097
"0 means untrained, n often means after n epochs or optimization steps, etc...\n"
00098
"The true meaning is learner-dependant."
00099
"You should never modify this option directly!"
00100
"It is the role of forget() to bring it back to 0,\n"
00101
"and the role of train() to bring it up to 'nstages'...");
00102
00103
declareOption(ol,
"n_examples", &PLearner::n_examples, OptionBase::learntoption,
00104
"The number of samples in the training set.\n"
00105
"Obtained from training set with setTrainingSet.");
00106
00107
declareOption(ol,
"inputsize", &PLearner::inputsize_, OptionBase::learntoption,
00108
"The number of input columns in the data sets."
00109
"Obtained from training set with setTrainingSet.");
00110
00111
declareOption(ol,
"targetsize", &PLearner::targetsize_, OptionBase::learntoption,
00112
"The number of target columns in the data sets."
00113
"Obtained from training set with setTrainingSet.");
00114
00115
declareOption(ol,
"weightsize", &PLearner::weightsize_, OptionBase::learntoption,
00116
"The number of cost weight columns in the data sets."
00117
"Obtained from training set with setTrainingSet.");
00118
00119
declareOption(ol,
"nstages", &PLearner::nstages, OptionBase::buildoption,
00120
"The stage until which train() should train this learner and return.\n"
00121
"The meaning of 'stage' is learner-dependent, but for learners whose \n"
00122
"training is incremental (such as involving incremental optimization), \n"
00123
"it is typically synonym with the number of 'epochs', i.e. the number \n"
00124
"of passages of the optimization process through the whole training set, \n"
00125
"since the last fresh initialisation.");
00126
00127
declareOption(ol,
"report_progress", &PLearner::report_progress, OptionBase::buildoption,
00128
"should progress in learning and testing be reported in a ProgressBar.\n");
00129
00130
declareOption(ol,
"verbosity", &PLearner::verbosity, OptionBase::buildoption,
00131
"Level of verbosity. If 0 should not write anything on cerr. \n"
00132
"If >0 may write some info on the steps performed along the way.\n"
00133
"The level of details written should depend on this value.");
00134
00135 inherited::declareOptions(ol);
00136 }
00137
00138
00139 void PLearner::setExperimentDirectory(
const string& the_expdir)
00140 {
00141
if(the_expdir==
"")
00142
expdir =
"";
00143
else
00144 {
00145
if(!
force_mkdir(the_expdir))
00146
PLERROR(
"In PLearner::setExperimentDirectory Could not create experiment directory %s",the_expdir.c_str());
00147
expdir =
abspath(the_expdir);
00148 }
00149 }
00150
00151 void PLearner::setTrainingSet(
VMat training_set,
bool call_forget)
00152 {
00153
00154
00155
bool training_set_has_changed = !
train_set || !(
train_set->looksTheSameAs(training_set));
00156
train_set = training_set;
00157
if (training_set_has_changed)
00158 {
00159
inputsize_ =
train_set->inputsize();
00160
targetsize_ =
train_set->targetsize();
00161
weightsize_ =
train_set->weightsize();
00162 }
00163
n_examples =
train_set->
length();
00164
if (training_set_has_changed || call_forget)
00165
build();
00166
if (call_forget)
00167
forget();
00168 }
00169
00170 void PLearner::setValidationSet(
VMat validset)
00171 {
validation_set = validset; }
00172
00173
00174 void PLearner::setTrainStatsCollector(
PP<VecStatsCollector> statscol)
00175 {
train_stats = statscol; }
00176
00177
00178 int PLearner::inputsize()
const
00179
{
00180
if (
inputsize_<0)
00181
PLERROR(
"Must specify a training set before calling PLearner::inputsize()");
00182
return inputsize_;
00183 }
00184
00185 int PLearner::targetsize()
const
00186
{
00187
if(
targetsize_ == -1)
00188
PLERROR(
"In PLearner::targetsize - 'targetsize_' is -1, either no training set has beeen specified or its sizes were not set properly");
00189
return targetsize_;
00190 }
00191
00192 void PLearner::build_()
00193 {
00194
if(
expdir!=
"")
00195 {
00196
if(!
force_mkdir(
expdir))
00197
PLERROR(
"In PLearner Could not create experiment directory %s",
expdir.c_str());
00198
expdir =
abspath(
expdir);
00199 }
00200 }
00201
00202 void PLearner::build()
00203 {
00204 inherited::build();
00205
build_();
00206 }
00207
00208 PLearner::~PLearner()
00209 {
00210 }
00211
00212 int PLearner::nTestCosts()
const
00213
{
00214
if(
n_test_costs_<0)
00215
n_test_costs_ =
getTestCostNames().
size();
00216
return n_test_costs_;
00217 }
00218
00219 int PLearner::nTrainCosts()
const
00220
{
00221
if(
n_train_costs_<0)
00222
n_train_costs_ =
getTrainCostNames().
size();
00223
return n_train_costs_;
00224 }
00225
00226 int PLearner::getTestCostIndex(
const string& costname)
const
00227
{
00228
TVec<string> costnames =
getTestCostNames();
00229
for(
int i=0; i<costnames.
length(); i++)
00230
if(costnames[i]==costname)
00231
return i;
00232
PLERROR(
"In PLearner::getTestCostIndex, No test cost named %s in this learner.\n"
00233
"Available test costs are: %s", costname.c_str(),
tostring(costnames).c_str());
00234
return -1;
00235 }
00236
00237 int PLearner::getTrainCostIndex(
const string& costname)
const
00238
{
00239
TVec<string> costnames =
getTrainCostNames();
00240
for(
int i=0; i<costnames.
length(); i++)
00241
if(costnames[i]==costname)
00242
return i;
00243
PLERROR(
"In PLearner::getTrainCostIndex, No train cost named %s in this learner.\n"
00244
"Available train costs are: %s", costname.c_str(),
tostring(costnames).c_str());
00245
return -1;
00246 }
00247
00248 void PLearner::computeOutputAndCosts(
const Vec& input,
const Vec& target,
00249
Vec& output,
Vec& costs)
const
00250
{
00251
computeOutput(input, output);
00252
computeCostsFromOutputs(input, output, target, costs);
00253 }
00254
00255 void PLearner::computeCostsOnly(
const Vec& input,
const Vec& target,
00256
Vec& costs)
const
00257
{
00258
tmp_output.
resize(
outputsize());
00259
computeOutputAndCosts(input, target,
tmp_output, costs);
00260 }
00261
00263
00265 void PLearner::use(
VMat testset,
VMat outputs)
const
00266
{
00267
int l = testset.
length();
00268
Vec input;
00269
Vec target;
00270
real weight;
00271
Vec output(
outputsize());
00272
00273
ProgressBar* pb = NULL;
00274
if(
report_progress)
00275 pb =
new ProgressBar(
"Using learner",l);
00276
00277
for(
int i=0; i<l; i++)
00278 {
00279 testset.
getExample(i, input, target, weight);
00280
computeOutput(input, output);
00281 outputs->putOrAppendRow(i,output);
00282
if(pb)
00283 pb->
update(i);
00284 }
00285
00286
if(pb)
00287
delete pb;
00288 }
00289
00291
00293 void PLearner::useOnTrain(
Mat& outputs)
const {
00294
PLWARNING(
"In PLearner::useOnTrain - This method has not been tested yet, remove this warning if it works fine");
00295
VMat train_output(outputs);
00296
use(
train_set, train_output);
00297 }
00298
00300
00302 void PLearner::test(
VMat testset,
PP<VecStatsCollector> test_stats,
00303
VMat testoutputs,
VMat testcosts)
const
00304
{
00305
int l = testset.
length();
00306
Vec input;
00307
Vec target;
00308
real weight;
00309
00310
Vec output(testoutputs ?
outputsize() :0);
00311
00312
Vec costs(
nTestCosts());
00313
00314
00315
00316
if(test_stats)
00317 test_stats->forget();
00318
00319
ProgressBar* pb = NULL;
00320
if(
report_progress)
00321 pb =
new ProgressBar(
"Testing learner",l);
00322
00323
if (l == 0) {
00324
00325 costs.
fill(-1);
00326 test_stats->update(costs);
00327 }
00328
00329
for(
int i=0; i<l; i++)
00330 {
00331 testset.
getExample(i, input, target, weight);
00332
00333
if(testoutputs)
00334 {
00335
computeOutputAndCosts(input, target, output, costs);
00336 testoutputs->putOrAppendRow(i,output);
00337 }
00338
else
00339
computeCostsOnly(input, target, costs);
00340
00341
if(testcosts)
00342 testcosts->putOrAppendRow(i, costs);
00343
00344
if(test_stats)
00345 test_stats->update(costs,weight);
00346
00347
if(pb)
00348 pb->
update(i);
00349 }
00350
00351
if(test_stats)
00352 test_stats->finalize();
00353
00354
if(pb)
00355
delete pb;
00356
00357 }
00358
00359 void PLearner::resetInternalState()
00360 {}
00361
00362 bool PLearner::isStatefulLearner()
const
00363
{
return false; }
00364
00365
00366 }
00367