Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #ifndef _LINEARHMM_H__
00013 #define _LINEARHMM_H__
00014
00015 #include "features/StringFeatures.h"
00016 #include "features/Labels.h"
00017 #include "distributions/Distribution.h"
00018
00019 namespace shogun
00020 {
00039 class CLinearHMM : public CDistribution
00040 {
00041 public:
00046 CLinearHMM(CStringFeatures<uint16_t>* f);
00047
00053 CLinearHMM(int32_t p_num_features, int32_t p_num_symbols);
00054 virtual ~CLinearHMM();
00055
00064 virtual bool train(CFeatures* data=NULL);
00065
00073 bool train(
00074 const int32_t* indizes, int32_t num_indizes,
00075 float64_t pseudo_count);
00076
00083 float64_t get_log_likelihood_example(uint16_t* vector, int32_t len);
00084
00091 float64_t get_likelihood_example(uint16_t* vector, int32_t len);
00092
00098 virtual float64_t get_log_likelihood_example(int32_t num_example);
00099
00106 virtual float64_t get_log_derivative(
00107 int32_t num_param, int32_t num_example);
00108
00115 virtual inline float64_t get_log_derivative_obsolete(
00116 uint16_t obs, int32_t pos)
00117 {
00118 return 1.0/transition_probs[pos*num_symbols+obs];
00119 }
00120
00127 virtual inline float64_t get_derivative_obsolete(
00128 uint16_t* vector, int32_t len, int32_t pos)
00129 {
00130 ASSERT(pos<len);
00131 return get_likelihood_example(vector, len)/transition_probs[pos*num_symbols+vector[pos]];
00132 }
00133
00138 virtual inline int32_t get_sequence_length() { return sequence_length; }
00139
00144 virtual inline int32_t get_num_symbols() { return num_symbols; }
00145
00150 virtual inline int32_t get_num_model_parameters() { return num_params; }
00151
00158 virtual inline float64_t get_positional_log_parameter(
00159 uint16_t obs, int32_t position)
00160 {
00161 return log_transition_probs[position*num_symbols+obs];
00162 }
00163
00169 virtual inline float64_t get_log_model_parameter(int32_t num_param)
00170 {
00171 ASSERT(log_transition_probs);
00172 ASSERT(num_param<num_params);
00173
00174 return log_transition_probs[num_param];
00175 }
00176
00184 virtual void get_log_transition_probs(float64_t** dst, int32_t* num);
00185
00192 virtual bool set_log_transition_probs(
00193 const float64_t* src, int32_t num);
00194
00200 virtual void get_transition_probs(float64_t** dst, int32_t* num);
00201
00208 virtual bool set_transition_probs(const float64_t* src, int32_t num);
00209
00211 inline virtual const char* get_name() const { return "LinearHMM"; }
00212
00213 protected:
00215 int32_t sequence_length;
00217 int32_t num_symbols;
00219 int32_t num_params;
00221 float64_t* transition_probs;
00223 float64_t* log_transition_probs;
00224 };
00225 }
00226 #endif