00001 // -*- C++ -*- 00002 00003 // NNet.h 00004 // Copyright (c) 1998-2002 Pascal Vincent 00005 // Copyright (C) 1999-2002 Yoshua Bengio and University of Montreal 00006 // Copyright (c) 2002 Jean-Sebastien Senecal, Xavier Saint-Mleux, Rejean Ducharme 00007 // 00008 // Redistribution and use in source and binary forms, with or without 00009 // modification, are permitted provided that the following conditions are met: 00010 // 00011 // 1. Redistributions of source code must retain the above copyright 00012 // notice, this list of conditions and the following disclaimer. 00013 // 00014 // 2. Redistributions in binary form must reproduce the above copyright 00015 // notice, this list of conditions and the following disclaimer in the 00016 // documentation and/or other materials provided with the distribution. 00017 // 00018 // 3. The name of the authors may not be used to endorse or promote 00019 // products derived from this software without specific prior written 00020 // permission. 00021 // 00022 // THIS SOFTWARE IS PROVIDED BY THE AUTHORS ``AS IS'' AND ANY EXPRESS OR 00023 // IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES 00024 // OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN 00025 // NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 00026 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED 00027 // TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 00028 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 00029 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 00030 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00031 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00032 // 00033 // This file is part of the PLearn library. 
// For more information on the PLearn
// library, go to the PLearn Web site at www.plearn.org


/* *******************************************************
 * $Id: NNet.h,v 1.20 2004/07/21 16:30:56 chrish42 Exp $
 ******************************************************* */

// NNet.h: declaration of a feed-forward neural network PLearner with up to
// two hidden layers, optional direct input-to-output connections, optional
// RBF output layer and optional input reconstruction penalty.

#ifndef NNet_INC
#define NNet_INC

#include "PLearner.h"
#include <plearn/opt/Optimizer.h>
//#include "Var_all.h"

namespace PLearn {
using namespace std;

// Neural-network learner.  The network is expressed as a graph of Var nodes
// (built in build_()); training minimizes training_cost with the configured
// optimizer.  All fields below are either pieces of that graph or build
// options controlling its topology and regularization.
class NNet: public PLearner
{
protected:
    // --- Vars making up the function graph ---
    Var input;           // Var(inputsize())
    Var target;          // Var(targetsize()-weightsize())
    Var sampleweight;    // Var(1) if train_set->hasWeights()
    Var w1;              // bias and weights of first hidden layer
    Var w2;              // bias and weights of second hidden layer
    Var wout;            // bias and weights of output layer
    Var outbias;         // bias used only if fixed_output_weights
    Var wdirect;         // bias and weights for direct in-to-out connection
    Var wrec;            // input reconstruction weights (optional), from hidden layer to predicted input
    Var rbf_centers;     // n_classes (or n_classes-1) x rbf_layer_size = mu_i of RBF gaussians
    Var rbf_sigmas;      // n_classes (or n_classes-1) entries (sigma's of the RBFs)
    Var junk_prob;       // scalar background (junk) probability, if first_class_is_junk
    Var output;          // network output Var
    Var predicted_input; // reconstructed input (goes with wrec / input_reconstruction_penalty)
    VarArray costs;      // all costs of interest
    VarArray penalties;  // regularization penalty terms (weight/bias decays)
    Var training_cost;   // weighted scalar costs[0] including penalties
    Var test_costs;      // hconcat(costs)
    VarArray invars;
    VarArray params;     // all parameter input vars

    Vec paramsvalues;    // values of all parameters

public:
    // Compiled functions over the Var graph; mutable so they can be
    // (re)built from the const compute* methods.
    mutable Func f;                         // input -> output
    mutable Func test_costf;                // input & target -> output & test_costs
    mutable Func output_and_target_to_cost; // output & target -> cost

public:

    typedef PLearner inherited;

    // Build options inherited from learner:
    // inputsize, outputsize, targetsize, experiment_name, save_at_every_epoch

    // Build options:
    int nhidden;  // number of hidden units in first hidden layer (default: 0)
    int nhidden2; // number of hidden units in second hidden layer (default: 0)
    int noutputs; // number of output units (outputsize)

    // Regularization.  A layer-specific decay of MISSING_VALUE presumably
    // means "fall back to the global weight_decay/bias_decay" — confirm in
    // build_().
    real weight_decay;                  // default: 0
    real bias_decay;                    // default: 0
    real layer1_weight_decay;           // default: MISSING_VALUE
    real layer1_bias_decay;             // default: MISSING_VALUE
    real layer2_weight_decay;           // default: MISSING_VALUE
    real layer2_bias_decay;             // default: MISSING_VALUE
    real output_layer_weight_decay;     // default: MISSING_VALUE
    real output_layer_bias_decay;       // default: MISSING_VALUE
    real direct_in_to_out_weight_decay; // default: MISSING_VALUE
    real classification_regularizer;    // default: 0
    real margin;                        // default: 1, used with margin_perceptron_cost
    bool fixed_output_weights;          // when set, outbias is used (see outbias above)

    int rbf_layer_size;       // number of representation units when adding an rbf layer in output
    bool first_class_is_junk; // see junk_prob above

    bool L1_penalty;                   // default: false
    real input_reconstruction_penalty; // default: 0
    bool direct_in_to_out;             // should we include direct input to output connections? default: false
    string output_transfer_func;       // tanh, sigmoid, softplus, softmax, etc... (default: "" means no transfer function)
    string hidden_transfer_func;       // tanh, sigmoid, softplus, softmax, etc... (default: "tanh")
    real interval_minval, interval_maxval; // if output_transfer_func = interval(minval,maxval), these are the interval bounds

    // The list of cost functions to compute,
    // where the cost functions can be one of mse, mse_onehot, NLL,
    // class_error or multiclass_error (no default)
    Array<string> cost_funcs;

    // Build options related to the optimization:
    PP<Optimizer> optimizer; // the optimizer to use (no default)

    int batch_size; // how many samples to use to estimate gradient before an update
                    // 0 means the whole training set (default: 1)

    string initialization_method; // parameter initialization scheme, used by initializeParams()


private:
    // Does the actual graph construction for build().
    void build_();

public:

    NNet();
    virtual ~NNet();
    PLEARN_DECLARE_OBJECT(NNet);

    virtual void build();
    virtual void forget(); // simply calls initializeParams()

    virtual int outputsize() const;
    virtual TVec<string> getTrainCostNames() const;
    virtual TVec<string> getTestCostNames() const;

    virtual void train();

    virtual void computeOutput(const Vec& input, Vec& output) const;

    virtual void computeOutputAndCosts(const Vec& input, const Vec& target,
                                       Vec& output, Vec& costs) const;

    virtual void computeCostsFromOutputs(const Vec& input, const Vec& output,
                                         const Vec& target, Vec& costs) const;

    virtual void makeDeepCopyFromShallowCopy(CopiesMap &copies);

protected:
    static void declareOptions(OptionList& ol);
    // Initializes the learnable parameters (called by forget()).
    void initializeParams();

};

DECLARE_OBJECT_PTR(NNet);

} // end of namespace PLearn

#endif