00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
#include "ArgminOfVariable.h"

#include <limits>
00044
00045
namespace PLearn {
00046
using namespace std;
00047
00048
00050
00051
00052 ArgminOfVariable::ArgminOfVariable(
Variable* the_v,
00053
Variable* the_expression,
00054
Variable* the_values_of_v,
00055
const VarArray& the_inputs)
00056 :
NaryVariable(the_inputs,1,1), inputs(the_inputs), expression(the_expression),
00057 values_of_v(the_values_of_v), v(the_v)
00058 {
00059
if (!
v->isScalar() || !
values_of_v->isVec())
00060
PLERROR(
"ArgminOfVariable currently implemented only for a scalar v and a vector values_of_v");
00061
vv_path =
propagationPath(
inputs,
values_of_v);
00062
e_path =
propagationPath(
inputs& (
VarArray)
v,
expression);
00063
v_path =
propagationPath(v,
expression);
00064 }
00065
00066
00067
// Registers ArgminOfVariable with PLearn's object system, replacing the
// placeholder one-line description and help text with an actual description.
PLEARN_IMPLEMENT_OBJECT(
    ArgminOfVariable,
    "Outputs the element of values_of_v that minimizes expression when assigned to v",
    "Given a scalar variable v, a vector variable values_of_v computed from the inputs,\n"
    "and a variable expression depending on v and the inputs, fprop assigns each element\n"
    "of values_of_v to v in turn, evaluates expression (only its first element is used),\n"
    "and outputs the element giving the smallest value (the first one in case of ties).\n"
    "In bprop, the output gradient is sent only to the selected element of values_of_v;\n"
    "no gradient is propagated through expression.\n");
00068
00069 void ArgminOfVariable::recomputeSize(
int& l,
int& w)
const
00070
{ l=1; w=1; }
00071
00072
00073 void ArgminOfVariable::makeDeepCopyFromShallowCopy(map<const void*, void*>& copies)
00074 {
00075 NaryVariable::makeDeepCopyFromShallowCopy(copies);
00076
deepCopyField(
inputs, copies);
00077
deepCopyField(
expression, copies);
00078
deepCopyField(
values_of_v, copies);
00079
deepCopyField(
v, copies);
00080
deepCopyField(
vv_path, copies);
00081
deepCopyField(
e_path, copies);
00082
deepCopyField(
v_path, copies);
00083 }
00084
00085
00086
00087 void ArgminOfVariable::fprop()
00088 {
00089
vv_path.
fprop();
00090
real min_value_of_expression = FLT_MAX;
00091
real argmin_value_of_v =
values_of_v->value[0];
00092
for (
int i=0;i<
values_of_v->nelems();i++)
00093 {
00094
v->value[0] =
values_of_v->value[i];
00095
if (i==0)
00096
e_path.
fprop();
00097
else
00098
v_path.
fprop();
00099
real e =
expression->value[0];
00100
if (e<min_value_of_expression)
00101 {
00102 min_value_of_expression = e;
00103 argmin_value_of_v =
v->value[0];
00104
index_of_argmin = i;
00105 }
00106 }
00107 value[0] = argmin_value_of_v;
00108 }
00109
00110
00111 void ArgminOfVariable::bprop()
00112 {
00113
vv_path.
clearGradient();
00114
values_of_v->gradientdata[
index_of_argmin] = gradientdata[0];
00115
vv_path.
bprop();
00116 }
00117
00118
00119
00120 }
00121
00122