Main Page | Namespace List | Class Hierarchy | Alphabetical List | Class List | File List | Namespace Members | Class Members | File Members

TrainValidTestSplitter.cc

Go to the documentation of this file.
// -*- C++ -*-

// TrainValidTestSplitter.cc
//
// Copyright (C) 2004 Olivier Delalleau
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
//  1. Redistributions of source code must retain the above copyright
//     notice, this list of conditions and the following disclaimer.
//
//  2. Redistributions in binary form must reproduce the above copyright
//     notice, this list of conditions and the following disclaimer in the
//     documentation and/or other materials provided with the distribution.
//
//  3. The name of the authors may not be used to endorse or promote
//     products derived from this software without specific prior written
//     permission.
//
// THIS SOFTWARE IS PROVIDED BY THE AUTHORS ``AS IS'' AND ANY EXPRESS OR
// IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
// OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN
// NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
// TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
// This file is part of the PLearn library.
For more information on the PLearn 00033 // library, go to the PLearn Web site at www.plearn.org 00034 00035 /* ******************************************************* 00036 * $Id: TrainValidTestSplitter.cc,v 1.3 2004/07/21 16:30:55 chrish42 Exp $ 00037 ******************************************************* */ 00038 00039 // Authors: Olivier Delalleau 00040 00043 #include "ConcatRowsVMatrix.h" 00044 #include <plearn/math/random.h> 00045 #include "SelectRowsVMatrix.h" 00046 #include "SubVMatrix.h" 00047 #include "TrainValidTestSplitter.h" 00048 00049 namespace PLearn { 00050 using namespace std; 00051 00053 // TrainValidTestSplitter // 00055 TrainValidTestSplitter::TrainValidTestSplitter() 00056 : Splitter(), 00057 append_train(0), 00058 append_valid(0), 00059 n_splits(1), 00060 n_train(-1), 00061 n_valid(-1), 00062 shuffle_valid_and_test(1) 00063 {} 00064 00065 PLEARN_IMPLEMENT_OBJECT(TrainValidTestSplitter, 00066 "This splitter will basically return [Train+Valid, Test].", 00067 "The train test returned by the splitter is formed from the first n_train+n_valid\n" 00068 "samples in the dataset. 
The other samples are returned in the test set.\n" 00069 "The validation and test sets (given by the samples after the n_train-th one) can\n" 00070 "be shuffled in order to get a different validation and test sets at each split.\n" 00071 "However, the train set (the first n_train samples) remains fixed."); 00072 00074 // declareOptions // 00076 void TrainValidTestSplitter::declareOptions(OptionList& ol) 00077 { 00078 declareOption(ol, "append_train", &TrainValidTestSplitter::append_train, OptionBase::buildoption, 00079 "If set to 1, the train set will be appended to each split, after the test set\n" 00080 "(the train set means the first n_train samples)."); 00081 00082 declareOption(ol, "append_valid", &TrainValidTestSplitter::append_valid, OptionBase::buildoption, 00083 "If set to 1, the validation set will be appended to each split, after the test set\n" 00084 "(or the train set if append_train is also set to 1)."); 00085 00086 declareOption(ol, "n_splits", &TrainValidTestSplitter::n_splits, OptionBase::buildoption, 00087 "The number of splits we want (a value > 1 is useful with shuffle_valid_and_test = 1)."); 00088 00089 declareOption(ol, "n_train", &TrainValidTestSplitter::n_train, OptionBase::buildoption, 00090 "The number of samples that define the train set, assumed to be at the beginning\n" 00091 "of the dataset."); 00092 00093 declareOption(ol, "n_valid", &TrainValidTestSplitter::n_valid, OptionBase::buildoption, 00094 "The number of samples that define the validation set (they are taken among\n" 00095 "the samples after the n_train first ones)."); 00096 00097 declareOption(ol, "shuffle_valid_and_test", &TrainValidTestSplitter::shuffle_valid_and_test, OptionBase::buildoption, 00098 "If set to 1, then the part of the dataset after the first n_train ones will\n" 00099 "be shuffled before taking the validation and test sets. 
Note that if you want\n" 00100 "to set it to 0, then using a TrainTestSplitter is probably more appropriate."); 00101 00102 // Now call the parent class' declareOptions 00103 inherited::declareOptions(ol); 00104 } 00105 00107 // build // 00109 void TrainValidTestSplitter::build() 00110 { 00111 inherited::build(); 00112 build_(); 00113 } 00114 00116 // build_ // 00118 void TrainValidTestSplitter::build_() 00119 { 00120 if (dataset) { 00121 if (n_train < 0 || n_valid < 0) { 00122 PLERROR("In TrainValidTestSplitter::build_ - Please initialize correctly 'n_train' and 'n_valid'"); 00123 } 00124 int n = dataset->length(); 00125 int n_test = n - n_train - n_valid; 00126 // Define the train set. 00127 train_set = new SubVMatrix(dataset, 0, 0, n_train, dataset->width()); 00128 // Precompute all the indices. 00129 valid_indices.resize(n_splits, n_valid); 00130 test_indices.resize(n_splits, n_test); 00131 TVec<int> valid_and_test_indices(n_valid + n_test); 00132 for (int i = 0; i < n_splits; i++) { 00133 for (int j = 0; j < n_valid + n_test; j++) { 00134 valid_and_test_indices[j] = j + n_train; 00135 } 00136 if (shuffle_valid_and_test) { 00137 shuffleElements(valid_and_test_indices); 00138 } 00139 valid_indices(i) << valid_and_test_indices.subVec(0, n_valid); 00140 test_indices(i) << valid_and_test_indices.subVec(n_valid, n_test); 00141 if (shuffle_valid_and_test) { 00142 // Now sort the indices for (hopefully) faster access. 00143 sortElements(valid_indices(i)); 00144 sortElements(test_indices(i)); 00145 } 00146 } 00147 } 00148 } 00149 00151 // getSplit // 00153 TVec<VMat> TrainValidTestSplitter::getSplit(int k) 00154 { 00155 // ### Build and return the kth split . 
00156 TVec<VMat> result(2); 00157 VMat valid_set = new SelectRowsVMatrix(dataset, valid_indices(k)); 00158 result[0] = vconcat(train_set, valid_set); 00159 result[1] = new SelectRowsVMatrix(dataset, test_indices(k)); 00160 if (append_train) { 00161 result.append(train_set); 00162 } 00163 if (append_valid) { 00164 result.append(valid_set); 00165 } 00166 return result; 00167 } 00168 00170 // makeDeepCopyFromShallowCopy // 00172 void TrainValidTestSplitter::makeDeepCopyFromShallowCopy(map<const void*, void*>& copies) 00173 { 00174 Splitter::makeDeepCopyFromShallowCopy(copies); 00175 00176 // ### Call deepCopyField on all "pointer-like" fields 00177 // ### that you wish to be deepCopied rather than 00178 // ### shallow-copied. 00179 // ### ex: 00180 // deepCopyField(trainvec, copies); 00181 00182 // ### Remove this line when you have fully implemented this method. 00183 PLERROR("TrainValidTestSplitter::makeDeepCopyFromShallowCopy not fully (correctly) implemented yet!"); 00184 } 00185 00187 // nsplits // 00189 int TrainValidTestSplitter::nsplits() const 00190 { 00191 // ### Return the number of available splits 00192 return this->n_splits; 00193 } 00194 00196 // nSetsPerSplit // 00198 int TrainValidTestSplitter::nSetsPerSplit() const 00199 { 00200 // ### Return the number of sets per split 00201 int result = 2; 00202 if (append_train) { 00203 result++; 00204 } 00205 if (append_valid) { 00206 result++; 00207 } 00208 return result; 00209 } 00210 00212 // setDataSet // 00214 void TrainValidTestSplitter::setDataSet(VMat the_dataset) { 00215 inherited::setDataSet(the_dataset); 00216 build_(); // To recompute the indices. 00217 } 00218 00219 } // end of namespace PLearn

Generated on Tue Aug 17 16:09:06 2004 for PLearn by doxygen 1.3.7