Main Page | Namespace List | Class Hierarchy | Alphabetical List | Class List | File List | Namespace Members | Class Members | File Members

DiagonalizedFactorsProductVariable.cc

Go to the documentation of this file.
00001 // -*- C++ -*- 00002 00003 // PLearn (A C++ Machine Learning Library) 00004 // Copyright (C) 1998 Pascal Vincent 00005 // Copyright (C) 1999-2002 Pascal Vincent, Yoshua Bengio, Rejean Ducharme and University of Montreal 00006 // Copyright (C) 2001-2002 Nicolas Chapados, Ichiro Takeuchi, Jean-Sebastien Senecal 00007 // Copyright (C) 2002 Xiangdong Wang, Christian Dorion 00008 00009 // Redistribution and use in source and binary forms, with or without 00010 // modification, are permitted provided that the following conditions are met: 00011 // 00012 // 1. Redistributions of source code must retain the above copyright 00013 // notice, this list of conditions and the following disclaimer. 00014 // 00015 // 2. Redistributions in binary form must reproduce the above copyright 00016 // notice, this list of conditions and the following disclaimer in the 00017 // documentation and/or other materials provided with the distribution. 00018 // 00019 // 3. The name of the authors may not be used to endorse or promote 00020 // products derived from this software without specific prior written 00021 // permission. 00022 // 00023 // THIS SOFTWARE IS PROVIDED BY THE AUTHORS ``AS IS'' AND ANY EXPRESS OR 00024 // IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES 00025 // OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN 00026 // NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 00027 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED 00028 // TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 00029 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 00030 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 00031 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00032 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00033 // 00034 // This file is part of the PLearn library. For more information on the PLearn 00035 // library, go to the PLearn Web site at www.plearn.org 00036 00037 00038 /* ******************************************************* 00039 * $Id: DiagonalizedFactorsProductVariable.cc,v 1.1 2004/07/19 22:31:11 yoshua Exp $ 00040 * This file is part of the PLearn library. 00041 ******************************************************* */ 00042 00043 #include "DiagonalizedFactorsProductVariable.h" 00044 #include "Var_utils.h" 00045 00046 namespace PLearn { 00047 using namespace std; 00048 00049 00052 PLEARN_IMPLEMENT_OBJECT(DiagonalizedFactorsProductVariable, 00053 "Variable that represents the leftmatrix*diag(vector)*rightmatrix product", 00054 "The three parents are respectively the left matrix U, the center vector d,\n" 00055 "and the right matrix V. Options allow to transpose the matrices.\n" 00056 "The output value has elements (i,j) equal to sum_k U_{ik} d_k V_{kj}\n" 00057 ); 00058 00059 DiagonalizedFactorsProductVariable::DiagonalizedFactorsProductVariable(Var left_matrix, 00060 Var center_diagonal, 00061 Var right_matrix, 00062 bool transpose_left_, 00063 bool transpose_right_) 00064 : inherited(left_matrix & center_diagonal & (VarArray)right_matrix, 00065 transpose_left_?left_matrix->width():left_matrix->length(), 00066 transpose_right_?right_matrix->length():right_matrix->width()), 00067 transpose_left(transpose_left_), transpose_right(transpose_right_) 00068 { 00069 build_(); 00070 } 00071 00072 void 00073 DiagonalizedFactorsProductVariable::build() 00074 { 00075 inherited::build(); 00076 build_(); 00077 } 00078 00079 void 00080 DiagonalizedFactorsProductVariable::build_() 00081 { 00082 if (varray.size()) { 00083 int nl = transpose_left?leftMatrix()->length():leftMatrix()->width(); 00084 int nr = transpose_right?rightMatrix()->width():rightMatrix()->length(); 00085 int nc = centerDiagonal()->size(); 00086 if (nl != nc || nr != nc) 00087 PLERROR("In DiagonalizedFactorsProductVariable: arguments have incompatible sizes!"); 00088 } 00089 } 00090 00091 void DiagonalizedFactorsProductVariable::recomputeSize(int& l, int& w) const 00092 { 00093 if (varray.size()) { 00094 l = transpose_left?varray[0]->width():varray[0]->length(); 00095 w = transpose_right?varray[2]->length():varray[2]->width(); 00096 } else 00097 l = w = 0; 00098 } 00099 00100 void DiagonalizedFactorsProductVariable::fprop() 00101 { 00102 if (transpose_left) 00103 { 00104 if (transpose_right) 00105 diagonalizedFactorsTransposeProductTranspose(matValue,leftMatrix()->matValue,centerDiagonal()->value,rightMatrix()->matValue); 00106 else 00107 diagonalizedFactorsTransposeProduct(matValue,leftMatrix()->matValue,centerDiagonal()->value,rightMatrix()->matValue); 00108 } else { 00109 if (transpose_right) 00110 diagonalizedFactorsProductTranspose(matValue,leftMatrix()->matValue,centerDiagonal()->value,rightMatrix()->matValue); 00111 else 00112 diagonalizedFactorsProduct(matValue,leftMatrix()->matValue,centerDiagonal()->value,rightMatrix()->matValue); 00113 } 00114 } 00115 00116 00117 void DiagonalizedFactorsProductVariable::bprop() 00118 { 00119 if (transpose_left) 00120 { 00121 if (transpose_right) 00122 { 00123 // SINCE res[i,j] = sum_k U[k,i] d[k] V[j,k] ==> 00124 // dC/dU[k,i] = d_k * sum_j dC/dres[i,j] V[j,k] 00125 // dC/dd[k] = sum_{ij} dC/dres[i,j] U[k,i] V[j,k] 00126 // dC/dV[j,k] = d_k * sum_i dC/dres[i,j] U[k,i] 00127 diagonalizedFactorsTransposeProductTransposeBprop(matGradient,leftMatrix()->matValue, 00128 centerDiagonal()->value,rightMatrix()->matValue, 00129 leftMatrix()->matGradient, 00130 centerDiagonal()->gradient,rightMatrix()->matGradient); 00131 } 00132 else 00133 { 00134 // SINCE res[i,j] = sum_k U[k,i] d[k] V[k,j] ==> 00135 // dC/dU[k,i] = d_k * sum_j dC/dres[i,j] V[k,j] 00136 // dC/dd[k] = sum_{ij} dC/dres[i,j] U[k,i] V[k,j] 00137 // dC/dV[k,j] = d_k sum_i dC/dres[i,j] U[k,i] 00138 diagonalizedFactorsTransposeProductBprop(matGradient,leftMatrix()->matValue,centerDiagonal()->value, 00139 rightMatrix()->matValue, leftMatrix()->matGradient, 00140 centerDiagonal()->gradient,rightMatrix()->matGradient); 00141 } 00142 } 00143 else 00144 { 00145 if (transpose_right) 00146 { 00147 // SINCE res[i,j] = sum_k U[i,k] d[k] V[j,k] ==> 00148 // dC/dU[i,k] = sum_j dC/dres[i,j] d_k V[j,k] 00149 // dC/dd[k] = sum_{ij} dC/dres[i,j] U[i,k] V[j,k] 00150 // dC/dV[j,k] = sum_i dC/dres[i,j] d_k U[i,k] 00151 diagonalizedFactorsProductTransposeBprop(matGradient,leftMatrix()->matValue,centerDiagonal()->value, 00152 rightMatrix()->matValue, leftMatrix()->matGradient, 00153 centerDiagonal()->gradient,rightMatrix()->matGradient); 00154 } 00155 else 00156 { 00157 // SINCE res[i,j] = sum_k U[i,k] d[k] V[k,j] ==> 00158 // dC/dU[i,k] += sum_j dC/dres[i,j] d_k V[k,j] 00159 // dC/dd[k] += sum_{ij} dC/dres[i,j] U[i,k] V[k,j] 00160 // dC/dV[k,j] += d_k * sum_i U[i,k] dC/dres[i,j] 00161 diagonalizedFactorsProductBprop(matGradient,leftMatrix()->matValue,centerDiagonal()->value, 00162 rightMatrix()->matValue,leftMatrix()->matGradient, 00163 centerDiagonal()->gradient,rightMatrix()->matGradient); 00164 } 00165 } 00166 } 00167 00168 00169 } // end of namespace PLearn 00170 00171

Generated on Tue Aug 17 15:51:13 2004 for PLearn by doxygen 1.3.7