00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
#include "ConcatRowsVMatrix.h"
00042
#include "SelectColumnsVMatrix.h"
00043
00044
namespace PLearn {
00045
using namespace std;
00046
00049
// PLEARN_IMPLEMENT_OBJECT invocation for ConcatRowsVMatrix: the class name,
// followed by a short description string and a longer multi-line help string.
PLEARN_IMPLEMENT_OBJECT(
    ConcatRowsVMatrix,
    "Concatenates the rows of a number of VMat.",
    "It can also be used to select fields which are common to those VMat,\n"
    "using the 'only_common_fields' option.\n"
    "Otherwise, the fields are just assumed to be those of the first VMat.\n"
    );
00055
00057
00059 ConcatRowsVMatrix::ConcatRowsVMatrix(
TVec<VMat> the_array)
00060 : array(the_array),
00061 fully_check_mappings(false),
00062 only_common_fields(false)
00063 {
00064
if (
array.
size() > 0)
00065
build_();
00066 };
00067
00068 ConcatRowsVMatrix::ConcatRowsVMatrix(
VMat d1,
VMat d2)
00069 : fully_check_mappings(false),
00070 only_common_fields(false)
00071 {
00072
array.
resize(2);
00073
array[0] = d1;
00074
array[1] = d2;
00075
build_();
00076 };
00077
00079
00081 void ConcatRowsVMatrix::declareOptions(
OptionList &ol)
00082 {
00083
declareOption(ol,
"array", &ConcatRowsVMatrix::array, OptionBase::buildoption,
00084
"The VMat to concatenate.");
00085
00086
declareOption(ol,
"fully_check_mappings", &ConcatRowsVMatrix::fully_check_mappings, OptionBase::buildoption,
00087
"If set to 1, then columns for which there is a string <-> real mapping will be examined\n"
00088
"to ensure that no numerical data in a VMat conflicts with a mapping in another VMat.");
00089
00090
declareOption(ol,
"only_common_fields", &ConcatRowsVMatrix::only_common_fields, OptionBase::buildoption,
00091
"If set to 1, then only the fields whose names are common to all VMat\n"
00092
"in 'array' will be kept (and reordered if needed).");
00093
00094 inherited::declareOptions(ol);
00095 }
00096
00098
00100 void ConcatRowsVMatrix::build()
00101 {
00102 inherited::build();
00103
build_();
00104 }
00105
00107
00109 void ConcatRowsVMatrix::build_()
00110 {
00111
int n =
array.
size();
00112
if (n < 1)
00113
PLERROR(
"ConcatRowsVMatrix expects underlying-array of length >= 1, got 0");
00114
00115
00116 fieldinfos =
array[0]->getFieldInfos();
00117
if (
only_common_fields) {
00118
findCommonFields();
00119 }
else {
00120
to_concat = array;
00121 }
00122
00123
00124
recomputeDimensions();
00125
00126
00127
ensureMappingsConsistency();
00128
if (
fully_check_mappings)
00129
fullyCheckMappings();
00130
if (
need_fix_mappings && !
fully_check_mappings)
00131
PLWARNING(
"In ConcatRowsVMatrix::build_ - Mappings need to be fixed, but you did not set 'fully_check_mappings' to true, this might be dangerous");
00132
00133
00134 inputsize_ =
to_concat[0]->inputsize();
00135 targetsize_ = to_concat[0]->targetsize();
00136 weightsize_ = to_concat[0]->weightsize();
00137 }
00138
00140
00142 real ConcatRowsVMatrix::dot(
int i1,
int i2,
int inputsize)
const
00143
{
00144
int whichvm1, rowofvm1;
00145
getpositions(i1,whichvm1,rowofvm1);
00146
int whichvm2, rowofvm2;
00147
getpositions(i2,whichvm2,rowofvm2);
00148
if(whichvm1==whichvm2 && !
need_fix_mappings)
00149
return to_concat[whichvm1]->dot(rowofvm1, rowofvm2, inputsize);
00150
else
00151
return VMatrix::dot(i1,i2,inputsize);
00152 }
00153
00154 real ConcatRowsVMatrix::dot(
int i,
const Vec& v)
const
00155
{
00156
if (!
need_fix_mappings) {
00157
int whichvm, rowofvm;
00158
getpositions(i,whichvm,rowofvm);
00159
return to_concat[whichvm]->dot(rowofvm,v);
00160 }
00161
else
00162
return VMatrix::dot(i, v);
00163 }
00164
00166
00168 void ConcatRowsVMatrix::ensureMappingsConsistency() {
00169
00170
00171
00172
00173
00174
need_fix_mappings =
false;
00175 copyStringMappingsFrom(
to_concat[0]);
00176 map<string, real> other_map;
00177 map<string, real>* cur_map;
00178 map<string, real>::iterator it, find_map, jt;
00179
bool report_progress =
false;
00180
ProgressBar* pb = 0;
00181
if (report_progress)
00182 pb =
new ProgressBar(
"Checking mappings consistency",
width());
00183
for (
int j = 0; j <
width(); j++) {
00184 cur_map = &map_sr[j];
00185
00186
real max = -
REAL_MAX;
00187
for (jt = cur_map->begin(); jt != cur_map->end(); jt++)
00188
if (jt->second >
max)
00189
max = jt->second;
00190
for (
int i = 1; i <
to_concat.
length(); i++) {
00191 other_map =
to_concat[i]->getStringToRealMapping(j);
00192
for(it = other_map.begin(); it != other_map.end(); it++) {
00193
00194
00195
00196
00197
00198
00199 find_map = cur_map->find(it->first);
00200
if (find_map != cur_map->end()) {
00201
00202
if (find_map->second != it->second) {
00203
00204
00205
00206
00207
need_fix_mappings =
true;
00208
fixed_mappings.
resize(to_concat.length(),
width());
00209
fixed_mappings(i, j)[it->second] = find_map->second;
00210 }
00211 }
else {
00212
00213
00214
00215
real new_map_val = it->second;
00216
if (getValString(j, it->second) !=
"") {
00217
00218
need_fix_mappings =
true;
00219
fixed_mappings.
resize(to_concat.length(),
width());
00220
00221
max++;
00222
00223
00224
00225
fixed_mappings(i, j)[it->second] =
max;
00226 new_map_val =
max;
00227 }
else {
00228
00229
if (new_map_val >
max)
00230
max = new_map_val;
00231
00232
00233 }
00234 addStringMapping(j, it->first, new_map_val);
00235 }
00236 }
00237 }
00238
if (report_progress)
00239 pb->
update(j + 1);
00240 }
00241
if (pb)
00242
delete pb;
00243 }
00244
00246
00248 void ConcatRowsVMatrix::findCommonFields() {
00249
00250
TVec<VMField> final_fields(fieldinfos.
length());
00251 final_fields << fieldinfos;
00252
TVec<VMField> other_fields;
00253
TVec<VMField> tmp(final_fields.
length());
00254
for (
int i = 1; i <
array.
length(); i++) {
00255 map<string, bool> can_be_kept;
00256 other_fields =
array[i]->getFieldInfos();
00257
for (
int j = 0; j < other_fields.
length(); j++) {
00258 can_be_kept[other_fields[j].name] =
true;
00259 }
00260 tmp.
resize(0);
00261
for (
int j = 0; j < final_fields.
length(); j++)
00262
if (can_be_kept.count(final_fields[j].name) > 0)
00263 tmp.
append(final_fields[j]);
00264 final_fields.
resize(tmp.
length());
00265 final_fields << tmp;
00266 }
00267 fieldinfos.
resize(final_fields.
length());
00268 fieldinfos << final_fields;
00269
00270
TVec<string> final_fieldnames(final_fields.length());
00271
for (
int i = 0; i < final_fields.length(); i++)
00272 final_fieldnames[i] = final_fields[i].name;
00273
00274
to_concat.
resize(
array.
length());
00275
for (
int i = 0; i <
array.
length(); i++)
00276
to_concat[i] =
new SelectColumnsVMatrix(
array[i], final_fieldnames);
00277 }
00278
00280
00282 void ConcatRowsVMatrix::fullyCheckMappings(
bool report_progress) {
00283
Vec row(
width());
00284
TVec<int> max(
width());
00285
ProgressBar* pb = 0;
00286
if (report_progress)
00287 pb =
new ProgressBar(
"Full check of string mappings",
length());
00288
int count = 0;
00289
for (
int i = 0; i <
to_concat.
length(); i++) {
00290
for (
int j = 0; j <
to_concat[i]->
length(); j++) {
00291 to_concat[i]->getRow(j, row);
00292
for (
int k = 0;
k <
width();
k++) {
00293
if (!
is_missing(row[
k]) && map_sr[
k].
size() > 0) {
00294
00295
00296
00297
00298
if (to_concat[i]->getValString(
k, row[
k]) ==
"") {
00299
00300
if (map_rs[
k].
find(row[
k]) != map_rs[
k].
end()) {
00301
00302
00303
00304
PLERROR(
"In ConcatRowsVMatrix::fullyCheckMappings - In column %s of concatenated VMat number %d, the row %d contains a numerical value (%f) that is used in a string mapping (mapped to %s)",
getFieldInfos(
k).name.c_str(), i, j, row[
k], map_rs[
k][row[
k]].c_str());
00305 }
00306 }
00307 }
00308 }
00309
if (report_progress)
00310 pb->
update(++
count);
00311 }
00312 }
00313
if (pb)
00314
delete pb;
00315 }
00316
00318
00320 real ConcatRowsVMatrix::get(
int i,
int j)
const
00321
{
00322
static real val;
00323
int whichvm, rowofvm;
00324
getpositions(i,whichvm,rowofvm);
00325
if (!
need_fix_mappings ||
fixed_mappings(whichvm, j).
size() == 0)
00326
return to_concat[whichvm]->get(rowofvm,j);
00327
else {
00328
val =
to_concat[whichvm]->get(rowofvm,j);
00329
if (!
is_missing(
val)) {
00330 map<real, real>::iterator it =
fixed_mappings(whichvm, j).find(
val);
00331
if (it != fixed_mappings(whichvm, j).end()) {
00332
00333
return it->second;
00334 }
00335 }
00336
return val;
00337 }
00338 }
00339
00341
00343 void ConcatRowsVMatrix::getpositions(
int i,
int& whichvm,
int& rowofvm)
const
00344
{
00345
#ifdef BOUNDCHECK
00346
if(i<0 || i>=
length())
00347
PLERROR(
"In ConcatRowsVMatrix::getpositions OUT OF BOUNDS");
00348
#endif
00349
00350
int pos = 0;
00351
int k=0;
00352
while(i>=pos+
to_concat[
k]->
length())
00353 {
00354 pos += to_concat[
k]->length();
00355
k++;
00356 }
00357
00358 whichvm =
k;
00359 rowofvm = i-pos;
00360 }
00361
00363
00365 void ConcatRowsVMatrix::getSubRow(
int i,
int j,
Vec v)
const
00366
{
00367
static map<real, real> fixed;
00368
static map<real, real>::iterator it;
00369
int whichvm, rowofvm;
00370
getpositions(i,whichvm,rowofvm);
00371
to_concat[whichvm]->getSubRow(rowofvm, j, v);
00372
if (
need_fix_mappings) {
00373
for (
int k = j;
k < j + v.
length();
k++) {
00374 fixed =
fixed_mappings(whichvm,
k);
00375
if (!
is_missing(v[
k-j])) {
00376 it = fixed.find(v[
k -j]);
00377
if (it != fixed.end())
00378 v[
k - j] = it->second;
00379 }
00380 }
00381 }
00382 }
00383
00385
00387 void ConcatRowsVMatrix::makeDeepCopyFromShallowCopy(map<const void*, void*>& copies) {
00388 inherited::makeDeepCopyFromShallowCopy(copies);
00389
00390
00391
00392
00393
00394
00395
00396
00397
PLERROR(
"ConcatRowsVMatrix::makeDeepCopyFromShallowCopy not fully (correctly) implemented yet!");
00398
00399 }
00400
00402
00404 void ConcatRowsVMatrix::putMat(
int i,
int j,
Mat m) {
00405
int whichvm, rowofvm;
00406
for (
int row = 0; row <
length(); row++) {
00407
getpositions(row + i, whichvm, rowofvm);
00408
to_concat[whichvm]->putSubRow(rowofvm, j, m(row));
00409 }
00410 }
00411
00413
00415 void ConcatRowsVMatrix::recomputeDimensions() {
00416 width_ =
to_concat[0]->width();
00417 length_ = 0;
00418
for (
int i=0; i<to_concat.length(); i++) {
00419
if (to_concat[i]->width() != width_)
00420
PLERROR(
"ConcatRowsVMatrix: underlying-VMat %d has %d width, while 0-th has %d",i,
array[i]->
width(),width_);
00421 length_ += to_concat[i]->length();
00422 }
00423 }
00424
00426
00428 void ConcatRowsVMatrix::reset_dimensions() {
00429
for (
int i=0;i<
to_concat.
size();i++)
00430
to_concat[i]->reset_dimensions();
00431
recomputeDimensions();
00432 }
00433
00434
00435 }