00001 
00002 
00003 
00004 
00005 
00006 
00007 
00008 
00009 
00010 
00011 
00012 
00013 
00014 
00015 
00016 
00017 
00018 
00019 
00020 
00021 
00022 
00023 
00024 
00025 
00026 
00027 
00028 
00029 
00030 
00031 
00032 
00033 
00034 
00035 
00036 
00037 
00038 
00039 
00040 
00043 
#include "ConcatRowsVMatrix.h"
00044 
#include <plearn/math/random.h>
00045 
#include "SelectRowsVMatrix.h"
00046 
#include "SubVMatrix.h"
00047 
#include "TrainValidTestSplitter.h"
00048 
00049 
namespace PLearn {
00050 
using namespace std;
00051 
00053 
00055 TrainValidTestSplitter::TrainValidTestSplitter() 
00056   : 
Splitter(),
00057     append_train(0),
00058     append_valid(0),
00059     n_splits(1),
00060     n_train(-1),
00061     n_valid(-1),
00062     shuffle_valid_and_test(1)
00063 {}
00064 
00065 
// Registers TrainValidTestSplitter with the PLearn object system, together
// with its one-line and detailed help texts (shown to users at runtime).
PLEARN_IMPLEMENT_OBJECT(
    TrainValidTestSplitter,
    "This splitter will basically return [Train+Valid, Test].",
    "The train set returned by the splitter is formed from the first n_train+n_valid\n"
    "samples in the dataset. The other samples are returned in the test set.\n"
    "The validation and test sets (given by the samples after the n_train-th one) can\n"
    "be shuffled in order to get different validation and test sets at each split.\n"
    "However, the train set (the first n_train samples) remains fixed.");
00072 
00074 
00076 void TrainValidTestSplitter::declareOptions(
OptionList& ol)
00077 {
00078    
declareOption(ol, 
"append_train", &TrainValidTestSplitter::append_train, OptionBase::buildoption,
00079        
"If set to 1, the train set will be appended to each split, after the test set\n"
00080        
"(the train set means the first n_train samples).");
00081 
00082    
declareOption(ol, 
"append_valid", &TrainValidTestSplitter::append_valid, OptionBase::buildoption,
00083        
"If set to 1, the validation set will be appended to each split, after the test set\n"
00084        
"(or the train set if append_train is also set to 1).");
00085 
00086    
declareOption(ol, 
"n_splits", &TrainValidTestSplitter::n_splits, OptionBase::buildoption,
00087        
"The number of splits we want (a value > 1 is useful with shuffle_valid_and_test = 1).");
00088 
00089    
declareOption(ol, 
"n_train", &TrainValidTestSplitter::n_train, OptionBase::buildoption,
00090        
"The number of samples that define the train set, assumed to be at the beginning\n"
00091        
"of the dataset.");
00092 
00093    
declareOption(ol, 
"n_valid", &TrainValidTestSplitter::n_valid, OptionBase::buildoption,
00094        
"The number of samples that define the validation set (they are taken among\n"
00095        
"the samples after the n_train first ones).");
00096      
00097    
declareOption(ol, 
"shuffle_valid_and_test", &TrainValidTestSplitter::shuffle_valid_and_test, OptionBase::buildoption,
00098        
"If set to 1, then the part of the dataset after the first n_train ones will\n"
00099        
"be shuffled before taking the validation and test sets. Note that if you want\n"
00100        
"to set it to 0, then using a TrainTestSplitter is probably more appropriate.");
00101    
00102   
00103   inherited::declareOptions(ol);
00104 }
00105 
00107 
00109 void TrainValidTestSplitter::build()
00110 {
00111   inherited::build();
00112   
build_();
00113 }
00114 
00116 
00118 void TrainValidTestSplitter::build_()
00119 {
00120   
if (dataset) {
00121     
if (
n_train < 0 || 
n_valid < 0) {
00122       
PLERROR(
"In TrainValidTestSplitter::build_ - Please initialize correctly 'n_train' and 'n_valid'");
00123     }
00124     
int n = dataset->
length();
00125     
int n_test = n - 
n_train - 
n_valid;
00126     
00127     
train_set = 
new SubVMatrix(dataset, 0, 0, 
n_train, dataset->
width());
00128     
00129     
valid_indices.
resize(
n_splits, n_valid);
00130     
test_indices.
resize(
n_splits, n_test);
00131     
TVec<int> valid_and_test_indices(n_valid + n_test);
00132     
for (
int i = 0; i < 
n_splits; i++) {
00133       
for (
int j = 0; j < n_valid + n_test; j++) {
00134         valid_and_test_indices[j] = j + 
n_train;
00135       }
00136       
if (
shuffle_valid_and_test) {
00137         
shuffleElements(valid_and_test_indices);
00138       }
00139       
valid_indices(i) << valid_and_test_indices.
subVec(0, n_valid);
00140       
test_indices(i) << valid_and_test_indices.
subVec(n_valid, n_test);
00141       
if (
shuffle_valid_and_test) {
00142         
00143         
sortElements(
valid_indices(i));
00144         
sortElements(
test_indices(i));
00145       }
00146     }
00147   }
00148 }
00149 
00151 
00153 TVec<VMat> TrainValidTestSplitter::getSplit(
int k)
00154 {
00155   
00156   
TVec<VMat> result(2);
00157   
VMat valid_set = 
new SelectRowsVMatrix(dataset, 
valid_indices(
k));
00158   result[0] = 
vconcat(
train_set, valid_set);
00159   result[1] = 
new SelectRowsVMatrix(dataset, 
test_indices(
k));
00160   
if (
append_train) {
00161     result.
append(
train_set);
00162   }
00163   
if (
append_valid) {
00164     result.
append(valid_set);
00165   }
00166   
return result;
00167 }
00168 
00170 
00172 void TrainValidTestSplitter::makeDeepCopyFromShallowCopy(map<const void*, void*>& copies)
00173 {
00174   Splitter::makeDeepCopyFromShallowCopy(copies);
00175 
00176   
00177   
00178   
00179   
00180   
00181 
00182   
00183   
PLERROR(
"TrainValidTestSplitter::makeDeepCopyFromShallowCopy not fully (correctly) implemented yet!");
00184 }
00185 
00187 
00189 int TrainValidTestSplitter::nsplits()
 const
00190 
{
00191   
00192   
return this->
n_splits;
00193 }
00194 
00196 
00198 int TrainValidTestSplitter::nSetsPerSplit()
 const
00199 
{
00200   
00201   
int result = 2;
00202   
if (
append_train) {
00203     result++;
00204   }
00205   
if (
append_valid) {
00206     result++;
00207   }
00208   
return result;
00209 }
00210 
00212 
00214 void TrainValidTestSplitter::setDataSet(
VMat the_dataset) {
00215   inherited::setDataSet(the_dataset);
00216   
build_(); 
00217 }
00218 
00219 }