Main Page | Namespace List | Class Hierarchy | Alphabetical List | Class List | File List | Namespace Members | Class Members | File Members

StackedLearner.cc

Go to the documentation of this file.
00001 // -*- C++ -*- 00002 00003 // StackedLearner.cc 00004 // 00005 // Copyright (C) 2003 Yoshua Bengio 00006 // 00007 // Redistribution and use in source and binary forms, with or without 00008 // modification, are permitted provided that the following conditions are met: 00009 // 00010 // 1. Redistributions of source code must retain the above copyright 00011 // notice, this list of conditions and the following disclaimer. 00012 // 00013 // 2. Redistributions in binary form must reproduce the above copyright 00014 // notice, this list of conditions and the following disclaimer in the 00015 // documentation and/or other materials provided with the distribution. 00016 // 00017 // 3. The name of the authors may not be used to endorse or promote 00018 // products derived from this software without specific prior written 00019 // permission. 00020 // 00021 // THIS SOFTWARE IS PROVIDED BY THE AUTHORS ``AS IS'' AND ANY EXPRESS OR 00022 // IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES 00023 // OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN 00024 // NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 00025 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED 00026 // TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 00027 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 00028 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 00029 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00030 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00031 // 00032 // This file is part of the PLearn library. 
// For more information on the PLearn
// library, go to the PLearn Web site at www.plearn.org

/* *******************************************************
 * $Id: StackedLearner.cc,v 1.16 2004/07/21 16:30:56 chrish42 Exp $
 ******************************************************* */

// Authors: Yoshua Bengio

#include "StackedLearner.h"
#include <plearn/vmat/PLearnerOutputVMatrix.h>
#include <plearn/vmat/ShiftAndRescaleVMatrix.h>

namespace PLearn {
using namespace std;

// Default constructor: by default the base learners are trained here
// (train_base_learners=true), their outputs are fed to the combiner
// un-normalized and un-precomputed, and the raw input is NOT appended
// to the combiner's input (put_raw_input=false).
StackedLearner::StackedLearner()
    :
    base_train_splitter(0),
    train_base_learners(true),
    normalize_base_learners_output(false),
    precompute_base_learners_output(false),
    put_raw_input(false)
{
}

PLEARN_IMPLEMENT_OBJECT(StackedLearner,
                        "Implements stacking, that combines two levels of learner, the 2nd level using the 1st outputs as inputs",
                        "Stacking is a generic strategy in which two levels (or more, recursively) of learners\n"
                        "are combined. The lower level may have one or more learners, and they may be trained\n"
                        "on the same or different data from the upper level single learner. The outputs of the\n"
                        "1st level learners are concatenated and serve as inputs to the second level learner.\n"
                        "IT IS ASSUMED THAT ALL BASE LEARNERS HAVE THE SAME NUMBER OF INPUTS AND OUTPUTS\n"
                        "There is also the option to copy the input of the 1st level learner as additional\n"
                        " inputs for the second level (put_raw_input).\n"
                        "A Splitter can optionally be provided to specify how to split the data into\n"
                        "the training /validation sets for the lower and upper levels respectively\n"
    );

// Declares this class's build options (all OptionBase::buildoption):
// the set of 1st-level learners, the optional 2nd-level combiner, the
// optional splitters controlling which data trains which level, and
// the flags controlling training/normalization/precomputation of the
// base learners' outputs.
void StackedLearner::declareOptions(OptionList& ol)
{
    declareOption(ol, "base_learners", &StackedLearner::base_learners, OptionBase::buildoption,
                  "A set of 1st level base learners that are independently trained (here or elsewhere)\n"
                  "and whose outputs will serve as inputs to the combiner (2nd level learner)");

    declareOption(ol, "combiner", &StackedLearner::combiner, OptionBase::buildoption,
                  "A learner that is trained (possibly on a data set different from the one used to train\n"
                  "the base_learners) using the outputs of the base_learners as inputs. If it is not\n"
                  "provided, then the StackedLearner simply AVERAGES the outputs of the base_learners\n");

    declareOption(ol, "splitter", &StackedLearner::splitter, OptionBase::buildoption,
                  "A Splitter used to select which data subset(s) goes to training the base_learners\n"
                  "and which data subset(s) goes to training the combiner. If not provided then the\n"
                  "same data is used to train and test both levels. If provided, in each split, there should be\n"
                  "two sets: the set on which to train the first level and the set on which to train the second one\n");

    declareOption(ol, "base_train_splitter", &StackedLearner::base_train_splitter, OptionBase::buildoption,
                  "This splitter can be used to split the training set into different training sets for each base learner\n"
                  "If it is not set, the same training set will be applied to the base learners.\n"
                  "If \"splitter\" is also used, it will be applied first to determine the training set used by base_train_splitter.\n"
                  "The splitter should give as many splits as base learners, and each split should contain one set.");

    declareOption(ol, "train_base_learners", &StackedLearner::train_base_learners, OptionBase::buildoption,
                  "whether to train the base learners in the method train (otherwise they should be\n"
                  "initialized properly at construction / setOption time)\n");

    declareOption(ol, "normalize_base_learners_output", &StackedLearner::normalize_base_learners_output, OptionBase::buildoption,
                  "If set to 1, the output of the base learners on the combiner training set\n"
                  "will be normalized (zero mean, unit variance) before training the combiner.");

    declareOption(ol, "precompute_base_learners_output", &StackedLearner::precompute_base_learners_output, OptionBase::buildoption,
                  "If set to 1, the output of the base learners on the combiner training set\n"
                  "will be precomputed in memory before training the combiner (this may speed\n"
                  "up significantly the combiner training process).");

    declareOption(ol, "put_raw_input", &StackedLearner::put_raw_input, OptionBase::buildoption,
                  "whether to put the raw inputs in addition of the base learners outputs, in input of the combiner (default=0)\n");

    // Now call the parent class' declareOptions
    inherited::declareOptions(ol);
}
void StackedLearner::setTrainStatsCollector(PP<VecStatsCollector> statscol) 00126 { 00127 train_stats = statscol; 00128 if (combiner) 00129 combiner->setTrainStatsCollector(statscol); 00130 } 00131 00132 void StackedLearner::build_() 00133 { 00134 // ### This method should do the real building of the object, 00135 // ### according to set 'options', in *any* situation. 00136 // ### Typical situations include: 00137 // ### - Initial building of an object from a few user-specified options 00138 // ### - Building of a "reloaded" object: i.e. from the complete set of all serialised options. 00139 // ### - Updating or "re-building" of an object after a few "tuning" options have been modified. 00140 // ### You should assume that the parent class' build_() has already been called. 00141 for (int i=0;i<base_learners.length();i++) 00142 { 00143 if (!base_learners[i]) 00144 PLERROR("StackedLearner::build: base learners have not been created!"); 00145 base_learners[i]->build(); 00146 if (i>0 && base_learners[i]->outputsize()!=base_learners[i-1]->outputsize()) 00147 PLERROR("StackedLearner: expecting base learners to have the same number of outputs!"); 00148 } 00149 if (combiner) 00150 combiner->build(); 00151 if (splitter) 00152 splitter->build(); 00153 if (splitter && splitter->nSetsPerSplit()!=2) 00154 PLERROR("StackedLearner: the Splitter should produce only two sets per split, got %d",splitter->nSetsPerSplit()); 00155 base_learners_outputs.resize(base_learners.length(),base_learners[0]->outputsize()); 00156 } 00157 00158 // ### Nothing to add here, simply calls build_ 00159 void StackedLearner::build() 00160 { 00161 inherited::build(); 00162 build_(); 00163 } 00164 00165 00166 void StackedLearner::makeDeepCopyFromShallowCopy(map<const void*, void*>& copies) 00167 { 00168 inherited::makeDeepCopyFromShallowCopy(copies); 00169 00170 deepCopyField(base_learners, copies); 00171 if (combiner) 00172 deepCopyField(combiner, copies); 00173 if (splitter) 00174 
deepCopyField(splitter, copies); 00175 if (base_train_splitter) 00176 deepCopyField(base_train_splitter, copies); 00177 } 00178 00179 00180 int StackedLearner::outputsize() const 00181 { 00182 // compute and return the size of this learner's output, (which typically 00183 // may depend on its inputsize(), targetsize() and set options) 00184 if (combiner) 00185 return combiner->outputsize(); 00186 else 00187 return base_learners[0]->outputsize(); 00188 } 00189 00190 void StackedLearner::forget() 00191 { 00192 if (train_base_learners) 00193 for (int i=0;i<base_learners.length();i++) 00194 base_learners[i]->forget(); 00195 if (combiner) 00196 combiner->forget(); 00197 } 00198 00199 void StackedLearner::setTrainingSet(VMat training_set, bool call_forget) 00200 { 00201 if (splitter) 00202 { 00203 splitter->setDataSet(training_set); 00204 if (splitter->nsplits()==1) 00205 { 00206 TVec<VMat> sets = splitter->getSplit(); 00207 VMat lower_trainset = sets[0]; 00208 VMat upper_trainset = sets[1]; 00209 if (base_train_splitter) { 00210 base_train_splitter->setDataSet(lower_trainset); 00211 } else { 00212 for (int i=0;i<base_learners.length();i++) 00213 base_learners[i]->setTrainingSet(lower_trainset,call_forget && train_base_learners); 00214 } 00215 if (combiner) 00216 combiner->setTrainingSet(new PLearnerOutputVMatrix(upper_trainset, base_learners, put_raw_input),call_forget); 00217 } else { 00218 PLERROR("In StackedLearner::setTrainingSet - The splitter provided should only return one split"); 00219 } 00220 } else 00221 { 00222 if (base_train_splitter) { 00223 base_train_splitter->setDataSet(training_set); 00224 } else { 00225 for (int i=0;i<base_learners.length();i++) 00226 base_learners[i]->setTrainingSet(training_set,call_forget && train_base_learners); 00227 } 00228 if (combiner) 00229 combiner->setTrainingSet(new PLearnerOutputVMatrix(training_set, base_learners, put_raw_input),call_forget); 00230 } 00231 if (base_train_splitter) { 00232 for (int 
i=0;i<base_learners.length();i++) { 00233 base_learners[i]->setTrainingSet(base_train_splitter->getSplit(i)[0],call_forget && train_base_learners); 00234 } 00235 } 00236 inherited::setTrainingSet(training_set, call_forget); 00237 } 00238 00239 void StackedLearner::train() 00240 { 00241 if (!train_stats) 00242 PLERROR("StackedLearner::train: train_stats has not been set!"); 00243 if (!splitter || splitter->nsplits()==1) // simplest case 00244 { 00245 if (train_base_learners) 00246 for (int i=0;i<base_learners.length();i++) 00247 { 00248 PP<VecStatsCollector> stats = new VecStatsCollector(); 00249 base_learners[i]->setTrainStatsCollector(stats); 00250 if (expdir!="") 00251 base_learners[i]->setExperimentDirectory(expdir+"Base"+tostring(i)); 00252 base_learners[i]->train(); 00253 stats->finalize(); // WE COULD OPTIONALLY SAVE THEM AS WELL! 00254 } 00255 if (combiner) 00256 { 00257 if (normalize_base_learners_output) { 00258 // Normalize the combiner training set. 00259 VMat normalized_trainset = 00260 new ShiftAndRescaleVMatrix(combiner->getTrainingSet(), -1); 00261 combiner->setTrainingSet(normalized_trainset); 00262 } 00263 if (precompute_base_learners_output) { 00264 // First precompute the train set of the combiner in memory. 
00265 VMat precomputed_trainset = combiner->getTrainingSet(); 00266 precomputed_trainset.precompute(); 00267 combiner->setTrainingSet(precomputed_trainset, false); 00268 } 00269 combiner->setTrainStatsCollector(train_stats); 00270 if (expdir!="") 00271 combiner->setExperimentDirectory(expdir+"Combiner"); 00272 combiner->train(); 00273 } 00274 } else PLERROR("StackedLearner: multi-splits case not implemented yet"); 00275 } 00276 00277 00278 void StackedLearner::computeOutput(const Vec& input, Vec& output) const 00279 { 00280 for (int i=0;i<base_learners.length();i++) 00281 { 00282 Vec out_i = base_learners_outputs(i); 00283 if (!base_learners[i]) 00284 PLERROR("StackedLearner::computeOutput: base learners have not been created!"); 00285 base_learners[i]->computeOutput(input,out_i); 00286 } 00287 if (combiner) 00288 combiner->computeOutput(base_learners_outputs.toVec(),output); 00289 else // just do a simple average of the outputs 00290 columnMean(base_learners_outputs,output); 00291 } 00292 00293 void StackedLearner::computeCostsFromOutputs(const Vec& input, const Vec& output, 00294 const Vec& target, Vec& costs) const 00295 { 00296 if (combiner) 00297 combiner->computeCostsFromOutputs(base_learners_outputs.toVec(),output,target,costs); 00298 else // cheat 00299 base_learners[0]->computeCostsFromOutputs(input,output,target,costs); 00300 } 00301 00302 TVec<string> StackedLearner::getTestCostNames() const 00303 { 00304 // Return the names of the costs computed by computeCostsFromOutpus 00305 // (these may or may not be exactly the same as what's returned by getTrainCostNames) 00306 if (combiner) 00307 return combiner->getTestCostNames(); 00308 else 00309 return base_learners[0]->getTestCostNames(); 00310 } 00311 00312 TVec<string> StackedLearner::getTrainCostNames() const 00313 { 00314 // Return the names of the objective costs that the train method computes and 00315 // for which it updates the VecStatsCollector train_stats 00316 if (combiner) 00317 return 
combiner->getTestCostNames(); 00318 else 00319 return base_learners[0]->getTestCostNames(); 00320 } 00321 00322 00323 00324 } // end of namespace PLearn

Generated on Tue Aug 17 16:06:46 2004 for PLearn by doxygen 1.3.7