Main Page | Namespace List | Class Hierarchy | Alphabetical List | Class List | File List | Namespace Members | Class Members | File Members

RepeatSplitter.cc

Go to the documentation of this file.
00001 // -*- C++ -*- 00002 00003 // RepeatSplitter.cc 00004 // 00005 // Copyright (C) 2003 Olivier Delalleau 00006 // 00007 // Redistribution and use in source and binary forms, with or without 00008 // modification, are permitted provided that the following conditions are met: 00009 // 00010 // 1. Redistributions of source code must retain the above copyright 00011 // notice, this list of conditions and the following disclaimer. 00012 // 00013 // 2. Redistributions in binary form must reproduce the above copyright 00014 // notice, this list of conditions and the following disclaimer in the 00015 // documentation and/or other materials provided with the distribution. 00016 // 00017 // 3. The name of the authors may not be used to endorse or promote 00018 // products derived from this software without specific prior written 00019 // permission. 00020 // 00021 // THIS SOFTWARE IS PROVIDED BY THE AUTHORS ``AS IS'' AND ANY EXPRESS OR 00022 // IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES 00023 // OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN 00024 // NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 00025 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED 00026 // TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 00027 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 00028 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 00029 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00030 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00031 // 00032 // This file is part of the PLearn library. For more information on the PLearn 00033 // library, go to the PLearn Web site at www.plearn.org 00034 00035 /* ******************************************************* 00036 * $Id: RepeatSplitter.cc,v 1.9 2004/07/21 16:30:55 chrish42 Exp $ 00037 ******************************************************* */ 00038 00041 #include <plearn/math/random.h> 00042 #include "RepeatSplitter.h" 00043 #include "SelectRowsVMatrix.h" 00044 00045 namespace PLearn { 00046 using namespace std; 00047 00049 // RepeatSplitter // 00051 RepeatSplitter::RepeatSplitter() 00052 : 00053 last_n(-1), 00054 do_not_shuffle_first(0), 00055 force_proportion(-1), 00056 n(1), 00057 seed(-1), 00058 shuffle(0) 00059 { 00060 } 00061 00062 PLEARN_IMPLEMENT_OBJECT(RepeatSplitter, 00063 "Repeat a given splitter a certain amount of times, with the possibility to\n" 00064 "shuffle randomly the dataset each time", 00065 "NO HELP"); 00066 00068 // declareOptions // 00070 void RepeatSplitter::declareOptions(OptionList& ol) 00071 { 00072 declareOption(ol, "do_not_shuffle_first", &RepeatSplitter::do_not_shuffle_first, OptionBase::buildoption, 00073 "If set to 1, then the dataset won't be shuffled the first time we do the splitting.\n" 00074 "It only makes sense to use this option if 'shuffle' is set to 1."); 00075 00076 declareOption(ol, "force_proportion", &RepeatSplitter::force_proportion, OptionBase::buildoption, 00077 "If a target value appears at least once every x samples, will ensure that after\n" 00078 "shuffling it appears at least once every (x * 'force_proportion') samples, and not\n" 00079 "more than once every (x / 'force_proportion') samples. Will be ignored if < 1.\n" 00080 "Note that this currently only works for a binary target! (and hasn't been 100% tested)."); 00081 00082 declareOption(ol, "n", &RepeatSplitter::n, OptionBase::buildoption, 00083 "How many times we want to repeat."); 00084 00085 declareOption(ol, "seed", &RepeatSplitter::seed, OptionBase::buildoption, 00086 "Initializes the random number generator (only if shuffle is set to 1).\n" 00087 "If set to -1, the initialization will depend on the clock."); 00088 00089 declareOption(ol, "shuffle", &RepeatSplitter::shuffle, OptionBase::buildoption, 00090 "If set to 1, the dataset will be shuffled differently at each repetition."); 00091 00092 declareOption(ol, "to_repeat", &RepeatSplitter::to_repeat, OptionBase::buildoption, 00093 "The splitter we want to repeat."); 00094 00095 inherited::declareOptions(ol); 00096 } 00097 00099 // build // 00101 void RepeatSplitter::build() 00102 { 00103 inherited::build(); 00104 build_(); 00105 } 00106 00108 // build_ // 00110 void RepeatSplitter::build_() 00111 { 00112 if (shuffle && dataset) { 00113 // Prepare the shuffled indices. 00114 if (seed >= 0) 00115 manual_seed(seed); 00116 else 00117 PLearn::seed(); 00118 int n_splits = nsplits(); 00119 indices = TMat<int>(n_splits, dataset.length()); 00120 TVec<int> shuffled; 00121 for (int i = 0; i < n_splits; i++) { 00122 shuffled = TVec<int>(0, dataset.length()-1, 1); 00123 // Don't shuffle if (i == 0) and do_not_shuffle_first is set to 1. 00124 if (!do_not_shuffle_first || i > 0) { 00125 shuffleElements(shuffled); 00126 if (force_proportion >= 1) { 00127 // We need to ensure the proportions of target values are respected. 00128 // First compute the target stats. 00129 StatsCollector tsc(2000); 00130 if (dataset->targetsize() != 1) { 00131 PLERROR("In RepeatSplitter::build_ - 'force_proportion' is only implemented for a 1-dimensional target"); 00132 } 00133 real t; 00134 for (int j = 0; j < dataset->length(); j++) { 00135 t = dataset->get(j, dataset->inputsize()); // We get the target. 00136 tsc.update(t); 00137 } 00138 tsc.finalize(); 00139 // Make sure the target is binary. 00140 int count = (int) tsc.getCounts()->size() - 1; 00141 if (count != 2) { 00142 PLERROR("In RepeatSplitter::build_ - 'force_proportion' is only implemented for a binary target"); 00143 } 00144 // Ensure the proportion of the targets respect the constraints. 00145 int j = 0; 00146 for (map<real,StatsCollectorCounts>::iterator it = tsc.getCounts()->begin(); j < count; j++) { 00147 t = it->first; 00148 real prop_t = real(it->second.n) / real(dataset->length()); 00149 // Find the step to use to check the proportion is ok. We want a 00150 // step such that each 'step' examples, there should be at least two 00151 // with this target, but less than 'step - 10'. 00152 // For instance, for a proportion of 0.1, 'step' would be 20, 00153 // and for a proportion of 0.95, it would be 200. 00154 // We also want the approximation made when rounding to be 00155 // negligible. 00156 int step = 20; 00157 bool ok = false; 00158 while (!ok) { 00159 int n = int(step * prop_t + 0.5); 00160 if (n >= 2 && n <= step - 10 00161 && abs(step * prop_t - real(n)) / real(step) < 0.01) { 00162 ok = true; 00163 } else { 00164 // We try a higher step. 00165 step *= 2; 00166 } 00167 } 00168 int expected_count = int(step * prop_t + 0.5); 00169 // cout << "step = " << step << ", expected_count = " << expected_count << endl; 00170 // Now verify the proportion. 00171 ok = false; 00172 int tc = dataset->inputsize(); // The target column. 00173 while (!ok) { 00174 ok = true; 00175 // First pass: ensure there is enough. 00176 int first_pass_step = int(step * force_proportion + 0.5); 00177 int k,l; 00178 for (k = 0; k < shuffled.length(); k += first_pass_step) { 00179 int count_target = 0; 00180 for (l = k; l < k + first_pass_step && l < shuffled.length(); l++) { 00181 if (dataset->get(shuffled[l], tc) == t) { 00182 count_target++; 00183 } 00184 } 00185 if (l - k == first_pass_step && count_target < expected_count) { 00186 // Not enough, need to add more. 00187 ok = false; 00188 // cout << "At l = " << l << ", need to add " << expected_count - count_target << " samples" << endl; 00189 for (int m = 0; m < expected_count - count_target; m++) { 00190 bool can_swap = false; 00191 int to_swap = -1; 00192 // Find a sample to swap in the current window. 00193 while (!can_swap) { 00194 to_swap = int(uniform_sample() * first_pass_step); 00195 if (dataset->get(shuffled[k + to_swap], tc) != t) { 00196 can_swap = true; 00197 } 00198 } 00199 to_swap += k; 00200 // Find a sample to swap in the next samples. 00201 int next = k + first_pass_step - 1; 00202 can_swap = false; 00203 while (!can_swap) { 00204 next++; 00205 if (next >= shuffled.length()) { 00206 next = 0; 00207 } 00208 if (dataset->get(shuffled[next], tc) == t) { 00209 can_swap = true; 00210 } 00211 } 00212 // And swap baby! 00213 int tmp = shuffled[next]; 00214 shuffled[next] = shuffled[to_swap]; 00215 shuffled[to_swap] = tmp; 00216 } 00217 } 00218 } 00219 // Second pass: ensure there aren't too many. 00220 int second_pass_step = int(step / force_proportion + 0.5); 00221 for (k = 0; k < shuffled.length(); k += second_pass_step) { 00222 int count_target = 0; 00223 for (l = k; l < k + second_pass_step && l < shuffled.length(); l++) { 00224 if (dataset->get(shuffled[l], tc) == count_target) { 00225 count_target++; 00226 } 00227 } 00228 if (l - k == second_pass_step && count_target > expected_count) { 00229 // Too many, need to remove some. 00230 ok = false; 00231 PLWARNING("In RepeatSplitter::build_ - The code reached hasn't been tested yet"); 00232 // cout << "At l = " << l << ", need to remove " << - expected_count + count_target << " samples" << endl; 00233 for (int m = 0; m < - expected_count + count_target; m++) { 00234 bool can_swap = false; 00235 int to_swap = k - 1; 00236 // Find a sample to swap in the current window. 00237 while (!can_swap) { 00238 to_swap++; 00239 if (dataset->get(shuffled[to_swap], tc) == t) { 00240 can_swap = true; 00241 } 00242 } 00243 // Find a sample to swap in the next samples. 00244 int next = k + first_pass_step - 1; 00245 can_swap = false; 00246 while (!can_swap) { 00247 next++; 00248 if (next >= shuffled.length()) { 00249 next = 0; 00250 } 00251 if (dataset->get(shuffled[next], tc) != t) { 00252 can_swap = true; 00253 } 00254 } 00255 // And swap baby! 00256 int tmp = shuffled[next]; 00257 shuffled[next] = shuffled[to_swap]; 00258 shuffled[to_swap] = tmp; 00259 } 00260 } 00261 } 00262 } 00263 it++; 00264 } 00265 } 00266 } 00267 indices(i) << shuffled; 00268 } 00269 } else { 00270 indices = TMat<int>(); 00271 } 00272 last_n = -1; 00273 } 00274 00276 // makeDeepCopyFromShallowCopy // 00278 void RepeatSplitter::makeDeepCopyFromShallowCopy(map<const void*, void*>& copies) 00279 { 00280 Splitter::makeDeepCopyFromShallowCopy(copies); 00281 00282 // ### Call deepCopyField on all "pointer-like" fields 00283 // ### that you wish to be deepCopied rather than 00284 // ### shallow-copied. 00285 // ### ex: 00286 // deepCopyField(trainvec, copies); 00287 00288 deepCopyField(to_repeat, copies); 00289 00290 } 00291 00293 // getSplit // 00295 TVec<VMat> RepeatSplitter::getSplit(int k) 00296 { 00297 int n_splits = this->nsplits(); 00298 if (k >= n_splits) { 00299 PLERROR("In RepeatSplitter::getSplit: split asked is too high"); 00300 } 00301 int child_splits = to_repeat->nsplits(); 00302 int real_k = k % child_splits; 00303 if (shuffle && dataset) { 00304 int shuffle_indice = k / child_splits; 00305 if (shuffle_indice != last_n) { 00306 // We have to reshuffle the dataset, according to indices. 00307 VMat m = new SelectRowsVMatrix(dataset, indices(shuffle_indice)); 00308 to_repeat->setDataSet(m); 00309 last_n = shuffle_indice; 00310 } 00311 } 00312 return to_repeat->getSplit(real_k); 00313 } 00314 00316 // nSetsPerSplit // 00318 int RepeatSplitter::nSetsPerSplit() const 00319 { 00320 return to_repeat->nSetsPerSplit(); 00321 } 00322 00324 // nsplits // 00326 int RepeatSplitter::nsplits() const 00327 { 00328 return to_repeat->nsplits() * n; 00329 } 00330 00332 // setDataSet // 00334 void RepeatSplitter::setDataSet(VMat the_dataset) { 00335 inherited::setDataSet(the_dataset); 00336 to_repeat->setDataSet(the_dataset); 00337 build(); // necessary to recompute the indices. 00338 } 00339 00340 } // end of namespace PLearn

Generated on Tue Aug 17 16:04:08 2004 for PLearn by doxygen 1.3.7