PluginEstimate.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #ifndef _PLUGINESTIMATE_H___
00012 #define _PLUGINESTIMATE_H___
00013 
00014 #include "classifier/Classifier.h"
00015 #include "features/StringFeatures.h"
00016 #include "features/Labels.h"
00017 #include "distributions/hmm/LinearHMM.h"
00018 
00032 class CPluginEstimate: public CClassifier
00033 {
00034     public:
00039         CPluginEstimate(float64_t pos_pseudo=1e-10, float64_t neg_pseudo=1e-10);
00040         virtual ~CPluginEstimate();
00041 
00046         bool train();
00047 
00049         CLabels* classify(CLabels* output=NULL);
00050 
00055         virtual inline void set_features(CStringFeatures<uint16_t>* feat)
00056         {
00057             SG_UNREF(features);
00058             SG_REF(feat);
00059             features=feat;
00060         }
00061 
00066         virtual CStringFeatures<uint16_t>* get_features() { SG_REF(features); return features; }
00067 
00069         float64_t classify_example(int32_t vec_idx);
00070 
00077         inline float64_t posterior_log_odds_obsolete(
00078             uint16_t* vector, int32_t len)
00079         {
00080             return pos_model->get_log_likelihood_example(vector, len) - neg_model->get_log_likelihood_example(vector, len);
00081         }
00082 
00089         inline float64_t get_parameterwise_log_odds(
00090             uint16_t obs, int32_t position)
00091         {
00092             return pos_model->get_positional_log_parameter(obs, position) - neg_model->get_positional_log_parameter(obs, position);
00093         }
00094 
00101         inline float64_t log_derivative_pos_obsolete(uint16_t obs, int32_t pos)
00102         {
00103             return pos_model->get_log_derivative_obsolete(obs, pos);
00104         }
00105 
00112         inline float64_t log_derivative_neg_obsolete(uint16_t obs, int32_t pos)
00113         {
00114             return neg_model->get_log_derivative_obsolete(obs, pos);
00115         }
00116 
00125         inline bool get_model_params(
00126             float64_t*& pos_params, float64_t*& neg_params,
00127             int32_t &seq_length, int32_t &num_symbols)
00128         {
00129             int32_t num;
00130 
00131             if ((!pos_model) || (!neg_model))
00132             {
00133                 SG_ERROR( "no model available\n");
00134                 return false;
00135             }
00136 
00137             pos_model->get_log_transition_probs(&pos_params, &num);
00138             neg_model->get_log_transition_probs(&neg_params, &num);
00139 
00140             seq_length = pos_model->get_sequence_length();
00141             num_symbols = pos_model->get_num_symbols();
00142             ASSERT(pos_model->get_num_model_parameters()==neg_model->get_num_model_parameters());
00143             ASSERT(pos_model->get_num_symbols()==neg_model->get_num_symbols());
00144             return true;
00145         }
00146 
00153         inline void set_model_params(
00154             const float64_t* pos_params, const float64_t* neg_params,
00155             int32_t seq_length, int32_t num_symbols)
00156         {
00157             int32_t num_params;
00158 
00159             delete pos_model;
00160             pos_model=new CLinearHMM(seq_length, num_symbols);
00161             delete neg_model;
00162             neg_model=new CLinearHMM(seq_length, num_symbols);
00163 
00164             num_params=pos_model->get_num_model_parameters();
00165             ASSERT(seq_length*num_symbols==num_params);
00166             ASSERT(num_params==neg_model->get_num_model_parameters());
00167 
00168             pos_model->set_log_transition_probs(pos_params, num_params);
00169             neg_model->set_log_transition_probs(neg_params, num_params);
00170         }
00171 
00176         inline int32_t get_num_params()
00177         {
00178             return pos_model->get_num_model_parameters()+neg_model->get_num_model_parameters();
00179         }
00180         
00185         inline bool check_models()
00186         {
00187             return ( (pos_model!=NULL) && (neg_model!=NULL) );
00188         }
00189 
00191         inline virtual const char* get_name() const { return "RealFeatures"; }
00192 
00193     protected:
00195         float64_t m_pos_pseudo;
00197         float64_t m_neg_pseudo;
00198 
00200         CLinearHMM* pos_model;
00202         CLinearHMM* neg_model;
00203 
00205         CStringFeatures<uint16_t>* features;
00206 };
00207 #endif

SHOGUN Machine Learning Toolbox - Documentation