00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00042
#include "KFoldSplitter.h"
00043
#include "VMat_maths.h"
00044
00045
namespace PLearn {
00046
using namespace std;
00047
00048 KFoldSplitter::KFoldSplitter(
int k)
00049 : K(
k),append_train(0)
00050 {}
00051
00052
// Invokes PLEARN_IMPLEMENT_OBJECT for KFoldSplitter, passing a one-line
// description followed by the longer multi-line help text.
PLEARN_IMPLEMENT_OBJECT(
  KFoldSplitter,
  "K-fold cross-validation splitter.",
  "KFoldSplitter implements K splits of the dataset into a training-set and a test-set.\n"
  "If the number of splits is higher than the number of examples, leave-one-out cross-validation\n"
  "will be performed."
);
00058
00059 void KFoldSplitter::declareOptions(
OptionList& ol)
00060 {
00061
declareOption(ol,
"K", &KFoldSplitter::K, OptionBase::buildoption,
00062
"Split dataset in K parts.");
00063
00064
declareOption(ol,
"append_train", &KFoldSplitter::append_train, OptionBase::buildoption,
00065
"If set to 1, the trainset will be appended after the test set (thus each split\n"
00066
"will contain three sets.");
00067
00068 inherited::declareOptions(ol);
00069 }
00070
00071 void KFoldSplitter::build_()
00072 {
00073 }
00074
00075
00076 void KFoldSplitter::build()
00077 {
00078 inherited::build();
00079
build_();
00080 }
00081
00082 int KFoldSplitter::nsplits()
const
00083
{
00084
return K;
00085 }
00086
00087 int KFoldSplitter::nSetsPerSplit()
const
00088
{
00089
if (
append_train)
00090
return 3;
00091
else
00092
return 2;
00093 }
00094
00095 TVec<VMat> KFoldSplitter::getSplit(
int k)
00096 {
00097
if (
k >=
K)
00098
PLERROR(
"KFoldSplitter::getSplit() - k (%d) cannot be greater than K (%d)",
k,
K);
00099
00100
int n_data = dataset->
length();
00101
real test_fraction =
K > 0 ? (n_data/(
real)
K) : 0;
00102
if ((
int)(test_fraction) < 1)
00103 test_fraction = 1;
00104
00105
TVec<VMat> split_(2);
00106
split(dataset, test_fraction, split_[0], split_[1],
k,
true);
00107
if (
append_train) {
00108 split_.
append(split_[0]);
00109 }
00110
return split_;
00111 }
00112
00113 }