Main Page | Namespace List | Class Hierarchy | Alphabetical List | Class List | File List | Namespace Members | Class Members | File Members

SumOfVariable.cc

Go to the documentation of this file.
00001 // -*- C++ -*- 00002 00003 // PLearn (A C++ Machine Learning Library) 00004 // Copyright (C) 1998 Pascal Vincent 00005 // Copyright (C) 1999-2002 Pascal Vincent, Yoshua Bengio, Rejean Ducharme and University of Montreal 00006 // Copyright (C) 2001-2002 Nicolas Chapados, Ichiro Takeuchi, Jean-Sebastien Senecal 00007 // Copyright (C) 2002 Xiangdong Wang, Christian Dorion 00008 00009 // Redistribution and use in source and binary forms, with or without 00010 // modification, are permitted provided that the following conditions are met: 00011 // 00012 // 1. Redistributions of source code must retain the above copyright 00013 // notice, this list of conditions and the following disclaimer. 00014 // 00015 // 2. Redistributions in binary form must reproduce the above copyright 00016 // notice, this list of conditions and the following disclaimer in the 00017 // documentation and/or other materials provided with the distribution. 00018 // 00019 // 3. The name of the authors may not be used to endorse or promote 00020 // products derived from this software without specific prior written 00021 // permission. 00022 // 00023 // THIS SOFTWARE IS PROVIDED BY THE AUTHORS ``AS IS'' AND ANY EXPRESS OR 00024 // IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES 00025 // OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN 00026 // NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 00027 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED 00028 // TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 00029 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 00030 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 00031 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00032 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00033 // 00034 // This file is part of the PLearn library. For more information on the PLearn 00035 // library, go to the PLearn Web site at www.plearn.org 00036 00037 00038 /* ******************************************************* 00039 * $Id: SumOfVariable.cc,v 1.10 2004/07/21 16:30:54 chrish42 Exp $ 00040 * This file is part of the PLearn library. 00041 ******************************************************* */ 00042 00043 #include "SumOfVariable.h" 00044 #include <plearn/sys/PLMPI.h> 00045 #include <plearn/display/DisplayUtils.h> 00046 00047 namespace PLearn { 00048 using namespace std; 00049 00050 00051 00054 PLEARN_IMPLEMENT_OBJECT(SumOfVariable, 00055 "Variable that sums the value of a Func evaluated on each row of a VMat", 00056 "NO HELP"); 00057 00058 SumOfVariable::SumOfVariable(VMat the_distr, Func the_f, int the_nsamples) 00059 : inherited(nonInputParentsOfPath(the_f->inputs,the_f->outputs), 00060 the_f->outputs[0]->length(), 00061 the_f->outputs[0]->width()), 00062 distr(the_distr), f(the_f), nsamples(the_nsamples), curpos(0), 00063 //input_value(the_distr->inputsize()+the_distr->targetsize()+the_distr->weightsize()), 00064 //input_gradient(the_distr->inputsize()+the_distr->targetsize()+the_distr->weightsize()), 00065 input_value(the_distr->width()), 00066 input_gradient(the_distr->width()), 00067 output_value(the_f->outputs[0]->size()) 00068 { 00069 build_(); 00070 } 00071 00072 void 00073 SumOfVariable::build() 00074 { 00075 inherited::build(); 00076 build_(); 00077 } 00078 00079 void 00080 SumOfVariable::build_() 00081 { 00082 if (f && distr) { 00083 input_value.resize(distr->inputsize() + distr->targetsize() + distr->weightsize()); 00084 input_gradient.resize(distr->inputsize() + distr->targetsize() + distr->weightsize()); 00085 if(f->outputs.size() != 1) 00086 PLERROR("In SumOfVariable: function must have a single variable output (maybe you can vconcat the vars into a single one prior to calling sumOf, if this is really what you want)"); 00087 00088 if(nsamples == -1) 00089 nsamples = distr->length(); 00090 f->inputs.setDontBpropHere(true); 00091 } 00092 } 00093 00094 void 00095 SumOfVariable::declareOptions(OptionList &ol) 00096 { 00097 declareOption(ol, "distr", &SumOfVariable::distr, OptionBase::buildoption, ""); 00098 declareOption(ol, "f", &SumOfVariable::f, OptionBase::buildoption, ""); 00099 declareOption(ol, "nsamples", &SumOfVariable::nsamples, OptionBase::buildoption, ""); 00100 declareOption(ol, "curpos", &SumOfVariable::curpos, OptionBase::buildoption, ""); 00101 inherited::declareOptions(ol); 00102 } 00103 00104 00105 void SumOfVariable::recomputeSize(int& l, int& w) const 00106 { 00107 if (f && f->outputs.size()) { 00108 l = f->outputs[0]->length(); 00109 w = f->outputs[0]->width(); 00110 } else 00111 l = w = 0; 00112 } 00113 00114 00115 void SumOfVariable::makeDeepCopyFromShallowCopy(map<const void*, void*>& copies) 00116 { 00117 NaryVariable::makeDeepCopyFromShallowCopy(copies); 00118 deepCopyField(distr, copies); 00119 deepCopyField(f, copies); 00120 } 00121 00122 00123 void SumOfVariable::fprop() 00124 { 00125 f->recomputeParents(); 00126 00127 if(nsamples==1) 00128 { 00129 input_value.resize(distr->width()); 00130 distr->getRow(curpos, input_value); 00131 input_value.resize(distr->inputsize()+distr->targetsize()+distr->weightsize()); 00132 f->fprop(input_value, value); 00133 if(++curpos == distr->length()) 00134 curpos = 0; 00135 } 00136 else 00137 { 00138 value.clear(); 00139 #if USING_MPI 00140 if (nsamples > distr->length()) 00141 PLERROR("In SumOfVariable::fprop, the case where nsamples is greater than distr->length is not supported in parallel computation"); 00142 int nb_sample = nsamples/PLMPI::size; 00143 int start_pos = PLMPI::rank * nb_sample; 00144 int end_pos = (PLMPI::rank==PLMPI::size-1) ? nsamples : start_pos + nb_sample; 00145 Vec dummy_value(value.length()); 00146 for(int i=start_pos; i<end_pos; i++) 00147 { 00148 input_value.resize(distr->width()); 00149 distr->getRow(i, input_value); 00150 input_value.resize(distr->inputsize()+distr->targetsize()+distr->weightsize()); 00151 f->fprop(input_value, output_value); 00152 dummy_value += output_value; 00153 } 00154 MPI_Allreduce(dummy_value.data(), value.data(), value.length(), PLMPI_REAL, MPI_SUM, MPI_COMM_WORLD); 00155 #else 00156 for(int i=0; i<nsamples; i++) 00157 { 00158 input_value.resize(distr->width()); 00159 distr->getRow(curpos, input_value); 00160 input_value.resize(distr->inputsize()+distr->targetsize()+distr->weightsize()); 00161 f->fprop(input_value, output_value); 00162 value += output_value; 00163 if(++curpos == distr->length()) 00164 curpos = 0; 00165 } 00166 #endif 00167 } 00168 } 00169 00170 00171 void SumOfVariable::bprop() 00172 { fbprop(); } 00173 00174 00175 void SumOfVariable::fbprop() 00176 { 00177 f->recomputeParents(); 00178 00179 if(nsamples==1) 00180 { 00181 input_value.resize(distr->width()); 00182 distr->getRow(curpos, input_value); 00183 input_value.resize(distr->inputsize()+distr->targetsize()+distr->weightsize()); 00184 //displayFunction(f, true, false, 250); 00185 f->fbprop(input_value, value, input_gradient, gradient); 00186 //displayFunction(f, true, false, 250); 00187 if(++curpos == distr->length()) 00188 curpos = 0; 00189 } 00190 else 00191 { 00192 value.clear(); 00193 #if USING_MPI 00194 if (nsamples > distr->length()) 00195 PLERROR("In SumOfVariable::fbprop, the case where nsamples is greater than distr->length is not supported in parallel computation"); 00196 int nb_sample = nsamples/PLMPI::size; 00197 int start_pos = PLMPI::rank * nb_sample; 00198 int end_pos = (PLMPI::rank==PLMPI::size-1) ? nsamples : start_pos + nb_sample; 00199 Vec dummy_value(value.length()); 00200 for(int i=start_pos; i<end_pos; i++) 00201 { 00202 input_value.resize(distr->width()); 00203 distr->getRow(i, input_value); 00204 input_value.resize(distr->inputsize()+distr->targetsize()+distr->weightsize()); 00205 f->fbprop(input_value, output_value, input_gradient, gradient); 00206 dummy_value += output_value; 00207 } 00208 MPI_Allreduce(dummy_value.data(), value.data(), value.length(), PLMPI_REAL, MPI_SUM, MPI_COMM_WORLD); 00209 VarArray params = f->parameters; 00210 for (int i=0; i<params->length(); i++) 00211 { 00212 Vec buffer(params[i]->size()); 00213 MPI_Reduce(params[i]->gradientdata, buffer.data(), buffer.length(), PLMPI_REAL, MPI_SUM, 0, MPI_COMM_WORLD); 00214 buffer >> params[i]->gradient; 00215 MPI_Bcast(params[i]->gradientdata, buffer.length(), PLMPI_REAL, 0, MPI_COMM_WORLD); 00216 } 00217 #else 00218 for(int i=0; i<nsamples; i++) 00219 { 00220 input_value.resize(distr->width()); 00221 distr->getRow(curpos, input_value); 00222 input_value.resize(distr->inputsize()+distr->targetsize()+distr->weightsize()); 00223 static bool display_fn=false; 00224 if (display_fn) 00225 displayFunction(f, true, false, 250); 00226 f->fbprop(input_value, output_value, input_gradient, gradient); 00227 value += output_value; 00228 if(++curpos == distr->length()) 00229 curpos = 0; 00230 } 00231 #endif 00232 } 00233 } 00234 00235 00236 void SumOfVariable::symbolicBprop() 00237 { 00238 /* 00239 // f is a function of its inputs, what we want is a function of the parameters of f (which are in the inputs field of this SumOfVariable) 00240 VarArray& params = varray; 00241 int nparams = params.size(); 00242 f->bproppath.symbolicBprop(); 00243 00244 VarArray dparams(nparams); 00245 for(int i=0; i<nparams; i++) 00246 dparams[i] = params[i]->g; 00247 00248 Var dparams_concat = new ConcatElementsVariable(dparams); 00249 Var dparams_sum = new SumOfVariable(distr, Func(params,dparams_concat), nsamples); 00250 00251 for(int i=0; i<nparams; i++) 00252 params[i]->g += dparams_sum.sub(...) 00253 */ 00254 } 00255 00256 00257 void SumOfVariable::rfprop() 00258 { 00259 if (rValue.length()==0) resizeRValue(); 00260 // TODO... (we will need a rfprop() in Func) 00261 00262 // f->recomputeParents(); 00263 00264 // if(nsamples==1) 00265 // { 00266 // distr->getRow(curpos, input_value); 00267 // f->fprop(input_value, value); 00268 // if(++curpos == distr->length()) 00269 // curpos = 0; 00270 // } 00271 // else 00272 // { 00273 // value.clear(); 00274 // #if USING_MPI 00275 // if (nsamples > distr->length()) 00276 // PLERROR("In SumOfVariable::fprop, the case where nsamples is greater than distr->length is not supported in parallel computation"); 00277 // int nb_sample = nsamples/PLMPI::size; 00278 // int start_pos = PLMPI::rank * nb_sample; 00279 // int end_pos = (PLMPI::rank==PLMPI::size-1) ? nsamples : start_pos + nb_sample; 00280 // Vec dummy_value(value.length()); 00281 // for(int i=start_pos; i<end_pos; i++) 00282 // { 00283 // distr->getRow(i, input_value); 00284 // f->fprop(input_value, output_value); 00285 // dummy_value += output_value; 00286 // } 00287 // MPI_Allreduce(dummy_value.data(), value.data(), value.length(), PLMPI_REAL, MPI_SUM, MPI_COMM_WORLD); 00288 // #else 00289 // for(int i=0; i<nsamples; i++) 00290 // { 00291 // distr->getRow(curpos, input_value); 00292 // f->fprop(input_value, output_value); 00293 // value += output_value; 00294 // if(++curpos == distr->length()) 00295 // curpos = 0; 00296 // } 00297 // #endif 00298 // } 00299 } 00300 00301 00302 void SumOfVariable::printInfo(bool print_gradient) 00303 { 00304 Vec input_value(distr->width()); 00305 Vec input_gradient(distr->width()); 00306 Vec output_value(nelems()); 00307 00308 f->recomputeParents(); 00309 value.clear(); 00310 00311 for(int i=0; i<nsamples; i++) 00312 { 00313 input_value.resize(distr->width()); 00314 distr->getRow(curpos++,input_value); 00315 input_value.resize(distr->inputsize()+distr->targetsize()+distr->weightsize()); 00316 if (print_gradient) 00317 f->fbprop(input_value, output_value, input_gradient, gradient); 00318 else 00319 f->fprop(input_value, output_value); 00320 value += output_value; 00321 if(curpos>=distr->length()) 00322 curpos = 0; 00323 f->fproppath.printInfo(print_gradient); 00324 } 00325 cout << info() << " : " << getName() << " = " << value; 00326 if (print_gradient) cout << " gradient=" << gradient; 00327 cout << endl; 00328 } 00329 00330 00331 00332 } // end of namespace PLearn 00333 00334

Generated on Tue Aug 17 16:07:59 2004 for PLearn by doxygen 1.3.7