00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00044
#include "StackedLearner.h"
00045
#include <plearn/vmat/PLearnerOutputVMatrix.h>
00046
#include <plearn/vmat/ShiftAndRescaleVMatrix.h>
00047
00048
namespace PLearn {
00049
using namespace std;
00050
00051 StackedLearner::StackedLearner()
00052
00053 :
00054 base_train_splitter(0),
00055 train_base_learners(true),
00056 normalize_base_learners_output(false),
00057 precompute_base_learners_output(false),
00058 put_raw_input(false)
00059 {
00060 }
00061
00062
// Object registration: one-line and multi-line help strings for the class.
PLEARN_IMPLEMENT_OBJECT(
    StackedLearner,
    "Implements stacking, that combines two levels of learner, the 2nd level using the 1st outputs as inputs",
    "Stacking is a generic strategy in which two levels (or more, recursively) of learners\n"
    "are combined. The lower level may have one or more learners, and they may be trained\n"
    "on the same or different data from the upper level single learner. The outputs of the\n"
    "1st level learners are concatenated and serve as inputs to the second level learner.\n"
    "IT IS ASSUMED THAT ALL BASE LEARNERS HAVE THE SAME NUMBER OF INPUTS AND OUTPUTS\n"
    "There is also the option to copy the input of the 1st level learner as additional\n"
    " inputs for the second level (put_raw_input).\n"
    "A Splitter can optionally be provided to specify how to split the data into\n"
    "the training /validation sets for the lower and upper levels respectively\n");
00074
00075 void StackedLearner::declareOptions(
OptionList& ol)
00076 {
00077
00078
00079
00080
00081
00082
00083
declareOption(ol,
"base_learners", &StackedLearner::base_learners, OptionBase::buildoption,
00084
"A set of 1st level base learners that are independently trained (here or elsewhere)\n"
00085
"and whose outputs will serve as inputs to the combiner (2nd level learner)");
00086
00087
declareOption(ol,
"combiner", &StackedLearner::combiner, OptionBase::buildoption,
00088
"A learner that is trained (possibly on a data set different from the one used to train\n"
00089
"the base_learners) using the outputs of the base_learners as inputs. If it is not\n"
00090
"provided, then the StackedLearner simply AVERAGES the outputs of the base_learners\n");
00091
00092
declareOption(ol,
"splitter", &StackedLearner::splitter, OptionBase::buildoption,
00093
"A Splitter used to select which data subset(s) goes to training the base_learners\n"
00094
"and which data subset(s) goes to training the combiner. If not provided then the\n"
00095
"same data is used to train and test both levels. If provided, in each split, there should be\n"
00096
"two sets: the set on which to train the first level and the set on which to train the second one\n");
00097
00098
declareOption(ol,
"base_train_splitter", &StackedLearner::base_train_splitter, OptionBase::buildoption,
00099
"This splitter can be used to split the training set into different training sets for each base learner\n"
00100
"If it is not set, the same training set will be applied to the base learners.\n"
00101
"If \"splitter\" is also used, it will be applied first to determine the training set used by base_train_splitter.\n"
00102
"The splitter should give as many splits as base learners, and each split should contain one set.");
00103
00104
declareOption(ol,
"train_base_learners", &StackedLearner::train_base_learners, OptionBase::buildoption,
00105
"whether to train the base learners in the method train (otherwise they should be\n"
00106
"initialized properly at construction / setOption time)\n");
00107
00108
declareOption(ol,
"normalize_base_learners_output", &StackedLearner::normalize_base_learners_output, OptionBase::buildoption,
00109
"If set to 1, the output of the base learners on the combiner training set\n"
00110
"will be normalized (zero mean, unit variance) before training the combiner.");
00111
00112
declareOption(ol,
"precompute_base_learners_output", &StackedLearner::precompute_base_learners_output, OptionBase::buildoption,
00113
"If set to 1, the output of the base learners on the combiner training set\n"
00114
"will be precomputed in memory before training the combiner (this may speed\n"
00115
"up significantly the combiner training process).");
00116
00117
00118
declareOption(ol,
"put_raw_input", &StackedLearner::put_raw_input, OptionBase::buildoption,
00119
"whether to put the raw inputs in addition of the base learners outputs, in input of the combiner (default=0)\n");
00120
00121
00122 inherited::declareOptions(ol);
00123 }
00124
00125 void StackedLearner::setTrainStatsCollector(
PP<VecStatsCollector> statscol)
00126 {
00127 train_stats = statscol;
00128
if (
combiner)
00129
combiner->setTrainStatsCollector(statscol);
00130 }
00131
00132 void StackedLearner::build_()
00133 {
00134
00135
00136
00137
00138
00139
00140
00141
for (
int i=0;i<
base_learners.
length();i++)
00142 {
00143
if (!
base_learners[i])
00144
PLERROR(
"StackedLearner::build: base learners have not been created!");
00145
base_learners[i]->build();
00146
if (i>0 &&
base_learners[i]->outputsize()!=base_learners[i-1]->outputsize())
00147
PLERROR(
"StackedLearner: expecting base learners to have the same number of outputs!");
00148 }
00149
if (
combiner)
00150
combiner->build();
00151
if (
splitter)
00152
splitter->build();
00153
if (
splitter &&
splitter->nSetsPerSplit()!=2)
00154
PLERROR(
"StackedLearner: the Splitter should produce only two sets per split, got %d",
splitter->nSetsPerSplit());
00155
base_learners_outputs.
resize(
base_learners.
length(),
base_learners[0]->outputsize());
00156 }
00157
00158
00159 void StackedLearner::build()
00160 {
00161 inherited::build();
00162
build_();
00163 }
00164
00165
00166 void StackedLearner::makeDeepCopyFromShallowCopy(map<const void*, void*>& copies)
00167 {
00168 inherited::makeDeepCopyFromShallowCopy(copies);
00169
00170
deepCopyField(
base_learners, copies);
00171
if (
combiner)
00172
deepCopyField(
combiner, copies);
00173
if (
splitter)
00174
deepCopyField(
splitter, copies);
00175
if (
base_train_splitter)
00176
deepCopyField(
base_train_splitter, copies);
00177 }
00178
00179
00180 int StackedLearner::outputsize()
const
00181
{
00182
00183
00184
if (
combiner)
00185
return combiner->outputsize();
00186
else
00187
return base_learners[0]->outputsize();
00188 }
00189
00190 void StackedLearner::forget()
00191 {
00192
if (
train_base_learners)
00193
for (
int i=0;i<
base_learners.
length();i++)
00194
base_learners[i]->forget();
00195
if (
combiner)
00196
combiner->forget();
00197 }
00198
00199 void StackedLearner::setTrainingSet(
VMat training_set,
bool call_forget)
00200 {
00201
if (
splitter)
00202 {
00203
splitter->setDataSet(training_set);
00204
if (
splitter->nsplits()==1)
00205 {
00206
TVec<VMat> sets =
splitter->getSplit();
00207
VMat lower_trainset = sets[0];
00208
VMat upper_trainset = sets[1];
00209
if (
base_train_splitter) {
00210
base_train_splitter->setDataSet(lower_trainset);
00211 }
else {
00212
for (
int i=0;i<
base_learners.
length();i++)
00213
base_learners[i]->setTrainingSet(lower_trainset,call_forget &&
train_base_learners);
00214 }
00215
if (
combiner)
00216
combiner->setTrainingSet(
new PLearnerOutputVMatrix(upper_trainset,
base_learners,
put_raw_input),call_forget);
00217 }
else {
00218
PLERROR(
"In StackedLearner::setTrainingSet - The splitter provided should only return one split");
00219 }
00220 }
else
00221 {
00222
if (
base_train_splitter) {
00223
base_train_splitter->setDataSet(training_set);
00224 }
else {
00225
for (
int i=0;i<
base_learners.
length();i++)
00226
base_learners[i]->setTrainingSet(training_set,call_forget &&
train_base_learners);
00227 }
00228
if (
combiner)
00229
combiner->setTrainingSet(
new PLearnerOutputVMatrix(training_set,
base_learners,
put_raw_input),call_forget);
00230 }
00231
if (
base_train_splitter) {
00232
for (
int i=0;i<
base_learners.
length();i++) {
00233
base_learners[i]->setTrainingSet(
base_train_splitter->getSplit(i)[0],call_forget &&
train_base_learners);
00234 }
00235 }
00236 inherited::setTrainingSet(training_set, call_forget);
00237 }
00238
00239 void StackedLearner::train()
00240 {
00241
if (!train_stats)
00242
PLERROR(
"StackedLearner::train: train_stats has not been set!");
00243
if (!
splitter ||
splitter->nsplits()==1)
00244 {
00245
if (
train_base_learners)
00246
for (
int i=0;i<
base_learners.
length();i++)
00247 {
00248
PP<VecStatsCollector> stats =
new VecStatsCollector();
00249
base_learners[i]->setTrainStatsCollector(stats);
00250
if (expdir!=
"")
00251
base_learners[i]->setExperimentDirectory(expdir+
"Base"+
tostring(i));
00252
base_learners[i]->train();
00253 stats->finalize();
00254 }
00255
if (
combiner)
00256 {
00257
if (
normalize_base_learners_output) {
00258
00259
VMat normalized_trainset =
00260
new ShiftAndRescaleVMatrix(
combiner->getTrainingSet(), -1);
00261
combiner->setTrainingSet(normalized_trainset);
00262 }
00263
if (
precompute_base_learners_output) {
00264
00265
VMat precomputed_trainset =
combiner->getTrainingSet();
00266 precomputed_trainset.
precompute();
00267
combiner->setTrainingSet(precomputed_trainset,
false);
00268 }
00269
combiner->setTrainStatsCollector(train_stats);
00270
if (expdir!=
"")
00271
combiner->setExperimentDirectory(expdir+
"Combiner");
00272
combiner->train();
00273 }
00274 }
else PLERROR(
"StackedLearner: multi-splits case not implemented yet");
00275 }
00276
00277
00278 void StackedLearner::computeOutput(
const Vec& input,
Vec& output)
const
00279
{
00280
for (
int i=0;i<
base_learners.
length();i++)
00281 {
00282
Vec out_i =
base_learners_outputs(i);
00283
if (!
base_learners[i])
00284
PLERROR(
"StackedLearner::computeOutput: base learners have not been created!");
00285
base_learners[i]->computeOutput(input,out_i);
00286 }
00287
if (
combiner)
00288
combiner->computeOutput(
base_learners_outputs.
toVec(),output);
00289
else
00290
columnMean(
base_learners_outputs,output);
00291 }
00292
00293 void StackedLearner::computeCostsFromOutputs(
const Vec& input,
const Vec& output,
00294
const Vec& target,
Vec& costs)
const
00295
{
00296
if (
combiner)
00297
combiner->computeCostsFromOutputs(
base_learners_outputs.
toVec(),output,target,costs);
00298
else
00299
base_learners[0]->computeCostsFromOutputs(input,output,target,costs);
00300 }
00301
00302 TVec<string> StackedLearner::getTestCostNames()
const
00303
{
00304
00305
00306
if (
combiner)
00307
return combiner->getTestCostNames();
00308
else
00309
return base_learners[0]->getTestCostNames();
00310 }
00311
00312 TVec<string> StackedLearner::getTrainCostNames()
const
00313
{
00314
00315
00316
if (
combiner)
00317
return combiner->getTestCostNames();
00318
else
00319
return base_learners[0]->getTestCostNames();
00320 }
00321
00322
00323
00324 }