00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00040
#ifndef PDistribution_INC
00041
#define PDistribution_INC
00042
00043
#include <plearn_learners/generic/PLearner.h>
00044
00045
namespace PLearn {
00046
using namespace std;
00047
00048 class PDistribution:
public PLearner
00049 {
00050
00051
private:
00052
00053 typedef PLearner inherited;
00054
00056 mutable Vec store_expect,
store_result;
00057 mutable Mat store_cov;
00058
00059
protected:
00060
00061
00062
00063
00064
00065 TVec<int> cond_sort;
00066 int n_input;
00067 int n_margin;
00068 int n_target;
00069
00070
00071
00075 bool already_sorted;
00076
00081 TVec<int> cond_swap;
00082
00084 real delta_curve;
00085
00089 bool full_joint_distribution;
00090
00091 mutable Vec input_part;
00092 mutable Vec target_part;
00093
00096 mutable bool need_set_input;
00097
00098
public:
00099
00100
00101
00102
00103
00104 TVec<int> conditional_flags;
00105 real lower_bound,
upper_bound;
00106 int n_curve_points;
00107 string outputs_def;
00108 Vec provide_input;
00109
00110
00111
00112
00113
00115
PDistribution();
00116
00117
00118
00119
00120
00121
private:
00122
00124
void build_();
00125
00126
protected:
00127
00129
static void declareOptions(
OptionList& ol);
00130
00131
public:
00132
00133
00134
00135
00136
00138
virtual void build();
00139
00141
virtual void makeDeepCopyFromShallowCopy(map<const void*, void*>& copies);
00142
00143
00144
PLEARN_DECLARE_OBJECT(
PDistribution);
00145
00146
00147
00148
00149
00151
virtual int outputsize() const;
00152
00155 virtual
void forget();
00156
00159 virtual
void train();
00160
00162 virtual
void computeOutput(const
Vec& input,
Vec& output) const;
00163
00166 virtual
void computeCostsFromOutputs(const
Vec& input, const
Vec& output,
00167 const
Vec& target,
Vec& costs) const;
00168
00169
00170
00171
00172
00173 private:
00174
00180
void setConditionalFlagsWithoutUpdate(
TVec<
int>& flags);
00181
00182 protected:
00183
00188
bool ensureFullJointDistribution(
TVec<
int>& old_flags);
00189
00193
void finishConditionalBuild();
00194
00196
void resizeParts();
00197
00202
void sortFromFlags(
Vec& v);
00203
void sortFromFlags(
Mat& m,
bool sort_columns = true,
bool sort_rows = false);
00204
00210
bool splitCond(const
Vec& input) const;
00211
00215 virtual
void updateFromConditionalSorting();
00216
00217 public:
00218
00220
void setConditionalFlags(
TVec<
int>& flags);
00221
00225 virtual
void setInput(const
Vec& input) const;
00226
00229 virtual
void setTrainingSet(
VMat training_set,
bool call_forget=true);
00230
00232 virtual
TVec<
string> getTestCostNames() const;
00233
00235 virtual
TVec<
string> getTrainCostNames() const;
00236
00238 virtual
real log_density(const
Vec& y) const;
00239
00242 virtual
real density(const
Vec& y) const;
00243
00245 virtual
real survival_fn(const
Vec& y) const;
00246
00248 virtual
real cdf(const
Vec& y) const;
00249
00251 virtual
void expectation(
Vec& mu) const;
00252
00254 virtual
void variance(
Mat& cov) const;
00255
00257 virtual
void resetGenerator(
long g_seed) const;
00258
00260 virtual
void generate(
Vec& y) const;
00261
00264
void generateN(const
Mat& Y) const;
00265
00266 };
00267
00268
00269 DECLARE_OBJECT_PTR(
PDistribution);
00270
00271 }
00272
00273 #endif