Main Page | Namespace List | Class Hierarchy | Alphabetical List | Class List | File List | Namespace Members | Class Members | File Members

Splitter.cc

Go to the documentation of this file.
00001 // -*- C++ -*- 00002 00003 // Splitter.cc 00004 // 00005 // Copyright (C) 2002 Pascal Vincent, Frederic Morin 00006 // 00007 // Redistribution and use in source and binary forms, with or without 00008 // modification, are permitted provided that the following conditions are met: 00009 // 00010 // 1. Redistributions of source code must retain the above copyright 00011 // notice, this list of conditions and the following disclaimer. 00012 // 00013 // 2. Redistributions in binary form must reproduce the above copyright 00014 // notice, this list of conditions and the following disclaimer in the 00015 // documentation and/or other materials provided with the distribution. 00016 // 00017 // 3. The name of the authors may not be used to endorse or promote 00018 // products derived from this software without specific prior written 00019 // permission. 00020 // 00021 // THIS SOFTWARE IS PROVIDED BY THE AUTHORS ``AS IS'' AND ANY EXPRESS OR 00022 // IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES 00023 // OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN 00024 // NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 00025 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED 00026 // TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 00027 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 00028 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 00029 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00030 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00031 // 00032 // This file is part of the PLearn library. For more information on the PLearn 00033 // library, go to the PLearn Web site at www.plearn.org 00034 00035 /* ******************************************************* 00036 * $Id: Splitter.cc,v 1.8 2004/07/21 16:30:55 chrish42 Exp $ 00037 ******************************************************* */ 00038 00040 #include "Splitter.h" 00041 #include "VMat.h" 00042 #include "ConcatRowsVMatrix.h" 00043 #include "ConcatColumnsVMatrix.h" 00044 #include <plearn/math/random.h> 00045 00046 namespace PLearn { 00047 using namespace std; 00048 00049 PLEARN_IMPLEMENT_ABSTRACT_OBJECT(Splitter, "ONE LINE DESCR", "NO HELP"); 00050 00051 void Splitter::makeDeepCopyFromShallowCopy(map<const void*, void*>& copies) 00052 { 00053 deepCopyField(dataset, copies); 00054 } 00055 00056 void Splitter::setDataSet(VMat the_dataset) 00057 { 00058 dataset = the_dataset; 00059 } 00060 00061 // Useful splitting functions 00062 00063 void split(VMat d, real test_fraction, VMat& train, VMat& test, int i, bool use_all) 00064 { 00065 int n = d.length(); 00066 real ftest = test_fraction>=1.0 ? test_fraction : test_fraction*real(n); 00067 int ntest = int(ftest); 00068 int ntrain_before_test = n - (i+1)*ntest; 00069 int ntrain_after_test = i*ntest; 00070 if (use_all) { 00071 // See how many splits there are. 00072 int nsplits = int(n / ftest + 0.5); 00073 // See how many examples will be left. 00074 int nleft = n - nsplits * ntest; 00075 // Deduce how many examples to add in each split. 00076 int ntest_more = nleft / nsplits; 00077 // And, finally, how many splits will have one more example so that they are 00078 // all taken somewhere. 00079 int nsplits_one_more = nleft % nsplits; 00080 // Now recompute ntest, ntrain_before_test and ntrain_after_test. 00081 ntest = ntest + ntest_more; 00082 if (i < nsplits_one_more) { 00083 ntest++; 00084 ntrain_before_test = n - (i+1) * ntest; 00085 } else { 00086 ntrain_before_test = 00087 n 00088 - (nsplits_one_more) * (ntest + 1) 00089 - (i - nsplits_one_more + 1) * ntest; 00090 } 00091 ntrain_after_test = n - ntest - ntrain_before_test; 00092 } 00093 00094 test = d.subMatRows(ntrain_before_test, ntest); 00095 if(ntrain_after_test == 0) 00096 train = d.subMatRows(0,ntrain_before_test); 00097 else if(ntrain_before_test==0) 00098 train = d.subMatRows(ntest, ntrain_after_test); 00099 else 00100 train = vconcat( d.subMatRows(0,ntrain_before_test), 00101 d.subMatRows(ntrain_before_test+ntest, ntrain_after_test) ); 00102 } 00103 00104 Vec randomSplit(VMat d, real test_fraction, VMat& train, VMat& test) 00105 { 00106 int ntest = int( test_fraction>=1.0 ?test_fraction :test_fraction*d.length() ); 00107 int ntrain = d.length()-ntest; 00108 Vec indices(0, d.length()-1, 1); // Range-vector 00109 shuffleElements(indices); 00110 train = d.rows(indices.subVec(0,ntrain)); 00111 test = d.rows(indices.subVec(ntrain,ntest)); 00112 return indices; 00113 } 00114 00115 void split(VMat d, real validation_fraction, real test_fraction, VMat& train, VMat& valid, VMat& test,bool do_shuffle) 00116 { 00117 int ntest = int( test_fraction>=1.0 ?test_fraction :test_fraction*d.length() ); 00118 int nvalid = int( validation_fraction>=1.0 ?validation_fraction :validation_fraction*d.length() ); 00119 int ntrain = d.length()-(ntest+nvalid); 00120 Vec indices(0, d.length()-1, 1); // Range-vector 00121 if (do_shuffle){ 00122 cout<<"shuffle !"<<endl; 00123 shuffleElements(indices); 00124 } 00125 train = d.rows(indices.subVec(0,ntrain)); 00126 valid = d.rows(indices.subVec(ntrain,nvalid)); 00127 test = d.rows(indices.subVec(ntrain+nvalid,ntest)); 00128 cout<<"n_train : "<<ntrain<<endl<<"n_valid : "<<nvalid<<endl<<"n_test : "<<(d.length()-ntrain+nvalid)<<endl; 00129 } 00130 00131 void randomSplit(VMat d, real validation_fraction, real test_fraction, VMat& train, VMat& valid, VMat& test) 00132 { 00133 split(d,validation_fraction,test_fraction,train,valid,test,true); 00134 } 00135 00136 } // end of namespace PLearn

Generated on Tue Aug 17 16:06:38 2004 for PLearn by doxygen 1.3.7