00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00048
#ifndef Learner_INC
00049
#define Learner_INC
00050
00051
#include <plearn/measure/Measurer.h>
00052
#include <plearn/ker/Kernel.h>
00053
#include <plearn/math/VecStatsCollector.h>
00054
#include <plearn/math/StatsIterator.h>
00055
#include <plearn/vmat/VVec.h>
00056
00057
00058
namespace PLearn {
00059
using namespace std;
00060
00072 class Learner:
public Object,
public Measurer
00073 {
00074
protected:
00075
00076 Vec tmpvec;
00077
00078
00079 ofstream*
train_objective_stream;
00080 Array<ofstream*> test_results_streams;
00081
00082
private:
00083
00084
static Vec tmp_input;
00085
static Vec tmp_target;
00086
static Vec tmp_weight;
00087
static Vec tmp_output;
00088
static Vec tmp_costs;
00089
00090
protected:
00091
00093
void openTrainObjectiveStream();
00094
00097 ostream&
getTrainObjectiveStream();
00098
00100
void openTestResultsStreams();
00101
00104 ostream&
getTestResultsStream(
int k);
00105
00107
void freeTestResultsStreams();
00108
00110
void outputResultLineToFile(
const string & filename,
const Vec& results,
bool append,
const string& names);
00111
00112
protected:
00115 string expdir;
00116
00118 int epoch_;
00119
00123 bool distributed_;
00124
00125
00126
public:
00127
00129
string basename() const;
00130
00131 typedef
Object inherited;
00132
00133 int inputsize_;
00134 int targetsize_;
00135 int outputsize_;
00136 int weightsize_;
00137
00140 bool dont_parallelize;
00141
00143
00144 PStream testout;
00145 int test_every;
00146 Vec avg_objective;
00147 Vec avgsq_objective;
00148 VMat train_set;
00149 Array<
VMat> test_sets;
00150 int minibatch_size;
00151
00156 int report_test_progress_every;
00157
00160 Vec options;
00161
00163 int earlystop_testsetnum;
00164 int earlystop_testresultindex;
00165 real earlystop_max_degradation;
00166 real earlystop_min_value;
00167 real earlystop_min_improvement;
00168 bool earlystop_relative_changes;
00169 bool earlystop_save_best;
00170 int earlystop_max_degraded_steps;
00171
00172 bool save_at_every_epoch;
00173 bool save_objective;
00174 int best_step;
00175
00176 protected:
00178 real earlystop_previousval;
00179 public:
00180 real earlystop_minval;
00181
00182
00183 string experiment_name;
00184
00185 protected:
00186
00187
00189 Array<
Measurer*> measurers;
00190
00191 bool measure_cpu_time_first;
00192
00193 bool each_cpu_saves_its_errors;
00194 public:
00195 Array<
CostFunc> test_costfuncs;
00196 StatsItArray test_statistics;
00197
00198 static
int use_file_if_bigger;
00199
00200
00201
00202 static
bool force_saving_on_all_processes;
00203
00204 static
PStream& default_vlog();
00205
00206
00207 PStream vlog;
00208 PStream objectiveout;
00209
00216
Learner(
int the_inputsize=0,
int the_targetsize=0,
int the_outputsize=0);
00217
00218 virtual ~
Learner();
00219
00222
00228 virtual
void setExperimentDirectory(const
string& the_expdir);
00229 string getExperimentDirectory()
const {
return expdir; }
00230
00233
PLEARN_DECLARE_ABSTRACT_OBJECT(
Learner);
00234
virtual void makeDeepCopyFromShallowCopy(CopiesMap& copies);
00235
00236
private:
00246
void build_();
00247
00248
public:
00251
virtual void build();
00252
00254 virtual void setTrainingSet(
VMat training_set) {
train_set = training_set; }
00255 inline VMat getTrainingSet() {
return train_set; }
00256
00263
virtual void train(
VMat training_set) =0;
00264
00265
00270
virtual void newtrain(
VecStatsCollector& train_stats);
00271
00272
00275
virtual void newtest(
VMat testset,
VecStatsCollector& test_stats,
00276
VMat testoutputs=0,
VMat testcosts=0);
00277
00278
00279
00280
00281
00282
00283
00284
00285
00286
00287
00292 virtual void train(
VMat training_set,
VMat accept_prob,
00293
real max_accept_prob=1.0,
VMat weights=
VMat())
00294 {
PLERROR(
"This method is not implemented for this learner"); }
00295
00301
virtual void use(
const Vec& input,
Vec& output) =0;
00302 virtual void use(
const Mat& inputs,
Mat outputs)
00303 {
00304
for (
int i=0;i<inputs.
length();i++)
00305 {
00306
Vec input = inputs(i);
00307
Vec output = outputs(i);
00308
use(input,output);
00309 }
00310 }
00311
00313 Vec vec_input;
00314
00317
00318
00319
virtual void computeOutput(
const VVec& input,
Vec& output);
00320
00324
00325
00326
virtual void computeCostsFromOutputs(
const VVec& input,
const Vec& output,
00327
const VVec& target,
const VVec& weight,
00328
Vec& costs);
00329
00330
00334
virtual void computeOutputAndCosts(
const VVec& input,
VVec& target,
const VVec& weight,
00335
Vec& output,
Vec& costs);
00336
00340
virtual void computeCosts(
const VVec& input,
VVec& target,
VVec& weight,
00341
Vec& costs);
00342
00343
00347
virtual void setModel(
const Vec& new_options);
00348
00355
virtual void forget();
00356
00378
virtual bool measure(
int step,
const Vec& costs);
00379
00387
virtual void oldwrite(ostream& out)
const;
00388
virtual void oldread(istream& in);
00389
00391
void save(
const string& filename=
"")
const;
00393
void load(
const string& filename=
"");
00394
00399
virtual void stop_if_wanted();
00400
00402 inline int inputsize()
const {
return inputsize_; }
00403 inline int targetsize()
const {
return targetsize_; }
00404 inline int outputsize()
const {
return outputsize_; }
00405 inline int weightsize()
const {
return weightsize_; }
00406 inline int epoch()
const {
return epoch_; }
00407
00411
virtual int costsize() const;
00412
00415 void setTestCostFunctions(
Array<
CostFunc> costfunctions)
00416 {
test_costfuncs = costfunctions; }
00417
00420 void setTestStatistics(
StatsItArray statistics)
00421 {
test_statistics = statistics; }
00422
00425
virtual void setTestDuringTrain(ostream& testout,
int every,
00426
Array<VMat> testsets);
00427
00429
virtual void setTestDuringTrain(
Array<VMat> testsets);
00430
00431
00433 const Array<VMat>&
getTestDuringTrain()
const {
00434
return test_sets;
00435 }
00436
00437
00454
void setEarlyStopping(
int which_testset,
int which_testresult,
00455
real max_degradation,
real min_value=-FLT_MAX,
00456
real min_improvement=0,
bool relative_changes=
true,
00457
bool save_best=
true,
int max_degraded_steps=-1);
00458
00466
virtual void computeCost(
const Vec& input,
const Vec& target,
const Vec& output,
const Vec& cost);
00467
00470
virtual void useAndCost(
const Vec& input,
const Vec& target,
00471
Vec output,
Vec cost);
00472
00478
virtual void useAndCostOnTestVec(
const VMat& test_set,
int i,
const Vec& output,
const Vec& cost);
00479
00486
virtual void apply(
const VMat& data,
VMat outputs);
00487
00492
virtual void applyAndComputeCosts(
const VMat& data,
VMat outputs,
VMat costs);
00493
00497
virtual void applyAndComputeCostsOnTestMat(
const VMat& test_set,
int i,
const Mat& output_block,
00498
const Mat& cost_block);
00499
00504
virtual void computeCosts(
const VMat& data,
VMat costs);
00505
00508
virtual void computeLeaveOneOutCosts(
const VMat& data,
VMat costs);
00509
00514
virtual void computeLeaveOneOutCosts(
const VMat& data,
VMat costsmat,
CostFunc costf);
00515
00521
Vec computeTestStatistics(
const VMat& costs);
00522
00527
virtual Vec test(
VMat test_set,
const string& save_test_outputs=
"",
00528
const string& save_test_costs=
"");
00529
00534
virtual Array<string> costNames() const;
00535
00541 virtual
Array<
string> testResultsNames() const;
00542
00546 virtual
Array<
string> trainObjectiveNames() const;
00547
00552 void appendMeasurer(
Measurer& measurer)
00553 {
measurers.
append(&measurer); }
00554
00555
protected:
00556
static void declareOptions(
OptionList& ol);
00557
00558 void setTrainCost(
Vec &cost)
00559 {
train_cost.
resize(cost.
length());
train_cost << cost; };
00560 Vec train_cost;
00561
public:
00562 Vec getTrainCost() {
return train_cost; };
00563 };
00564
00565
DECLARE_OBJECT_PTR(Learner);
00566
00567 typedef PP<Learner> PPLearner;
00568
00569 inline void prettyprint_test_results(ostream& out,
const Learner& learner,
const Vec& results)
00570 {
00571
Array<string> names = learner.
testResultsNames();
00572
for (
int i=0; i<names.
size(); i++)
00573 out << names[i] <<
": " << results[i] <<
endl;
00574 }
00575
00576
00577 }
00578
00579
#endif
00580
00581
00582
00583
00584