00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00043
#include "SequentialSplitter.h"
00044
00045
namespace PLearn {
00046
using namespace std;
00047
00048 SequentialSplitter::SequentialSplitter(
int horizon_,
int init_train_size_,
bool return_entire_vmat_)
00049 : horizon(horizon_), init_train_size(init_train_size_), return_entire_vmat(return_entire_vmat_)
00050 {}
00051
00052
// PLearn object-implementation macro for SequentialSplitter. The two string
// arguments are the class's one-line and multi-line descriptions (presumably
// surfaced by PLearn's help system); they replace the unfinished placeholder
// texts and describe what getSplit() actually returns.
PLEARN_IMPLEMENT_OBJECT(
    SequentialSplitter,
    "Splits a sequential dataset into growing train sets and look-ahead test sets",
    "Split number k (k = 0, 1, ...) is made of two sets:\n"
    "  - split[0], the train set: the first (init_train_size + k) rows;\n"
    "  - split[1], the test set: the 'horizon' rows that follow the train set,\n"
    "    or, when return_entire_vmat is true, every row from row 0 up to the\n"
    "    end of that look-ahead window.\n");
00054
00055 void SequentialSplitter::declareOptions(
OptionList& ol)
00056 {
00057
declareOption(ol,
"horizon", &SequentialSplitter::horizon, OptionBase::buildoption,
00058
"How far in the future is the test set (split[1])");
00059
00060
declareOption(ol,
"init_train_size", &SequentialSplitter::init_train_size, OptionBase::buildoption,
00061
"Initial length of the train set (split[0])");
00062
00063
declareOption(ol,
"return_entire_vmat", &SequentialSplitter::return_entire_vmat, OptionBase::buildoption,
00064
"If true, the test split (split[1]) will start at t=0.");
00065
00066 inherited::declareOptions(ol);
00067 }
00068
00069 void SequentialSplitter::build_()
00070 {
00071 }
00072
00073
00074 void SequentialSplitter::build()
00075 {
00076 inherited::build();
00077
build_();
00078 }
00079
00080 int SequentialSplitter::nSetsPerSplit()
const
00081
{
00082
return 2;
00083 }
00084
00085 int SequentialSplitter::nsplits()
const
00086
{
00087
if (dataset.
isNull())
00088
PLERROR(
"SequentialSplitter::nsplits() - Must call setDataSet()");
00089
if (
init_train_size < 1)
00090
PLERROR(
"SequentialSplitter::nsplits() - init_train_size must be stricktly positive (%d)",
init_train_size);
00091
if (
horizon < 1)
00092
PLERROR(
"SequentialSplitter::nsplits() - horizon must be stricktly positive (%d)",
horizon);
00093
00094
return dataset.
length() -
init_train_size -
horizon + 1;
00095 }
00096
00097 TVec<VMat> SequentialSplitter::getSplit(
int k)
00098 {
00099
if (dataset.
isNull())
00100
PLERROR(
"SequentialSplitter::getSplit() - Must call setDataSet()");
00101
00102
int n_splits =
nsplits();
00103
if (
k >= n_splits)
00104
PLERROR(
"SequentialSplitter::getSplit() - k (%d) cannot be greater than K (%d)",
k, n_splits);
00105
00106
int seq_length = dataset.
length();
00107
if (
init_train_size >= seq_length)
00108
PLERROR(
"SequentialSplitter::getSplit() - init_train_size (%d) >= dataset.length() (%d)",
init_train_size, seq_length);
00109
00110
int t =
init_train_size +
k;
00111
int start_test_t =
return_entire_vmat ? 0 : t;
00112
int n_test = t +
horizon - start_test_t;
00113
00114
TVec<VMat> split_(2);
00115 split_[0] = dataset.
subMatRows(0, t);
00116 split_[1] = dataset.
subMatRows(start_test_t, n_test);
00117
00118
return split_;
00119 }
00120
00121
00122 }