00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
#include "MiniBatchClassificationLossVariable.h"
00044
00045
namespace PLearn {
00046
using namespace std;
00047
00048
00051
// Registers the class with the PLearn object system.  The two strings are the
// one-line description and the multi-line help shown for this class; they
// replace the unfilled template placeholders and summarise what fprop(),
// build_() and symbolicBprop() in this file actually do.
PLEARN_IMPLEMENT_OBJECT(
    MiniBatchClassificationLossVariable,
    "Per-sample 0/1 classification loss over a mini-batch of network outputs",
    "For each sample i of the mini-batch, the output is 0 when the argmax of\n"
    "the i-th score vector of netout equals classnum[i], and 1 otherwise.\n"
    "netout holds one score vector per row (its length equals the size of\n"
    "classnum) or one per column (its width equals the size of classnum).\n"
    "classnum must be a vector variable. symbolicBprop is not implemented.");
00054
00055 MiniBatchClassificationLossVariable::MiniBatchClassificationLossVariable(
Variable* netout,
Variable* classnum)
00056 :
inherited(netout,classnum,classnum->length(),classnum->width())
00057 {
00058
build_();
00059 }
00060
00061
void
00062 MiniBatchClassificationLossVariable::build()
00063 {
00064 inherited::build();
00065
build_();
00066 }
00067
00068
void
00069 MiniBatchClassificationLossVariable::build_()
00070 {
00071
00072
if(input2 && !input2->isVec())
00073
PLERROR(
"In MiniBatchClassificationLossVariable: classnum must be a vector variable representing the indexs of netout (typically class numbers)");
00074 }
00075
00076
00077 void MiniBatchClassificationLossVariable::recomputeSize(
int& l,
int& w)
const
00078
{
00079
if (input2) {
00080 l = input2->
length();
00081 w = input2->
width();
00082 }
else
00083 l = w = 0;
00084 }
00085
00086 void MiniBatchClassificationLossVariable::fprop()
00087 {
00088
int n = input2->size();
00089
if(input1->
length()==n)
00090
for (
int i=0; i<n; i++)
00091 {
00092
int topscorepos =
argmax(input1->matValue.
row(i));
00093
int num =
int(input2->valuedata[i]);
00094 valuedata[i] = (topscorepos==num ?0 :1);
00095 }
00096
else if(input1->
width()==n)
00097
for (
int i=0; i<n; i++)
00098 {
00099
int topscorepos =
argmax(input1->matValue.
column(i));
00100
int num =
int(input2->valuedata[i]);
00101 valuedata[i] = (topscorepos==num ?0 :1);
00102 }
00103
else PLERROR(
"In MiniBatchClassificationLossVariable: The length or width of netout doesn't equal to the size of classnum");
00104 }
00105
00106
00107 void MiniBatchClassificationLossVariable::symbolicBprop()
00108 {
00109
PLERROR(
"MiniBatchClassificationLossVariable::symbolicBprop not implemented.");
00110 }
00111
00112
00113
00114 }
00115
00116