00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
#include "VarElementVariable.h"
00044
#include "ElementAtPositionVariable.h"
00045
00046
namespace PLearn {
00047
using namespace std;
00048
00049
00052
PLEARN_IMPLEMENT_OBJECT(VarElementVariable,
00053
"ONE LINE DESCR",
00054
"NO HELP");
00055
00056 VarElementVariable::VarElementVariable(
Variable* input1,
Variable* input2)
00057 :
inherited(input1, input2, 1, 1)
00058 {
00059
build_();
00060 }
00061
00062
void
00063 VarElementVariable::build()
00064 {
00065 inherited::build();
00066
build_();
00067 }
00068
00069
void
00070 VarElementVariable::build_()
00071 {
00072
if (input2 && input2->nelems() > 2)
00073
PLERROR(
"IN VarElementVariable(Variable* input1, Variable* input2) input2 must have 1 (a k position index) or 2 elements (an i,j position index)");
00074 }
00075
00076 void VarElementVariable::recomputeSize(
int& l,
int& w)
const
00077
{ l=1; w=1; }
00078
00079 void VarElementVariable::fprop()
00080 {
00081
if(input2->isScalar())
00082 {
00083
int k =
int(input2->valuedata[0]);
00084
#ifdef BOUNDCHECK
00085
if (
k >= input1->
length())
00086
PLERROR(
"VarElementVariable::fprop() - k = %d is out of range (size is %d)",
k, input1->
length());
00087
#endif
00088
valuedata[0] = input1->valuedata[
k];
00089 }
00090
else
00091 {
00092
int i =
int(input2->valuedata[0]);
00093
int j = int(input2->valuedata[1]);
00094
#ifdef BOUNDCHECK
00095
if ((i * input1->
width() + j) >= input1->
width() * input1->
length())
00096
PLERROR(
"VarElementVariable::fprop() - (%d, %d) out of range"
00097
"(size is %d)", i, j, input1->
length() * input1->
width());
00098
#endif
00099
valuedata[0] = input1->valuedata[i*input1->
width()+j];
00100 }
00101 }
00102
00103
00104 void VarElementVariable::bprop()
00105 {
00106
if(input2->isScalar())
00107 {
00108
int k =
int(input2->valuedata[0]);
00109 input1->gradientdata[
k] += gradientdata[0];
00110 }
00111
else
00112 {
00113
int i =
int(input2->valuedata[0]);
00114
int j = int(input2->valuedata[1]);
00115 input1->gradientdata[i*input1->
width()+j] += gradientdata[0];
00116 }
00117 }
00118
00119
00120 void VarElementVariable::symbolicBprop()
00121 {
00122 input1->accg(
new ElementAtPositionVariable(g,input2,input1->
length(),input1->
width()));
00123 }
00124
00125
00126 void VarElementVariable::rfprop()
00127 {
00128
if (rValue.
length()==0)
resizeRValue();
00129
if(input2->isScalar())
00130 {
00131
int k =
int(input2->valuedata[0]);
00132 rvaluedata[0] = input1->rvaluedata[
k];
00133 }
00134
else
00135 {
00136
int i =
int(input2->valuedata[0]);
00137
int j = int(input2->valuedata[1]);
00138 rvaluedata[0] = input1->rvaluedata[i*input1->
width()+j];
00139 }
00140 }
00141
00142
00143
00144 }
00145
00146