Main Page | Namespace List | Class Hierarchy | Alphabetical List | Class List | File List | Namespace Members | Class Members | File Members

DisplayUtils.cc

Go to the documentation of this file.
00001 // -*- C++ -*- 00002 00003 // PLearn (A C++ Machine Learning Library) 00004 // Copyright (C) 1998 Pascal Vincent 00005 // Copyright (C) 1999-2002 Pascal Vincent, Yoshua Bengio and University of Montreal 00006 // 00007 00008 // Redistribution and use in source and binary forms, with or without 00009 // modification, are permitted provided that the following conditions are met: 00010 // 00011 // 1. Redistributions of source code must retain the above copyright 00012 // notice, this list of conditions and the following disclaimer. 00013 // 00014 // 2. Redistributions in binary form must reproduce the above copyright 00015 // notice, this list of conditions and the following disclaimer in the 00016 // documentation and/or other materials provided with the distribution. 00017 // 00018 // 3. The name of the authors may not be used to endorse or promote 00019 // products derived from this software without specific prior written 00020 // permission. 00021 // 00022 // THIS SOFTWARE IS PROVIDED BY THE AUTHORS ``AS IS'' AND ANY EXPRESS OR 00023 // IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES 00024 // OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN 00025 // NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 00026 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED 00027 // TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 00028 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 00029 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 00030 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 00031 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 00032 // 00033 // This file is part of the PLearn library. For more information on the PLearn 00034 // library, go to the PLearn Web site at www.plearn.org 00035 00036 00037 00038 00039 /* ******************************************************* 00040 * $Id: DisplayUtils.cc,v 1.6 2004/07/21 16:30:51 chrish42 Exp $ 00041 * AUTHORS: Pascal Vincent & Yoshua Bengio 00042 * This file is part of the PLearn library. 00043 ******************************************************* */ 00044 00045 #include "DisplayUtils.h" 00046 #include <plearn/io/TmpFilenames.h> 00047 #include <strstream> 00048 00049 #ifdef WIN32 00050 #include <io.h> 00051 #define unlink _unlink 00052 #endif 00053 00054 namespace PLearn { 00055 using namespace std; 00056 00057 void displayHistogram(Gnuplot& gp, Mat dataColumn, 00058 int n_bins, Vec* pbins, 00059 bool regular_bins, 00060 bool normalized, string extra_args) 00061 { 00062 Vec sorted_data = dataColumn.toVecCopy(); 00063 sortElements(sorted_data); 00064 int n=sorted_data.length(); 00065 real minv = sorted_data[0]; 00066 real maxv = sorted_data[n-1]; 00067 00068 // compute "bins" vector, which specifies histogram intervals 00069 // [minv, bins[0]), [bins[0],bins[1]), ... [bin[n_bins-2],maxv] 00070 Vec bins; 00071 if (pbins) 00072 { 00073 bins = *pbins; 00074 n_bins = bins.length()+1; 00075 } 00076 else 00077 { 00078 if (n_bins==0) 00079 n_bins = MIN(5+n/10,1000); 00080 bins.resize(n_bins-1); 00081 00082 // fill the bins 00083 if (regular_bins) 00084 { 00085 real delta = (maxv-minv)/n_bins; 00086 real v = minv+delta; 00087 real* b=bins.data(); 00088 for (int i=0;i<n_bins-1;i++,v+=delta) b[i]=v; 00089 } 00090 else 00091 { 00092 real n_expected_per_bin = n/(real)n_bins; 00093 int current_bin=0; 00094 real* v=sorted_data.data(); 00095 real* b=bins.data(); 00096 real previous = 1e30; 00097 int n_repeat = 0; 00098 int previous_n_repeat = 0; 00099 int first_of_mass_point = 0; 00100 for (int i=0;i<n;i++) 00101 { 00102 if (previous==v[i]) 00103 { 00104 if (previous_n_repeat==0) first_of_mass_point = i-1; 00105 n_repeat++; 00106 } 00107 else 00108 n_repeat=0; 00109 if (n_repeat==0 && current_bin < n_bins-1) // put a left_side at i only if v[i]!=v[i-1] 00110 { 00111 if (previous_n_repeat==0) 00112 { 00113 if (i+1 >= n_expected_per_bin*(1+current_bin)) 00114 b[current_bin++]=v[i]; 00115 } 00116 else 00117 { 00118 if (n_repeat/(real)n > n_expected_per_bin) 00119 { 00120 if (current_bin>0 && b[current_bin-1] < v[first_of_mass_point]) 00121 b[current_bin++]=v[first_of_mass_point]; 00122 if (current_bin < n_bins-1) 00123 b[current_bin++]=v[i]; 00124 } 00125 else 00126 if (i+1 >= n_expected_per_bin*(1+current_bin)) 00127 b[current_bin++]=v[i]; 00128 } 00129 } 00130 previous = v[i]; 00131 previous_n_repeat = n_repeat; 00132 } 00133 } 00134 } 00135 00136 // fill histogram vector with counts in each interval: 00137 // first column is the left border of each bin, 2nd is the count 00138 Mat histogram(n_bins+1,2); 00139 real* left_side = &histogram(0,0); 00140 real* frequency = left_side+1; 00141 real* b = bins.data(); 00142 real* v=sorted_data.data(); 00143 int current_bin=0; 00144 real left = minv; 00145 for (int i=0;i<n;i++) 00146 { 00147 if (current_bin<n_bins-1 && 00148 v[i]>=b[current_bin]) 00149 { 00150 left_side[2*current_bin]=left; 00151 left = v[i]; 00152 current_bin++; 00153 } 00154 frequency[2*current_bin]++; 00155 } 00156 left_side[2*current_bin]=left; 00157 left_side[2*n_bins]=maxv+(maxv-minv)/n; 00158 real norm_factor = normalized? (1.0/n) : 1.0; 00159 for (int i=0;i<n_bins;i++) 00160 { 00161 real deltax = left_side[2*(i+1)]-left_side[2*i]; 00162 if (deltax==0) { PLWARNING("displayHistogram: 0 deltax!"); deltax=1.0; } 00163 frequency[i*2] *= norm_factor/deltax; 00164 } 00165 00166 histogram(n_bins,1)=histogram(n_bins-1,1); 00167 00168 // display the histogram 00169 string comm = string(" with steps")+extra_args; 00170 gp.plot(histogram,comm.c_str()); 00171 } 00172 00173 00174 00177 void displayVarGraph(const VarArray& outputs, bool display_values, real boxwidth, const char* the_filename, bool must_wait, VarArray display_only_these) 00178 { 00179 // parameters controlling appearance... 00180 real deltay = 100; 00181 real boxheight = 50; 00182 00183 char filename[100]; 00184 if(the_filename) 00185 strcpy(filename, the_filename); 00186 else 00187 { 00188 TmpFilenames tmpnam; 00189 strcpy(filename, tmpnam.addFilename().c_str()); 00190 } 00191 00192 multimap<real,Var> layers; 00193 typedef multimap<real,Var>::iterator mmit; 00194 00195 Mat center(Variable::nvars+1,2); 00196 center.fill(FLT_MAX); 00197 00198 int n_display_only_these = display_only_these.size(); 00199 bool display_all = n_display_only_these==0; 00200 00201 // find sources of outputs which are not in the outputs array: 00202 outputs.unmarkAncestors(); 00203 VarArray sources = outputs.sources(); 00204 outputs.unmarkAncestors(); 00205 // We dont want any source Var that is in outputs to be in sources so we remove them: 00206 outputs.setMark(); 00207 VarArray nonoutputsources; 00208 for(int i=0; i<sources.size(); i++) 00209 if(!sources[i]->isMarked() && (display_all || display_only_these.contains(sources[i]))) 00210 nonoutputsources.append(sources[i]); 00211 sources = nonoutputsources; 00212 outputs.clearMark(); 00213 00214 sources.setMark(); 00215 00216 // Place everything but the sources starting from outputs at the bottom 00217 00218 outputs.unmarkAncestors(); 00219 00220 real y = boxheight; 00221 VarArray varray = outputs; 00222 00223 while(varray.size()>0) 00224 { 00225 // varray.setMark(); // so that these don't get put in subsequent parents() calls 00226 VarArray parents; 00227 int nvars = varray.size(); 00228 for(int i=0; i<nvars; i++) 00229 { 00230 Var v = varray[i]; 00231 real old_y = center(v->varnum,1); 00232 if (old_y != FLT_MAX) // remove pair (old_y,v) from layers 00233 { 00234 pair<mmit,mmit> range = layers.equal_range(old_y); 00235 for (mmit it = range.first; it != range.second; it++) 00236 if (v->varnum == it->second->varnum) 00237 { 00238 layers.erase(it); 00239 break; 00240 } 00241 } 00242 layers.insert(pair<real,Var>(y, v)); 00243 center(v->varnum,1) = y; 00244 VarArray parents_i = v->parents(); 00245 for (int j=0;j<parents_i.size();j++) 00246 if((display_all || display_only_these.contains(parents_i[j])) && !parents.contains(parents_i[j])) 00247 parents &= parents_i[j]; 00248 } 00249 varray = parents; 00250 y += deltay; 00251 } 00252 // now place the sources 00253 int nvars = sources.size(); 00254 for(int i=0; i<nvars; i++) 00255 { 00256 Var v = sources[i]; 00257 real old_y = center(v->varnum,1); 00258 if (old_y != FLT_MAX) // remove pair (old_y,v) from layers 00259 { 00260 pair<mmit,mmit> range = layers.equal_range(old_y); 00261 for (mmit it = range.first; it != range.second; it++) 00262 if (v->varnum == it->second->varnum) 00263 { 00264 layers.erase(it); 00265 break; 00266 } 00267 } 00268 layers.insert(pair<real,Var>(y,v)); 00269 } 00270 real topy = y; 00271 00272 outputs.unmarkAncestors(); 00273 if (display_all) 00274 { 00275 VarArray ancestors = outputs.ancestors(); 00276 outputs.unmarkAncestors(); 00277 varray = ancestors; 00278 } 00279 else varray = display_only_these; 00280 00281 // Find the maximum number of vars in a level... 00282 int maxvarsperlevel = sources.size(); 00283 00284 for (real y=boxheight;y<=topy;y+=deltay) 00285 { 00286 pair<mmit,mmit> range = layers.equal_range(y); 00287 int nvars = (int)distance(range.first,range.second); 00288 if (maxvarsperlevel < nvars) 00289 maxvarsperlevel = nvars; 00290 } 00291 00292 real usewidth = (maxvarsperlevel+1)*(boxwidth+boxheight); 00293 00294 // Compute the bounding box: 00295 real min_x = 0; 00296 real min_y = 0; 00297 real max_x = usewidth; 00298 real max_y = topy; 00299 00300 min_x -= boxwidth/2; 00301 max_x += boxwidth/2; 00302 min_y -= boxheight/2; 00303 max_y += boxheight/2; 00304 00305 for (real y=boxheight;y<=topy;y+=deltay) 00306 { 00307 pair<mmit,mmit> range = layers.equal_range(y); 00308 int nvars = (int)distance(range.first,range.second); 00309 real deltax = usewidth/(nvars+1); 00310 real x = deltax; 00311 for (mmit it = range.first; it != range.second; it++, x+=deltax) 00312 { 00313 Var v = it->second; 00314 center(v->varnum,0) = x; 00315 center(v->varnum,1) = y; 00316 } 00317 } 00318 00319 // Start outputting to the file 00320 { 00321 // make it an eps file with the computed bounding box 00322 GhostScript gs(filename,min_x,min_y,max_x,max_y); 00323 00324 // Now paint 00325 00326 // gs.setlinewidth(1.0); 00327 00328 for (real y=boxheight;y<=topy;y+=deltay) 00329 { 00330 pair<mmit,mmit> range = layers.equal_range(y); 00331 int nvars = (int)distance(range.first,range.second); 00332 real deltax = usewidth/(nvars+1); 00333 real x = deltax; 00334 for (mmit it = range.first; it != range.second; it++, x+=deltax) 00335 { 00336 Var v = it->second; 00337 real my_x = x; 00338 real my_y = y; 00339 00340 // Display v 00341 gs.drawBox(my_x-boxwidth/2, my_y-boxheight/2, boxwidth, boxheight); 00342 char nameline[100]; 00343 sprintf(nameline,"%s (%d,%d)",v->getName().c_str(), v->matValue.length(), v->matValue.width()); 00344 00345 char buf[200]; 00346 ostrstream descr(buf,200); 00347 v->print(descr); 00348 descr << ends; 00349 00350 if(display_values && v->size() <= 16) 00351 { 00352 gs.usefont("Times-Bold", 11.0); 00353 gs.centerShow(my_x, my_y+boxheight/4, descr.str()); 00354 gs.usefont("Times-Roman", 10.0); 00355 gs.centerShow(my_x, my_y, nameline); 00356 gs.usefont("Courrier", 6.0); 00357 if (v->rValue.length()>0) // print rvalue if there are some... 00358 { 00359 gs.centerShow(my_x, my_y-boxheight/5, v->value); 00360 gs.centerShow(my_x, my_y-boxheight/3, v->gradient); 00361 gs.centerShow(my_x, my_y-boxheight/1, v->rValue); 00362 } 00363 else 00364 { 00365 gs.centerShow(my_x, my_y-boxheight/5, v->value); 00366 gs.centerShow(my_x, my_y-boxheight/2.5, v->gradient); 00367 } 00368 /* 00369 cout << descr.str() << " " << nameline << " (" << v->value.length() << ")" << endl; 00370 cout << "value: " << v->value << endl; 00371 cout << "gradient: " << v->gradient << endl; 00372 */ 00373 } 00374 else 00375 { 00376 gs.usefont("Times-Bold", 12.0); 00377 gs.centerShow(my_x, my_y+boxheight/4, descr.str()); 00378 gs.usefont("Times-Roman", 11.0); 00379 gs.centerShow(my_x, my_y-boxheight/4, nameline); 00380 } 00381 00382 // Display the arrows from the parents 00383 VarArray parents = v->parents(); 00384 int nparents = parents.size(); 00385 for(int p=0; p<nparents; p++) 00386 { 00387 Var parent = parents[p]; 00388 if (display_all || display_only_these.contains(parent)) 00389 { 00390 real parent_x = center(parent->varnum,0); 00391 real parent_y = center(parent->varnum,1); 00392 00393 gs.drawArrow(parent_x, parent_y-boxheight/2, 00394 my_x+0.75*boxwidth*(real(p+1)/real(nparents+1)-0.5), 00395 my_y+boxheight/2); 00396 } 00397 } 00398 } 00399 } 00400 outputs.unmarkAncestors(); 00401 } 00402 char command[1000]; 00403 if (must_wait) 00404 sprintf(command,"gv %s",filename); 00405 else 00406 sprintf(command,"gv %s &",filename); 00407 00408 system(command); 00409 00410 if(the_filename==0 && must_wait) 00411 unlink(filename); 00412 } 00413 00414 void OldDisplayVarGraph(const VarArray& outputs, bool display_values, real boxwidth, const char* the_filename, bool must_wait, VarArray display_only_these) 00415 { 00416 // parameters controlling appearance... 00417 real deltay = 100; 00418 real boxheight = 50; 00419 00420 char filename[100]; 00421 if(the_filename) 00422 strcpy(filename, the_filename); 00423 else 00424 { 00425 TmpFilenames tmpnam; 00426 strcpy(filename, tmpnam.addFilename().c_str()); 00427 } 00428 00429 Mat center(Variable::nvars+1,2); 00430 center.fill(FLT_MAX); 00431 00432 int n_display_only_these = display_only_these.size(); 00433 bool display_all = n_display_only_these==0; 00434 00435 // find sources of outputs which are not in the outputs array: 00436 outputs.unmarkAncestors(); 00437 VarArray sources = outputs.sources(); 00438 outputs.unmarkAncestors(); 00439 // We dont want any source Var that is in outputs to be in sources so we remove them: 00440 outputs.setMark(); 00441 VarArray nonoutputsources; 00442 for(int i=0; i<sources.size(); i++) 00443 if(!sources[i]->isMarked() && (display_all || display_only_these.contains(sources[i]))) 00444 nonoutputsources.append(sources[i]); 00445 sources = nonoutputsources; 00446 outputs.clearMark(); 00447 00448 // Find the maximum number of vars in a level... 00449 int maxvarsperlevel = sources.size(); 00450 sources.setMark(); 00451 VarArray varray = outputs; 00452 while(varray.size()>0) 00453 { 00454 if(varray.size()>maxvarsperlevel) 00455 maxvarsperlevel = varray.size(); 00456 varray.setMark(); // so that these don't get put in subsequent parents() calls 00457 VarArray parents; 00458 for(int i=0; i<varray.size(); i++) 00459 parents &= varray[i]->parents(); 00460 varray = VarArray(); 00461 for (int i=0;i<parents.size();i++) 00462 if(display_all || display_only_these.contains(parents[i])) 00463 varray &= parents[i]; 00464 } 00465 sources.setMark(); 00466 00467 real usewidth = (maxvarsperlevel+1)*(boxwidth+boxheight); 00468 00469 // Place everything but the sources starting from outputs at the bottom 00470 00471 outputs.unmarkAncestors(); 00472 00473 real y = boxheight; 00474 varray = outputs; 00475 00476 while(varray.size()>0) 00477 { 00478 // varray.setMark(); // so that these don't get put in subsequent parents() calls 00479 VarArray parents; 00480 int nvars = varray.size(); 00481 for(int i=0; i<nvars; i++) 00482 { 00483 Var v = varray[i]; 00484 center(v->varnum,0) = usewidth*(i+1)/(nvars+1); 00485 center(v->varnum,1) = y; 00486 // bool marked = v->isMarked(); 00487 // v->clearMark(); 00488 VarArray parents_i = v->parents(); 00489 for (int j=0;j<parents_i.size();j++) 00490 if((display_all || display_only_these.contains(parents_i[j])) && !parents.contains(parents_i[j])) 00491 parents &= parents_i[j]; 00492 } 00493 varray = parents; 00494 y += deltay; 00495 } 00496 // now place the sources 00497 int nvars = sources.size(); 00498 for(int i=0; i<nvars; i++) 00499 { 00500 Var v = sources[i]; 00501 center(v->varnum,0) = usewidth*(i+1)/(nvars+1); 00502 center(v->varnum,1) = y; 00503 } 00504 00505 outputs.unmarkAncestors(); 00506 if (display_all) 00507 { 00508 VarArray ancestors = outputs.ancestors(); 00509 outputs.unmarkAncestors(); 00510 varray = ancestors; 00511 } 00512 else varray = display_only_these; 00513 00514 // Compute the bounding box: 00515 real min_x = FLT_MAX; 00516 real min_y = FLT_MAX; 00517 real max_x = -FLT_MAX; 00518 real max_y = -FLT_MAX; 00519 00520 for(int i=0; i<varray.size(); i++) 00521 { 00522 Var v = varray[i]; 00523 real x = center(v->varnum,0); 00524 real y = center(v->varnum,1); 00525 if(x<min_x) 00526 min_x = x; 00527 if(y<min_y) 00528 min_y = y; 00529 if(x>max_x) 00530 max_x = x; 00531 if(y>max_y) 00532 max_y = y; 00533 } 00534 min_x -= boxwidth/2; 00535 max_x += boxwidth/2; 00536 min_y -= boxheight/2; 00537 max_y += boxheight/2; 00538 00539 // Start outputting to the file 00540 { 00541 // make it an eps file with the computed bounding box 00542 GhostScript gs(filename,min_x,min_y,max_x,max_y); 00543 00544 // Now paint 00545 00546 // gs.setlinewidth(1.0); 00547 00548 for(int i=0; i<varray.size(); i++) 00549 { 00550 Var v = varray[i]; 00551 real my_x = center(v->varnum,0); 00552 real my_y = center(v->varnum,1); 00553 00554 // Display v 00555 gs.drawBox(my_x-boxwidth/2, my_y-boxheight/2, boxwidth, boxheight); 00556 char nameline[100]; 00557 sprintf(nameline,"%s (%d,%d)",v->getName().c_str(), v->matValue.length(), v->matValue.width()); 00558 00559 char buf[200]; 00560 ostrstream descr(buf,200); 00561 v->print(descr); 00562 descr << ends; 00563 00564 if(display_values) 00565 { 00566 gs.usefont("Times-Bold", 11.0); 00567 gs.centerShow(my_x, my_y+boxheight/4, descr.str()); 00568 gs.usefont("Times-Roman", 10.0); 00569 gs.centerShow(my_x, my_y, nameline); 00570 gs.usefont("Courrier", 6.0); 00571 gs.centerShow(my_x, my_y-boxheight/5, v->value); 00572 gs.centerShow(my_x, my_y-boxheight/2.5, v->gradient); 00573 } 00574 else 00575 { 00576 gs.usefont("Times-Bold", 12.0); 00577 gs.centerShow(my_x, my_y+boxheight/4, descr.str()); 00578 gs.usefont("Times-Roman", 11.0); 00579 gs.centerShow(my_x, my_y-boxheight/4, nameline); 00580 } 00581 00582 // Display the arrows from the parents 00583 VarArray parents = v->parents(); 00584 int nparents = parents.size(); 00585 for(int p=0; p<nparents; p++) 00586 { 00587 Var parent = parents[p]; 00588 if (display_all || display_only_these.contains(parent)) 00589 { 00590 real parent_x = center(parent->varnum,0); 00591 real parent_y = center(parent->varnum,1); 00592 00593 gs.drawArrow(parent_x, parent_y-boxheight/2, 00594 my_x+0.75*boxwidth*(real(p+1)/real(nparents+1)-0.5), 00595 my_y+boxheight/2); 00596 } 00597 } 00598 } 00599 outputs.unmarkAncestors(); 00600 } 00601 00602 char command[1000]; 00603 if (must_wait) 00604 sprintf(command,"gv %s",filename); 00605 else 00606 sprintf(command,"gv %s &",filename); 00607 00608 system(command); 00609 00610 if(the_filename==0) 00611 unlink(filename); 00612 } 00613 00614 void displayFunction(Func f, bool display_values, bool display_differentiation, real boxwidth, const char* the_filename, bool must_wait) 00615 { 00616 if(display_differentiation) 00617 displayVarGraph(f->outputs & f->differentiate()->outputs, display_values, boxwidth, the_filename, must_wait); 00618 else 00619 displayVarGraph(f->outputs, display_values, boxwidth, the_filename, must_wait); 00620 } 00621 00622 Mat compute2dGridOutputs(Learner& learner, real min_x, real max_x, real min_y, real max_y, int length, int width, real singleoutput_threshold) 00623 { 00624 Mat m(length,width); 00625 real delta_x = (max_x-min_x)/(width-1); 00626 real delta_y = (max_y-min_y)/(length-1); 00627 00628 if(learner.inputsize()!=2 || (learner.outputsize()!=1 && learner.outputsize()!=2) ) 00629 PLERROR("learner is expected to have an inputsize of 2, and an outputsize of 1 (or possibly 2 for binary classification)"); 00630 00631 Vec input(2); 00632 Vec output(learner.outputsize()); 00633 for(int i=0; i<length; i++) 00634 { 00635 input[1] = min_y+(length-i-1)*delta_y; 00636 for(int j=0; j<width; j++) 00637 { 00638 input[0] = min_x+j*delta_x; 00639 learner.use(input,output); 00640 if(learner.outputsize()==2) 00641 m(i,j) = output[0]-output[1]; 00642 else 00643 m(i,j) = output[0]-singleoutput_threshold; 00644 } 00645 } 00646 return m; 00647 } 00648 00649 void displayPoints(GhostScript& gs, Mat data, real radius, bool color) 00650 { 00651 for(int i=0; i<data.length(); i++) 00652 { 00653 Vec point = data(i); 00654 if(color) 00655 { 00656 if(point[2]<=0.0) 00657 gs.setcolor(1.0,0.0,0.0); 00658 else 00659 gs.setcolor(0.0,0.0,1.0); 00660 gs.drawCross(point[0], point[1], radius); 00661 } 00662 else 00663 gs.drawCross(point[0], point[1], radius, point[2]<=0); 00664 } 00665 } 00666 00667 /* 00668 // Old version based on a Classifier 00669 void displayDecisionSurface(GhostScript& gs, Classifier& cl, real xmin, real xmax, int nxsamples, real ymin, real ymax, int nysamples) 00670 { 00671 Vec input(2); 00672 Vec scores(1); 00673 Mat bm(nysamples,nxsamples); 00674 00675 for(int i=0; i<nysamples; i++) 00676 for(int j=0; j<nxsamples; j++) 00677 { 00678 input[0] = xmin+(xmax-xmin)/(nxsamples-1)*j; 00679 input[1] = ymax-(ymax-ymin)/(nysamples-1)*i; 00680 cl.use(input,scores); 00681 // cerr << scores[0] << "| "; 00682 real r,g,b; 00683 if(scores[0]<=0.5) 00684 r(i,j) = scores[0]*1.8; 00685 else 00686 b(i,j) = (1.0-scores[0])*1.8; 00687 } 00688 gs.gsave(); 00689 gs.translate(xmin,ymin); 00690 gs.scale((xmax-xmin)/nxsamples, (ymax-ymin)/nysamples); 00691 gs.displayRGB(0,0,r,g,b); 00692 gs.grestore(); 00693 } 00694 */ 00695 00696 void displayDecisionSurface(GhostScript& gs, real destx, real desty, real destwidth, real destheight, 00697 Learner& learner, Mat trainset, 00698 Vec svindexes, Vec outlierindexes, int nextsvindex, 00699 real min_x, real max_x, real min_y, real max_y, 00700 real radius, 00701 int nx, int ny) 00702 { 00703 gs.gsave(); 00704 real scalefactor = (max_x-min_x)/destwidth; 00705 gs.mapping(min_x,min_y,max_x-min_x,max_y-min_y, destx, desty, destwidth, destheight); 00706 gs.setlinewidth(1.0*scalefactor); 00707 00708 real singleoutput_threshold = 0.; 00709 if(learner.outputsize()==1) 00710 { 00711 Mat targets = trainset.column(learner.inputsize()); 00712 singleoutput_threshold = 0.5*(min(targets)+max(targets)); 00713 } 00714 Mat decisions = compute2dGridOutputs(learner, min_x, max_x, min_y, max_y, ny, nx, singleoutput_threshold); 00715 00716 //real posrange = max(decisions); 00717 //real negrange = min(decisions); 00718 00719 for(int i=0; i<ny; i++) 00720 for(int j=0; j<nx; j++) 00721 { 00722 decisions(i,j) = (decisions(i,j)<0. ? 0.75 : 1.0); 00723 /* 00724 if(decisions(i,j) < 0.0) 00725 decisions(i,j) = 1.0-.5*decisions(i,j)/negrange; 00726 else 00727 decisions(i,j) = .5+.5*decisions(i,j)/posrange; 00728 */ 00729 } 00730 00731 gs.displayGray(decisions,min_x,min_y,max_x-min_x,max_y-min_y); 00732 00733 // draw x and + 00734 displayPoints(gs, trainset, radius, false); 00735 00736 // draw black circles around support vectors 00737 for(int k=0; k<svindexes.length(); k++) 00738 { 00739 real x = trainset(int(svindexes[k]),0); 00740 real y = trainset(int(svindexes[k]),1); 00741 // cerr << "{" << x << "," << y << "}"; 00742 gs.drawCircle(x,y,radius); 00743 } 00744 // cerr << endl; 00745 00746 // draw half radius circle around next support vector 00747 if(nextsvindex>=0) 00748 { 00749 real x = trainset(nextsvindex,0); 00750 real y = trainset(nextsvindex,1); 00751 gs.drawCircle(x,y,radius/2); 00752 } 00753 00754 // draw white circles around outliers 00755 Vec dashpattern(2,4.0*scalefactor); 00756 gs.setdash(dashpattern); 00757 for(int k=0; k<outlierindexes.length(); k++) 00758 { 00759 real x = trainset(int(outlierindexes[k]),0); 00760 real y = trainset(int(outlierindexes[k]),1); 00761 gs.drawCircle(x,y,radius); 00762 } 00763 00764 gs.grestore(); 00765 } 00766 00767 #ifdef WIN32 00768 #undef unlink 00769 #endif 00770 00771 } // end of namespace PLearn

Generated on Tue Aug 17 15:51:24 2004 for PLearn by doxygen 1.3.7