Main Page | Namespace List | Class Hierarchy | Alphabetical List | Class List | File List | Namespace Members | Class Members | File Members

MultiInstanceNNet.h

Go to the documentation of this file.
// -*- C++ -*-

// MultiInstanceNNet.h
// Copyright (c) 1998-2002 Pascal Vincent
// Copyright (C) 1999-2002 Yoshua Bengio and University of Montreal
// Copyright (c) 2002 Jean-Sebastien Senecal, Xavier Saint-Mleux, Rejean Ducharme
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
//  1. Redistributions of source code must retain the above copyright
//     notice, this list of conditions and the following disclaimer.
//
//  2. Redistributions in binary form must reproduce the above copyright
//     notice, this list of conditions and the following disclaimer in the
//     documentation and/or other materials provided with the distribution.
//
//  3. The name of the authors may not be used to endorse or promote
//     products derived from this software without specific prior written
//     permission.
//
// THIS SOFTWARE IS PROVIDED BY THE AUTHORS ``AS IS'' AND ANY EXPRESS OR
// IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
// OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN
// NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
// TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
// This file is part of the PLearn library.
For more information on the PLearn 00034 // library, go to the PLearn Web site at www.plearn.org 00035 00036 00037 /* ******************************************************* 00038 * $Id: MultiInstanceNNet.h,v 1.14 2004/07/21 20:14:37 tihocan Exp $ 00039 ******************************************************* */ 00040 00043 #ifndef MultiInstanceNNet_INC 00044 #define MultiInstanceNNet_INC 00045 00046 #include <plearn_learners/generic/PLearner.h> 00047 #include <plearn/opt/Optimizer.h> 00048 00049 namespace PLearn { 00050 using namespace std; 00051 00052 class MultiInstanceNNet: public PLearner 00053 { 00054 00055 private: 00056 00057 typedef PLearner inherited; 00058 00060 mutable Vec instance_logP0; 00061 00062 protected: 00063 00064 Var input; // Var(inputsize()) 00065 Var target; // Var(targetsize()-weightsize()) 00066 Var sampleweight; // Var(1) if train_set->hasWeights() 00067 Var w1; // bias and weights of first hidden layer 00068 Var w2; // bias and weights of second hidden layer 00069 Var wout; // bias and weights of output layer 00070 Var wdirect; // bias and weights for direct in-to-out connection 00071 00072 Var output; // output (P(y_i|x_i)) for a single bag element 00073 Var bag_size; // filled up by SumOverBagsVariable 00074 Var bag_inputs; // filled up by SumOverBagsVariable 00075 Var bag_output; // P(y=1|bag_inputs) 00076 00077 Func inputs_and_targets_to_test_costs; // (bag inputs and targets) -> (bag NLL, bag class. err, lift_output) 00078 Func inputs_and_targets_to_training_costs; // (bag inputs and targets) -> (bag NLL + penalty, bag NLL, bag class. 
err, lift_output) 00079 Func input_to_logP0; // single input x -> log P(y=0|x) 00080 Var nll; 00081 00082 VarArray costs; // (negative log-likelihood, classification error) for the bag 00083 VarArray penalties; 00084 Var training_cost; // weighted cost + penalties 00085 Var test_costs; // hconcat(costs) 00086 VarArray invars; 00087 VarArray params; // all arameter input vars 00088 00089 Vec paramsvalues; // values of all parameters 00090 00091 int optstage_per_lstage; // number of bags in training set / batch_size (in nb of bags) 00092 bool training_set_has_changed; // if so, must count nb of bags in training set 00093 00094 public: 00095 mutable Func f; // input -> output 00096 mutable Func test_costf; // input & target -> output & test_costs 00097 mutable Func output_and_target_to_cost; // output & target -> cost 00098 00099 public: 00100 00101 // Build options inherited from learner: 00102 // inputsize, outputsize, targetsize, experiment_name, save_at_every_epoch 00103 00104 // Build options: 00105 int max_n_instances; // maximum number of instances (input vectors x_i) allowed 00106 00107 int nhidden; // number of hidden units in first hidden layer (default:0) 00108 int nhidden2; // number of hidden units in second hidden layer (default:0) 00109 00110 real weight_decay; // default: 0 00111 real bias_decay; // default: 0 00112 real layer1_weight_decay; // default: MISSING_VALUE 00113 real layer1_bias_decay; // default: MISSING_VALUE 00114 real layer2_weight_decay; // default: MISSING_VALUE 00115 real layer2_bias_decay; // default: MISSING_VALUE 00116 real output_layer_weight_decay; // default: MISSING_VALUE 00117 real output_layer_bias_decay; // default: MISSING_VALUE 00118 real direct_in_to_out_weight_decay; // default: MISSING_VALUE 00119 real classification_regularizer; // default: 0 00120 00121 bool L1_penalty; // default: false 00122 bool direct_in_to_out; // should we include direct input to output connecitons? 
default: false 00123 real interval_minval, interval_maxval; // if output_transfer_func = interval(minval,maxval), these are the interval bounds 00124 mutable int test_bag_size; // counting the number of instances in a test bag 00125 00126 // Build options related to the optimization: 00127 PP<Optimizer> optimizer; // the optimizer to use (no default) 00128 00129 int batch_size; // how many samples to use to estimate gradient before an update 00130 // 0 means the whole training set (default: 1) 00131 00132 private: 00133 void build_(); 00134 00135 public: 00136 00137 MultiInstanceNNet(); 00138 virtual ~MultiInstanceNNet(); 00139 PLEARN_DECLARE_OBJECT(MultiInstanceNNet); 00140 00141 // PLearner methods 00142 00143 virtual void setTrainingSet(VMat training_set, bool call_forget=true); 00144 00145 virtual void build(); 00146 virtual void forget(); // simply calls initializeParams() 00147 00148 virtual int outputsize() const; 00149 virtual TVec<string> getTrainCostNames() const; 00150 virtual TVec<string> getTestCostNames() const; 00151 00152 virtual void train(); 00153 00154 virtual void computeOutput(const Vec& input, Vec& output) const; 00155 00156 virtual void computeOutputAndCosts(const Vec& input, const Vec& target, 00157 Vec& output, Vec& costs) const; 00158 00159 virtual void computeCostsFromOutputs(const Vec& input, const Vec& output, 00160 const Vec& target, Vec& costs) const; 00161 00162 00163 virtual void makeDeepCopyFromShallowCopy(CopiesMap &copies); 00164 00165 protected: 00166 static void declareOptions(OptionList& ol); 00167 void initializeParams(); 00168 00169 }; 00170 00171 DECLARE_OBJECT_PTR(MultiInstanceNNet); 00172 00173 } // end of namespace PLearn 00174 00175 #endif 00176

Generated on Tue Aug 17 15:59:12 2004 for PLearn by doxygen 1.3.7