00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00040
#include "Splitter.h"
00041
#include "VMat.h"
00042
#include "ConcatRowsVMatrix.h"
00043
#include "ConcatColumnsVMatrix.h"
00044
#include <plearn/math/random.h>
00045
00046
namespace PLearn {
00047
using namespace std;
00048
00049
// Presumably registers Splitter with the PLearn object system as an abstract
// class (by the macro's name; its definition is not visible here).
// NOTE(review): the one-line description and help strings below are still
// placeholders — TODO replace with a real description of Splitter.
PLEARN_IMPLEMENT_ABSTRACT_OBJECT(Splitter,
"ONE LINE DESCR",
"NO HELP");
00050
00051 void Splitter::makeDeepCopyFromShallowCopy(map<const void*, void*>& copies)
00052 {
00053
deepCopyField(
dataset, copies);
00054 }
00055
00056 void Splitter::setDataSet(
VMat the_dataset)
00057 {
00058
dataset = the_dataset;
00059 }
00060
00061
00062
00063 void split(
VMat d,
real test_fraction,
VMat& train,
VMat& test,
int i,
bool use_all)
00064 {
00065
int n = d.
length();
00066
real ftest = test_fraction>=1.0 ? test_fraction : test_fraction*
real(n);
00067
int ntest =
int(ftest);
00068
int ntrain_before_test = n - (i+1)*ntest;
00069
int ntrain_after_test = i*ntest;
00070
if (use_all) {
00071
00072
int nsplits = int(n / ftest + 0.5);
00073
00074
int nleft = n - nsplits * ntest;
00075
00076
int ntest_more = nleft / nsplits;
00077
00078
00079
int nsplits_one_more = nleft % nsplits;
00080
00081 ntest = ntest + ntest_more;
00082
if (i < nsplits_one_more) {
00083 ntest++;
00084 ntrain_before_test = n - (i+1) * ntest;
00085 }
else {
00086 ntrain_before_test =
00087 n
00088 - (nsplits_one_more) * (ntest + 1)
00089 - (i - nsplits_one_more + 1) * ntest;
00090 }
00091 ntrain_after_test = n - ntest - ntrain_before_test;
00092 }
00093
00094 test = d.
subMatRows(ntrain_before_test, ntest);
00095
if(ntrain_after_test == 0)
00096 train = d.
subMatRows(0,ntrain_before_test);
00097
else if(ntrain_before_test==0)
00098 train = d.
subMatRows(ntest, ntrain_after_test);
00099
else
00100 train =
vconcat( d.
subMatRows(0,ntrain_before_test),
00101 d.
subMatRows(ntrain_before_test+ntest, ntrain_after_test) );
00102 }
00103
00104 Vec randomSplit(
VMat d,
real test_fraction,
VMat& train,
VMat& test)
00105 {
00106
int ntest =
int( test_fraction>=1.0 ?test_fraction :test_fraction*d.
length() );
00107
int ntrain = d.
length()-ntest;
00108
Vec indices(0, d.
length()-1, 1);
00109
shuffleElements(indices);
00110 train = d.
rows(indices.
subVec(0,ntrain));
00111 test = d.
rows(indices.
subVec(ntrain,ntest));
00112
return indices;
00113 }
00114
00115 void split(
VMat d,
real validation_fraction,
real test_fraction,
VMat& train,
VMat& valid,
VMat& test,
bool do_shuffle)
00116 {
00117
int ntest =
int( test_fraction>=1.0 ?test_fraction :test_fraction*d.
length() );
00118
int nvalid = int( validation_fraction>=1.0 ?validation_fraction :validation_fraction*d.
length() );
00119
int ntrain = d.
length()-(ntest+nvalid);
00120
Vec indices(0, d.
length()-1, 1);
00121
if (do_shuffle){
00122 cout<<
"shuffle !"<<
endl;
00123
shuffleElements(indices);
00124 }
00125 train = d.
rows(indices.
subVec(0,ntrain));
00126 valid = d.
rows(indices.
subVec(ntrain,nvalid));
00127 test = d.
rows(indices.
subVec(ntrain+nvalid,ntest));
00128 cout<<
"n_train : "<<ntrain<<
endl<<
"n_valid : "<<nvalid<<
endl<<
"n_test : "<<(d.
length()-ntrain+nvalid)<<
endl;
00129 }
00130
00131 void randomSplit(
VMat d,
real validation_fraction,
real test_fraction,
VMat& train,
VMat& valid,
VMat& test)
00132 {
00133
split(d,validation_fraction,test_fraction,train,valid,test,
true);
00134 }
00135
00136 }