00001 
00002 
00003 
00004 
00005 
00006 
00007 
00008 
00009 
00010 
00011 
00012 
00013 
00014 
00015 
00016 
00017 
00018 
00019 
00020 
00021 
00022 
00023 
00024 
00025 
00026 
00027 
00028 
00029 
00030 
00031 
00032 
00033 
00034 
00035 
00036 
00037 
00038 
00039 
00040 
00041 
00043 
#include "SequentialSplitter.h"
00044 
00045 
namespace PLearn {
00046 
using namespace std;
00047 
00048 SequentialSplitter::SequentialSplitter(
int horizon_, 
int init_train_size_, 
bool return_entire_vmat_)
00049     : horizon(horizon_), init_train_size(init_train_size_), return_entire_vmat(return_entire_vmat_)
00050 {}
00051 
00052 
// Registers SequentialSplitter with the PLearn object system. The second
// argument is the one-line description and the third is the longer help text;
// both previously held unfinished placeholders. The text below restates what
// nsplits() and getSplit() compute (row ranges assume subMatRows(start, length)).
PLEARN_IMPLEMENT_OBJECT(
  SequentialSplitter,
  "Splits a sequential dataset into a growing train set and a test set ending 'horizon' rows later",
  "For split k, let t = init_train_size + k. The train set (split[0]) holds\n"
  "rows [0, t) and the test set (split[1]) ends at row t + horizon (excluded).\n"
  "The test set starts at row t, or at row 0 if return_entire_vmat is true.\n"
  "The number of splits is dataset.length() - init_train_size - horizon + 1.\n");
00054 
00055 void SequentialSplitter::declareOptions(
OptionList& ol)
00056 {
00057   
declareOption(ol, 
"horizon", &SequentialSplitter::horizon, OptionBase::buildoption,
00058       
"How far in the future is the test set (split[1])");
00059 
00060   
declareOption(ol, 
"init_train_size", &SequentialSplitter::init_train_size, OptionBase::buildoption,
00061       
"Initial length of the train set (split[0])");
00062 
00063   
declareOption(ol, 
"return_entire_vmat", &SequentialSplitter::return_entire_vmat, OptionBase::buildoption,
00064       
"If true, the test split (split[1]) will start at t=0.");
00065 
00066   inherited::declareOptions(ol);
00067 }
00068 
00069 void SequentialSplitter::build_()
00070 {
00071 }
00072 
00073 
00074 void SequentialSplitter::build()
00075 {
00076   inherited::build();
00077   
build_();
00078 }
00079 
00080 int SequentialSplitter::nSetsPerSplit()
 const
00081 
{
00082   
return 2;
00083 }
00084 
00085 int SequentialSplitter::nsplits()
 const
00086 
{
00087   
if (dataset.
isNull())
00088     
PLERROR(
"SequentialSplitter::nsplits() - Must call setDataSet()");
00089   
if (
init_train_size < 1)
00090     
PLERROR(
"SequentialSplitter::nsplits() - init_train_size must be stricktly positive (%d)", 
init_train_size);
00091   
if (
horizon < 1)
00092     
PLERROR(
"SequentialSplitter::nsplits() - horizon must be stricktly positive (%d)", 
horizon);
00093 
00094   
return dataset.
length() - 
init_train_size - 
horizon + 1;
00095 }
00096 
00097 TVec<VMat> SequentialSplitter::getSplit(
int k)
00098 {
00099   
if (dataset.
isNull())
00100     
PLERROR(
"SequentialSplitter::getSplit() - Must call setDataSet()");
00101 
00102   
int n_splits = 
nsplits();
00103   
if (
k >= n_splits)
00104     
PLERROR(
"SequentialSplitter::getSplit() - k (%d) cannot be greater than K (%d)", 
k, n_splits);
00105 
00106   
int seq_length = dataset.
length();
00107   
if (
init_train_size >= seq_length)
00108     
PLERROR(
"SequentialSplitter::getSplit() - init_train_size (%d) >= dataset.length() (%d)", 
init_train_size, seq_length);
00109 
00110   
int t = 
init_train_size + 
k;
00111   
int start_test_t = 
return_entire_vmat ? 0 : t;
00112   
int n_test = t + 
horizon - start_test_t;
00113 
00114   
TVec<VMat> split_(2);
00115   split_[0] = dataset.
subMatRows(0, t);
00116   split_[1] = dataset.
subMatRows(start_test_t, n_test);
00117 
00118   
return split_;
00119 }
00120 
00121 
00122 }