HMM.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Written (W) 1999-2008 Gunnar Raetsch
00009  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00010  */
00011 
00012 #ifndef __CHMM_H__
00013 #define __CHMM_H__
00014 
00015 #include "lib/Mathematics.h"
00016 #include "lib/common.h"
00017 #include "lib/io.h"
00018 #include "lib/config.h"
00019 #include "features/StringFeatures.h"
00020 #include "distributions/Distribution.h"
00021 
00022 #include <stdio.h>
00023 
00024 #ifdef USE_HMMPARALLEL
00025 #define USE_HMMPARALLEL_STRUCTURES 1
00026 #endif
00027 
00028 class CHMM;
00031 
00033 typedef float64_t T_ALPHA_BETA_TABLE;
00034 
00035 #ifndef DOXYGEN_SHOULD_SKIP_THIS
00037 struct T_ALPHA_BETA
00038 {
          // Cache record used by CHMM::forward() and CHMM::backward().
          // Index of the observation sequence the cached table belongs to;
          // the cache is only used when it equals the requested dimension.
00040     int32_t dimension;
00041 
          // Cached values, read as table[time*N+state]; a NULL table means
          // the cache is not usable.
00043     T_ALPHA_BETA_TABLE* table;
00044 
          // Must be true for the cached table/sum to be read.
00046     bool updated;
00047 
          // Value returned instead of a table entry: by forward() when time is
          // not below the sequence length, by backward() when time is negative.
00049     float64_t sum;
00050 };
00051 #endif // DOXYGEN_SHOULD_SKIP_THIS
00052 
00057 #ifdef USE_BIGSTATES
00058 typedef uint16_t T_STATES ;
00059 #else
00060 typedef uint8_t T_STATES ;
00061 #endif
00062 typedef T_STATES* P_STATES ;
00063 
00065 
00066 enum BaumWelchViterbiType
00067 {
          // Selector passed to CHMM::baum_welch_viterbi_train().
          // NOTE(review): the enumerators presumably map onto the
          // estimate_model_baum_welch*/estimate_model_viterbi* routines of CHMM;
          // confirm against the implementation file.
00068     BW_NORMAL,
00069     BW_TRANS,
00070     BW_DEFINED,
00071     VIT_NORMAL,
00072     VIT_DEFINED
00073 };
00074 
00075 
00077 class CModel
00078 {
00079     public:
00081         CModel();
00082 
00084         virtual ~CModel();
00085 
00087         inline void sort_learn_a()
00088         {
00089             CMath::sort(learn_a,2) ;
00090         }
00091 
00093         inline void sort_learn_b()
00094         {
00095             CMath::sort(learn_b,2) ;
00096         }
00097 
00102 
00103         inline int32_t get_learn_a(int32_t line, int32_t column) const
00104         {
00105             return learn_a[line*2 + column];
00106         }
00107 
00109         inline int32_t get_learn_b(int32_t line, int32_t column) const 
00110         {
00111             return learn_b[line*2 + column];
00112         }
00113 
00115         inline int32_t get_learn_p(int32_t offset) const 
00116         {
00117             return learn_p[offset];
00118         }
00119 
00121         inline int32_t get_learn_q(int32_t offset) const 
00122         {
00123             return learn_q[offset];
00124         }
00125 
00127         inline int32_t get_const_a(int32_t line, int32_t column) const
00128         {
00129             return const_a[line*2 + column];
00130         }
00131 
00133         inline int32_t get_const_b(int32_t line, int32_t column) const 
00134         {
00135             return const_b[line*2 + column];
00136         }
00137 
00139         inline int32_t get_const_p(int32_t offset) const 
00140         {
00141             return const_p[offset];
00142         }
00143 
00145         inline int32_t get_const_q(int32_t offset) const
00146         {
00147             return const_q[offset];
00148         }
00149 
00151         inline float64_t get_const_a_val(int32_t line) const
00152         {
00153             return const_a_val[line];
00154         }
00155 
00157         inline float64_t get_const_b_val(int32_t line) const 
00158         {
00159             return const_b_val[line];
00160         }
00161 
00163         inline float64_t get_const_p_val(int32_t offset) const 
00164         {
00165             return const_p_val[offset];
00166         }
00167 
00169         inline float64_t get_const_q_val(int32_t offset) const
00170         {
00171             return const_q_val[offset];
00172         }
00173 #ifdef FIX_POS
00175         inline char get_fix_pos_state(int32_t pos, T_STATES state, T_STATES num_states)
00176         {
00177 #ifdef HMM_DEBUG
00178             if ((pos<0)||(pos*num_states+state>65336))
00179                 SG_DEBUG("index out of range in get_fix_pos_state(%i,%i,%i) \n", pos,state,num_states) ;
00180 #endif
00181             return fix_pos_state[pos*num_states+state] ;
00182         }
00183 #endif
00184 
00185 
00190 
00191         inline void set_learn_a(int32_t offset, int32_t value)
00192         {
00193             learn_a[offset]=value;
00194         }
00195 
00197         inline void set_learn_b(int32_t offset, int32_t value)
00198         {
00199             learn_b[offset]=value;
00200         }
00201 
00203         inline void set_learn_p(int32_t offset, int32_t value)
00204         {
00205             learn_p[offset]=value;
00206         }
00207 
00209         inline void set_learn_q(int32_t offset, int32_t value)
00210         {
00211             learn_q[offset]=value;
00212         }
00213 
00215         inline void set_const_a(int32_t offset, int32_t value)
00216         {
00217             const_a[offset]=value;
00218         }
00219 
00221         inline void set_const_b(int32_t offset, int32_t value)
00222         {
00223             const_b[offset]=value;
00224         }
00225 
00227         inline void set_const_p(int32_t offset, int32_t value)
00228         {
00229             const_p[offset]=value;
00230         }
00231 
00233         inline void set_const_q(int32_t offset, int32_t value)
00234         {
00235             const_q[offset]=value;
00236         }
00237 
00239         inline void set_const_a_val(int32_t offset, float64_t value)
00240         {
00241             const_a_val[offset]=value;
00242         }
00243 
00245         inline void set_const_b_val(int32_t offset, float64_t value)
00246         {
00247             const_b_val[offset]=value;
00248         }
00249 
00251         inline void set_const_p_val(int32_t offset, float64_t value)
00252         {
00253             const_p_val[offset]=value;
00254         }
00255 
00257         inline void set_const_q_val(int32_t offset, float64_t value)
00258         {
00259             const_q_val[offset]=value;
00260         }
00261 #ifdef FIX_POS
00263         inline void set_fix_pos_state(
00264             int32_t pos, T_STATES state, T_STATES num_states, char value)
00265         {
00266 #ifdef HMM_DEBUG
00267             if ((pos<0)||(pos*num_states+state>65336))
00268                 SG_DEBUG("index out of range in set_fix_pos_state(%i,%i,%i,%i) [%i]\n", pos,state,num_states,(int)value, pos*num_states+state) ;
00269 #endif
00270             fix_pos_state[pos*num_states+state]=value;
00271             if (value==FIX_ALLOWED)
00272                 for (int32_t i=0; i<num_states; i++)
00273                     if (get_fix_pos_state(pos,i,num_states)==FIX_DEFAULT)
00274                         set_fix_pos_state(pos,i,num_states,FIX_DISALLOWED) ;
00275         }
00277 
00279         const static char FIX_DISALLOWED ;
00280 
00282         const static char FIX_ALLOWED ;
00283 
00285         const static char FIX_DEFAULT ;
00286 
00288         const static float64_t DISALLOWED_PENALTY ;
00289 #endif
00290     protected:
00297 
00298         int32_t* learn_a;
00299 
00301         int32_t* learn_b;
00302 
00304         int32_t* learn_p;
00305 
00307         int32_t* learn_q;
00309 
00316 
00317         int32_t* const_a;
00318 
00320         int32_t* const_b;
00321 
00323         int32_t* const_p;
00324 
00326         int32_t* const_q;       
00327 
00328 
00330         float64_t* const_a_val;
00331 
00333         float64_t* const_b_val;
00334 
00336         float64_t* const_p_val;
00337 
00339         float64_t* const_q_val;     
00340 
00341 #ifdef FIX_POS
00342 
00345         char* fix_pos_state;
00346 #endif
00347 
00348 };
00349 
00350 
00361 class CHMM : public CDistribution
00362 {
00363     private:
00364 
00365         T_STATES trans_list_len ;
00366         T_STATES **trans_list_forward  ;
00367         T_STATES *trans_list_forward_cnt  ;
00368         float64_t **trans_list_forward_val ;
00369         T_STATES **trans_list_backward  ;
00370         T_STATES *trans_list_backward_cnt  ;
00371         bool mem_initialized ;
00372 
00373 #ifdef USE_HMMPARALLEL_STRUCTURES
00374 
00376         struct S_DIM_THREAD_PARAM
00377         {
00378             CHMM* hmm;
00379             int32_t dim;
00380             float64_t prob_sum;
00381         };
00382 
00384         struct S_BW_THREAD_PARAM
00385         {
00386             CHMM* hmm;
00387             int32_t dim_start;
00388             int32_t dim_stop;
00389 
00390             float64_t ret;
00391 
00392             float64_t* p_buf;
00393             float64_t* q_buf;
00394             float64_t* a_buf;
00395             float64_t* b_buf;
00396         };
00397 
00398         inline T_ALPHA_BETA & ALPHA_CACHE(int32_t dim) {
00399             return alpha_cache[dim%parallel->get_num_threads()] ; } ;
00400         inline T_ALPHA_BETA & BETA_CACHE(int32_t dim) {
00401             return beta_cache[dim%parallel->get_num_threads()] ; } ;
00402 #ifdef USE_LOGSUMARRAY 
00403         inline float64_t* ARRAYS(int32_t dim) {
00404             return arrayS[dim%parallel->get_num_threads()] ; } ;
00405 #endif
00406         inline float64_t* ARRAYN1(int32_t dim) {
00407             return arrayN1[dim%parallel->get_num_threads()] ; } ;
00408         inline float64_t* ARRAYN2(int32_t dim) {
00409             return arrayN2[dim%parallel->get_num_threads()] ; } ;
00410         inline T_STATES* STATES_PER_OBSERVATION_PSI(int32_t dim) {
00411             return states_per_observation_psi[dim%parallel->get_num_threads()] ; } ;
00412         inline const T_STATES* STATES_PER_OBSERVATION_PSI(int32_t dim) const {
00413             return states_per_observation_psi[dim%parallel->get_num_threads()] ; } ;
00414         inline T_STATES* PATH(int32_t dim) {
00415             return path[dim%parallel->get_num_threads()] ; } ;
00416         inline bool & PATH_PROB_UPDATED(int32_t dim) {
00417             return path_prob_updated[dim%parallel->get_num_threads()] ; } ;
00418         inline int32_t & PATH_PROB_DIMENSION(int32_t dim) {
00419             return path_prob_dimension[dim%parallel->get_num_threads()] ; } ;
00420 #else
00421         inline T_ALPHA_BETA & ALPHA_CACHE(int32_t /*dim*/) {
00422             return alpha_cache ; } ;
00423         inline T_ALPHA_BETA & BETA_CACHE(int32_t /*dim*/) {
00424             return beta_cache ; } ;
00425 #ifdef USE_LOGSUMARRAY
00426         inline float64_t* ARRAYS(int32_t dim) {
00427             return arrayS ; } ;
00428 #endif
00429         inline float64_t* ARRAYN1(int32_t /*dim*/) {
00430             return arrayN1 ; } ;
00431         inline float64_t* ARRAYN2(int32_t /*dim*/) {
00432             return arrayN2 ; } ;
00433         inline T_STATES* STATES_PER_OBSERVATION_PSI(int32_t /*dim*/) {
00434             return states_per_observation_psi ; } ;
00435         inline const T_STATES* STATES_PER_OBSERVATION_PSI(int32_t /*dim*/) const {
00436             return states_per_observation_psi ; } ;
00437         inline T_STATES* PATH(int32_t /*dim*/) {
00438             return path ; } ;
00439         inline bool & PATH_PROB_UPDATED(int32_t /*dim*/) {
00440             return path_prob_updated ; } ;
00441         inline int32_t & PATH_PROB_DIMENSION(int32_t /*dim*/) {
00442             return path_prob_dimension ; } ;
00443 #endif
00444 
00449         bool converged(float64_t x, float64_t y);
00450 
00456     public:
00467         CHMM(
00468             int32_t N, int32_t M, CModel* model, float64_t PSEUDO);
00469         CHMM(
00470             CStringFeatures<uint16_t>* obs, int32_t N, int32_t M,
00471             float64_t PSEUDO);
00472         CHMM(
00473             int32_t N, float64_t* p, float64_t* q, float64_t* a);
00474         CHMM(
00475             int32_t N, float64_t* p, float64_t* q, int32_t num_trans,
00476             float64_t* a_trans);
00477 
00482         CHMM(FILE* model_file, float64_t PSEUDO);
00483 
00485         CHMM(CHMM* h);
00486 
00488         virtual ~CHMM();
00489 
00490         virtual inline bool train() { return false; }
00491         virtual inline int32_t get_num_model_parameters() { return N*(N+M+2); }
00492         virtual float64_t get_log_model_parameter(int32_t num_param);
00493         virtual float64_t get_log_derivative(int32_t num_param, int32_t num_example);
00494         virtual float64_t get_log_likelihood_example(int32_t num_example)
00495         {
00496             return model_probability(num_example);
00497         }
00498 
00504         bool initialize(CModel* model, float64_t PSEUDO, FILE* model_file=NULL);
00506 
00508         bool alloc_state_dependend_arrays();
00509 
00511         void free_state_dependend_arrays();
00512 
00524         float64_t forward_comp(int32_t time, int32_t state, int32_t dimension);
00525         float64_t forward_comp_old(
00526             int32_t time, int32_t state, int32_t dimension);
00527 
00535         float64_t backward_comp(int32_t time, int32_t state, int32_t dimension);
00536         float64_t backward_comp_old(
00537             int32_t time, int32_t state, int32_t dimension);
00538 
00543         float64_t best_path(int32_t dimension);
00544         inline uint16_t get_best_path_state(int32_t dim, int32_t t)
00545         {
00546             ASSERT(PATH(dim));
00547             return PATH(dim)[t];
00548         }
00549 
00552         float64_t model_probability_comp() ;
00553 
00555         inline float64_t model_probability(int32_t dimension=-1)
00556         {
00557             //for faster calculation cache model probability
00558             if (dimension==-1)
00559             {
00560                 if (mod_prob_updated)
00561                     return mod_prob/p_observations->get_num_vectors();
00562                 else
00563                     return model_probability_comp()/p_observations->get_num_vectors();
00564             }
00565             else
00566                 return forward(p_observations->get_vector_length(dimension), 0, dimension);
00567         }
00568 
00574         inline float64_t linear_model_probability(int32_t dimension)
00575         {
00576             float64_t lik=0;
00577             int32_t len=0;
00578             uint16_t* o=p_observations->get_feature_vector(dimension, len);
00579             float64_t* obs_b=observation_matrix_b;
00580 
00581             ASSERT(N==len);
00582 
00583             for (int32_t i=0; i<N; i++)
00584             {
00585                 lik+=obs_b[*o++];
00586                 obs_b+=M;
00587             }
00588             return lik;
00589 
00590             // sorry, the above code is the speed optimized version of :
00591             /*  float64_t lik=0;
00592 
00593                 for (int32_t i=0; i<N; i++)
00594                 lik+=get_b(i, p_observations->get_feature(dimension, i));
00595                 return lik;
00596                 */
00597             // : that
00598         }
00599 
00601 
00604         inline bool set_iterations(int32_t num) { iterations=num; return true; }
00605         inline int32_t get_iterations() { return iterations; }
00606         inline bool set_epsilon (float64_t eps) { epsilon=eps; return true; }
00607         inline float64_t get_epsilon() { return epsilon; }
00608 
00612         bool baum_welch_viterbi_train(BaumWelchViterbiType type);
00613 
00620         void estimate_model_baum_welch(CHMM* train);
00621         void estimate_model_baum_welch_trans(CHMM* train);
00622 
00623 #ifdef USE_HMMPARALLEL_STRUCTURES
00624         void ab_buf_comp(
00625             float64_t* p_buf, float64_t* q_buf, float64_t* a_buf,
00626             float64_t* b_buf, int32_t dim) ;
00627 #else
00628         void estimate_model_baum_welch_old(CHMM* train);
00629 #endif
00630 
00634         void estimate_model_baum_welch_defined(CHMM* train);
00635 
00639         void estimate_model_viterbi(CHMM* train);
00640 
00644         void estimate_model_viterbi_defined(CHMM* train);
00645 
00647 
00649         bool linear_train(bool right_align=false);
00650 
00652         bool permutation_entropy(int32_t window_width, int32_t sequence_number);
00653 
00660         void output_model(bool verbose=false);
00661 
00663         void output_model_defined(bool verbose=false);
00665 
00666 
00669 
00671         void normalize(bool keep_dead_states=false);
00672 
00676         void add_states(int32_t num_states, float64_t default_val=0);
00677 
00683         bool append_model(
00684             CHMM* append_model, float64_t* cur_out, float64_t* app_out);
00685 
00689         bool append_model(CHMM* append_model);
00690 
00692         void chop(float64_t value);
00693 
00695         void convert_to_log();
00696 
00698         void init_model_random();
00699 
00705         void init_model_defined();
00706 
00708         void clear_model();
00709 
00711         void clear_model_defined();
00712 
00714         void copy_model(CHMM* l);
00715 
00720         void invalidate_model();
00721 
00725         inline bool get_status() const 
00726         {   
                  // Returns the status member (declared with the note
                  // "true->ok, false->error").
00727             return status; 
00728         } 
00729 
00731         inline float64_t get_pseudo() const
00732         {
                  // Returns the stored PSEUDO value.
00733             return PSEUDO ;
00734         }
00735 
00737         inline void set_pseudo(float64_t pseudo) 
00738         {
                  // Overwrites PSEUDO; nothing else is touched here.
00739             PSEUDO=pseudo ;
00740         }
00741 
00742 #ifdef USE_HMMPARALLEL_STRUCTURES
00743         static void* bw_dim_prefetch(void * params);
00744         static void* bw_single_dim_prefetch(void * params);
00745         static void* vit_dim_prefetch(void * params);
00746 #endif
00747 
00748 #ifdef FIX_POS
00749 
00752         inline bool set_fix_pos_state(int32_t pos, T_STATES state, char value)
00753         {
00754             if (!model)
00755                 return false ;
00756             model->set_fix_pos_state(pos, state, N, value) ;
00757             return true ;
00758         } ;
00759 #endif  
00760 
00761 
00770         void set_observations(CStringFeatures<uint16_t>* obs, CHMM* hmm=NULL);
00771 
00775         void set_observation_nocache(CStringFeatures<uint16_t>* obs);
00776 
00778         inline CStringFeatures<uint16_t>* get_observations()
00779         {
                  // SG_REF is applied before the pointer is handed out.
                  // NOTE(review): presumably increments a reference count that
                  // the caller must release again - confirm against SG_REF.
00780             SG_REF(p_observations);
00781             return p_observations;
00782         }
00784 
00852         bool load_definitions(FILE* file, bool verbose, bool initialize=true);
00853 
00889         bool load_model(FILE* file);
00890 
00894         bool save_model(FILE* file);
00895 
00899         bool save_model_derivatives(FILE* file);
00900 
00904         bool save_model_derivatives_bin(FILE* file);
00905 
00909         bool save_model_bin(FILE* file);
00910 
00912         bool check_model_derivatives() ;
00913         bool check_model_derivatives_combined() ;
00914 
00920         T_STATES* get_path(int32_t dim, float64_t& prob);
00921 
00925         bool save_path(FILE* file);
00926 
00930         bool save_path_derivatives(FILE* file);
00931 
00935         bool save_path_derivatives_bin(FILE* file);
00936 
00937 #ifdef USE_HMMDEBUG
00939         bool check_path_derivatives() ;
00940 #endif //USE_HMMDEBUG
00941 
00945         bool save_likelihood_bin(FILE* file);
00946 
00950         bool save_likelihood(FILE* file);
00952 
00958 
00960         inline T_STATES get_N() const { return N ; } // N is int32_t; the result is narrowed to T_STATES (uint8_t, or uint16_t with USE_BIGSTATES)
00961 
00963         inline int32_t get_M() const { return M ; } // returns M unchanged
00964 
00969         inline void set_q(T_STATES offset, float64_t value)
00970         {
00971 #ifdef HMM_DEBUG
00972             if (offset>=N)
00973                 SG_DEBUG("index out of range in set_q(%i,%e) [%i]\n", offset,value,N) ;
00974 #endif
00975             end_state_distribution_q[offset]=value;
00976         }
00977 
00982         inline void set_p(T_STATES offset, float64_t value)
00983         {
00984 #ifdef HMM_DEBUG
00985             if (offset>=N)
00986                 SG_DEBUG("index out of range in set_p(%i,.) [%i]\n", offset,N) ;
00987 #endif
00988             initial_state_distribution_p[offset]=value;
00989         }
00990 
00996         inline void set_A(T_STATES line_, T_STATES column, float64_t value)
00997         {
00998 #ifdef HMM_DEBUG
00999             if ((line_>N)||(column>N))
01000                 SG_DEBUG("index out of range in set_A(%i,%i,.) [%i,%i]\n",line_,column,N,N) ;
01001 #endif
01002             transition_matrix_A[line_+column*N]=value;
01003         }
01004 
01010         inline void set_a(T_STATES line_, T_STATES column, float64_t value)
01011         {
01012 #ifdef HMM_DEBUG
01013             if ((line_>N)||(column>N))
01014                 SG_DEBUG("index out of range in set_a(%i,%i,.) [%i,%i]\n",line_,column,N,N) ;
01015 #endif
01016             transition_matrix_a[line_+column*N]=value; // look also best_path!
01017         }
01018 
01024         inline void set_B(T_STATES line_, uint16_t column, float64_t value)
01025         {
01026 #ifdef HMM_DEBUG
01027             if ((line_>=N)||(column>=M))
01028                 SG_DEBUG("index out of range in set_B(%i,%i) [%i,%i]\n", line_, column,N,M) ;
01029 #endif
01030             observation_matrix_B[line_*M+column]=value;
01031         }
01032 
01038         inline void set_b(T_STATES line_, uint16_t column, float64_t value)
01039         {
01040 #ifdef HMM_DEBUG
01041             if ((line_>=N)||(column>=M))
01042                 SG_DEBUG("index out of range in set_b(%i,%i) [%i,%i]\n", line_, column,N,M) ;
01043 #endif
01044             observation_matrix_b[line_*M+column]=value;
01045         }
01046 
01053         inline void set_psi(
01054             int32_t time, T_STATES state, T_STATES value, int32_t dimension)
01055         {
01056 #ifdef HMM_DEBUG
01057             if ((time>=p_observations->get_max_vector_length())||(state>N))
01058                 SG_DEBUG("index out of range in set_psi(%i,%i,.) [%i,%i]\n",time,state,p_observations->get_max_vector_length(),N) ;
01059 #endif
01060             STATES_PER_OBSERVATION_PSI(dimension)[time*N+state]=value;
01061         }
01062 
01067         inline float64_t get_q(T_STATES offset) const 
01068         {
01069 #ifdef HMM_DEBUG
01070             if (offset>=N)
01071                 SG_DEBUG("index out of range in %e=get_q(%i) [%i]\n", end_state_distribution_q[offset],offset,N) ;
01072 #endif
01073             return end_state_distribution_q[offset];
01074         }
01075 
01080         inline float64_t get_p(T_STATES offset) const 
01081         {
01082 #ifdef HMM_DEBUG
01083             if (offset>=N)
01084                 SG_DEBUG("index out of range in get_p(%i,.) [%i]\n", offset,N) ;
01085 #endif
01086             return initial_state_distribution_p[offset];
01087         }
01088 
01094         inline float64_t get_A(T_STATES line_, T_STATES column) const
01095         {
01096 #ifdef HMM_DEBUG
01097             if ((line_>N)||(column>N))
01098                 SG_DEBUG("index out of range in get_A(%i,%i) [%i,%i]\n",line_,column,N,N) ;
01099 #endif
01100             return transition_matrix_A[line_+column*N];
01101         }
01102 
01108         inline float64_t get_a(T_STATES line_, T_STATES column) const
01109         {
01110 #ifdef HMM_DEBUG
01111             if ((line_>N)||(column>N))
01112                 SG_DEBUG("index out of range in get_a(%i,%i) [%i,%i]\n",line_,column,N,N) ;
01113 #endif
01114             return transition_matrix_a[line_+column*N]; // look also best_path()!
01115         }
01116 
01122         inline float64_t get_B(T_STATES line_, uint16_t column) const
01123         {
01124 #ifdef HMM_DEBUG
01125             if ((line_>=N)||(column>=M))
01126                 SG_DEBUG("index out of range in get_B(%i,%i) [%i,%i]\n", line_, column,N,M) ;
01127 #endif
01128             return observation_matrix_B[line_*M+column];
01129         }
01130 
01136         inline float64_t get_b(T_STATES line_, uint16_t column) const 
01137         {
01138 #ifdef HMM_DEBUG
01139             if ((line_>=N)||(column>=M))
01140                 SG_DEBUG("index out of range in get_b(%i,%i) [%i,%i]\n", line_, column,N,M) ;
01141 #endif
01142             //SG_PRINT("idx %d\n", line_*M+column);
01143             return observation_matrix_b[line_*M+column];
01144         }
01145 
01152         inline T_STATES get_psi(
01153             int32_t time, T_STATES state, int32_t dimension) const
01154         {
01155 #ifdef HMM_DEBUG
01156             if ((time>=p_observations->get_max_vector_length())||(state>N))
01157                 SG_DEBUG("index out of range in get_psi(%i,%i) [%i,%i]\n",time,state,p_observations->get_max_vector_length(),N) ;
01158 #endif
01159             return STATES_PER_OBSERVATION_PSI(dimension)[time*N+state];
01160         }
01161 
01163 
01165         inline virtual const char* get_name() const { return "HMM"; }
01166     protected:
01171 
01172         int32_t M;
01173 
01175         int32_t N;
01176 
01178         float64_t PSEUDO;
01179 
01180         // line number during processing input files
01181         int32_t line;
01182 
01184         CStringFeatures<uint16_t>* p_observations;
01185 
01186         //train definition for HMM
01187         CModel* model;
01188 
01190         float64_t* transition_matrix_A;
01191 
01193         float64_t* observation_matrix_B;
01194 
01196         float64_t* transition_matrix_a;
01197 
01199         float64_t* initial_state_distribution_p;
01200 
01202         float64_t* end_state_distribution_q;        
01203 
01205         float64_t* observation_matrix_b;    
01206 
01208         int32_t iterations;
01209         int32_t iteration_count;
01210 
01212         float64_t epsilon;
01213         int32_t conv_it;
01214 
01216         float64_t all_pat_prob; 
01217 
01219         float64_t pat_prob; 
01220 
01222         float64_t mod_prob; 
01223 
01225         bool mod_prob_updated;  
01226 
01228         bool all_path_prob_updated; 
01229 
01231         int32_t path_deriv_dimension;
01232 
01234         bool path_deriv_updated;
01235 
01236         // true if model is using log likelihood
01237         bool loglikelihood;     
01238 
01239         // true->ok, false->error
01240         bool status;            
01241 
01242         // true->stolen from other HMMs, false->got own
01243         bool reused_caches;
01245 
01246 #ifdef USE_HMMPARALLEL_STRUCTURES
01247 
01248         float64_t** arrayN1 /*[parallel.get_num_threads()]*/ ;
01250         float64_t** arrayN2 /*[parallel.get_num_threads()]*/ ;
01251 #else //USE_HMMPARALLEL_STRUCTURES
01252 
01253         float64_t* arrayN1;
01255         float64_t* arrayN2;
01256 #endif //USE_HMMPARALLEL_STRUCTURES
01257 
01258 #ifdef USE_LOGSUMARRAY
01259 #ifdef USE_HMMPARALLEL_STRUCTURES
01260 
01261         float64_t** arrayS /*[parallel.get_num_threads()]*/;
01262 #else
01263 
01264         float64_t* arrayS;
01265 #endif // USE_HMMPARALLEL_STRUCTURES
01266 #endif // USE_LOGSUMARRAY
01267 
01268 #ifdef USE_HMMPARALLEL_STRUCTURES
01269 
01271         T_ALPHA_BETA* alpha_cache /*[parallel.get_num_threads()]*/ ;
01273         T_ALPHA_BETA* beta_cache /*[parallel.get_num_threads()]*/ ;
01274 
01276         T_STATES** states_per_observation_psi /*[parallel.get_num_threads()]*/ ;
01277 
01279         T_STATES** path /*[parallel.get_num_threads()]*/ ;
01280 
01282         bool* path_prob_updated /*[parallel.get_num_threads()]*/;
01283 
01285         int32_t* path_prob_dimension /*[parallel.get_num_threads()]*/ ; 
01286 
01287 #else //USE_HMMPARALLEL_STRUCTURES
01289         T_ALPHA_BETA alpha_cache;
01291         T_ALPHA_BETA beta_cache;
01292 
01294         T_STATES* states_per_observation_psi;
01295 
01297         T_STATES* path;
01298 
01300         bool path_prob_updated;
01301 
01303         int32_t path_prob_dimension;
01304 
01305 #endif //USE_HMMPARALLEL_STRUCTURES
01306 
01307 
01309         static const int32_t GOTN;
01311         static const int32_t GOTM;
01313         static const int32_t GOTO;
01315         static const int32_t GOTa;
01317         static const int32_t GOTb;
01319         static const int32_t GOTp;
01321         static const int32_t GOTq;
01322 
01324         static const int32_t GOTlearn_a;
01326         static const int32_t GOTlearn_b;
01328         static const int32_t GOTlearn_p;
01330         static const int32_t GOTlearn_q;
01332         static const int32_t GOTconst_a;
01334         static const int32_t GOTconst_b;
01336         static const int32_t GOTconst_p;
01338         static const int32_t GOTconst_q;
01339 
01340         public:
01345 
01347 inline float64_t state_probability(
01348     int32_t time, int32_t state, int32_t dimension)
01349 {
01350     return forward(time, state, dimension) + backward(time, state, dimension) - model_probability(dimension);
01351 }
01352 
01354 inline float64_t transition_probability(
01355     int32_t time, int32_t state_i, int32_t state_j, int32_t dimension)
01356 {
01357     return forward(time, state_i, dimension) + 
01358         backward(time+1, state_j, dimension) + 
01359         get_a(state_i,state_j) + get_b(state_j,p_observations->get_feature(dimension ,time+1)) - model_probability(dimension);
01360 }
01361 
01368 
01371 inline float64_t linear_model_derivative(
01372     T_STATES i, uint16_t j, int32_t dimension)
01373 {
01374     float64_t der=0;
01375 
01376     for (int32_t k=0; k<N; k++)
01377     {
01378         if (k!=i || p_observations->get_feature(dimension, k) != j)
01379             der+=get_b(k, p_observations->get_feature(dimension, k));
01380     }
01381 
01382     return der;
01383 }
01384 
01388 inline float64_t model_derivative_p(T_STATES i, int32_t dimension)
01389 {
01390     return backward(0,i,dimension)+get_b(i, p_observations->get_feature(dimension, 0));     
01391 }
01392 
01396 inline float64_t model_derivative_q(T_STATES i, int32_t dimension)
01397 {
01398     return forward(p_observations->get_vector_length(dimension)-1,i,dimension) ;
01399 }
01400 
01402 inline float64_t model_derivative_a(T_STATES i, T_STATES j, int32_t dimension)
01403 {
01404     float64_t sum=-CMath::INFTY;
01405     for (int32_t t=0; t<p_observations->get_vector_length(dimension)-1; t++)
01406         sum= CMath::logarithmic_sum(sum, forward(t, i, dimension) + backward(t+1, j, dimension) + get_b(j, p_observations->get_feature(dimension,t+1)));
01407 
01408     return sum;
01409 }
01410 
01411 
01413 inline float64_t model_derivative_b(T_STATES i, uint16_t j, int32_t dimension)
01414 {
01415     float64_t sum=-CMath::INFTY;
01416     for (int32_t t=0; t<p_observations->get_vector_length(dimension); t++)
01417     {
01418         if (p_observations->get_feature(dimension,t)==j)
01419             sum= CMath::logarithmic_sum(sum, forward(t,i,dimension)+backward(t,i,dimension)-get_b(i,p_observations->get_feature(dimension,t)));
01420     }
01421     //if (sum==-CMath::INFTY)
01422     // SG_DEBUG( "log derivative is -inf: dim=%i, state=%i, obs=%i\n",dimension, i, j) ;
01423     return sum;
01424 } 
01426 
01433 
01435 inline float64_t path_derivative_p(T_STATES i, int32_t dimension)
01436 {
01437     best_path(dimension);
01438     return (i==PATH(dimension)[0]) ? (exp(-get_p(PATH(dimension)[0]))) : (0) ;
01439 }
01440 
01442 inline float64_t path_derivative_q(T_STATES i, int32_t dimension)
01443 {
01444     best_path(dimension);
01445     return (i==PATH(dimension)[p_observations->get_vector_length(dimension)-1]) ? (exp(-get_q(PATH(dimension)[p_observations->get_vector_length(dimension)-1]))) : 0 ;
01446 }
01447 
01449 inline float64_t path_derivative_a(T_STATES i, T_STATES j, int32_t dimension)
01450 {
01451     prepare_path_derivative(dimension) ;
01452     return (get_A(i,j)==0) ? (0) : (get_A(i,j)*exp(-get_a(i,j))) ;
01453 }
01454 
01456 inline float64_t path_derivative_b(T_STATES i, uint16_t j, int32_t dimension)
01457 {
01458     prepare_path_derivative(dimension) ;
01459     return (get_B(i,j)==0) ? (0) : (get_B(i,j)*exp(-get_b(i,j))) ;
01460 } 
01461 
01463 
01464 
01465 protected:
01470 
01471     bool get_numbuffer(FILE* file, char* buffer, int32_t length);
01472 
01474     void open_bracket(FILE* file);
01475 
01477     void close_bracket(FILE* file);
01478 
01480     bool comma_or_space(FILE* file);
01481 
01483     inline void error(int32_t p_line, const char* str)
01484     {
01485         if (p_line)
01486             SG_ERROR( "error in line %d %s\n", p_line, str);
01487         else
01488             SG_ERROR( "error %s\n", str);
01489     }
01491 
01493     inline void prepare_path_derivative(int32_t dim)
01494     {
01495         if (path_deriv_updated && (path_deriv_dimension==dim))
01496             return ;
01497         int32_t i,j,t ;
01498         best_path(dim);
01499         //initialize with zeros
01500         for (i=0; i<N; i++)
01501         {
01502             for (j=0; j<N; j++)
01503                 set_A(i,j, 0);
01504             for (j=0; j<M; j++)
01505                 set_B(i,j, 0);
01506         }
01507 
01508         //counting occurences for A and B
01509         for (t=0; t<p_observations->get_vector_length(dim)-1; t++)
01510         {
01511             set_A(PATH(dim)[t], PATH(dim)[t+1], get_A(PATH(dim)[t], PATH(dim)[t+1])+1);
01512             set_B(PATH(dim)[t], p_observations->get_feature(dim,t),  get_B(PATH(dim)[t], p_observations->get_feature(dim,t))+1);
01513         }
01514         set_B(PATH(dim)[p_observations->get_vector_length(dim)-1], p_observations->get_feature(dim,p_observations->get_vector_length(dim)-1),  get_B(PATH(dim)[p_observations->get_vector_length(dim)-1], p_observations->get_feature(dim,p_observations->get_vector_length(dim)-1)) + 1);
01515         path_deriv_dimension=dim ;
01516         path_deriv_updated=true ;
01517     } ;
01519 
01521     inline float64_t forward(int32_t time, int32_t state, int32_t dimension)
01522     {
01523         if (time<1)
01524             time=0;
01525 
01526         if (ALPHA_CACHE(dimension).table && (dimension==ALPHA_CACHE(dimension).dimension) && ALPHA_CACHE(dimension).updated)
01527         {
01528             if (time<p_observations->get_vector_length(dimension))
01529                 return ALPHA_CACHE(dimension).table[time*N+state];
01530             else
01531                 return ALPHA_CACHE(dimension).sum;
01532         }
01533         else
01534             return forward_comp(time, state, dimension) ;
01535     }
01536 
01538     inline float64_t backward(int32_t time, int32_t state, int32_t dimension)
01539     {
01540         if (BETA_CACHE(dimension).table && (dimension==BETA_CACHE(dimension).dimension) && (BETA_CACHE(dimension).updated))
01541         {
01542             if (time<0)
01543                 return BETA_CACHE(dimension).sum;
01544             if (time<p_observations->get_vector_length(dimension))
01545                 return BETA_CACHE(dimension).table[time*N+state];
01546             else
01547                 return -CMath::INFTY;
01548         }
01549         else
01550             return backward_comp(time, state, dimension) ;
01551     }
01552 
01553 };
01554 #endif

SHOGUN Machine Learning Toolbox - Documentation