00001 
00002 
00003 
00004 
00005 
00006 
00007 
00008 
00009 
00010 
00011 
00012 
00013 
00014 
00015 
00016 
00017 
00018 
00019 
00020 
00021 
00022 
00023 
00024 
00025 
00026 
00027 
00028 
00029 
00030 
00031 
00032 
00033 
00034 
00035 
00036 
00037 
00038 
00039 
00040 
00041 
00042 
00043 
#include "DiagonalizedFactorsProductVariable.h"
00044 
#include "Var_utils.h"
00045 
00046 
namespace PLearn {
00047 
using namespace std;
00048 
00049 
00052 
// Invokes PLEARN_IMPLEMENT_OBJECT for DiagonalizedFactorsProductVariable,
// passing the class name, a one-line description, and a longer multi-line
// description of the three parents and of the output elements.
PLEARN_IMPLEMENT_OBJECT(
    DiagonalizedFactorsProductVariable,
    "Variable that represents the leftmatrix*diag(vector)*rightmatrix product",
    "The three parents are respectively the left matrix U, the center vector d,\n"
    "and the right matrix V. Options allow to transpose the matrices.\n"
    "The output value has elements (i,j) equal to sum_k U_{ik} d_k V_{kj}\n"
    );
00058 
00059 DiagonalizedFactorsProductVariable::DiagonalizedFactorsProductVariable(
Var left_matrix, 
00060                                                                        
Var center_diagonal, 
00061                                                                        
Var right_matrix,
00062                                                                        
bool transpose_left_, 
00063                                                                        
bool transpose_right_)
00064     : 
inherited(left_matrix & center_diagonal & (
VarArray)right_matrix,
00065                 transpose_left_?left_matrix->width():left_matrix->length(), 
00066                 transpose_right_?right_matrix->length():right_matrix->width()),
00067       transpose_left(transpose_left_), transpose_right(transpose_right_)
00068 {
00069     
build_();
00070 }
00071 
00072 
void
00073 DiagonalizedFactorsProductVariable::build()
00074 {
00075     inherited::build();
00076     
build_();
00077 }
00078 
00079 
void
00080 DiagonalizedFactorsProductVariable::build_()
00081 {
00082     
if (varray.
size()) {
00083       
int nl = 
transpose_left?
leftMatrix()->
length():
leftMatrix()->
width();
00084       
int nr = 
transpose_right?
rightMatrix()->
width():
rightMatrix()->
length();
00085       
int nc = 
centerDiagonal()->size();
00086       
if (nl != nc || nr != nc)
00087         
PLERROR(
"In DiagonalizedFactorsProductVariable: arguments have incompatible sizes!");
00088     }
00089 }
00090 
00091 void DiagonalizedFactorsProductVariable::recomputeSize(
int& l, 
int& w)
 const
00092 
{
00093     
if (varray.
size()) {
00094       l = 
transpose_left?varray[0]->width():varray[0]->length();
00095       w = 
transpose_right?varray[2]->length():varray[2]->width();
00096     } 
else
00097       l = w = 0;
00098 }
00099 
00100 void DiagonalizedFactorsProductVariable::fprop()
00101 {
00102   
if (
transpose_left)
00103     {
00104       
if (
transpose_right)
00105         
diagonalizedFactorsTransposeProductTranspose(matValue,
leftMatrix()->matValue,
centerDiagonal()->value,
rightMatrix()->matValue);
00106       
else
00107         
diagonalizedFactorsTransposeProduct(matValue,
leftMatrix()->matValue,
centerDiagonal()->value,
rightMatrix()->matValue);
00108     } 
else {
00109       
if (
transpose_right)
00110         
diagonalizedFactorsProductTranspose(matValue,
leftMatrix()->matValue,
centerDiagonal()->value,
rightMatrix()->matValue);
00111       
else
00112         
diagonalizedFactorsProduct(matValue,
leftMatrix()->matValue,
centerDiagonal()->value,
rightMatrix()->matValue);
00113     }
00114 }
00115 
00116 
00117 void DiagonalizedFactorsProductVariable::bprop()
00118 {
00119   
if (
transpose_left)
00120   {
00121     
if (
transpose_right)
00122     {
00123       
00124       
00125       
00126       
00127       
diagonalizedFactorsTransposeProductTransposeBprop(matGradient,
leftMatrix()->matValue,
00128                                                         
centerDiagonal()->value,
rightMatrix()->matValue,
00129                                                         
leftMatrix()->matGradient,
00130                                                         
centerDiagonal()->gradient,
rightMatrix()->matGradient); 
00131     }
00132     
else
00133     {
00134       
00135       
00136       
00137       
00138       
diagonalizedFactorsTransposeProductBprop(matGradient,
leftMatrix()->matValue,
centerDiagonal()->value,
00139                                                
rightMatrix()->matValue, 
leftMatrix()->matGradient,
00140                                                
centerDiagonal()->gradient,
rightMatrix()->matGradient); 
00141     }
00142   } 
00143   
else 
00144   {
00145     
if (
transpose_right)
00146     {
00147       
00148       
00149       
00150       
00151       
diagonalizedFactorsProductTransposeBprop(matGradient,
leftMatrix()->matValue,
centerDiagonal()->value,
00152                                                
rightMatrix()->matValue, 
leftMatrix()->matGradient,
00153                                                
centerDiagonal()->gradient,
rightMatrix()->matGradient); 
00154     }
00155     
else
00156     {    
00157       
00158       
00159       
00160       
00161       
diagonalizedFactorsProductBprop(matGradient,
leftMatrix()->matValue,
centerDiagonal()->value,
00162                                       
rightMatrix()->matValue,
leftMatrix()->matGradient,
00163                                       
centerDiagonal()->gradient,
rightMatrix()->matGradient);
00164     }  
00165   }
00166 }
00167 
00168 
00169 } 
00170 
00171