#include <plearn/var/AffineTransformVariable.h>
#include <plearn/var/AffineTransformWeightPenalty.h>
#include <plearn/var/BinaryClassificationLossVariable.h>
#include <plearn/var/ClassificationLossVariable.h>
#include <plearn/var/ConcatColumnsVariable.h>
#include <plearn/vmat/ConcatColumnsVMatrix.h>
#include <plearn/var/CrossEntropyVariable.h>
#include <plearn/var/ExpVariable.h>
#include <plearn/var/LogVariable.h>
#include <plearn/var/LiftOutputVariable.h>
#include <plearn/var/LogSoftmaxVariable.h>
#include <plearn/var/MulticlassLossVariable.h>
#include "MultiInstanceNNet.h"
#include <plearn/var/UnfoldedSumOfVariable.h>
#include <plearn/var/SumOverBagsVariable.h>
#include <plearn/var/SumSquareVariable.h>
#include <plearn/math/random.h>
#include <plearn/var/SigmoidVariable.h>
#include <plearn/var/SumVariable.h>
#include <plearn/var/SumAbsVariable.h>
#include <plearn/var/SumOfVariable.h>
#include <plearn/vmat/SubVMatrix.h>
#include <plearn/var/TanhVariable.h>
#include <plearn/var/TransposeProductVariable.h>
#include <plearn/var/Var_operators.h>
#include <plearn/var/Var_utils.h>
00071
00072
00073
00074
namespace PLearn {
00075
using namespace std;
00076
00077
PLEARN_IMPLEMENT_OBJECT(MultiInstanceNNet,
00078
"Multi-instance feedforward neural network for probabilistic classification",
00079
"The data has the form of a set of input vectors x_i associated with a single\n"
00080
"label y. Each x_i is an instance and the overall set of instance is called a bag.\n"
00081
"We don't know which of the inputs is responsible for the label, i.e.\n"
00082
"there are hidden (not observed) labels y_i associated with each of the inputs x_i.\n"
00083
"We also know that y=1 if at least one of the y_i is 1, otherwise y=0, i.e.\n"
00084
" y = y_1 or y_2 or ... y_m\n"
00085
"In terms of probabilities, it means that\n"
00086
" P(Y=0|x_1..x_m) = \\prod_{i=1}^m P(y_i=0|x_i)\n"
00087
"which determines the likelihood of the observation (x_1...x_m,y).\n"
00088
"The neural network implements the computation of P(y_i=1|x_i). The same\n"
00089
"model is assumed for all instances in the bag. The number of instances is variable but\n"
00090
"bounded a-priori (max_n_instances). The gradient is computed for a whole bag\n"
00091
"at a time. The architectural parameters and hyper-parameters of the model\n"
00092
"are otherwise the same as for the generic NNet class.\n"
00093
"The bags within each data set are specified with a 2nd target column\n"
00094
"(the first column is 0, 1 or missing; it should not be missing for the\n"
00095
"last column of the bag). The second target column should be 0,1,2, or 3:\n"
00096
" 1: first row of a bag\n"
00097
" 2: last row of a bag\n"
00098
" 3: simultaneously first and last, there is only one row in this bag\n"
00099
" 0: intermediate row of a bag\n"
00100
"following the protocol expected by the SumOverBagsVariable.\n"
00101 );
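
// Illustrative example (not from the original source): with inputsize=2 and
// targetsize=2, a 3-instance bag followed by a single-instance bag would be
// encoded row by row as
//
//     x1    x2    y      bag_signal
//     0.2   1.1   NaN    1    <- first row of bag 1 (label may be missing)
//     0.5   0.3   NaN    0    <- intermediate row
//     0.9   0.7   1      2    <- last row of bag 1: carries the bag label y
//     0.1   0.4   0      3    <- single-row bag (first and last at once)
//
// The signal values act as bit flags: bit 1 marks the first row of a bag and
// bit 2 marks the last one, which is how SumOverBagsVariable and the cost
// methods below decode them.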

MultiInstanceNNet::MultiInstanceNNet()
    : training_set_has_changed(false),
      max_n_instances(1),
      nhidden(0),
      nhidden2(0),
      weight_decay(0),
      bias_decay(0),
      layer1_weight_decay(0),
      layer1_bias_decay(0),
      layer2_weight_decay(0),
      layer2_bias_decay(0),
      output_layer_weight_decay(0),
      output_layer_bias_decay(0),
      direct_in_to_out_weight_decay(0),
      L1_penalty(false),
      direct_in_to_out(false),
      interval_minval(0), interval_maxval(1),
      test_bag_size(0),
      batch_size(1)
{}

MultiInstanceNNet::~MultiInstanceNNet()
{
}

void MultiInstanceNNet::declareOptions(OptionList& ol)
{
    declareOption(ol, "max_n_instances", &MultiInstanceNNet::max_n_instances, OptionBase::buildoption,
                  "    maximum number of instances (input vectors x_i) allowed\n");

    declareOption(ol, "nhidden", &MultiInstanceNNet::nhidden, OptionBase::buildoption,
                  "    number of hidden units in first hidden layer (0 means no hidden layer)\n");

    declareOption(ol, "nhidden2", &MultiInstanceNNet::nhidden2, OptionBase::buildoption,
                  "    number of hidden units in second hidden layer (0 means no hidden layer)\n");

    declareOption(ol, "weight_decay", &MultiInstanceNNet::weight_decay, OptionBase::buildoption,
                  "    global weight decay for all layers\n");

    declareOption(ol, "bias_decay", &MultiInstanceNNet::bias_decay, OptionBase::buildoption,
                  "    global bias decay for all layers\n");

    declareOption(ol, "layer1_weight_decay", &MultiInstanceNNet::layer1_weight_decay, OptionBase::buildoption,
                  "    Additional weight decay for the first hidden layer. Is added to weight_decay.\n");

    declareOption(ol, "layer1_bias_decay", &MultiInstanceNNet::layer1_bias_decay, OptionBase::buildoption,
                  "    Additional bias decay for the first hidden layer. Is added to bias_decay.\n");

    declareOption(ol, "layer2_weight_decay", &MultiInstanceNNet::layer2_weight_decay, OptionBase::buildoption,
                  "    Additional weight decay for the second hidden layer. Is added to weight_decay.\n");

    declareOption(ol, "layer2_bias_decay", &MultiInstanceNNet::layer2_bias_decay, OptionBase::buildoption,
                  "    Additional bias decay for the second hidden layer. Is added to bias_decay.\n");

    declareOption(ol, "output_layer_weight_decay", &MultiInstanceNNet::output_layer_weight_decay, OptionBase::buildoption,
                  "    Additional weight decay for the output layer. Is added to 'weight_decay'.\n");

    declareOption(ol, "output_layer_bias_decay", &MultiInstanceNNet::output_layer_bias_decay, OptionBase::buildoption,
                  "    Additional bias decay for the output layer. Is added to 'bias_decay'.\n");

    declareOption(ol, "direct_in_to_out_weight_decay", &MultiInstanceNNet::direct_in_to_out_weight_decay, OptionBase::buildoption,
                  "    Additional weight decay for the direct in-to-out layer. Is added to 'weight_decay'.\n");

    declareOption(ol, "L1_penalty", &MultiInstanceNNet::L1_penalty, OptionBase::buildoption,
                  "    should we use L1 penalty instead of the default L2 penalty on the weights?\n");

    declareOption(ol, "direct_in_to_out", &MultiInstanceNNet::direct_in_to_out, OptionBase::buildoption,
                  "    should we include direct input to output connections?\n");

    declareOption(ol, "optimizer", &MultiInstanceNNet::optimizer, OptionBase::buildoption,
                  "    specify the optimizer to use\n");

    declareOption(ol, "batch_size", &MultiInstanceNNet::batch_size, OptionBase::buildoption,
                  "    how many samples to use to estimate the average gradient before updating the weights\n"
                  "    0 is equivalent to specifying training_set->n_non_missing_rows()\n");

    declareOption(ol, "paramsvalues", &MultiInstanceNNet::paramsvalues, OptionBase::learntoption,
                  "    The learned parameter vector\n");

    inherited::declareOptions(ol);
}
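
// A minimal usage sketch in PLearn script syntax (all values hypothetical,
// and GradientOptimizer is just one possible choice of optimizer):
//
//   MultiInstanceNNet(
//       max_n_instances = 10;
//       nhidden = 20;
//       weight_decay = 1e-5;
//       batch_size = 1;
//       optimizer = GradientOptimizer(start_learning_rate = 0.01);
//       nstages = 50;
//   )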

void MultiInstanceNNet::build()
{
    inherited::build();
    build_();
}

void MultiInstanceNNet::setTrainingSet(VMat training_set, bool call_forget)
{
    // Rebuild only if the new training set's layout differs from the old one.
    training_set_has_changed =
        !train_set || train_set->width()!=training_set->width() ||
        train_set->length()!=training_set->length() ||
        train_set->inputsize()!=training_set->inputsize() ||
        train_set->weightsize()!=training_set->weightsize();

    train_set = training_set;
    if (training_set_has_changed)
    {
        inputsize_ = train_set->inputsize();
        targetsize_ = train_set->targetsize();
        weightsize_ = train_set->weightsize();
    }

    if (training_set_has_changed || call_forget)
    {
        build();
        if (call_forget)
            forget();
    }
}

void MultiInstanceNNet::build_()
{
    if(inputsize_>=0 && targetsize_>=0 && weightsize_>=0)
    {
        // Initialize the input variable and the parameter list.
        input = Var(inputsize(), "input");
        output = input;
        params.resize(0);

        if (targetsize()!=2)
            PLERROR("MultiInstanceNNet:: expected the data to have 2 target columns, got %d",
                    targetsize());

        // First hidden layer.
        if(nhidden>0)
        {
            w1 = Var(1+inputsize(), nhidden, "w1");
            output = tanh(affine_transform(output, w1));
            params.append(w1);
        }

        // Second hidden layer.
        if(nhidden2>0)
        {
            w2 = Var(1+nhidden, nhidden2, "w2");
            output = tanh(affine_transform(output, w2));
            params.append(w2);
        }

        if (nhidden2>0 && nhidden==0)
            PLERROR("MultiInstanceNNet:: can't have nhidden2 (=%d) > 0 while nhidden=0", nhidden2);

        // Output layer, before the transfer function.
        wout = Var(1+output->size(), outputsize(), "wout");
        output = affine_transform(output, wout);
        params.append(wout);

        // Optional direct input-to-output connections.
        if(direct_in_to_out)
        {
            wdirect = Var(inputsize(), outputsize(), "wdirect");
            output += transposeProduct(wdirect, input);
            params.append(wdirect);
        }

        output = sigmoid(output);
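
        // At this point 'output' is the per-instance probability P(y_i=1|x_i):
        // a sigmoid on top of up to two tanh hidden layers (and, optionally,
        // the direct input-to-output term).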

        target = Var(1, "target");

        if(weightsize_>0)
        {
            if (weightsize_!=1)
                PLERROR("MultiInstanceNNet: expected weightsize to be 1 or 0 (or unspecified = -1, meaning 0), got %d", weightsize_);
            sampleweight = Var(1, "weight");
        }

        // Build the weight-decay penalties.
        penalties.resize(0);
        if(w1 && ((layer1_weight_decay + weight_decay)!=0 || (layer1_bias_decay + bias_decay)!=0))
            penalties.append(affine_transform_weight_penalty(w1, (layer1_weight_decay + weight_decay),
                                                             (layer1_bias_decay + bias_decay), L1_penalty));
        if(w2 && ((layer2_weight_decay + weight_decay)!=0 || (layer2_bias_decay + bias_decay)!=0))
            penalties.append(affine_transform_weight_penalty(w2, (layer2_weight_decay + weight_decay),
                                                             (layer2_bias_decay + bias_decay), L1_penalty));
        if(wout && ((output_layer_weight_decay + weight_decay)!=0 || (output_layer_bias_decay + bias_decay)!=0))
            penalties.append(affine_transform_weight_penalty(wout, (output_layer_weight_decay + weight_decay),
                                                             (output_layer_bias_decay + bias_decay), L1_penalty));
        if(wdirect && (direct_in_to_out_weight_decay + weight_decay) != 0)
        {
            if (L1_penalty)
                penalties.append(sumabs(wdirect)*(direct_in_to_out_weight_decay + weight_decay));
            else
                penalties.append(sumsquare(wdirect)*(direct_in_to_out_weight_decay + weight_decay));
        }

        // Reuse saved parameter values if they match the current architecture,
        // otherwise (re)initialize; then share storage between params and
        // paramsvalues.
        if((bool)paramsvalues && (paramsvalues.size() == params.nelems()))
            params << paramsvalues;
        else
        {
            paramsvalues.resize(params.nelems());
            initializeParams();
        }
        params.makeSharedValue(paramsvalues);

        output->setName("element output");

        f = Func(input, output);

        input_to_logP0 = Func(input, log(1 - output));

        bag_size = Var(1,1);
        bag_inputs = Var(max_n_instances, inputsize());
        bag_output = 1 - exp(unfoldedSumOf(bag_inputs, bag_size, input_to_logP0, max_n_instances));
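
        // Derivation: the doc above gives P(Y=0|x_1..x_m) = prod_i P(y_i=0|x_i),
        // so P(Y=1|x_1..x_m) = 1 - exp( sum_i log(1 - P(y_i=1|x_i)) ).
        // unfoldedSumOf applies input_to_logP0 to the first bag_size rows of
        // bag_inputs and sums the results, which makes bag_output exactly this
        // bag-level probability.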

        costs.resize(3);

        costs[0] = cross_entropy(bag_output, target);
        costs[1] = binary_classification_loss(bag_output, target);
        costs[2] = lift_output(bag_output, target);
        test_costs = hconcat(costs);

        // Assemble the cost to be minimized during training.
        if(penalties.size() != 0) {
            if (weightsize_>0)
                // Only multiply by sampleweight if there are weights.
                training_cost = hconcat(sampleweight*sum(hconcat(costs[0] & penalties))
                                        & (costs[0]*sampleweight) & (costs[1]*sampleweight) & costs[2]);
            else {
                training_cost = hconcat(sum(hconcat(costs[0] & penalties)) & test_costs);
            }
        }
        else {
            if(weightsize_>0) {
                training_cost = hconcat(costs[0]*sampleweight & costs[0]*sampleweight
                                        & costs[1]*sampleweight & costs[2]);
            } else {
                training_cost = hconcat(costs[0] & test_costs);
            }
        }
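
        // The four columns of training_cost line up with getTrainCostNames():
        // NLL+penalty (the objective actually minimized), NLL, class_error,
        // and lift_output.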

        training_cost->setName("training_cost");
        test_costs->setName("test_costs");

        if (weightsize_>0)
            invars = bag_inputs & bag_size & target & sampleweight;
        else
            invars = bag_inputs & bag_size & target;

        inputs_and_targets_to_test_costs = Func(invars, test_costs);
        inputs_and_targets_to_training_costs = Func(invars, training_cost);

        inputs_and_targets_to_test_costs->recomputeParents();
        inputs_and_targets_to_training_costs->recomputeParents();
    }
}

int MultiInstanceNNet::outputsize() const
{ return 1; }

TVec<string> MultiInstanceNNet::getTrainCostNames() const
{
    TVec<string> names(4);
    names[0] = "NLL+penalty";
    names[1] = "NLL";
    names[2] = "class_error";
    names[3] = "lift_output";
    return names;
}

TVec<string> MultiInstanceNNet::getTestCostNames() const
{
    TVec<string> names(3);
    names[0] = "NLL";
    names[1] = "class_error";
    names[2] = "lift_output";
    return names;
}

void MultiInstanceNNet::train()
{
    if(!train_set)
        PLERROR("In MultiInstanceNNet::train, you did not setTrainingSet");

    if(!train_stats)
        PLERROR("In MultiInstanceNNet::train, you did not setTrainStatsCollector");

    if(f.isNull())
        build();

    if (training_set_has_changed)
    {
        // Count the bags in the training set to convert learner stages
        // (epochs) into optimizer stages.
        optstage_per_lstage = 0;
        int n_bags = -1;
        if (batch_size<=0)
            optstage_per_lstage = 1;
        else
        {
            n_bags = 0;
            int l = train_set->length();
            ProgressBar* pb = 0;
            if(report_progress)
                pb = new ProgressBar("Counting nb bags in train_set for MultiInstanceNNet ", l);
            Vec row(train_set->width());
            int tag_column = train_set->inputsize() + train_set->targetsize() - 1;
            for (int i=0;i<l;i++) {
                train_set->getRow(i,row);
                int tag = (int)row[tag_column];
                if (tag & SumOverBagsVariable::TARGET_COLUMN_FIRST) {
                    // Indicates the beginning of a new bag.
                    n_bags++;
                }
                if(pb)
                    pb->update(i);
            }
            if(pb)
                delete pb;
            optstage_per_lstage = n_bags/batch_size;
        }
        training_set_has_changed = false;
    }
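
    // One learner stage (epoch) thus corresponds to n_bags/batch_size
    // optimizer stages, i.e. roughly one pass over all the bags; with
    // batch_size<=0 an epoch is a single optimizer stage over the whole set.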

    Var totalcost = sumOverBags(train_set, inputs_and_targets_to_training_costs,
                                max_n_instances, batch_size);
    if(optimizer)
    {
        optimizer->setToOptimize(params, totalcost);
        optimizer->build();
    }

    ProgressBar* pb = 0;
    if(report_progress)
        pb = new ProgressBar("Training MultiInstanceNNet from stage " + tostring(stage)
                             + " to " + tostring(nstages), nstages-stage);

    int initial_stage = stage;
    bool early_stop=false;
    while(stage<nstages && !early_stop)
    {
        optimizer->nstages = optstage_per_lstage;
        train_stats->forget();
        optimizer->early_stop = false;
        optimizer->optimizeN(*train_stats);
        train_stats->finalize();
        if(verbosity>2)
            cout << "Epoch " << stage << " train objective: " << train_stats->getMean() << endl;
        ++stage;
        if(pb)
            pb->update(stage-initial_stage);
    }
    if(verbosity>1)
        cout << "EPOCH " << stage << " train objective: " << train_stats->getMean() << endl;

    if(pb)
        delete pb;
}

void MultiInstanceNNet::computeOutput(const Vec& inputv, Vec& outputv) const
{
    f->fprop(inputv,outputv);
}
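
// The two methods below receive one instance (one row) at a time and decode
// the bag signal in targetv[1] as bit flags (bit 1: first row of a bag,
// bit 2: last row). Instances are accumulated across calls, the bag costs are
// only defined on the last row of each bag, and every other row gets
// MISSING_VALUE costs. This makes the computation stateful (test_bag_size
// persists between calls), so rows must be presented in order, bag by bag.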

void MultiInstanceNNet::computeOutputAndCosts(const Vec& inputv, const Vec& targetv,
                                              Vec& outputv, Vec& costsv) const
{
    f->fprop(inputv,outputv);
    int bag_signal = int(targetv[1]);
    if (bag_signal & 1) // first instance of the bag: restart accumulating
        test_bag_size=0;
    bag_inputs->matValue(test_bag_size++) << inputv;
    if (!(bag_signal & 2)) // not at the end of the bag yet
        costsv.fill(MISSING_VALUE);
    else
    {
        bag_size->valuedata[0]=test_bag_size;
        target->valuedata[0] = targetv[0];
        if (weightsize_>0)
            sampleweight->valuedata[0]=1;
        inputs_and_targets_to_test_costs->fproppath.fprop();
        inputs_and_targets_to_test_costs->outputs.copyTo(costsv);
    }
}

void MultiInstanceNNet::computeCostsFromOutputs(const Vec& inputv, const Vec& outputv,
                                                const Vec& targetv, Vec& costsv) const
{
    instance_logP0.resize(max_n_instances);
    int bag_signal = int(targetv[1]);
    if (bag_signal & 1) // first instance of the bag: restart accumulating
        test_bag_size=0;
    instance_logP0[test_bag_size++] = safeflog(1-outputv[0]);
    if (!(bag_signal & 2)) // not at the end of the bag yet
        costsv.fill(MISSING_VALUE);
    else
    {
        instance_logP0.resize(test_bag_size);
        real bag_P0 = safeexp(sum(instance_logP0));
        int classe = int(targetv[0]);
        int predicted_classe = (bag_P0>0.5)?0:1;
        real nll = (classe==0)?-safeflog(bag_P0):-safeflog(1-bag_P0);
        int classification_error = (classe != predicted_classe);
        costsv[0] = nll;
        costsv[1] = classification_error;
        // Lift output: the instance output, signed by the target class.
        if (targetv[0] > 0) {
            costsv[2] = outputv[0];
        } else {
            costsv[2] = -outputv[0];
        }
    }
}
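
// Parameter initialization: weights are drawn from a zero-mean normal whose
// standard deviation is the inverse of the layer's fan-in, and the first row
// of each weight matrix (the bias row of the affine transform) is cleared
// to zero.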

void MultiInstanceNNet::initializeParams()
{
    if (seed_>=0)
        manual_seed(seed_);
    else
        PLearn::seed();

    real delta = 1./inputsize();

    if(nhidden>0)
    {
        fill_random_normal(w1->value, 0, delta);
        if(direct_in_to_out)
        {
            fill_random_normal(wdirect->value, 0, 0.01*delta);
            wdirect->matValue(0).clear();
        }
        delta = 1./nhidden;
        w1->matValue(0).clear();
    }
    if(nhidden2>0)
    {
        fill_random_normal(w2->value, 0, delta);
        delta = 1./nhidden2;
        w2->matValue(0).clear();
    }
    fill_random_normal(wout->value, 0, delta);
    wout->matValue(0).clear();

    if(optimizer)
        optimizer->reset();
}

void MultiInstanceNNet::forget()
{
    if (train_set)
        initializeParams();
    stage = 0;
}

extern void varDeepCopyField(Var& field, CopiesMap& copies);

void MultiInstanceNNet::makeDeepCopyFromShallowCopy(CopiesMap& copies)
{
    inherited::makeDeepCopyFromShallowCopy(copies);
    deepCopyField(instance_logP0, copies);
    varDeepCopyField(input, copies);
    varDeepCopyField(target, copies);
    varDeepCopyField(sampleweight, copies);
    varDeepCopyField(w1, copies);
    varDeepCopyField(w2, copies);
    varDeepCopyField(wout, copies);
    varDeepCopyField(wdirect, copies);
    varDeepCopyField(output, copies);
    varDeepCopyField(bag_size, copies);
    varDeepCopyField(bag_inputs, copies);
    varDeepCopyField(bag_output, copies);
    deepCopyField(inputs_and_targets_to_test_costs, copies);
    deepCopyField(inputs_and_targets_to_training_costs, copies);
    deepCopyField(input_to_logP0, copies);
    varDeepCopyField(nll, copies);
    deepCopyField(costs, copies);
    deepCopyField(penalties, copies);
    varDeepCopyField(training_cost, copies);
    varDeepCopyField(test_costs, copies);
    deepCopyField(invars, copies);
    deepCopyField(params, copies);
    deepCopyField(paramsvalues, copies);
    deepCopyField(f, copies);
    deepCopyField(test_costf, copies);
    deepCopyField(output_and_target_to_cost, copies);
    deepCopyField(optimizer, copies);
}

} // end of namespace PLearn