00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00041
#include "ConstantRegressor.h"
00042
00043
namespace PLearn {
00044
using namespace std;
00045
00046 ConstantRegressor::ConstantRegressor()
00047 : weight_decay(0.0)
00048 {
00049 }
00050
00051
// Registers ConstantRegressor with the PLearn object system, together with
// its one-line summary and its multi-line help text.
PLEARN_IMPLEMENT_OBJECT(
    ConstantRegressor,
    "PLearner that outputs a constant (input-independent) vector.\n",
    "ConstantRegressor is a PLearner that outputs a constant (input-independent\n"
    "but training-data-dependent) vector. It is a regressor (i.e. during training\n"
    "the constant vector is chosen to minimize the (possibly weighted) mean\n"
    "squared error with respect to the training set targets). Let\n"
    " N = number of training examples,\n"
    " M = target size (= output size),\n"
    " y_{ij} = the jth target value of the ith training example,\n"
    " w_i = weight associated to the ith training example,\n"
    "then the j-th component of the learned vector is\n"
    " (sum_{i=1}^N w_i * y_{ij}) / (sum_{i=1}^N w_i)\n"
    "The output can also be set manually with the 'constant_output' vector option\n");
00064
00065 void ConstantRegressor::declareOptions(
OptionList& ol)
00066 {
00067
00068
00069
00070
00071
00072
00073
declareOption(ol,
"weight_decay", &ConstantRegressor::weight_decay,
00074 OptionBase::buildoption,
00075
"Weight decay parameter. Default=0. NOT CURRENTLY TAKEN INTO ACCOUNT!");
00076
00077
00078
declareOption(ol,
"constant_output", &ConstantRegressor::constant_output,
00079 OptionBase::learntoption,
00080
"This is the learnt parameter, the constant output. During training\n"
00081
"It is set to the (possibly weighted) average of the targets.\n"
00082 );
00083
00084
00085 inherited::declareOptions(ol);
00086 }
00087
00088 void ConstantRegressor::build_()
00089 {
00090 }
00091
00092
00093 void ConstantRegressor::build()
00094 {
00095 inherited::build();
00096
build_();
00097 }
00098
00099
00100 void ConstantRegressor::makeDeepCopyFromShallowCopy(map<const void*, void*>& copies)
00101 {
00102 inherited::makeDeepCopyFromShallowCopy(copies);
00103 }
00104
00105
00106 int ConstantRegressor::outputsize()
const
00107
{
00108
return targetsize();
00109 }
00110
00111 void ConstantRegressor::forget()
00112 {
00113
00114 }
00115
00116 void ConstantRegressor::train()
00117 {
00118
00119
00120
00121
Vec input;
00122
Vec target;
00123
Vec train_costs;
00124
Vec sum_of_weighted_targets;
00125
real weight;
00126 train_costs.
resize(1);
00127 input.
resize(
inputsize());
00128 target.
resize(
targetsize());
00129 sum_of_weighted_targets.
resize(
targetsize());
00130
constant_output.
resize(
targetsize());
00131
00132
if(!train_stats)
00133 train_stats =
new VecStatsCollector();
00134
00135
real sum_of_weights = 0;
00136 sum_of_weighted_targets.
clear();
00137
00138
int n_examples = train_set->
length();
00139
for (
int i=0;i<n_examples;i++)
00140 {
00141 train_set->
getExample(i, input, target, weight);
00142
multiplyAdd(sum_of_weighted_targets,target,weight,sum_of_weighted_targets);
00143 sum_of_weights += weight;
00144
multiply(sum_of_weighted_targets,
real(1.0/sum_of_weights),
constant_output);
00145 train_costs[0] =
00146 weight*
powdistance(
constant_output,target);
00147 train_stats->update(train_costs);
00148 }
00149 train_stats->finalize();
00150 }
00151
00152
00153 void ConstantRegressor::computeOutput(
const Vec& input,
Vec& output)
const
00154
{
00155
00156 output.
resize(
outputsize());
00157 output <<
constant_output;
00158 }
00159
00160 void ConstantRegressor::computeCostsFromOutputs(
const Vec& input,
const Vec& output,
00161
const Vec& target,
Vec& costs)
const
00162
{
00163
00164 costs[0] =
powdistance(output,target);
00165 }
00166
00167 TVec<string> ConstantRegressor::getTestCostNames()
const
00168
{
00169
00170
return getTrainCostNames();
00171 }
00172
00173 TVec<string> ConstantRegressor::getTrainCostNames()
const
00174
{
00175
00176
00177
return TVec<string>(1,
"mse");
00178 }
00179
00180
00181
00182 }