00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
#include "BatchVMatrix.h"
00040
00041
namespace PLearn {
00042
using namespace std;
00043
00044
// Registers BatchVMatrix with the PLearn object machinery.
// Second argument: one-line description (previously left as the template
// placeholder). Third argument: presumably the longer help text -- kept as is.
PLEARN_IMPLEMENT_OBJECT(BatchVMatrix,
    "VMat that shows every mini-batch of an underlying matrix twice in a row.",
    "VMat class that replicates small parts of a matrix (mini-batches), so that each mini-batch appears twice (consecutively).");
00045
00047
00049 void BatchVMatrix::declareOptions(
OptionList& ol)
00050 {
00051
declareOption(ol,
"m", &BatchVMatrix::m, OptionBase::buildoption,
00052
" The matrix viewed by the BatchVMatrix\n");
00053
declareOption(ol,
"batch_size", &BatchVMatrix::batch_size, OptionBase::buildoption,
00054
" The size of each mini-batch\n");
00055 inherited::declareOptions(ol);
00056 }
00057
00059
00061 void BatchVMatrix::makeDeepCopyFromShallowCopy(map<const void*, void*>& copies)
00062 {
00063 inherited::makeDeepCopyFromShallowCopy(copies);
00064
deepCopyField(
m, copies);
00065 }
00066
00068
00070 void BatchVMatrix::build()
00071 {
00072 inherited::build();
00073
build_();
00074 }
00075
00077
00079 void BatchVMatrix::build_()
00080 {
00081
if (
m) {
00082 width_ =
m->
width();
00083 length_ =
m->
length() * 2;
00084 fieldinfos =
m->getFieldInfos();
00085
last_batch = (
m->
length()-1) /
batch_size;
00086
last_batch_size =
m->
length() %
batch_size;
00087
if (
last_batch_size == 0)
00088
last_batch_size =
batch_size;
00089 }
00090 }
00091
00093
00095 real BatchVMatrix::get(
int i,
int j)
const {
00096
int n_batch = i / (2 *
batch_size);
00097
int k =
batch_size;
00098
if (n_batch ==
last_batch) {
00099
00100
k =
last_batch_size;
00101 }
00102
int i_ = n_batch * batch_size + (i - n_batch * 2 * batch_size) %
k;
00103
return m->get(i_, j);
00104 }
00105
00107
00109 void BatchVMatrix::put(
int i,
int j,
real value) {
00110
int n_batch = i / (2 *
batch_size);
00111
int i_ = n_batch *
batch_size + (i - n_batch * 2 *
batch_size) %
batch_size;
00112
m->put(i_, j, value);
00113 }
00114
00115 }