Main Page | Namespace List | Class Hierarchy | Alphabetical List | Class List | File List | Namespace Members | Class Members | File Members

Learner.h

Go to the documentation of this file.
00001 // -*- C++ -*-4 1999/10/29 20:41:34 dugas 00002 00003 // Learner.h 00004 // 00005 // Copyright (C) 1998-2002 Pascal Vincent 00006 // Copyright (C) 1999-2002 Yoshua Bengio, Nicolas Chapados, Charles Dugas, Rejean Ducharme, Universite de Montreal 00007 // Copyright (C) 2001,2002 Francis Pieraut, Jean-Sebastien Senecal 00008 // Copyright (C) 2002 Frederic Morin, Xavier Saint-Mleux, Julien Keable 00009 // 00010 // Redistribution and use in source and binary forms, with or without 00011 // modification, are permitted provided that the following conditions are met: 00012 // 00013 // 1. Redistributions of source code must retain the above copyright 00014 // notice, this list of conditions and the following disclaimer. 00015 // 00016 // 2. Redistributions in binary form must reproduce the above copyright 00017 // notice, this list of conditions and the following disclaimer in the 00018 // documentation and/or other materials provided with the distribution. 00019 // 00020 // 3. The name of the authors may not be used to endorse or promote 00021 // products derived from this software without specific prior written 00022 // permission. 00023 // 00024 // THIS SOFTWARE IS PROVIDED BY THE AUTHORS ``AS IS'' AND ANY EXPRESS OR 00025 // IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES 00026 // OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN 00027 // NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 00028 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED 00029 // TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 00030 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 00031 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 00032 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00033 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00034 // 00035 // This file is part of the PLearn library. For more information on the PLearn 00036 // library, go to the PLearn Web site at www.plearn.org 00037 00038 00039 00040 00041 /* ******************************************************* 00042 * $Id: Learner.h,v 1.15 2004/07/21 16:30:56 chrish42 Exp $ 00043 ******************************************************* */ 00044 00045 00048 #ifndef Learner_INC 00049 #define Learner_INC 00050 00051 #include <plearn/measure/Measurer.h> 00052 #include <plearn/ker/Kernel.h> 00053 #include <plearn/math/VecStatsCollector.h> 00054 #include <plearn/math/StatsIterator.h> 00055 #include <plearn/vmat/VVec.h> 00056 //#include "TimeMeasurer.h" 00057 00058 namespace PLearn { 00059 using namespace std; 00060 00072 class Learner: public Object, public Measurer//, public TimeMeasurer 00073 { 00074 protected: 00075 00076 Vec tmpvec; // for temporary storage. 00077 00078 // EN FAIRE UN POINTEUR AUSSI 00079 ofstream* train_objective_stream; 00080 Array<ofstream*> test_results_streams; 00081 00082 private: 00083 00084 static Vec tmp_input; // temporary input vec 00085 static Vec tmp_target; // temporary target vec 00086 static Vec tmp_weight; // temporary example weight vec 00087 static Vec tmp_output; // temporary output vec 00088 static Vec tmp_costs; // temporary costs vec 00089 00090 protected: 00091 00093 void openTrainObjectiveStream(); 00094 00097 ostream& getTrainObjectiveStream(); 00098 00100 void openTestResultsStreams(); 00101 00104 ostream& getTestResultsStream(int k); 00105 00107 void freeTestResultsStreams(); 00108 00110 void outputResultLineToFile(const string & filename, const Vec& results,bool append,const string& names); 00111 00112 protected: 00115 string expdir; 00116 00118 int epoch_; 00119 00123 bool distributed_; 00124 00125 00126 public: 00127 00129 string basename() const; 00130 00131 typedef Object inherited; 00132 00133 int inputsize_; 00134 int targetsize_; 00135 int outputsize_; 00136 int weightsize_; //<! number of weight fields in the target vec (all_targets = actual_target & weights) 00137 00140 bool dont_parallelize; //<! (default: false) 00141 00143 //oassignstream testout; 00144 PStream testout; 00145 int test_every; 00146 Vec avg_objective; 00147 Vec avgsq_objective; 00148 VMat train_set; 00149 Array<VMat> test_sets; 00150 int minibatch_size; 00151 00156 int report_test_progress_every; 00157 00160 Vec options; 00161 00163 int earlystop_testsetnum; 00164 int earlystop_testresultindex; 00165 real earlystop_max_degradation; 00166 real earlystop_min_value; 00167 real earlystop_min_improvement; 00168 bool earlystop_relative_changes; 00169 bool earlystop_save_best; 00170 int earlystop_max_degraded_steps; 00171 00172 bool save_at_every_epoch; 00173 bool save_objective; //<! whether to save in basename()+".objective" the cost after each measure (e.g. after each epoch) 00174 int best_step; 00175 00176 protected: 00178 real earlystop_previousval; 00179 public: 00180 real earlystop_minval; 00181 00182 // DPERECATED. Please use the expdir system from now on, through setExperimentDirectory 00183 string experiment_name; 00184 00185 protected: 00186 //strstream earlystop_best_model; //!< string stream where the currently best model is saved 00187 00189 Array<Measurer*> measurers; 00190 00191 bool measure_cpu_time_first; // the first el. in measure(..) will be cpu time instead of courant step 00192 00193 bool each_cpu_saves_its_errors; 00194 public: 00195 Array<CostFunc> test_costfuncs; 00196 StatsItArray test_statistics; 00197 00198 static int use_file_if_bigger; 00199 00200 00201 00202 static bool force_saving_on_all_processes; 00203 00204 static PStream& /*oassignstream&*/ default_vlog(); 00205 //oassignstream vlog; //!< The log stream to which all the verbose output from this learner should be sent 00206 //oassignstream objectiveout; //!< The log stream to use to record the objective function during training 00207 PStream vlog; 00208 PStream objectiveout; 00209 00216 Learner(int the_inputsize=0, int the_targetsize=0, int the_outputsize=0); 00217 00218 virtual ~Learner(); 00219 00222 00228 virtual void setExperimentDirectory(const string& the_expdir); 00229 string getExperimentDirectory() const { return expdir; } 00230 00233 PLEARN_DECLARE_ABSTRACT_OBJECT(Learner); 00234 virtual void makeDeepCopyFromShallowCopy(CopiesMap& copies); 00235 00236 private: 00246 void build_(); 00247 00248 public: 00251 virtual void build(); 00252 00254 virtual void setTrainingSet(VMat training_set) { train_set = training_set; } 00255 inline VMat getTrainingSet() { return train_set; } 00256 00263 virtual void train(VMat training_set) =0; 00264 00265 00270 virtual void newtrain(VecStatsCollector& train_stats); 00271 00272 00275 virtual void newtest(VMat testset, VecStatsCollector& test_stats, 00276 VMat testoutputs=0, VMat testcosts=0); 00277 00278 /* 00279 virtual void useAndCost(Vec input, Vec target, Vec output, Vec cost) 00280 00281 virtual void trainTest(VMat train, Array<VMat> testsets); 00282 virtual void trainKFold(VMat trainset, int k); 00283 virtual void trainBootstrap(VMat trainset, int k, Array<VMat> testsets); 00284 virtual void trainSequential(VMat dataset, sequence_spec); 00285 00286 */ 00287 00292 virtual void train(VMat training_set, VMat accept_prob, 00293 real max_accept_prob=1.0, VMat weights=VMat()) 00294 { PLERROR("This method is not implemented for this learner"); } 00295 00301 virtual void use(const Vec& input, Vec& output) =0; 00302 virtual void use(const Mat& inputs, Mat outputs) 00303 { 00304 for (int i=0;i<inputs.length();i++) 00305 { 00306 Vec input = inputs(i); 00307 Vec output = outputs(i); 00308 use(input,output); 00309 } 00310 } 00311 00313 Vec vec_input; 00314 00317 // NOTE: For backward compatibility, default version currently calls 00318 // deprecated method use which should ultimately be removed... 00319 virtual void computeOutput(const VVec& input, Vec& output); 00320 00324 // NOTE: For backward compatibility, default version currently calls the 00325 // deprecated method computeCost which should ultimately be removed... 00326 virtual void computeCostsFromOutputs(const VVec& input, const Vec& output, 00327 const VVec& target, const VVec& weight, 00328 Vec& costs); 00329 00330 00334 virtual void computeOutputAndCosts(const VVec& input, VVec& target, const VVec& weight, 00335 Vec& output, Vec& costs); 00336 00340 virtual void computeCosts(const VVec& input, VVec& target, VVec& weight, 00341 Vec& costs); 00342 00343 00347 virtual void setModel(const Vec& new_options); 00348 00355 virtual void forget(); 00356 00378 virtual bool measure(int step, const Vec& costs); 00379 00387 virtual void oldwrite(ostream& out) const; 00388 virtual void oldread(istream& in); 00389 00391 void save(const string& filename="") const; 00393 void load(const string& filename=""); 00394 00399 virtual void stop_if_wanted(); 00400 00402 inline int inputsize() const { return inputsize_; } 00403 inline int targetsize() const { return targetsize_; } 00404 inline int outputsize() const { return outputsize_; } 00405 inline int weightsize() const { return weightsize_; } 00406 inline int epoch() const { return epoch_; } 00407 00411 virtual int costsize() const; 00412 00415 void setTestCostFunctions(Array<CostFunc> costfunctions) 00416 { test_costfuncs = costfunctions; } 00417 00420 void setTestStatistics(StatsItArray statistics) 00421 { test_statistics = statistics; } 00422 00425 virtual void setTestDuringTrain(ostream& testout, int every, 00426 Array<VMat> testsets); 00427 00429 virtual void setTestDuringTrain(Array<VMat> testsets); 00430 00431 00433 const Array<VMat>& getTestDuringTrain() const { 00434 return test_sets; 00435 } 00436 00437 00454 void setEarlyStopping(int which_testset, int which_testresult, 00455 real max_degradation, real min_value=-FLT_MAX, 00456 real min_improvement=0, bool relative_changes=true, 00457 bool save_best=true, int max_degraded_steps=-1); 00458 00466 virtual void computeCost(const Vec& input, const Vec& target, const Vec& output, const Vec& cost); 00467 00470 virtual void useAndCost(const Vec& input, const Vec& target, 00471 Vec output, Vec cost); 00472 00478 virtual void useAndCostOnTestVec(const VMat& test_set, int i, const Vec& output, const Vec& cost); 00479 00486 virtual void apply(const VMat& data, VMat outputs); 00487 00492 virtual void applyAndComputeCosts(const VMat& data, VMat outputs, VMat costs); 00493 00497 virtual void applyAndComputeCostsOnTestMat(const VMat& test_set, int i, const Mat& output_block, 00498 const Mat& cost_block); 00499 00504 virtual void computeCosts(const VMat& data, VMat costs); 00505 00508 virtual void computeLeaveOneOutCosts(const VMat& data, VMat costs); 00509 00514 virtual void computeLeaveOneOutCosts(const VMat& data, VMat costsmat, CostFunc costf); 00515 00521 Vec computeTestStatistics(const VMat& costs); 00522 00527 virtual Vec test(VMat test_set, const string& save_test_outputs="", 00528 const string& save_test_costs=""); 00529 00534 virtual Array<string> costNames() const; 00535 00541 virtual Array<string> testResultsNames() const; 00542 00546 virtual Array<string> trainObjectiveNames() const; 00547 00552 void appendMeasurer(Measurer& measurer) 00553 { measurers.append(&measurer); } 00554 00555 protected: 00556 static void declareOptions(OptionList& ol); 00557 00558 void setTrainCost(Vec &cost) 00559 { train_cost.resize(cost.length()); train_cost << cost; }; 00560 Vec train_cost; 00561 public: 00562 Vec getTrainCost() { return train_cost; }; 00563 }; 00564 00565 DECLARE_OBJECT_PTR(Learner); 00566 00567 typedef PP<Learner> PPLearner; 00568 00569 inline void prettyprint_test_results(ostream& out, const Learner& learner, const Vec& results) 00570 { 00571 Array<string> names = learner.testResultsNames(); 00572 for (int i=0; i<names.size(); i++) 00573 out << names[i] << ": " << results[i] << endl; 00574 } 00575 00576 00577 } // end of namespace PLearn 00578 00579 #endif 00580 00581 00582 00583 00584

Generated on Tue Aug 17 15:57:04 2004 for PLearn by doxygen 1.3.7