00001
#ifndef PROBABILITY_SPARSE_MATRIX_H
00002
#define PROBABILITY_SPARSE_MATRIX_H
00003
00004
#include <plearn/math/RowMapSparseMatrix.h>
00005
#include <plearn/base/Object.h>
00006
#include "Set.h"
00007
00008 #define NUMWIDTH 10
00009
00010
namespace PLearn {
00011
00012
00013 typedef map<int, real>
SparseVec;
00014 typedef const map<int, real>
ConstSparseVec;
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025 class ProbabilitySparseMatrix :
public Object
00026 {
00027
private:
00028 RowMapSparseMatrix<real> x2y;
00029 RowMapSparseMatrix<real> y2x;
00030
public:
00031
00032 Set Y;
00033 Set X;
00034
00035 bool raise_error;
00036 string name;
00037
00038 ProbabilitySparseMatrix(
int ny=0,
int nx=0,
string pname =
"pYX") :
raise_error(true)
00039 {
00040
resize(ny,nx);
00041
name = pname;
00042 }
00043 void resize(
int ny,
int nx) {
00044
x2y.
resize(nx,ny);
00045
y2x.
resize(ny,nx);
00046 }
00047 int nx() {
return x2y.
length(); }
00048 int ny() {
return y2x.
length(); }
00049
00050 void rename(
string new_name) {
name = new_name; }
00051
00052 Set computeX()
00053 {
00054
X.
clear();
00055
for (
int xx = 0; xx <
nx(); xx++)
00056 {
00057
if (
sumPYx(xx) > 0)
00058 {
00059
X.
insert(xx);
00060 }
00061 }
00062
return X;
00063 }
00064
00065 Set computeY()
00066 {
00067
Y.
clear();
00068
for (
int yy = 0; yy <
ny(); yy++)
00069 {
00070
if (
sumPyX(yy) > 0)
00071 {
00072
Y.
insert(yy);
00073 }
00074 }
00075
return Y;
00076 }
00077
00078 real get(
int y,
int x) {
00079 map<int,real>& mx =
x2y(
x);
00080
if (mx.size()==0)
00081 {
00082
if (
raise_error)
00083
00084
PLWARNING(
"trying to access an invalid probability at P(%d|%d) in %s",y,
x,
name.c_str());
00085
return 0;
00086 }
00087
if (mx.find(y)!=mx.end())
00088
return mx[y];
00089
return 0;
00090 }
00091
00092 bool exists(
int y,
int x) {
00093 map<int,real>& mx =
x2y(
x);
00094
if (mx.size()==0)
return false;
00095
return mx.find(y)!=mx.end();
00096 }
00097
00098
00099 real operator()(
int y,
int x) {
return get(y,
x); }
00100 void incr(
int y,
int x,
real increment=1)
00101 {
00102
if (increment!=0)
00103 {
00104
real current_value = 0;
00105 map<int,real>& mx =
x2y(
x);
00106
if (mx.size()!=0)
00107 {
00108 map<int,real>::const_iterator it = mx.find(y);
00109
if (it!=mx.end())
00110 current_value = it->second;
00111 }
00112
set(y,
x,current_value+increment);
00113 }
00114 }
00115 const map<int,real>&
getPYx(
int x,
bool dont_raise_error=
false)
00116 {
00117
const map<int,real>& PYx =
x2y(
x);
00118
if (
raise_error && !dont_raise_error && PYx.size()==0)
00119
PLERROR(
"ProbabilitySparseMatrix::getPyx: accessing an empty column at X=%d",
x);
00120
return PYx;
00121 }
00122 const map<int,real>&
getPyX(
int y)
00123 {
00124
return y2x(y);
00125 }
00126
00127 map<int,real>
getPYxCopy(
int x,
bool dont_raise_error=
false)
00128 {
00129 map<int,real> PYx =
x2y(
x);
00130
if (
raise_error && !dont_raise_error && PYx.size()==0)
00131
PLERROR(
"ProbabilitySparseMatrix::getPyx: accessing an empty column at X=%d",
x);
00132
return PYx;
00133 }
00134
00135 map<int,real>
getPyXCopy(
int y)
00136 {
00137 map<int, real> pyX =
y2x(y);
00138
return pyX;
00139 }
00140
00141 void setPYx(
int x,
const map<int, real>& pYx)
00142 {
00143
for (map<int, real>::const_iterator it = pYx.begin(); it != pYx.end(); ++it)
00144
set(it->first,
x, it->second);
00145 }
00146
00147 void setPyX(
int y,
const map<int, real>& pyX)
00148 {
00149
for (map<int, real>::const_iterator it = pyX.begin(); it != pyX.end(); ++it)
00150
set(y, it->first, it->second);
00151 }
00152
00153 void set(
int y,
int x,
real v,
bool dont_warn_for_zero =
false) {
00154
if (v!=0)
00155 {
00156
x2y(
x,y)=v;
00157
y2x(y,
x)=v;
00158 }
else
00159 {
00160
if (!dont_warn_for_zero)
00161
PLWARNING(
"setting something to 0 in ProbabilitySparseMatrix %s",
name.c_str());
00162 map<int,real>& PYx =
x2y(
x);
00163
if (PYx.find(y)!=PYx.end())
00164 {
00165 PYx.erase(y);
00166 map<int,real>& PyX =
y2x(y);
00167 PyX.erase(
x);
00168 }
00169 }
00170 }
00171
00172 void removeElem(
int y,
int x)
00173 {
00174 set(y,
x, 0.0,
true);
00175 }
00176
00177 real sumPYx(
int x,
Set Y)
00178 {
00179
real sum_pYx = 0.0;
00180 map<int, real>& col =
x2y(
x);
00181
for (map<int, real>::const_iterator yit = col.begin(); yit != col.end(); ++yit)
00182 {
00183
int y = yit->first;
00184
if (Y.
contains(y))
00185 sum_pYx += yit->second;
00186 }
00187
return sum_pYx;
00188 }
00189
00190 real sumPyX(
int y,
Set X)
00191 {
00192
real sum_pyX = 0.0;
00193 map<int, real>& row =
y2x(y);
00194
for (map<int, real>::const_iterator xit = row.begin(); xit != row.end(); ++xit)
00195 {
00196
int x = xit->first;
00197
if (X.
contains(
x))
00198 sum_pyX += xit->second;
00199 }
00200
return sum_pyX;
00201 }
00202
00203 real sumPYx(
int x)
00204 {
00205
real sum_pYx = 0.0;
00206 map<int, real>& col =
x2y(
x);
00207
for (map<int, real>::const_iterator yit = col.begin(); yit != col.end(); ++yit)
00208 {
00209 sum_pYx += yit->second;
00210 }
00211
return sum_pYx;
00212 }
00213
00214 real sumPyX(
int y)
00215 {
00216
real sum_pyX = 0.0;
00217 map<int, real>& row =
y2x(y);
00218
for (map<int, real>::const_iterator xit = row.begin(); xit != row.end(); ++xit)
00219 {
00220 sum_pyX += xit->second;
00221 }
00222
return sum_pyX;
00223 }
00224
00225
00226
00227 void clearElements() {
00228
for (
int x=0;
x<
nx();
x++)
00229 {
00230 map<int,real>& r =
x2y(
x);
00231
for (map<int,real>::iterator it = r.begin(); it!=r.end(); ++it)
00232 {
00233 it->second=0;
00234
y2x(it->first,
x)=0;
00235 }
00236 }
00237 }
00238
00239
00240 void clear() {
x2y.
clear();
y2x.
clear(); }
00241
00242 void removeRow(
int y,
Set X) {
00243
y2x(y).
clear();
00244
for (
SetIterator it=X.
begin();it!=X.
end();++it)
00245 {
00246
int x = *it;
00247 map<int,real>& pYx =
x2y(
x);
00248
if (pYx.size()>0)
00249 pYx.erase(y);
00250 }
00251 }
00252
00253 void removeRow(
int y)
00254 {
00255 map<int, real>& row =
y2x(y);
00256
for (map<int, real>::iterator it = row.begin(); it != row.end(); ++it)
00257 {
00258
int x = it->first;
00259 map<int,real>& Yx =
x2y(
x);
00260
if (Yx.size()>0)
00261 Yx.erase(y);
00262 }
00263 row.clear();
00264 }
00265
00266 void removeColumn(
int x)
00267 {
00268 map<int, real>& col =
x2y(
x);
00269
for (map<int, real>::iterator it = col.begin(); it != col.end(); ++it)
00270 {
00271
int y = it->first;
00272 map<int,real>& yX =
y2x(y);
00273
if (yX.size()>0)
00274 yX.erase(
x);
00275 }
00276 col.clear();
00277 }
00278
00279 int size()
00280 {
00281
if (
x2y.
size() !=
y2x.
size())
00282
PLWARNING(
"x2y and y2x sizes dont match");
00283
return y2x.
size();
00284 }
00285
00286 void removeExtra(
ProbabilitySparseMatrix& m)
00287 {
00288
00289
int _ny =
ny();
00290
Set x_to_remove;
00291
for (
int y = 0; y < _ny; y++)
00292 {
00293
00294
ConstSparseVec& yX = getPyX(y);
00295 x_to_remove.
clear();
00296
for (SparseVec::const_iterator xit = yX.begin(); xit != yX.end(); ++xit)
00297 {
00298
int x = xit->first;
00299
if (!m.
exists(y,
x))
00300 x_to_remove.
insert(
x);
00301 }
00302
for (
SetIterator xit = x_to_remove.
begin(); xit != x_to_remove.
end(); ++xit)
00303 {
00304
int x = *xit;
00305 removeElem(y,
x);
00306 }
00307 }
00308 }
00309
00310 void fullPrint()
00311 {
00312 cout <<
"y2x" <<
endl;
00313
for (
int y = 0; y <
ny(); y++)
00314 {
00315
for (
int x = 0;
x <
nx();
x++)
00316 {
00317 cout <<
y2x(y,
x) <<
" ";
00318 }
00319 cout <<
endl;
00320 }
00321 cout <<
"x2y" <<
endl;
00322
for (
int x = 0;
x <
nx();
x++)
00323 {
00324
for (
int y = 0; y <
ny(); y++)
00325 {
00326 cout <<
x2y(
x, y) <<
" ";
00327 }
00328 cout <<
endl;
00329 }
00330 }
00331
00332 void save(
string filename)
00333 {
00334
y2x.
save(filename +
".y2x");
00335
x2y.
save(filename +
".x2y");
00336 }
00337
00338 void load(
string filename)
00339 {
00340
y2x.
load(filename +
".y2x");
00341
x2y.
load(filename +
".x2y");
00342 }
00343
00344 real*
getAsFullVector()
00345 {
00346
00347
int vector_size =
y2x.
size() * 3;
00348
real* full_vector =
new real[vector_size];
00349
int pos = 0;
00350
for (
int i = 0; i <
ny(); i++)
00351 {
00352 map<int, real>& row_i =
y2x(i);
00353
for (map<int, real>::iterator it = row_i.begin(); it != row_i.end(); ++it)
00354 {
00355
int j = it->first;
00356 real value = it->second;
00357 full_vector[pos++] = (real)i;
00358 full_vector[pos++] = (real)j;
00359 full_vector[pos++] = value;
00360 }
00361 }
00362
if (pos != vector_size)
00363
PLERROR(
"weird");
00364
return full_vector;
00365 }
00366
00367 void getAsMaxSizedVectors(
int max_size,
vector<pair<real*, int> >& vectors)
00368 {
00369
if ((max_size % 3) != 0)
PLWARNING(
"dangerous vector size (max_size mod 3 must equal 0)");
00370
00371
int n_elems =
y2x.
size() * 3;
00372
int n_vecs = n_elems / max_size;
00373
int remaining = n_elems % max_size;
00374
if (remaining > 0)
00375 {
00376 n_vecs += 1;
00377
int mod3 = remaining % 3;
00378
if (mod3 != 0)
00379 remaining += (3 - mod3);
00380 }
00381 vectors.resize(n_vecs);
00382
00383
for (
int i = 0; i < n_vecs; i++)
00384 {
00385
if (i == (n_vecs - 1) && remaining > 0)
00386 {
00387 vectors[i].first =
new real[remaining];
00388 vectors[i].second = remaining;
00389
00390 }
else
00391 {
00392 vectors[i].first =
new real[max_size];
00393 vectors[i].second = max_size;
00394
00395 }
00396 }
00397
int pos = 0;
00398
for (
int i = 0; i <
ny(); i++)
00399 {
00400 map<int, real>& row_i =
y2x(i);
00401
for (map<int, real>::iterator it = row_i.begin(); it != row_i.end(); ++it)
00402 {
00403
int j = it->first;
00404
real value = it->second;
00405 vectors[pos / max_size].first[pos++ % max_size] = i;
00406 vectors[pos / max_size].first[pos++ % max_size] = j;
00407 vectors[pos / max_size].first[pos++ % max_size] = value;
00408 }
00409 }
00410
while (pos < n_elems)
00411 {
00412 vectors[pos / max_size].first[pos++ % max_size] = 0;
00413 vectors[pos / max_size].first[pos++ % max_size] = 0;
00414 vectors[pos / max_size].first[pos++ % max_size] = 0;
00415 }
00416 }
00417
00418 void add(
real* full_vector,
int n_elems)
00419 {
00420
for (
int i = 0; i < n_elems; i += 3)
00421 incr((
int)full_vector[i], (
int)full_vector[i + 1], full_vector[i + 2]);
00422 }
00423
00424 void set(
real* full_vector,
int n_elems)
00425 {
00426
clear();
00427
for (
int i = 0; i < n_elems; i += 3)
00428 set((
int)full_vector[i], (
int)full_vector[i + 1], full_vector[i + 2]);
00429 }
00430
00431 real sumOfElements()
00432 {
00433
real sum = 0.0;
00434
for (
int i = 0; i <
ny(); i++)
00435 {
00436 map<int, real>& row_i =
y2x(i);
00437
for (map<int, real>::iterator it = row_i.begin(); it != row_i.end(); ++it)
00438 {
00439
int j = it->first;
00440
sum += get(i, j);
00441 }
00442 }
00443
return sum;
00444 }
00445
00446 };
00447
00448 inline void samePos(
ProbabilitySparseMatrix& m1,
ProbabilitySparseMatrix& m2,
string m1name,
string m2name)
00449 {
00450
for (
SetIterator yit = m1.
Y.
begin(); yit != m1.
Y.
end(); ++yit)
00451 {
00452
int y = *yit;
00453
const map<int, real>& yX = m1.
getPyX(y);
00454
for (map<int, real>::const_iterator it = yX.begin(); it != yX.end(); ++it)
00455 {
00456
int x = it->first;
00457
if (!m2.
exists(y,
x))
00458
PLERROR(
"in samePos, %s contains an element that is not present in %s", m1name.c_str(), m2name.c_str());
00459 }
00460 }
00461 }
00462
00463 inline void check_prob(
ProbabilitySparseMatrix& pYX,
string Yname,
string Xname)
00464 {
00465
bool failed =
false;
00466
real sum_pY = 0.0;
00467
Set X = pYX.
X;
00468
for (
SetIterator x_it = X.
begin(); x_it!=X.
end(); ++x_it)
00469 {
00470
int x = *x_it;
00471
const map<int,real>& pYx = pYX.
getPYx(
x);
00472 sum_pY = 0.0;
00473
for (map<int,real>::const_iterator y_it=pYx.begin();y_it!=pYx.end();++y_it)
00474 {
00475 sum_pY += y_it->second;
00476 }
00477
if (fabs(sum_pY - 1.0) > 1e-4)
00478 {
00479 failed =
true;
00480
break;
00481 }
00482 }
00483
if (failed)
00484
PLERROR(
"check_prob failed for %s -> %s (sum of a column = %g)", Xname.c_str(), Yname.c_str(), sum_pY);
00485 }
00486
00487 inline void check_prob(
Set Y,
const map<int, real>& pYx)
00488 {
00489
real sum_y=0;
00490
for (map<int,real>::const_iterator y_it=pYx.begin();y_it!=pYx.end();++y_it)
00491
if (Y.
contains(y_it->first))
00492 sum_y += y_it->second;
00493
if (fabs(sum_y-1)>1e-4 && pYx.size() != 0)
00494
PLERROR(
"check_prob failed, sum_y=%g",sum_y);
00495 }
00496
00497 inline void update(
ProbabilitySparseMatrix& pYX,
ProbabilitySparseMatrix& nYX)
00498 {
00499 pYX.
clear();
00500 nYX.
computeX();
00501 nYX.
computeY();
00502
for (
SetIterator xit = nYX.
X.
begin(); xit != nYX.
X.
end(); ++xit)
00503 {
00504
int x = *xit;
00505
real sumYx = nYX.
sumPYx(
x);
00506
if (sumYx != 0.0)
00507 {
00508
for (
SetIterator yit = nYX.
Y.
begin(); yit != nYX.
Y.
end(); ++yit)
00509 {
00510
int y = *yit;
00511
real p = nYX(y,
x) / sumYx;
00512
if (p)
00513 pYX.
set(y,
x, p);
00514 }
00515 }
00516 }
00517 pYX.
computeY();
00518 pYX.
computeX();
00519 }
00520
00521 inline void updateAndClearCounts(
ProbabilitySparseMatrix& pYX,
ProbabilitySparseMatrix& nYX)
00522 {
00523 pYX.
clear();
00524 nYX.
computeX();
00525 nYX.
computeY();
00526
for (
SetIterator xit = nYX.
X.
begin(); xit != nYX.
X.
end(); ++xit)
00527 {
00528
int x = *xit;
00529
real sumYx = nYX.
sumPYx(
x);
00530
if (sumYx != 0.0)
00531 {
00532
for (
SetIterator yit = nYX.
Y.
begin(); yit != nYX.
Y.
end(); ++yit)
00533 {
00534
int y = *yit;
00535
real p = nYX(y,
x) / sumYx;
00536
if (p)
00537 pYX.
set(y,
x, p);
00538 }
00539 }
00540 }
00541 pYX.
computeY();
00542 pYX.
computeX();
00543 }
00544
00545 inline ostream&
operator<<(ostream& out,
ProbabilitySparseMatrix& pyx)
00546 {
00547
bool re = pyx.
raise_error;
00548 pyx.
raise_error =
false;
00549
for (
int y = 0; y < pyx.
ny(); y++)
00550 {
00551
for (
int x = 0;
x < pyx.
nx();
x++)
00552 {
00553 out << setw(
NUMWIDTH) << pyx(y,
x);
00554 }
00555 out <<
endl;
00556 }
00557 pyx.
raise_error = re;
00558
return out;
00559 }
00560
00561 inline void print(ostream& out,
ProbabilitySparseMatrix& pyx,
Set Y,
Set X)
00562 {
00563
for (
SetIterator yit = Y.
begin(); yit != Y.
end(); ++yit)
00564 {
00565
int y = *yit;
00566
for (
SetIterator xit = X.
begin(); xit != X.
end(); ++xit)
00567 {
00568
int x = *xit;
00569 out << setw(
NUMWIDTH) << pyx(y,
x);
00570 }
00571 out <<
endl;
00572 }
00573 }
00574
00575 inline void print(ostream& out,
RowMapSparseMatrix<real>& m)
00576 {
00577
for (
int i = 0; i < m.
length(); i++)
00578 {
00579
for (
int j = 0; j < m.
width(); j++)
00580 {
00581 out << setw(
NUMWIDTH) << m(i, j);
00582 }
00583 out <<
endl;
00584 }
00585 }
00586
00587 inline void print(ostream& out,
const map<int, real>& vec,
int size)
00588 {
00589
for (
int i = 0; i < size; i++)
00590 {
00591 map<int, real>::const_iterator vec_it = vec.find(i);
00592
if (vec_it != vec.end())
00593 out << setw(
NUMWIDTH) << vec_it->second;
00594
else
00595 out << setw(
NUMWIDTH) << 0;
00596 }
00597 out <<
endl;
00598 }
00599
00600 inline void print(ostream& out,
const map<int, real>& vec)
00601 {
00602
for (map<int, real>::const_iterator it = vec.begin(); it != vec.end(); ++it)
00603 {
00604 out << setw(
NUMWIDTH) << it->second;
00605 }
00606 out <<
endl;
00607 }
00608
00609 inline void print(ostream& out,
const map<int, real>& vec,
Set V)
00610 {
00611
for (
SetIterator vit = V.
begin(); vit != V.
end(); ++vit)
00612 {
00613
int v = *vit;
00614 map<int, real>::const_iterator vec_it = vec.find(v);
00615
if (vec_it != vec.end())
00616 out << setw(
NUMWIDTH) << vec_it->second;
00617
else
00618 out << setw(
NUMWIDTH) << 0;
00619 }
00620 out <<
endl;
00621 }
00622
00623 }
00624
00625
#endif