LinearHMM.h
Go to the documentation of this file.
00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #ifndef _LINEARHMM_H__
00013 #define _LINEARHMM_H__
00014
00015 #include "features/StringFeatures.h"
00016 #include "features/Labels.h"
00017 #include "distributions/Distribution.h"
00018
00037 class CLinearHMM : public CDistribution
00038 {
00039 public:
00044 CLinearHMM(CStringFeatures<uint16_t>* f);
00045
00051 CLinearHMM(int32_t p_num_features, int32_t p_num_symbols);
00052 virtual ~CLinearHMM();
00053
00058 bool train();
00059
00067 bool train(
00068 const int32_t* indizes, int32_t num_indizes,
00069 float64_t pseudo_count);
00070
00077 float64_t get_log_likelihood_example(uint16_t* vector, int32_t len);
00078
00085 float64_t get_likelihood_example(uint16_t* vector, int32_t len);
00086
00092 virtual float64_t get_log_likelihood_example(int32_t num_example);
00093
00100 virtual float64_t get_log_derivative(
00101 int32_t num_param, int32_t num_example);
00102
00109 virtual inline float64_t get_log_derivative_obsolete(
00110 uint16_t obs, int32_t pos)
00111 {
00112 return 1.0/transition_probs[pos*num_symbols+obs];
00113 }
00114
00121 virtual inline float64_t get_derivative_obsolete(
00122 uint16_t* vector, int32_t len, int32_t pos)
00123 {
00124 ASSERT(pos<len);
00125 return get_likelihood_example(vector, len)/transition_probs[pos*num_symbols+vector[pos]];
00126 }
00127
00132 virtual inline int32_t get_sequence_length() { return sequence_length; }
00133
00138 virtual inline int32_t get_num_symbols() { return num_symbols; }
00139
00144 virtual inline int32_t get_num_model_parameters() { return num_params; }
00145
00152 virtual inline float64_t get_positional_log_parameter(
00153 uint16_t obs, int32_t position)
00154 {
00155 return log_transition_probs[position*num_symbols+obs];
00156 }
00157
00163 virtual inline float64_t get_log_model_parameter(int32_t num_param)
00164 {
00165 ASSERT(log_transition_probs);
00166 ASSERT(num_param<num_params);
00167
00168 return log_transition_probs[num_param];
00169 }
00170
00178 virtual void get_log_transition_probs(float64_t** dst, int32_t* num);
00179
00186 virtual bool set_log_transition_probs(
00187 const float64_t* src, int32_t num);
00188
00194 virtual void get_transition_probs(float64_t** dst, int32_t* num);
00195
00202 virtual bool set_transition_probs(const float64_t* src, int32_t num);
00203
00205 inline virtual const char* get_name() const { return "LinearHMM"; }
00206
00207 protected:
00209 int32_t sequence_length;
00211 int32_t num_symbols;
00213 int32_t num_params;
00215 float64_t* transition_probs;
00217 float64_t* log_transition_probs;
00218 };
00219 #endif