00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
#include "SelectedOutputCostFunction.h"
00044
#include "SquaredErrorCostFunction.h"
00045
00046
namespace PLearn {
00047
using namespace std;
00048
00049
00050
//! One-line description and help text for SquaredErrorCostFunction
//! (previously left as "ONE LINE DESCR" / "NO HELP" placeholders).
PLEARN_IMPLEMENT_OBJECT(SquaredErrorCostFunction,
    "Squared-error cost between an output vector and a target vector.",
    "If targetindex >= 0, only that component is compared:\n"
    "  (output[targetindex] - target[targetindex])^2.\n"
    "Otherwise, if classification is true, target holds a single class index c and the cost is\n"
    "  sum_i (output[i] - v_i)^2, with v_i = hotvalue if i == c and coldvalue otherwise.\n"
    "Otherwise the cost is sum_i (output[i] - target[i])^2.\n");
00051
00052 real SquaredErrorCostFunction::evaluate(
const Vec& output,
const Vec& target)
const
00053
{
00054
#ifdef BOUNDCHECK
00055
if(target.
length()!=output.
length() &&
classification==
false)
00056
PLERROR(
"In SquaredErrorCostFunction::evaluate target.length() %d should be equal to output.length() %d",target.
length(),output.
length());
00057
#endif
00058
00059
real result = 0.0;
00060
if (
targetindex>=0)
00061 result =
square(output[
targetindex]-target[
targetindex]);
00062
else
00063 {
00064
real* outputdata = output.
data();
00065
real* targetdata = target.
data();
00066
if (
classification) {
00067
if (target.
length() != 1)
00068
PLERROR(
"In SquaredErrorCostFunction::evaluate target.length() %s should be 1", target.
length());
00069
00070
for (
int i = 0; i < output.
length(); ++i)
00071
if (i == targetdata[0])
00072 result +=
square(outputdata[i] -
hotvalue);
00073
else
00074 result +=
square(outputdata[i] -
coldvalue);
00075 }
else {
00076
for(
int i=0; i<output.
length(); i++)
00077 result +=
square(outputdata[i]-targetdata[i]);
00078 }
00079 }
00080
return result;
00081 }
00082
00083 void SquaredErrorCostFunction::declareOptions(
OptionList &ol)
00084 {
00085
declareOption(ol,
"targetindex", &SquaredErrorCostFunction::targetindex, OptionBase::buildoption,
"Index of target");
00086
declareOption(ol,
"hotvalue", &SquaredErrorCostFunction::hotvalue, OptionBase::buildoption,
"Hot value");
00087
declareOption(ol,
"coldvalue", &SquaredErrorCostFunction::coldvalue, OptionBase::buildoption,
"Cold value");
00088
declareOption(ol,
"classification", &SquaredErrorCostFunction::classification, OptionBase::buildoption,
"Used as classification cost");
00089 inherited::declareOptions(ol);
00090 }
00091
00092 CostFunc squared_error(
int singleoutputindex)
00093 {
00094
if(singleoutputindex>=0)
00095
return new SelectedOutputCostFunction(
new SquaredErrorCostFunction(),singleoutputindex);
00096
else
00097
return new SquaredErrorCostFunction();
00098 }
00099
00100 }
00101