WeightedDegreePositionStringKernel.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Written (W) 1999-2008 Gunnar Raetsch
00009  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00010  */
00011 
00012 #ifndef _WEIGHTEDDEGREEPOSITIONSTRINGKERNEL_H___
00013 #define _WEIGHTEDDEGREEPOSITIONSTRINGKERNEL_H___
00014 
00015 #include "lib/common.h"
00016 #include "kernel/StringKernel.h"
00017 #include "kernel/WeightedDegreeStringKernel.h"
00018 #include "lib/Trie.h"
00019 
00020 class CSVM ;
00021 
/**
 * Weighted Degree string kernel with shifts, operating on char strings.
 *
 * Derives from CStringKernel<char>. Besides the direct kernel computation it
 * keeps per-degree weights (`weights`), optional per-position weights
 * (`position_weights`, plus separate lhs/rhs variants), shift tolerances
 * (`shift`, `max_shift`) and tries (`tries`) used by the optimized evaluation
 * path (init_optimization() / compute_optimized()). The optimized path
 * asserts that the alphabet is DNA or RNA.
 */
00045 class CWeightedDegreePositionStringKernel: public CStringKernel<char>
00046 {
00047     public:
              /** Constructor without explicit degree weights.
               *
               * @param size presumably the kernel cache size (TODO confirm in .cpp)
               * @param degree degree of the kernel
               * @param max_mismatch maximum number of mismatches (default 0)
               * @param mkl_stepsize number of consecutive weights grouped into
               *        one MKL sub-kernel (default 1), see get_num_subkernels()
               */
00055         CWeightedDegreePositionStringKernel(
00056             int32_t size, int32_t degree,
00057             int32_t max_mismatch=0, int32_t mkl_stepsize=1);
00058 
              /** Constructor with explicit degree weights and shifts.
               *
               * @param size presumably the kernel cache size (TODO confirm in .cpp)
               * @param weights degree weights (copy vs. ownership is decided
               *        in the .cpp — confirm)
               * @param degree degree of the kernel
               * @param max_mismatch maximum number of mismatches
               * @param shift array of shifts
               * @param shift_len number of entries in @p shift
               * @param mkl_stepsize MKL sub-kernel grouping step (default 1)
               */
00069         CWeightedDegreePositionStringKernel(
00070             int32_t size, float64_t* weights, int32_t degree,
00071             int32_t max_mismatch, int32_t* shift, int32_t shift_len,
00072             int32_t mkl_stepsize=1);
00073 
              /** Constructor taking the lhs/rhs string features directly
               * (presumably also initializes the kernel on them — confirm).
               *
               * @param l features of the left-hand side
               * @param r features of the right-hand side
               * @param degree degree of the kernel
               */
00080         CWeightedDegreePositionStringKernel(
00081             CStringFeatures<char>* l, CStringFeatures<char>* r, int32_t degree);
00082 
              /** Destructor. */
00083         virtual ~CWeightedDegreePositionStringKernel();
00084 
              /** Initialize the kernel with left- and right-hand side features.
               *
               * @param l features of the left-hand side
               * @param r features of the right-hand side
               * @return success flag
               */
00091         virtual bool init(CFeatures* l, CFeatures* r);
00092 
              /** Clean up the kernel (overrides the base-class cleanup()). */
00094         virtual void cleanup();
00095 
              /** Load kernel initialization data.
               *
               * @param src file to read from
               * @return success flag
               */
00101         bool load_init(FILE* src);
00102 
              /** Save kernel initialization data.
               *
               * @param dest file to write to
               * @return success flag
               */
00108         bool save_init(FILE* dest);
00109 
              /** @return kernel type, always K_WEIGHTEDDEGREEPOS */
00114         virtual EKernelType get_kernel_type() { return K_WEIGHTEDDEGREEPOS; }
00115 
              /** @return kernel name, always "WeightedDegreePos" */
00120         virtual const char* get_name() const { return "WeightedDegreePos"; }
00121 
              /** Initialize the optimized (trie-based) evaluation.
               *
               * Forwards to the 5-argument overload with tree_num = -1
               * (presumably meaning "all trees" — confirm in .cpp).
               *
               * @param p_count number of entries in @p IDX and @p alphas
               * @param IDX example (support vector) indices
               * @param alphas per-example coefficients
               * @return result of the forwarded call
               */
00129         inline virtual bool init_optimization(
00130             int32_t p_count, int32_t *IDX, float64_t * alphas)
00131         { 
00132             return init_optimization(p_count, IDX, alphas, -1);
00133         }
00134 
              /** Initialize the optimized (trie-based) evaluation.
               *
               * @param count number of entries in @p IDX and @p alphas
               * @param IDX example (support vector) indices
               * @param alphas per-example coefficients
               * @param tree_num tree to build; the 3-argument overload passes -1
               * @param upto_tree upper tree bound (default -1; semantics in .cpp)
               * @return success flag
               */
00146         virtual bool init_optimization(
00147             int32_t count, int32_t *IDX, float64_t * alphas, int32_t tree_num,
00148             int32_t upto_tree=-1);
00149 
              /** Release the data structures built by init_optimization().
               *
               * @return success flag
               */
00154         virtual bool delete_optimization();
00155 
              /** Optimized evaluation for example @p idx via the tries.
               *
               * Asserts that the optimization is initialized, that an alphabet
               * is set, and that it is DNA or RNA.
               *
               * @param idx index of the example to evaluate
               * @return compute_by_tree(idx)
               */
00161         inline virtual float64_t compute_optimized(int32_t idx)
00162         { 
00163             ASSERT(get_is_initialized());
00164             ASSERT(alphabet);
00165             ASSERT(alphabet->get_alphabet()==DNA || alphabet->get_alphabet()==RNA);
00166             return compute_by_tree(idx);
00167         }
00168 
              /** Static worker entry point. Its void*(void*) signature matches a
               * pthread start routine (presumably used by compute_batch() —
               * confirm in .cpp).
               *
               * @param p opaque parameter block
               * @return opaque result
               */
00173         static void* compute_batch_helper(void* p);
00174 
              /** Compute the optimized output for a batch of examples.
               *
               * @param num_vec number of entries in @p vec_idx
               * @param vec_idx indices of the examples to evaluate
               * @param target output array (presumably num_vec entries — confirm)
               * @param num_suppvec number of entries in @p IDX and @p alphas
               * @param IDX support vector indices
               * @param alphas support vector coefficients
               * @param factor scaling factor (default 1.0)
               */
00185         virtual void compute_batch(
00186             int32_t num_vec, int32_t* vec_idx, float64_t* target,
00187             int32_t num_suppvec, int32_t* IDX, float64_t* alphas,
00188             float64_t factor=1.0);
00189 
              /** Reset the tries so that a new normal can be accumulated.
               *
               * With FASTBUTMEMHUNGRY, compact terminal trie nodes are disabled
               * first. If the optimization is initialized, the tries are deleted
               * (delete_trees(true) for SLOWBUTMEMEFFICIENT, delete_trees(false)
               * for FASTBUTMEMHUNGRY) and the initialized flag is cleared. Any
               * other optimization type raises an error via SG_ERROR.
               */
00193         inline virtual void clear_normal()
00194         {
00195             if ((opt_type==FASTBUTMEMHUNGRY) && (tries.get_use_compact_terminal_nodes()))
00196             {
00197                 tries.set_use_compact_terminal_nodes(false) ;
00198                 SG_DEBUG( "disabling compact trie nodes with FASTBUTMEMHUNGRY\n") ;
00199             }
00200 
00201             if (get_is_initialized())
00202             {
00203                 if (opt_type==SLOWBUTMEMEFFICIENT)
00204                     tries.delete_trees(true); 
00205                 else if (opt_type==FASTBUTMEMHUNGRY)
00206                     tries.delete_trees(false);  // still buggy
00207                 else
00208                     SG_ERROR( "unknown optimization type\n");
00209 
00210                 set_is_initialized(false);
00211             }
00212         }
00213 
              /** Add example @p idx with coefficient @p weight to the tries and
               * mark the optimization as initialized.
               *
               * @param idx example index
               * @param weight coefficient of the example
               */
00219         inline virtual void add_to_normal(int32_t idx, float64_t weight)
00220         {
00221             add_example_to_tree(idx, weight);
00222             set_is_initialized(true);
00223         }
00224 
              /** Number of MKL sub-kernels.
               *
               * Weights are grouped in chunks of mkl_stepsize, so the result is
               * ceil(N / mkl_stepsize). N is seq_length when position weights
               * are set; otherwise it is degree (length == 0) or degree*length.
               *
               * @return number of sub-kernels
               */
00229         inline virtual int32_t get_num_subkernels()
00230         {
00231             if (position_weights!=NULL)
00232                 return (int32_t) ceil(1.0*seq_length/mkl_stepsize) ;
00233             if (length==0)
00234                 return (int32_t) ceil(1.0*get_degree()/mkl_stepsize);
00235             return (int32_t) ceil(1.0*get_degree()*length/mkl_stepsize) ;
00236         }
00237 
              /** Per-sub-kernel contributions for example @p idx.
               *
               * Requires an initialized optimization; otherwise an error is
               * raised via SG_ERROR.
               *
               * @param idx example index
               * @param subkernel_contrib output buffer filled by compute_by_tree()
               *        (presumably get_num_subkernels() entries — confirm)
               */
00243         inline void compute_by_subkernel(
00244             int32_t idx, float64_t * subkernel_contrib)
00245         { 
00246             if (get_is_initialized())
00247             {
00248                 compute_by_tree(idx, subkernel_contrib);
00249                 return ;
00250             }
00251 
00252             SG_ERROR( "CWeightedDegreePositionStringKernel optimization not initialized\n") ;
00253         }
00254 
              /** Current sub-kernel weights, one per mkl_stepsize group.
               *
               * Each entry is the first weight of its group: position weights
               * if set, degree weights otherwise. The returned array is the
               * internal weights_buffer, which is freed and reallocated on every
               * call. A pointer from a previous call therefore becomes invalid.
               *
               * @param num_weights set to get_num_subkernels()
               * @return internal buffer owned by the kernel (do not delete)
               */
00260         inline const float64_t* get_subkernel_weights(int32_t& num_weights)
00261         {
00262             num_weights = get_num_subkernels() ;
00263 
00264             delete[] weights_buffer ;
00265             weights_buffer = new float64_t[num_weights] ;
00266 
00267             if (position_weights!=NULL)
00268                 for (int32_t i=0; i<num_weights; i++)
00269                     weights_buffer[i] = position_weights[i*mkl_stepsize] ;
00270             else
00271                 for (int32_t i=0; i<num_weights; i++)
00272                     weights_buffer[i] = weights[i*mkl_stepsize] ;
00273 
00274             return weights_buffer ;
00275         }
00276 
              /** Set sub-kernel weights.
               *
               * weights2[i] is copied into every weight of the i-th mkl_stepsize
               * group. The last group is clipped at the end of the underlying
               * array (seq_length, degree or degree*length entries). An error is
               * raised via SG_ERROR if @p num_weights2 != get_num_subkernels().
               *
               * @param weights2 new sub-kernel weights
               * @param num_weights2 number of entries in @p weights2
               */
00282         inline void set_subkernel_weights(
00283             float64_t* weights2, int32_t num_weights2)
00284         {
00285             int32_t num_weights = get_num_subkernels() ;
00286             if (num_weights!=num_weights2)
00287                 SG_ERROR( "number of weights do not match\n") ;
00288 
00289             if (position_weights!=NULL)
00290                 for (int32_t i=0; i<num_weights; i++)
00291                     for (int32_t j=0; j<mkl_stepsize; j++)
00292                     {
00293                         if (i*mkl_stepsize+j<seq_length)
00294                             position_weights[i*mkl_stepsize+j] = weights2[i] ;
00295                     }
00296             else if (length==0)
00297             {
00298                 for (int32_t i=0; i<num_weights; i++)
00299                     for (int32_t j=0; j<mkl_stepsize; j++)
00300                         if (i*mkl_stepsize+j<get_degree())
00301                             weights[i*mkl_stepsize+j] = weights2[i] ;
00302             }
00303             else
00304             {
00305                 for (int32_t i=0; i<num_weights; i++)
00306                     for (int32_t j=0; j<mkl_stepsize; j++)
00307                         if (i*mkl_stepsize+j<get_degree()*length)
00308                             weights[i*mkl_stepsize+j] = weights2[i] ;
00309             }
00310         }
00311 
00312         // other kernel tree operations
              /** Compute absolute weights (semantics and ownership of the
               * returned array are defined in the .cpp — confirm).
               *
               * @param len output: length of the returned array (presumably)
               * @return weight array
               */
00318         float64_t* compute_abs_weights(int32_t & len);
00319 
              /** @return value of the tree_initialized flag */
00324         bool is_tree_initialized() { return tree_initialized; }
00325 
              /** @return maximum number of mismatches (max_mismatch) */
00330         inline int32_t get_max_mismatch() { return max_mismatch; }
00331 
              /** @return degree of the kernel */
00336         inline int32_t get_degree() { return degree; }
00337 
              /** Degree weights.
               *
               * @param d set to degree
               * @param len set to length
               * @return pointer to the internal weights array (not a copy)
               */
00343         inline float64_t *get_degree_weights(int32_t& d, int32_t& len)
00344         {
00345             d=degree;
00346             len=length;
00347             return weights;
00348         }
00349 
              /** Active weights: position weights if set, else degree weights.
               *
               * @param num_weights set to seq_length (position weights), degree
               *        (length == 0) or degree*length
               * @return pointer to the internal array (not a copy)
               */
00355         inline float64_t *get_weights(int32_t& num_weights)
00356         {
00357             if (position_weights!=NULL)
00358             {
00359                 num_weights = seq_length ;
00360                 return position_weights ;
00361             }
00362             if (length==0)
00363                 num_weights = degree ;
00364             else
00365                 num_weights = degree*length ;
00366             return weights;
00367         }
00368 
              /** Position weights.
               *
               * @param len set to seq_length
               * @return internal position weights array (NULL when not set)
               */
00374         inline float64_t *get_position_weights(int32_t& len)
00375         {
00376             len=seq_length;
00377             return position_weights;
00378         }
00379 
              /** Set the shifts.
               *
               * @param shifts array of shifts
               * @param len number of entries in @p shifts
               * @return success flag
               */
00385         bool set_shifts(int32_t* shifts, int32_t len);
00386 
              /** Set the degree weights.
               *
               * @param weights new weights
               * @param d degree
               * @param len second dimension of the weights (default 0)
               * @return success flag
               */
00393         virtual bool set_weights(float64_t* weights, int32_t d, int32_t len=0);
00394 
              /** Set default weighted-degree weights (scheme defined in .cpp).
               *
               * @return success flag
               */
00399         virtual bool set_wd_weights();
00400 
              /** Set per-position weights.
               *
               * @param pws position weights
               * @param len number of entries in @p pws (default 0)
               * @return success flag
               */
00407         virtual bool set_position_weights(float64_t* pws, int32_t len=0);
00408 
              /** Set position weights for the left-hand side.
               *
               * @param pws position weights
               * @param len length per vector (presumably — confirm)
               * @param num number of vectors (presumably — confirm)
               * @return success flag
               */
00416         bool set_position_weights_lhs(float64_t* pws, int32_t len, int32_t num);
00417 
              /** Set position weights for the right-hand side.
               *
               * @param pws position weights
               * @param len length per vector (presumably — confirm)
               * @param num number of vectors (presumably — confirm)
               * @return success flag
               */
00425         bool set_position_weights_rhs(float64_t* pws, int32_t len, int32_t num);
00426 
              /** Initialize block_weights (scheme selected in .cpp,
               * presumably from `type`). @return success flag */
00431         bool init_block_weights();
00432 
              /** Initialize block weights from the WD weights. @return success flag */
00437         bool init_block_weights_from_wd();
00438 
              /** Initialize block weights from external WD weights. @return success flag */
00443         bool init_block_weights_from_wd_external();
00444 
              /** Initialize constant block weights. @return success flag */
00449         bool init_block_weights_const();
00450 
              /** Initialize linear block weights. @return success flag */
00455         bool init_block_weights_linear();
00456 
              /** Initialize square-polynomial block weights. @return success flag */
00461         bool init_block_weights_sqpoly();
00462 
              /** Initialize cubic-polynomial block weights. @return success flag */
00467         bool init_block_weights_cubicpoly();
00468 
              /** Initialize exponential block weights. @return success flag */
00473         bool init_block_weights_exp();
00474 
              /** Initialize logarithmic block weights. @return success flag */
00479         bool init_block_weights_log();
00480 
              /** Initialize block weights from block_weights_external. @return success flag */
00485         bool init_block_weights_external();
00486 
              /** Free position_weights and reset it to NULL.
               *
               * @return always true
               */
00491         bool delete_position_weights()
00492         {
00493             delete[] position_weights;
00494             position_weights=NULL;
00495             return true;
00496         }
00497 
              /** Free position_weights_lhs and reset it to NULL.
               *
               * @return always true
               */
00502         bool delete_position_weights_lhs()
00503         {
00504             delete[] position_weights_lhs;
00505             position_weights_lhs=NULL;
00506             return true;
00507         }
00508 
              /** Free position_weights_rhs and reset it to NULL.
               *
               * @return always true
               */
00513         bool delete_position_weights_rhs()
00514         {
00515             delete[] position_weights_rhs;
00516             position_weights_rhs=NULL;
00517             return true;
00518         }
00519 
              /** Evaluate example @p idx using the tries.
               *
               * @param idx example index
               * @return computed value
               */
00525         virtual float64_t compute_by_tree(int32_t idx);
00526 
              /** Evaluate example @p idx using the tries, per level.
               *
               * @param idx example index
               * @param LevelContrib output buffer for the per-level contributions
               */
00532         virtual void compute_by_tree(int32_t idx, float64_t* LevelContrib);
00533 
              /** Compute scoring (semantics and ownership of the result are
               * defined in the .cpp — confirm).
               *
               * @param max_degree maximum degree to consider
               * @param num_feat number of features (reference parameter)
               * @param num_sym number of symbols (reference parameter)
               * @param target target buffer
               * @param num_suppvec number of entries in @p IDX and @p weights
               * @param IDX support vector indices
               * @param weights support vector coefficients
               * @return result array
               */
00546         float64_t* compute_scoring(
00547             int32_t max_degree, int32_t& num_feat, int32_t& num_sym,
00548             float64_t* target, int32_t num_suppvec, int32_t* IDX,
00549             float64_t* weights);
00550 
              /** Compute a consensus sequence from the given support vectors.
               *
               * @param num_feat number of features (reference parameter)
               * @param num_suppvec number of entries in @p IDX and @p alphas
               * @param IDX support vector indices
               * @param alphas support vector coefficients
               * @return consensus string (ownership defined in .cpp — confirm)
               */
00559         char* compute_consensus(
00560             int32_t &num_feat, int32_t num_suppvec, int32_t* IDX,
00561             float64_t* alphas);
00562 
              /** Extract the weight vector w.
               *
               * @param max_degree maximum degree to consider
               * @param num_feat number of features (reference parameter)
               * @param num_sym number of symbols (reference parameter)
               * @param w_result result buffer
               * @param num_suppvec number of entries in @p IDX and @p alphas
               * @param IDX support vector indices
               * @param alphas support vector coefficients
               * @return result array
               */
00574         float64_t* extract_w(
00575             int32_t max_degree, int32_t& num_feat, int32_t& num_sym,
00576             float64_t* w_result, int32_t num_suppvec, int32_t* IDX,
00577             float64_t* alphas);
00578 
              /** Compute the POIM (presumably positional oligomer importance
               * matrices — confirm).
               *
               * @param max_degree maximum degree to consider
               * @param num_feat number of features (reference parameter)
               * @param num_sym number of symbols (reference parameter)
               * @param poim_result result buffer
               * @param num_suppvec number of entries in @p IDX and @p alphas
               * @param IDX support vector indices
               * @param alphas support vector coefficients
               * @param distrib symbol distribution
               * @return result array
               */
00591         float64_t* compute_POIM(
00592             int32_t max_degree, int32_t& num_feat, int32_t& num_sym,
00593             float64_t* poim_result, int32_t num_suppvec, int32_t* IDX,
00594             float64_t* alphas, float64_t* distrib);
00595 
              /** Prepare POIM2 computation (presumably stores @p distrib and the
               * dimensions in m_poim_distrib / m_poim_num_sym / m_poim_num_feat).
               *
               * @param distrib symbol distribution
               * @param num_sym number of symbols
               * @param num_feat number of features
               */
00602         void prepare_POIM2(
00603             float64_t* distrib, int32_t num_sym, int32_t num_feat);
00604 
              /** Compute POIM2 for @p svm up to @p max_degree.
               *
               * @param max_degree maximum degree to consider
               * @param svm SVM whose expansion is analysed
               */
00611         void compute_POIM2(int32_t max_degree, CSVM* svm);
00612 
              /** Retrieve the computed POIM2 (presumably m_poim / m_poim_result_len).
               *
               * @param poim output: pointer to the POIM data
               * @param result_len output: number of entries
               */
00618         void get_POIM2(float64_t** poim, int32_t* result_len);
00619 
              /** Release the POIM2 data. */
00621         void cleanup_POIM2();
00622         
00623     protected:
              /** Create empty tries. */
00625         void create_empty_tries();
00626 
              /** Add example @p idx with coefficient @p weight to the tries.
               *
               * @param idx example index
               * @param weight coefficient of the example
               */
00632         virtual void add_example_to_tree(
00633             int32_t idx, float64_t weight);
00634 
              /** Add example @p idx with coefficient @p weight to one trie.
               *
               * @param idx example index
               * @param weight coefficient of the example
               * @param tree_num index of the tree to update
               */
00641         void add_example_to_single_tree(
00642             int32_t idx, float64_t weight, int32_t tree_num);
00643 
              /** Kernel value between example @p idx_a and example @p idx_b.
               *
               * @param idx_a index of the first example
               * @param idx_b index of the second example
               * @return kernel value
               */
00652         virtual float64_t compute(int32_t idx_a, int32_t idx_b);
00653 
              /** Kernel value of two sequences, allowing mismatches.
               *
               * @param avec first sequence
               * @param alen length of @p avec
               * @param bvec second sequence
               * @param blen length of @p bvec
               * @return kernel value
               */
00662         float64_t compute_with_mismatch(
00663             char* avec, int32_t alen, char* bvec, int32_t blen);
00664 
              /** Kernel value of two sequences, without mismatches.
               *
               * @param avec first sequence
               * @param alen length of @p avec
               * @param bvec second sequence
               * @param blen length of @p bvec
               * @return kernel value
               */
00673         float64_t compute_without_mismatch(
00674             char* avec, int32_t alen, char* bvec, int32_t blen);
00675 
              /** Kernel value of two sequences, without mismatches, using the
               * matrix form of the weights (presumably length != 0 — confirm).
               *
               * @param avec first sequence
               * @param alen length of @p avec
               * @param bvec second sequence
               * @param blen length of @p bvec
               * @return kernel value
               */
00684         float64_t compute_without_mismatch_matrix(
00685             char* avec, int32_t alen, char* bvec, int32_t blen);
00686 
              /** Kernel value of two sequences, without mismatches, using
               * separate lhs/rhs position weights.
               *
               * @param avec first sequence
               * @param posweights_lhs position weights for @p avec
               * @param alen length of @p avec
               * @param bvec second sequence
               * @param posweights_rhs position weights for @p bvec
               * @param blen length of @p bvec
               * @return kernel value
               */
00697         float64_t compute_without_mismatch_position_weights(
00698             char* avec, float64_t *posweights_lhs, int32_t alen,
00699             char* bvec, float64_t *posweights_rhs, int32_t blen);
00700 
              /** Remove the lhs features (overrides the base-class remove_lhs()). */
00702         virtual void remove_lhs();
00703 
00704     protected:
              /** degree weights: degree entries, or degree*length when length != 0 */
00706         float64_t* weights;
              /** per-position weights (seq_length entries); NULL when unset */
00708         float64_t* position_weights;
              /** position weights for the left-hand side; NULL when unset */
00710         float64_t* position_weights_lhs;
              /** position weights for the right-hand side; NULL when unset */
00712         float64_t* position_weights_rhs;
              /** position mask (usage defined in .cpp) */
00714         bool* position_mask;
00715 
              /** buffer returned by get_subkernel_weights(); reallocated per call */
00717         float64_t* weights_buffer;
              /** number of consecutive weights grouped into one MKL sub-kernel */
00719         int32_t mkl_stepsize;
00720 
              /** degree of the kernel */
00722         int32_t degree;
              /** second dimension of `weights` (0: one weight per degree) */
00724         int32_t length;
00725 
              /** maximum number of mismatches */
00727         int32_t max_mismatch;
              /** number of position weights (presumably the sequence length) */
00729         int32_t seq_length;
00730 
              /** array of shifts (shift_len entries) */
00732         int32_t *shift;
              /** number of entries in `shift` */
00734         int32_t shift_len;
              /** presumably the largest entry of `shift` — confirm */
00736         int32_t max_shift;
00737 
              /** whether block computation is enabled */
00739         bool block_computation;
00740 
              /** number of entries in block_weights_external */
00742         int32_t num_block_weights_external;
              /** externally supplied block weights */
00744         float64_t* block_weights_external;
00745 
              /** block weights (presumably filled by init_block_weights*()) */
00747         float64_t* block_weights;
              /** weighting scheme */
00749         EWDKernType type;
              /** degree selector (semantics defined in .cpp) */
00751         int32_t which_degree;
00752 
              /** tries used by the optimized evaluation */
00754         CTrie<DNATrie> tries;
              /** tries used for POIM computation */
00756         CTrie<POIMTrie> poim_tries;
00757 
              /** whether the trees have been initialized */
00759         bool tree_initialized;
              /** whether poim_tries are in use */
00761         bool use_poim_tries;
00762 
              /** symbol distribution used for POIM2 */
00764         float64_t* m_poim_distrib;
              /** computed POIM2 data */
00766         float64_t* m_poim;
00767 
              /** number of symbols for POIM2 */
00769         int32_t m_poim_num_sym;
              /** number of features for POIM2 */
00771         int32_t m_poim_num_feat;
              /** number of entries in m_poim */
00773         int32_t m_poim_result_len;
00774 
              /** alphabet of the features; compute_optimized() requires DNA or RNA */
00776         CAlphabet* alphabet;
00777 };
00778 #endif /* _WEIGHTEDDEGREEPOSITIONSTRINGKERNEL_H___ */

SHOGUN Machine Learning Toolbox - Documentation