PluginEstimate.h
Go to the documentation of this file.
00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #ifndef _PLUGINESTIMATE_H___
00012 #define _PLUGINESTIMATE_H___
00013
00014 #include "classifier/Classifier.h"
00015 #include "features/StringFeatures.h"
00016 #include "features/Labels.h"
00017 #include "distributions/hmm/LinearHMM.h"
00018
00032 class CPluginEstimate: public CClassifier
00033 {
00034 public:
00039 CPluginEstimate(float64_t pos_pseudo=1e-10, float64_t neg_pseudo=1e-10);
00040 virtual ~CPluginEstimate();
00041
00046 bool train();
00047
00049 CLabels* classify(CLabels* output=NULL);
00050
00055 virtual inline void set_features(CStringFeatures<uint16_t>* feat)
00056 {
00057 SG_UNREF(features);
00058 SG_REF(feat);
00059 features=feat;
00060 }
00061
00066 virtual CStringFeatures<uint16_t>* get_features() { SG_REF(features); return features; }
00067
00069 float64_t classify_example(int32_t vec_idx);
00070
00077 inline float64_t posterior_log_odds_obsolete(
00078 uint16_t* vector, int32_t len)
00079 {
00080 return pos_model->get_log_likelihood_example(vector, len) - neg_model->get_log_likelihood_example(vector, len);
00081 }
00082
00089 inline float64_t get_parameterwise_log_odds(
00090 uint16_t obs, int32_t position)
00091 {
00092 return pos_model->get_positional_log_parameter(obs, position) - neg_model->get_positional_log_parameter(obs, position);
00093 }
00094
00101 inline float64_t log_derivative_pos_obsolete(uint16_t obs, int32_t pos)
00102 {
00103 return pos_model->get_log_derivative_obsolete(obs, pos);
00104 }
00105
00112 inline float64_t log_derivative_neg_obsolete(uint16_t obs, int32_t pos)
00113 {
00114 return neg_model->get_log_derivative_obsolete(obs, pos);
00115 }
00116
00125 inline bool get_model_params(
00126 float64_t*& pos_params, float64_t*& neg_params,
00127 int32_t &seq_length, int32_t &num_symbols)
00128 {
00129 int32_t num;
00130
00131 if ((!pos_model) || (!neg_model))
00132 {
00133 SG_ERROR( "no model available\n");
00134 return false;
00135 }
00136
00137 pos_model->get_log_transition_probs(&pos_params, &num);
00138 neg_model->get_log_transition_probs(&neg_params, &num);
00139
00140 seq_length = pos_model->get_sequence_length();
00141 num_symbols = pos_model->get_num_symbols();
00142 ASSERT(pos_model->get_num_model_parameters()==neg_model->get_num_model_parameters());
00143 ASSERT(pos_model->get_num_symbols()==neg_model->get_num_symbols());
00144 return true;
00145 }
00146
00153 inline void set_model_params(
00154 const float64_t* pos_params, const float64_t* neg_params,
00155 int32_t seq_length, int32_t num_symbols)
00156 {
00157 int32_t num_params;
00158
00159 delete pos_model;
00160 pos_model=new CLinearHMM(seq_length, num_symbols);
00161 delete neg_model;
00162 neg_model=new CLinearHMM(seq_length, num_symbols);
00163
00164 num_params=pos_model->get_num_model_parameters();
00165 ASSERT(seq_length*num_symbols==num_params);
00166 ASSERT(num_params==neg_model->get_num_model_parameters());
00167
00168 pos_model->set_log_transition_probs(pos_params, num_params);
00169 neg_model->set_log_transition_probs(neg_params, num_params);
00170 }
00171
00176 inline int32_t get_num_params()
00177 {
00178 return pos_model->get_num_model_parameters()+neg_model->get_num_model_parameters();
00179 }
00180
00185 inline bool check_models()
00186 {
00187 return ( (pos_model!=NULL) && (neg_model!=NULL) );
00188 }
00189
00191 inline virtual const char* get_name() const { return "RealFeatures"; }
00192
00193 protected:
00195 float64_t m_pos_pseudo;
00197 float64_t m_neg_pseudo;
00198
00200 CLinearHMM* pos_model;
00202 CLinearHMM* neg_model;
00203
00205 CStringFeatures<uint16_t>* features;
00206 };
00207 #endif