Main Page | Namespace List | Class Hierarchy | Alphabetical List | Class List | File List | Namespace Members | Class Members | File Members

SemiSupervisedProbClassCostVariable.cc

Go to the documentation of this file.
00001 // -*- C++ -*- 00002 00003 // PLearn (A C++ Machine Learning Library) 00004 // Copyright (C) 1998 Pascal Vincent 00005 // Copyright (C) 1999-2002 Pascal Vincent, Yoshua Bengio, Rejean Ducharme and University of Montreal 00006 // Copyright (C) 2001-2002 Nicolas Chapados, Ichiro Takeuchi, Jean-Sebastien Senecal 00007 // Copyright (C) 2002 Xiangdong Wang, Christian Dorion 00008 00009 // Redistribution and use in source and binary forms, with or without 00010 // modification, are permitted provided that the following conditions are met: 00011 // 00012 // 1. Redistributions of source code must retain the above copyright 00013 // notice, this list of conditions and the following disclaimer. 00014 // 00015 // 2. Redistributions in binary form must reproduce the above copyright 00016 // notice, this list of conditions and the following disclaimer in the 00017 // documentation and/or other materials provided with the distribution. 00018 // 00019 // 3. The name of the authors may not be used to endorse or promote 00020 // products derived from this software without specific prior written 00021 // permission. 00022 // 00023 // THIS SOFTWARE IS PROVIDED BY THE AUTHORS ``AS IS'' AND ANY EXPRESS OR 00024 // IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES 00025 // OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN 00026 // NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 00027 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED 00028 // TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 00029 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 00030 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 00031 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00032 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00033 // 00034 // This file is part of the PLearn library. For more information on the PLearn 00035 // library, go to the PLearn Web site at www.plearn.org 00036 00037 00038 /* ******************************************************* 00039 * $Id: SemiSupervisedProbClassCostVariable.cc,v 1.4 2004/04/27 16:03:35 morinf Exp $ 00040 * This file is part of the PLearn library. 00041 ******************************************************* */ 00042 00043 #include "SemiSupervisedProbClassCostVariable.h" 00044 #include "Var_utils.h" 00045 00046 namespace PLearn { 00047 using namespace std; 00048 00049 00052 PLEARN_IMPLEMENT_OBJECT(SemiSupervisedProbClassCostVariable, 00053 "ONE LINE DESCR", 00054 "NO HELP"); 00055 00056 SemiSupervisedProbClassCostVariable::SemiSupervisedProbClassCostVariable(Var prob_, Var target_, Var prior_, real ff) 00057 : inherited(prob_ & target_ & (VarArray)prior_,1,1), flatten_factor(ff) 00058 { 00059 build_(); 00060 } 00061 00062 void 00063 SemiSupervisedProbClassCostVariable::build() 00064 { 00065 inherited::build(); 00066 build_(); 00067 } 00068 00069 void 00070 SemiSupervisedProbClassCostVariable::build_() 00071 { 00072 if (varray.size() >= 3 && varray[0] && varray[1] && varray[2]) { 00073 // varray[0], varray[1] and varray[2] are (respectively) prob_, target_ and prior_ from constructor 00074 if (varray[2]->length()>0 && varray[0]->length() != varray[2]->length()) 00075 PLERROR("In SemiSupervisedProbClassCostVariable: If prior.length()>0 then prior and prob must have the same size"); 00076 if (!varray[1]->isScalar()) 00077 PLERROR("In SemiSupervisedProbClassCostVariable: target must be a scalar"); 00078 raised_prob.resize(varray[0]->length()); 00079 } 00080 if (flatten_factor <= 0) 00081 PLERROR("In SemiSupervisedProbClassCostVariable: flatten_factor must be positive, and even > 1 for normal use."); 00082 } 00083 00084 void 00085 SemiSupervisedProbClassCostVariable::declareOptions(OptionList &ol) 00086 { 00087 declareOption(ol, "flatten_factor", &SemiSupervisedProbClassCostVariable::flatten_factor, OptionBase::buildoption, ""); 00088 inherited::declareOptions(ol); 00089 } 00090 00091 void SemiSupervisedProbClassCostVariable::recomputeSize(int& l, int& w) const 00092 { l=1; w=1; } 00093 00094 00095 void SemiSupervisedProbClassCostVariable::fprop() 00096 { 00101 real target_value = target()->valuedata[0]; 00102 int n=prob()->size(); 00103 real* p=prob()->valuedata; 00104 if (finite(target_value)) // supervised case 00105 { 00106 int t = int(target_value); 00107 if (t<0 || t>=n) 00108 PLERROR("In SemiSupervisedProbClassCostVariable: target must be either missing or between 0 and %d incl.\n",prob()->size()-1); 00109 valuedata[0] = -safeflog(p[t]); 00110 } 00111 else // unsupervised case 00112 { 00113 sum_raised_prob=0; 00114 real* priorv = prior()->valuedata; 00115 for (int i=0;i<n;i++) 00116 { 00117 raised_prob[i] = pow(priorv[i]*p[i],flatten_factor); 00118 sum_raised_prob += raised_prob[i]; 00119 } 00120 valuedata[0] = - safeflog(sum_raised_prob)/flatten_factor; 00121 } 00122 } 00123 00124 void SemiSupervisedProbClassCostVariable::bprop() 00125 { 00126 real target_value = target()->valuedata[0]; 00127 int n=prob()->size(); 00128 real* dprob=prob()->gradientdata; 00129 real* p=prob()->valuedata; 00130 if (finite(target_value)) // supervised case 00131 { 00132 int t = int(target_value); 00133 for (int i=0;i<n;i++) 00134 if (i==t && p[t]>0) 00135 dprob[i] += -gradientdata[0]/p[t]; 00136 } 00137 else // unsupervised case 00138 { 00139 for (int i=0;i<n;i++) 00140 if (p[i]>0) 00141 { 00142 real grad = - gradientdata[0]*raised_prob[i]/(p[i]*sum_raised_prob); 00143 if (finite(grad)) 00144 dprob[i] += grad; 00145 } 00146 } 00147 } 00148 00149 00150 void SemiSupervisedProbClassCostVariable::symbolicBprop() 00151 { 00152 PLERROR("SemiSupervisedProbClassCostVariable::symbolicBprop() not implemented"); 00153 } 00154 00155 00156 void SemiSupervisedProbClassCostVariable::rfprop() 00157 { 00158 PLERROR("SemiSupervisedProbClassCostVariable::rfprop() not implemented"); 00159 } 00160 00161 00162 00163 } // end of namespace PLearn 00164 00165

Generated on Tue Aug 17 16:05:10 2004 for PLearn by doxygen 1.3.7