00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00044
#include "LearnerCommand.h"
00045
#include <plearn_learners/generic/PLearner.h>
00046
00047
#include <plearn/vmat/FileVMatrix.h>
00048
#include <plearn/db/getDataSet.h>
00049
00050
00051
namespace PLearn {
00052
using namespace std;
00053
00055
PLearnCommandRegistry LearnerCommand::reg_(
new LearnerCommand);
00056
00057 LearnerCommand::LearnerCommand():
00058
PLearnCommand("learner",
00059
00060 "Allows to train,
use and test a learner",
00061
00062 "learner train <learner_spec.plearn> <trainset.vmat> <trained_learner.psave>\n"
00063 " -> Will train the specified learner on the specified trainset and
save the resulting trained learner as trained_learner.psave\n"
00064 "learner test <trained_learner.psave> <testset.vmat> <cost.stats> [<outputs.pmat>] [<costs.pmat>]\n"
00065 " -> Tests the specified learner on the testset. Will produce a cost.stats file (viewable with the plearn stats command) and optionally saves individual outputs and costs\n"
00066 "learner compute_outputs <trained_learner.psave> <test_inputs.vmat> <outputs.pmat> (there is 'learner co' as a shortcut for compute_outputs)\n"
00067
00068 "The datasets do not need to be .vmat they can be any valid vmatrix (.amat .pmat .dmat)"
00069 )
00070 {}
00071
00072
00073 void LearnerCommand::train(
const string& learner_spec_file,
const string& trainset_spec,
const string& save_learner_file)
00074 {
00075
00076
00077
PP<PLearner> learner;
00078
PLearn::load(learner_spec_file,learner);
00079
VMat trainset =
getDataSet(trainset_spec);
00080
PP<VecStatsCollector> train_stats =
new VecStatsCollector();
00081 learner->setTrainStatsCollector(train_stats);
00082 learner->setTrainingSet(trainset);
00083 learner->train();
00084
PLearn::save(save_learner_file, learner);
00085 }
00086
00087 void LearnerCommand::test(
const string& trained_learner_file,
const string& testset_spec,
const string& stats_file,
const string& outputs_file,
const string& costs_file)
00088 {
00089
PP<PLearner> learner;
00090
PLearn::load(trained_learner_file,learner);
00091
VMat testset =
getDataSet(testset_spec);
00092
int l = testset.
length();
00093
VMat testoutputs;
00094
if(outputs_file!=
"")
00095 testoutputs =
new FileVMatrix(outputs_file,l,learner->outputsize());
00096
VMat testcosts;
00097
if(costs_file!=
"")
00098 testcosts =
new FileVMatrix(costs_file,l,learner->nTestCosts());
00099
00100
PP<VecStatsCollector> test_stats;
00101 learner->test(testset, test_stats, testoutputs, testcosts);
00102
00103
PLearn::save(stats_file,test_stats);
00104 }
00105
00106 void LearnerCommand::compute_outputs(
const string& trained_learner_file,
const string& test_inputs_spec,
const string& outputs_file)
00107 {
00108
PP<PLearner> learner;
00109
PLearn::load(trained_learner_file,learner);
00110
VMat testinputs =
getDataSet(test_inputs_spec);
00111
int l = testinputs.
length();
00112
VMat testoutputs =
new FileVMatrix(outputs_file,l,learner->outputsize());
00113 learner->use(testinputs,testoutputs);
00114 }
00115
00117 void LearnerCommand::run(
const vector<string>& args)
00118 {
00119
string command = args[0];
00120
if(command==
"train")
00121 {
00122
if (args.size()==4)
00123
train(args[1],args[2],args[3]);
00124
else
00125
PLERROR(
"LearnerCommand::run you must provide 'plearn learner train learner_spec_file trainset_spec save_learner_file'");
00126 }
00127
else if(command==
"test")
00128 {
00129
if (args.size()>3)
00130 {
00131
string trained_learner_file = args[1];
00132
string testset_spec = args[2];
00133
string stats_basename = args[3];
00134
string outputs_file;
00135
if(args.size()>4)
00136 outputs_file = args[4];
00137
string costs_file;
00138
if(args.size()>5)
00139 costs_file = args[5];
00140
test(trained_learner_file, testset_spec, stats_basename, outputs_file, costs_file);
00141 }
00142
else
00143
PLERROR(
"LearnerCommand::run you must provide at least 'plearn learner test <trained_learner.psave> <testset.vmat> <cost.stats>'");
00144 }
00145
else if ((command==
"compute_outputs") ||(command==
"co"))
00146 {
00147
if (args.size()==4)
00148
compute_outputs(args[1],args[2],args[3]);
00149
else
00150
PLERROR(
"LearnerCommand::run you must provide 'plearn learner compute_outputs learner_spec_file trainset_spec save_learner_file'");
00151 }
00152
00153
else
00154
PLERROR(
"Invalid command %s check the help for available commands",command.c_str());
00155 }
00156
00157 }
00158