00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00045
#include "TrainTestSplitter.h"
00046
00047
namespace PLearn {
00048
using namespace std;
00049
00050 TrainTestSplitter::TrainTestSplitter(
real the_test_fraction)
00051 : append_train(0), test_fraction(the_test_fraction)
00052 {}
00053
00054
// Registers TrainTestSplitter with the PLearn object system.
// The second argument is the one-line summary and the third is the longer
// help text.
PLEARN_IMPLEMENT_OBJECT(
    TrainTestSplitter,
    "Splits a dataset once into a training set and a test set",
    "TrainTestSplitter implements a single split of the dataset into a training-set and a test-set (the test part being the last few samples of the dataset)");
00056
00057 void TrainTestSplitter::declareOptions(
OptionList& ol)
00058 {
00059
declareOption(ol,
"append_train", &TrainTestSplitter::append_train, OptionBase::buildoption,
00060
"if set to 1, the trainset will be appended after the test set (thus each split"
00061
" will contain three sets)");
00062
00063
declareOption(ol,
"test_fraction", &TrainTestSplitter::test_fraction, OptionBase::buildoption,
00064
"the fraction of the dataset reserved to the test set");
00065
00066 inherited::declareOptions(ol);
00067 }
00068
00069 void TrainTestSplitter::build_()
00070 {
00071 }
00072
00073
00074 void TrainTestSplitter::build()
00075 {
00076 inherited::build();
00077
build_();
00078 }
00079
00080 int TrainTestSplitter::nsplits()
const
00081
{
00082
return 1;
00083 }
00084
00085 int TrainTestSplitter::nSetsPerSplit()
const
00086
{
00087
if (
append_train)
00088
return 3;
00089
else
00090
return 2;
00091 }
00092
00093 TVec<VMat> TrainTestSplitter::getSplit(
int k)
00094 {
00095
if (
k)
00096
PLERROR(
"TrainTestSplitter::getSplit() - k cannot be greater than 0");
00097
00098
TVec<VMat> split_(2);
00099
00100
int l = dataset->
length();
00101
int test_length =
int(
test_fraction*l);
00102
int train_length = l - test_length;
00103
00104 split_[0] = dataset.
subMatRows(0, train_length);
00105 split_[1] = dataset.
subMatRows(train_length, test_length);
00106
if (
append_train) {
00107 split_.
resize(3);
00108 split_[2] = split_[0];
00109 }
00110
return split_;
00111 }
00112
00113 }