00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
#include "IfThenElseVariable.h"
00044
#include "Var_utils.h"
00045
00046
namespace PLearn {
00047
using namespace std;
00048
00049
00052
PLEARN_IMPLEMENT_OBJECT(IfThenElseVariable,
00053
"Variable that represents the element-wise IF-THEN-ELSE",
00054
"NO HELP");
00055
00056 IfThenElseVariable::IfThenElseVariable(
Var IfVar,
Var ThenVar,
Var ElseVar)
00057 :
inherited(IfVar & ThenVar & (
VarArray)ElseVar,ThenVar->length(), ThenVar->width())
00058 {
00059
build_();
00060 }
00061
00062
void
00063 IfThenElseVariable::build()
00064 {
00065 inherited::build();
00066
build_();
00067 }
00068
00069
void
00070 IfThenElseVariable::build_()
00071 {
00072
if (varray.
size()) {
00073
if (varray[1]->
length() != varray[2]->
length() || varray[1]->width() != varray[2]->width())
00074
PLERROR(
"In IfThenElseVariable: ElseVar and ThenVar must have the same size");
00075
if (!varray[0]->isScalar() && (varray[0]->
length() != varray[1]->
length() || varray[0]->width() != varray[1]->width()))
00076
PLERROR(
"In IfThenElseVariable: IfVar must either be a scalar or have the same size as ThenVar and ElseVar");
00077 }
00078 }
00079
00080 void IfThenElseVariable::recomputeSize(
int& l,
int& w)
const
00081
{
00082
if (varray.
size()) {
00083 l = varray[1]->
length();
00084 w = varray[1]->width();
00085 }
else
00086 l = w = 0;
00087 }
00088
00089 void IfThenElseVariable::fprop()
00090 {
00091
if(
If()->isScalar())
00092 {
00093
00094
00095
bool test =
If()->valuedata[0] == 0 ?
false :
true;
00096
if (test)
00097 value <<
Then()->value;
00098
else
00099 value <<
Else()->value;
00100 }
00101
else
00102 {
00103
real* ifv =
If()->valuedata;
00104
real* thenv =
Then()->valuedata;
00105
real* elsev =
Else()->valuedata;
00106
for (
int k=0;
k<
nelems();
k++)
00107 {
00108
00109
00110
if ( ifv[
k] == 0 ?
false:
true )
00111 valuedata[
k]=thenv[
k];
00112
else
00113 valuedata[
k]=elsev[
k];
00114 }
00115 }
00116 }
00117
00118
00119 void IfThenElseVariable::bprop()
00120 {
00121
if(
If()->isScalar())
00122 {
00123
00124
00125
bool test =
If()->valuedata[0] == 0 ?
false :
true;
00126
if (test)
00127
Then()->gradient += gradient;
00128
else
00129
Else()->gradient += gradient;
00130 }
00131
else
00132 {
00133
real* ifv =
If()->valuedata;
00134
real* theng =
Then()->gradientdata;
00135
real* elseg =
Else()->gradientdata;
00136
for (
int k=0;
k<
nelems();
k++)
00137 {
00138
00139
00140
if ( ifv[
k] == 0 ?
false:
true )
00141 theng[
k] += gradientdata[
k];
00142
else
00143 elseg[
k] += gradientdata[
k];
00144 }
00145 }
00146 }
00147
00148
00149 void IfThenElseVariable::symbolicBprop()
00150 {
00151
Var zero(
length(),
width());
00152
Then()->accg(
ifThenElse(
If(), g, zero));
00153
Else()->accg(
ifThenElse(
If(), zero, g));
00154 }
00155
00156
00157 void IfThenElseVariable::rfprop()
00158 {
00159
if (rValue.
length()==0)
resizeRValue();
00160
if(
If()->isScalar())
00161 {
00162
00163
00164
bool test =
If()->valuedata[0] == 0 ?
false :
true;
00165
if (test)
00166 rValue <<
Then()->rValue;
00167
else
00168 rValue <<
Else()->rValue;
00169 }
00170
else
00171 {
00172
real* ifv =
If()->valuedata;
00173
real* rthenv =
Then()->rvaluedata;
00174
real* relsev =
Else()->rvaluedata;
00175
for (
int k=0;
k<
nelems();
k++)
00176 {
00177
00178
00179
if ( ifv[
k] == 0 ?
false:
true )
00180 rvaluedata[
k]=rthenv[
k];
00181
else
00182 rvaluedata[
k]=relsev[
k];
00183 }
00184 }
00185 }
00186
00187
00188
00189 }
00190
00191