Main Page | Namespace List | Class Hierarchy | Alphabetical List | Class List | File List | Namespace Members | Class Members | File Members

ProbabilitySparseMatrix.h

Go to the documentation of this file.
00001 #ifndef PROBABILITY_SPARSE_MATRIX_H 00002 #define PROBABILITY_SPARSE_MATRIX_H 00003 00004 #include <plearn/math/RowMapSparseMatrix.h> 00005 #include <plearn/base/Object.h> 00006 #include "Set.h" 00007 00008 #define NUMWIDTH 10 00009 00010 namespace PLearn { 00011 00012 //typedef map<int, real>::const_iterator RowIterator; 00013 typedef map<int, real> SparseVec; 00014 typedef const map<int, real> ConstSparseVec; 00015 00016 // Represent a sparse probability matrix P(Y=y|X=x). 00017 // On the non-filled entries, there are two possible values: zero and 00018 // "undefined". The latter would be correct if no instance 00019 // of X=x has ever been seen. This is represented in output as -1 00020 // (or a warning is raised) but does not take space internally 00021 // (it occurs when the whole column x is empty). 00022 // Unlike RowMapSparseMatrix, this class clearly 00023 // distinguishes between a read access and a write access 00024 // to avoid inefficient creation of elements when reading. 00025 class ProbabilitySparseMatrix : public Object 00026 { 00027 private: 00028 RowMapSparseMatrix<real> x2y; 00029 RowMapSparseMatrix<real> y2x; 00030 public: 00031 00032 Set Y; 00033 Set X; 00034 00035 bool raise_error; 00036 string name; 00037 00038 ProbabilitySparseMatrix(int ny=0, int nx=0, string pname = "pYX") : raise_error(true) 00039 { 00040 resize(ny,nx); 00041 name = pname; 00042 } 00043 void resize(int ny, int nx) { 00044 x2y.resize(nx,ny); 00045 y2x.resize(ny,nx); 00046 } 00047 int nx() { return x2y.length(); } 00048 int ny() { return y2x.length(); } 00049 00050 void rename(string new_name) { name = new_name; } 00051 00052 Set computeX() 00053 { 00054 X.clear(); 00055 for (int xx = 0; xx < nx(); xx++) 00056 { 00057 if (sumPYx(xx) > 0) 00058 { 00059 X.insert(xx); 00060 } 00061 } 00062 return X; 00063 } 00064 00065 Set computeY() 00066 { 00067 Y.clear(); 00068 for (int yy = 0; yy < ny(); yy++) 00069 { 00070 if (sumPyX(yy) > 0) 00071 { 00072 Y.insert(yy); 00073 } 00074 } 00075 return Y; 00076 } 00077 00078 real get(int y,int x) { 00079 map<int,real>& mx = x2y(x); 00080 if (mx.size()==0) 00081 { 00082 if (raise_error) 00083 //PLERROR("trying to access an invalid probability at P(%d|%d) in %s",y,x, name); 00084 PLWARNING("trying to access an invalid probability at P(%d|%d) in %s",y,x, name.c_str()); 00085 return 0; 00086 } 00087 if (mx.find(y)!=mx.end()) 00088 return mx[y]; 00089 return 0; 00090 } 00091 00092 bool exists(int y, int x) { 00093 map<int,real>& mx = x2y(x); 00094 if (mx.size()==0) return false; 00095 return mx.find(y)!=mx.end(); 00096 } 00097 00098 // most access are for reading, allow operator() for convenience 00099 real operator()(int y, int x) { return get(y,x); } 00100 void incr(int y, int x, real increment=1) 00101 { 00102 if (increment!=0) 00103 { 00104 real current_value = 0; 00105 map<int,real>& mx = x2y(x); 00106 if (mx.size()!=0) 00107 { 00108 map<int,real>::const_iterator it = mx.find(y); 00109 if (it!=mx.end()) 00110 current_value = it->second; 00111 } 00112 set(y,x,current_value+increment); 00113 } 00114 } 00115 const map<int,real>& getPYx(int x, bool dont_raise_error=false) // const is to force user to call set for writing 00116 { 00117 const map<int,real>& PYx = x2y(x); 00118 if (raise_error && !dont_raise_error && PYx.size()==0) 00119 PLERROR("ProbabilitySparseMatrix::getPyx: accessing an empty column at X=%d",x); 00120 return PYx; 00121 } 00122 const map<int,real>& getPyX(int y) // const is to force user to call set for writing 00123 { 00124 return y2x(y); 00125 } 00126 00127 map<int,real> getPYxCopy(int x, bool dont_raise_error=false) 00128 { 00129 map<int,real> PYx = x2y(x); 00130 if (raise_error && !dont_raise_error && PYx.size()==0) 00131 PLERROR("ProbabilitySparseMatrix::getPyx: accessing an empty column at X=%d",x); 00132 return PYx; 00133 } 00134 00135 map<int,real> getPyXCopy(int y) 00136 { 00137 map<int, real> pyX = y2x(y); 00138 return pyX; 00139 } 00140 00141 void setPYx(int x, const map<int, real>& pYx) 00142 { 00143 for (map<int, real>::const_iterator it = pYx.begin(); it != pYx.end(); ++it) 00144 set(it->first, x, it->second); 00145 } 00146 00147 void setPyX(int y, const map<int, real>& pyX) 00148 { 00149 for (map<int, real>::const_iterator it = pyX.begin(); it != pyX.end(); ++it) 00150 set(y, it->first, it->second); 00151 } 00152 00153 void set(int y,int x,real v, bool dont_warn_for_zero = false) { 00154 if (v!=0) 00155 { 00156 x2y(x,y)=v; 00157 y2x(y,x)=v; 00158 } else 00159 { 00160 if (!dont_warn_for_zero) 00161 PLWARNING("setting something to 0 in ProbabilitySparseMatrix %s", name.c_str()); 00162 map<int,real>& PYx = x2y(x); 00163 if (PYx.find(y)!=PYx.end()) 00164 { 00165 PYx.erase(y); 00166 map<int,real>& PyX = y2x(y); 00167 PyX.erase(x); 00168 } 00169 } 00170 } 00171 00172 void removeElem(int y, int x) 00173 { 00174 set(y, x, 0.0, true); 00175 } 00176 00177 real sumPYx(int x, Set Y) 00178 { 00179 real sum_pYx = 0.0; 00180 map<int, real>& col = x2y(x); 00181 for (map<int, real>::const_iterator yit = col.begin(); yit != col.end(); ++yit) 00182 { 00183 int y = yit->first; 00184 if (Y.contains(y)) 00185 sum_pYx += yit->second; 00186 } 00187 return sum_pYx; 00188 } 00189 00190 real sumPyX(int y, Set X) 00191 { 00192 real sum_pyX = 0.0; 00193 map<int, real>& row = y2x(y); 00194 for (map<int, real>::const_iterator xit = row.begin(); xit != row.end(); ++xit) 00195 { 00196 int x = xit->first; 00197 if (X.contains(x)) 00198 sum_pyX += xit->second; 00199 } 00200 return sum_pyX; 00201 } 00202 00203 real sumPYx(int x) 00204 { 00205 real sum_pYx = 0.0; 00206 map<int, real>& col = x2y(x); 00207 for (map<int, real>::const_iterator yit = col.begin(); yit != col.end(); ++yit) 00208 { 00209 sum_pYx += yit->second; 00210 } 00211 return sum_pYx; 00212 } 00213 00214 real sumPyX(int y) 00215 { 00216 real sum_pyX = 0.0; 00217 map<int, real>& row = y2x(y); 00218 for (map<int, real>::const_iterator xit = row.begin(); xit != row.end(); ++xit) 00219 { 00220 sum_pyX += xit->second; 00221 } 00222 return sum_pyX; 00223 } 00224 00225 // DEPRECATED 00226 // put existing elements to 0 without removing them 00227 void clearElements() { 00228 for (int x=0;x<nx();x++) 00229 { 00230 map<int,real>& r = x2y(x); 00231 for (map<int,real>::iterator it = r.begin(); it!=r.end(); ++it) 00232 { 00233 it->second=0; 00234 y2x(it->first,x)=0; 00235 } 00236 } 00237 } 00238 00239 // release all elements in the maps 00240 void clear() { x2y.clear(); y2x.clear(); } 00241 00242 void removeRow(int y, Set X) { 00243 y2x(y).clear(); 00244 for (SetIterator it=X.begin();it!=X.end();++it) 00245 { 00246 int x = *it; 00247 map<int,real>& pYx = x2y(x); 00248 if (pYx.size()>0) 00249 pYx.erase(y); 00250 } 00251 } 00252 00253 void removeRow(int y) 00254 { 00255 map<int, real>& row = y2x(y); 00256 for (map<int, real>::iterator it = row.begin(); it != row.end(); ++it) 00257 { 00258 int x = it->first; 00259 map<int,real>& Yx = x2y(x); 00260 if (Yx.size()>0) 00261 Yx.erase(y); 00262 } 00263 row.clear(); 00264 } 00265 00266 void removeColumn(int x) 00267 { 00268 map<int, real>& col = x2y(x); 00269 for (map<int, real>::iterator it = col.begin(); it != col.end(); ++it) 00270 { 00271 int y = it->first; 00272 map<int,real>& yX = y2x(y); 00273 if (yX.size()>0) 00274 yX.erase(x); 00275 } 00276 col.clear(); 00277 } 00278 00279 int size() 00280 { 00281 if (x2y.size() != y2x.size()) 00282 PLWARNING("x2y and y2x sizes dont match"); 00283 return y2x.size(); 00284 } 00285 00286 void removeExtra(ProbabilitySparseMatrix& m) 00287 { 00288 //for (SetIterator yit = Y.begin(); yit != Y.end(); ++yit) 00289 int _ny = ny(); 00290 Set x_to_remove; 00291 for (int y = 0; y < _ny; y++) 00292 { 00293 //int y = *yit; 00294 ConstSparseVec& yX = getPyX(y); 00295 x_to_remove.clear(); 00296 for (SparseVec::const_iterator xit = yX.begin(); xit != yX.end(); ++xit) 00297 { 00298 int x = xit->first; 00299 if (!m.exists(y, x)) 00300 x_to_remove.insert(x); 00301 } 00302 for (SetIterator xit = x_to_remove.begin(); xit != x_to_remove.end(); ++xit) 00303 { 00304 int x = *xit; 00305 removeElem(y, x); 00306 } 00307 } 00308 } 00309 00310 void fullPrint() 00311 { 00312 cout << "y2x" << endl; 00313 for (int y = 0; y < ny(); y++) 00314 { 00315 for (int x = 0; x < nx(); x++) 00316 { 00317 cout << y2x(y, x) << " "; 00318 } 00319 cout << endl; 00320 } 00321 cout << "x2y" << endl; 00322 for (int x = 0; x < nx(); x++) 00323 { 00324 for (int y = 0; y < ny(); y++) 00325 { 00326 cout << x2y(x, y) << " "; 00327 } 00328 cout << endl; 00329 } 00330 } 00331 00332 void save(string filename) 00333 { 00334 y2x.save(filename + ".y2x"); 00335 x2y.save(filename + ".x2y"); 00336 } 00337 00338 void load(string filename) 00339 { 00340 y2x.load(filename + ".y2x"); 00341 x2y.load(filename + ".x2y"); 00342 } 00343 00344 real* getAsFullVector() 00345 { 00346 // a vector of triples : (row, col, value) 00347 int vector_size = y2x.size() * 3; 00348 real* full_vector = new real[vector_size]; 00349 int pos = 0; 00350 for (int i = 0; i < ny(); i++) 00351 { 00352 map<int, real>& row_i = y2x(i); 00353 for (map<int, real>::iterator it = row_i.begin(); it != row_i.end(); ++it) 00354 { 00355 int j = it->first; 00356 real value = it->second; 00357 full_vector[pos++] = (real)i; 00358 full_vector[pos++] = (real)j; 00359 full_vector[pos++] = value; 00360 } 00361 } 00362 if (pos != vector_size) 00363 PLERROR("weird"); 00364 return full_vector; 00365 } 00366 00367 void getAsMaxSizedVectors(int max_size, vector<pair<real*, int> >& vectors) 00368 { 00369 if ((max_size % 3) != 0) PLWARNING("dangerous vector size (max_size mod 3 must equal 0)"); 00370 00371 int n_elems = y2x.size() * 3; 00372 int n_vecs = n_elems / max_size; 00373 int remaining = n_elems % max_size; 00374 if (remaining > 0) // padding to get sure that last block size (= remaining) is moddable by 3 00375 { 00376 n_vecs += 1; 00377 int mod3 = remaining % 3; 00378 if (mod3 != 0) 00379 remaining += (3 - mod3); 00380 } 00381 vectors.resize(n_vecs); 00382 //cerr << "(" << name << ") size = " << n_elems << endl; 00383 for (int i = 0; i < n_vecs; i++) 00384 { 00385 if (i == (n_vecs - 1) && remaining > 0) 00386 { 00387 vectors[i].first = new real[remaining]; 00388 vectors[i].second = remaining; 00389 //cerr << "(" << name << ") i = " << i << ", size = " << remaining << endl; 00390 } else 00391 { 00392 vectors[i].first = new real[max_size]; 00393 vectors[i].second = max_size; 00394 //cerr << "(" << name << ") i = " << i << ", size = " << max_size << endl; 00395 } 00396 } 00397 int pos = 0; 00398 for (int i = 0; i < ny(); i++) 00399 { 00400 map<int, real>& row_i = y2x(i); 00401 for (map<int, real>::iterator it = row_i.begin(); it != row_i.end(); ++it) 00402 { 00403 int j = it->first; 00404 real value = it->second; 00405 vectors[pos / max_size].first[pos++ % max_size] = i; 00406 vectors[pos / max_size].first[pos++ % max_size] = j; 00407 vectors[pos / max_size].first[pos++ % max_size] = value; 00408 } 00409 } 00410 while (pos < n_elems) // pad with (0, 0, 0) 00411 { 00412 vectors[pos / max_size].first[pos++ % max_size] = 0; 00413 vectors[pos / max_size].first[pos++ % max_size] = 0; 00414 vectors[pos / max_size].first[pos++ % max_size] = 0; 00415 } 00416 } 00417 00418 void add(real* full_vector, int n_elems) 00419 { 00420 for (int i = 0; i < n_elems; i += 3) 00421 incr((int)full_vector[i], (int)full_vector[i + 1], full_vector[i + 2]); 00422 } 00423 00424 void set(real* full_vector, int n_elems) 00425 { 00426 clear(); 00427 for (int i = 0; i < n_elems; i += 3) 00428 set((int)full_vector[i], (int)full_vector[i + 1], full_vector[i + 2]); 00429 } 00430 00431 real sumOfElements() 00432 { 00433 real sum = 0.0; 00434 for (int i = 0; i < ny(); i++) 00435 { 00436 map<int, real>& row_i = y2x(i); 00437 for (map<int, real>::iterator it = row_i.begin(); it != row_i.end(); ++it) 00438 { 00439 int j = it->first; 00440 sum += get(i, j); 00441 } 00442 } 00443 return sum; 00444 } 00445 00446 }; 00447 00448 inline void samePos(ProbabilitySparseMatrix& m1, ProbabilitySparseMatrix& m2, string m1name, string m2name) 00449 { 00450 for (SetIterator yit = m1.Y.begin(); yit != m1.Y.end(); ++yit) 00451 { 00452 int y = *yit; 00453 const map<int, real>& yX = m1.getPyX(y); 00454 for (map<int, real>::const_iterator it = yX.begin(); it != yX.end(); ++it) 00455 { 00456 int x = it->first; 00457 if (!m2.exists(y, x)) 00458 PLERROR("in samePos, %s contains an element that is not present in %s", m1name.c_str(), m2name.c_str()); 00459 } 00460 } 00461 } 00462 00463 inline void check_prob(ProbabilitySparseMatrix& pYX, string Yname, string Xname) 00464 { 00465 bool failed = false; 00466 real sum_pY = 0.0; 00467 Set X = pYX.X; 00468 for (SetIterator x_it = X.begin(); x_it!=X.end(); ++x_it) 00469 { 00470 int x = *x_it; 00471 const map<int,real>& pYx = pYX.getPYx(x); 00472 sum_pY = 0.0; 00473 for (map<int,real>::const_iterator y_it=pYx.begin();y_it!=pYx.end();++y_it) 00474 { 00475 sum_pY += y_it->second; 00476 } 00477 if (fabs(sum_pY - 1.0) > 1e-4) 00478 { 00479 failed = true; 00480 break; 00481 } 00482 } 00483 if (failed) 00484 PLERROR("check_prob failed for %s -> %s (sum of a column = %g)", Xname.c_str(), Yname.c_str(), sum_pY); 00485 } 00486 00487 inline void check_prob(Set Y, const map<int, real>& pYx) 00488 { 00489 real sum_y=0; 00490 for (map<int,real>::const_iterator y_it=pYx.begin();y_it!=pYx.end();++y_it) 00491 if (Y.contains(y_it->first)) 00492 sum_y += y_it->second; 00493 if (fabs(sum_y-1)>1e-4 && pYx.size() != 0) 00494 PLERROR("check_prob failed, sum_y=%g",sum_y); 00495 } 00496 00497 inline void update(ProbabilitySparseMatrix& pYX, ProbabilitySparseMatrix& nYX) 00498 { 00499 pYX.clear(); 00500 nYX.computeX(); 00501 nYX.computeY(); 00502 for (SetIterator xit = nYX.X.begin(); xit != nYX.X.end(); ++xit) 00503 { 00504 int x = *xit; 00505 real sumYx = nYX.sumPYx(x); 00506 if (sumYx != 0.0) 00507 { 00508 for (SetIterator yit = nYX.Y.begin(); yit != nYX.Y.end(); ++yit) 00509 { 00510 int y = *yit; 00511 real p = nYX(y, x) / sumYx; 00512 if (p) 00513 pYX.set(y, x, p); 00514 } 00515 } 00516 } 00517 pYX.computeY(); 00518 pYX.computeX(); 00519 } 00520 00521 inline void updateAndClearCounts(ProbabilitySparseMatrix& pYX, ProbabilitySparseMatrix& nYX) 00522 { 00523 pYX.clear(); 00524 nYX.computeX(); 00525 nYX.computeY(); 00526 for (SetIterator xit = nYX.X.begin(); xit != nYX.X.end(); ++xit) 00527 { 00528 int x = *xit; 00529 real sumYx = nYX.sumPYx(x); 00530 if (sumYx != 0.0) 00531 { 00532 for (SetIterator yit = nYX.Y.begin(); yit != nYX.Y.end(); ++yit) 00533 { 00534 int y = *yit; 00535 real p = nYX(y, x) / sumYx; 00536 if (p) 00537 pYX.set(y, x, p); 00538 } 00539 } 00540 } 00541 pYX.computeY(); 00542 pYX.computeX(); 00543 } 00544 00545 inline ostream& operator<<(ostream& out, ProbabilitySparseMatrix& pyx) 00546 { 00547 bool re = pyx.raise_error; 00548 pyx.raise_error = false; 00549 for (int y = 0; y < pyx.ny(); y++) 00550 { 00551 for (int x = 0; x < pyx.nx(); x++) 00552 { 00553 out << setw(NUMWIDTH) << pyx(y, x); 00554 } 00555 out << endl; 00556 } 00557 pyx.raise_error = re; 00558 return out; 00559 } 00560 00561 inline void print(ostream& out, ProbabilitySparseMatrix& pyx, Set Y, Set X) 00562 { 00563 for (SetIterator yit = Y.begin(); yit != Y.end(); ++yit) 00564 { 00565 int y = *yit; 00566 for (SetIterator xit = X.begin(); xit != X.end(); ++xit) 00567 { 00568 int x = *xit; 00569 out << setw(NUMWIDTH) << pyx(y, x); 00570 } 00571 out << endl; 00572 } 00573 } 00574 00575 inline void print(ostream& out, RowMapSparseMatrix<real>& m) 00576 { 00577 for (int i = 0; i < m.length(); i++) 00578 { 00579 for (int j = 0; j < m.width(); j++) 00580 { 00581 out << setw(NUMWIDTH) << m(i, j); 00582 } 00583 out << endl; 00584 } 00585 } 00586 00587 inline void print(ostream& out, const map<int, real>& vec, int size) 00588 { 00589 for (int i = 0; i < size; i++) 00590 { 00591 map<int, real>::const_iterator vec_it = vec.find(i); 00592 if (vec_it != vec.end()) 00593 out << setw(NUMWIDTH) << vec_it->second; 00594 else 00595 out << setw(NUMWIDTH) << 0; 00596 } 00597 out << endl; 00598 } 00599 00600 inline void print(ostream& out, const map<int, real>& vec) 00601 { 00602 for (map<int, real>::const_iterator it = vec.begin(); it != vec.end(); ++it) 00603 { 00604 out << setw(NUMWIDTH) << it->second; 00605 } 00606 out << endl; 00607 } 00608 00609 inline void print(ostream& out, const map<int, real>& vec, Set V) 00610 { 00611 for (SetIterator vit = V.begin(); vit != V.end(); ++vit) 00612 { 00613 int v = *vit; 00614 map<int, real>::const_iterator vec_it = vec.find(v); 00615 if (vec_it != vec.end()) 00616 out << setw(NUMWIDTH) << vec_it->second; 00617 else 00618 out << setw(NUMWIDTH) << 0; 00619 } 00620 out << endl; 00621 } 00622 00623 } 00624 00625 #endif

Generated on Tue Aug 17 16:02:52 2004 for PLearn by doxygen 1.3.7