SpectrumMismatchRBFKernel.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Written (W) 1999-2008 Gunnar Raetsch
00009  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00010  */
00011 
00012 #include <vector>
00013 
00014 #include "lib/common.h"
00015 #include "lib/io.h"
00016 #include "lib/Signal.h"
00017 #include "lib/Trie.h"
00018 #include "base/Parallel.h"
00019 
00020 #include "kernel/SpectrumMismatchRBFKernel.h"
00021 #include "features/Features.h"
00022 #include "features/StringFeatures.h"
00023 
00024 
#include <vector>
#include <string>
#include <map>

#include <assert.h>
00030 #ifndef WIN32
00031 #include <pthread.h>
00032 #endif
00033 
00034 
00035 using namespace shogun;
00036 
00037 /*
00038 #ifndef DOXYGEN_SHOULD_SKIP_THIS
00039 struct S_THREAD_PARAM
00040 {
00041 
00042     int32_t* vec;
00043     float64_t* result;
00044     float64_t* weights;
00045     CSpectrumMismatchRBFKernel* kernel;
00046     CTrie<DNATrie>* tries;
00047     float64_t factor;
00048     int32_t j;
00049     int32_t start;
00050     int32_t end;
00051     int32_t length;
00052     int32_t* vec_idx;
00053 };
00054 #endif // DOXYGEN_SHOULD_SKIP_THIS
00055 */
00056         
00057 CSpectrumMismatchRBFKernel::CSpectrumMismatchRBFKernel (
00058     int32_t size, float64_t *AA_matrix_, int32_t degree_, int32_t max_mismatch_, float64_t width_)
00059 : CStringKernel<char>(size), alphabet(NULL), degree(degree_), max_mismatch(max_mismatch_), width(width_)
00060 {
00061     lhs=NULL;
00062     rhs=NULL;
00063 
00064     target_letter_0=-1 ;
00065 
00066     AA_matrix=new float64_t[128*128];
00067     memcpy(AA_matrix, AA_matrix_, 128*128*sizeof(float64_t)) ;
00068 }
00069 
00070 CSpectrumMismatchRBFKernel::CSpectrumMismatchRBFKernel(
00071     CStringFeatures<char>* l, CStringFeatures<char>* r, int32_t size, float64_t* AA_matrix_, int32_t degree_, int32_t max_mismatch_, float64_t width_)
00072 : CStringKernel<char>(size), alphabet(NULL), degree(degree_), max_mismatch(max_mismatch_), width(width_)
00073 {
00074     target_letter_0=-1 ;
00075 
00076     AA_matrix=new float64_t[128*128];
00077     memcpy(AA_matrix, AA_matrix_, 128*128*sizeof(float64_t)) ;
00078     init(l, r);
00079 }
00080 
00081 CSpectrumMismatchRBFKernel::~CSpectrumMismatchRBFKernel()
00082 {
00083     cleanup();
00084     delete[] AA_matrix ;
00085 }
00086 
00087 
00088 void CSpectrumMismatchRBFKernel::remove_lhs()
00089 {
00090 
00091     CKernel::remove_lhs();
00092 }
00093 
00094 bool CSpectrumMismatchRBFKernel::init(CFeatures* l, CFeatures* r)
00095 {
00096     int32_t lhs_changed=(lhs!=l);
00097     int32_t rhs_changed=(rhs!=r);
00098 
00099     CStringKernel<char>::init(l,r);
00100 
00101     SG_DEBUG("lhs_changed: %i\n", lhs_changed);
00102     SG_DEBUG("rhs_changed: %i\n", rhs_changed);
00103 
00104     CStringFeatures<char>* sf_l=(CStringFeatures<char>*) l;
00105     CStringFeatures<char>* sf_r=(CStringFeatures<char>*) r;
00106 
00107     SG_UNREF(alphabet);
00108     alphabet=sf_l->get_alphabet();
00109     CAlphabet* ralphabet=sf_r->get_alphabet();
00110 
00111     if (!((alphabet->get_alphabet()==DNA) || (alphabet->get_alphabet()==RNA)))
00112         properties &= ((uint64_t) (-1)) ^ (KP_LINADD | KP_BATCHEVALUATION);
00113 
00114     ASSERT(ralphabet->get_alphabet()==alphabet->get_alphabet());
00115     SG_UNREF(ralphabet);
00116 
00117     compute_all() ;
00118     
00119     return init_normalizer();
00120 }
00121 
00122 void CSpectrumMismatchRBFKernel::cleanup()
00123 {
00124 
00125     SG_UNREF(alphabet);
00126     alphabet=NULL;
00127 
00128     CKernel::cleanup();
00129 }
00130 
00131 float64_t CSpectrumMismatchRBFKernel::AA_helper(std::string &path, const char* joint_seq, unsigned int index)
00132 {
00133     float64_t diff=0.0 ;
00134 
00135     for (unsigned int i=0; i<path.size(); i++)
00136     {
00137         if (path[i]!=joint_seq[index+i])
00138         {
00139             diff += AA_matrix[ (path[i]-1)*128 + path[i] - 1] ;
00140             diff -= 2*AA_matrix[ (path[i]-1)*128 + joint_seq[index+i] - 1] ;
00141             diff += AA_matrix[ (joint_seq[index+i]-1)*128 + joint_seq[index+i] - 1] ;
00142         }
00143     }
00144 
00145     return exp( - diff/width) ;
00146 }
00147 
00148 /*
00149 float64_t CSpectrumMismatchRBFKernel::compute_helper(const char* joint_seq, 
00150                                                       std::vector<unsigned int> joint_index, std::vector<unsigned int> joint_mismatch, 
00151                                                       std::string path, unsigned int d, 
00152                                                       const int & alen) 
00153 {
00154     const char* AA = "ACDEFGHIKLMNPQRSTVWY" ;
00155     const unsigned int num_AA = strlen(AA) ;
00156 
00157     assert(path.size()==d) ;
00158     assert(joint_mismatch.size()==joint_index.size()) ;
00159     
00160     float64_t res = 0.0 ;
00161     
00162     for (unsigned int i=0; i<num_AA; i++)
00163     {
00164         std::vector<unsigned int> joint_mismatch_ ;
00165         std::vector<unsigned int> joint_index_ ;
00166 
00167         for (unsigned int j=0; j<joint_index.size(); j++)
00168         {
00169             if (joint_seq[joint_index[j]+d] != AA[i])
00170             {
00171                 if (joint_mismatch[j]+1 <= (unsigned int) max_mismatch)
00172                 {
00173                     joint_mismatch_.push_back(joint_mismatch[j]+1) ;
00174                     joint_index_.push_back(joint_index[j]) ;
00175                 }
00176             }
00177             else
00178             {
00179                 joint_mismatch_.push_back(joint_mismatch[j]) ;
00180                 joint_index_.push_back(joint_index[j]) ;
00181             }
00182         }
00183         if (joint_mismatch_.size()>0)
00184         {
00185             std::string path_ = path + AA[i] ;
00186 
00187             if (d+1 < (unsigned int) degree)
00188             {
00189                 res += compute_helper(joint_seq,  joint_index_, joint_mismatch_, path_, d+1, alen) ;
00190             }
00191             else
00192             {
00193                 int anum=0, bnum=0;
00194                 for (unsigned int j=0; j<joint_index_.size(); j++)
00195                     if (joint_index_[j] < (unsigned int)alen)
00196                     {
00197                         if (1)
00198                         {
00199                             anum++ ;
00200                             if (joint_mismatch_[j]==0)
00201                                 anum+=3 ;
00202                         }
00203                         else
00204                         {
00205                             if (joint_mismatch_[j]!=0)
00206                                 anum += AA_helper(path_, joint_seq, joint_index_[j]) ;
00207                             else
00208                                 anum++ ;
00209                         }
00210                     }
00211                     else
00212                     {
00213                         if (1)
00214                         {
00215                             bnum++ ;
00216                             if (joint_mismatch_[j]==0)
00217                                 bnum+=3 ;
00218                         }
00219                         else
00220                         {
00221                             if (joint_mismatch_[j]!=0)
00222                                 bnum += AA_helper(path_, joint_seq, joint_index_[j]) ;
00223                             else
00224                                 bnum++ ;
00225                         }
00226                     }
00227                 
00228                 //fprintf(stdout, "%s: %i x %i\n", path_.c_str(), anum, bnum) ;
00229                 
00230                 res+= anum*bnum ;
00231             }
00232         }
00233     }
00234     return res ;
00235 }
00236 */
00237 
00238 void CSpectrumMismatchRBFKernel::compute_helper_all(const char *joint_seq, 
00239                                                      std::vector<struct joint_list_struct> &joint_list,
00240                                                      std::string path, unsigned int d) 
00241 {
00242     const char* AA = "ACDEFGHIKLMNPQRSTVWY" ;
00243     const unsigned int num_AA = strlen(AA) ;
00244 
00245     assert(path.size()==d) ;
00246     
00247     for (unsigned int i=0; i<num_AA; i++)
00248     {
00249         std::vector<struct joint_list_struct> joint_list_ ;
00250         
00251         if (d==0)
00252             fprintf(stderr, "i=%i: ", i) ;
00253         if (d==0 && target_letter_0!=-1 && (int)i != target_letter_0 )
00254             continue ;
00255         
00256         if (d==1)
00257         {
00258             fprintf(stdout, "*") ;
00259             fflush(stdout) ;
00260         }
00261         if (d==2)
00262         {
00263             fprintf(stdout, "+") ;
00264             fflush(stdout) ;
00265         }
00266 
00267         for (unsigned int j=0; j<joint_list.size(); j++)
00268         {
00269             if (joint_seq[joint_list[j].index+d] != AA[i])
00270             {
00271                 if (joint_list[j].mismatch+1 <= (unsigned int) max_mismatch)
00272                 {
00273                     struct joint_list_struct list_item ;
00274                     list_item = joint_list[j] ;
00275                     list_item.mismatch = joint_list[j].mismatch+1 ;
00276                     joint_list_.push_back(list_item) ;
00277                 }
00278             }
00279             else
00280                 joint_list_.push_back(joint_list[j]) ;
00281         }
00282 
00283         if (joint_list_.size()>0)
00284         {
00285             std::string path_ = path + AA[i] ;
00286 
00287             if (d+1 < (unsigned int) degree)
00288             {
00289                 compute_helper_all(joint_seq,  joint_list_, path_, d+1) ;
00290             }
00291             else
00292             {
00293                 CArray<float64_t> feats ;
00294                 feats.resize_array(kernel_matrix.get_dim1()) ;
00295                 feats.zero() ;
00296                 
00297                 for (unsigned int j=0; j<joint_list_.size(); j++)
00298                 {
00299                     if (width==0.0)
00300                     {
00301                         feats[joint_list_[j].ex_index]++ ;
00302                         //if (joint_mismatch_[j]==0)
00303                         //  feats[joint_ex_index_[j]]+=3 ;
00304                     }
00305                     else
00306                     {
00307                         if (joint_list_[j].mismatch!=0)
00308                             feats[joint_list_[j].ex_index] += AA_helper(path_, joint_seq, joint_list_[j].index) ;
00309                         else
00310                             feats[joint_list_[j].ex_index] ++ ;
00311                     }
00312                 }
00313 
00314                 std::vector<int> idx ;
00315                 for (int r=0; r<feats.get_array_size(); r++)
00316                     if (feats[r]!=0.0)
00317                         idx.push_back(r) ;
00318 
00319                 for (unsigned int r=0; r<idx.size(); r++)
00320                     for (unsigned int s=r; s<idx.size(); s++)
00321                         if (s==r)
00322                             kernel_matrix.set_element(feats[idx[r]]*feats[idx[s]] + kernel_matrix.get_element(idx[r],idx[s]), idx[r], idx[s])  ;
00323                         else
00324                         {
00325                             kernel_matrix.set_element(feats[idx[r]]*feats[idx[s]] + kernel_matrix.get_element(idx[r],idx[s]), idx[r], idx[s])  ;
00326                             kernel_matrix.set_element(feats[idx[r]]*feats[idx[s]] + kernel_matrix.get_element(idx[s],idx[r]), idx[s], idx[r])  ;
00327                         }
00328             }
00329         }
00330         if (d==0)
00331             fprintf(stdout, "\n") ;
00332     }
00333 }
00334 
00335 void CSpectrumMismatchRBFKernel::compute_all()
00336 {
00337     std::string joint_seq ; 
00338     std::vector<struct joint_list_struct> joint_list ;
00339 
00340     assert(lhs->get_num_vectors()==rhs->get_num_vectors()) ;
00341     kernel_matrix.resize_array(lhs->get_num_vectors(), lhs->get_num_vectors()) ;
00342     for (int i=0; i<lhs->get_num_vectors(); i++)
00343         for (int j=0; j<lhs->get_num_vectors(); j++)
00344             kernel_matrix.set_element(0, i, j) ;
00345     
00346     for (int i=0; i<lhs->get_num_vectors(); i++)
00347     {
00348         int32_t alen ;
00349         bool free_avec ;
00350         char* avec = ((CStringFeatures<char>*) lhs)->get_feature_vector(i, alen, free_avec);
00351 
00352         for (int apos=0; apos+degree-1<alen; apos++)
00353         {
00354             struct joint_list_struct list_item ;
00355             list_item.ex_index = i ;
00356             list_item.index = apos+joint_seq.size() ;
00357             list_item.mismatch = 0 ;
00358             
00359             joint_list.push_back(list_item) ;
00360         }
00361         joint_seq += std::string(avec, alen) ;
00362         
00363         ((CStringFeatures<char>*) lhs)->free_feature_vector(avec, i, free_avec);
00364     }
00365     
00366     compute_helper_all(joint_seq.c_str(), joint_list, "", 0) ;
00367 }
00368 
00369 
00370 float64_t CSpectrumMismatchRBFKernel::compute(int32_t idx_a, int32_t idx_b)
00371 {
00372     return kernel_matrix.element(idx_a, idx_b) ;
00373 }
00374 /*
00375 bool CSpectrumMismatchRBFKernel::set_weights(
00376     float64_t* ws, int32_t d, int32_t len)
00377 {
00378     if (d==128 && len==128)
00379     {
00380         SG_DEBUG("Setting AA_matrix\n") ;
00381         memcpy(AA_matrix, ws, 128*128*sizeof(float64_t)) ;
00382         return true ;
00383     }
00384 
00385     if (d==1 && len==1)
00386     {
00387         sigma=ws[0] ;
00388         SG_DEBUG("Setting sigma to %e\n", sigma) ;
00389         return true ;
00390     }
00391 
00392     if (d==2 && len==2)
00393     {
00394         target_letter_0=ws[0] ;
00395         SG_DEBUG("Setting target letter to %c\n", target_letter_0) ;
00396         return true ;
00397     }
00398 
00399     if (d!=degree || len<1)
00400         SG_ERROR("Dimension mismatch (should be de(seq_length | 1) x degree)\n");
00401 
00402     length=len;
00403 
00404     if (length==0)
00405         length=1;
00406 
00407     int32_t num_weights=degree*(length+max_mismatch);
00408     delete[] weights;
00409     weights=new float64_t[num_weights];
00410 
00411     if (weights)
00412     {
00413         for (int32_t i=0; i<num_weights; i++) {
00414             if (ws[i]) // len(ws) might be != num_weights?
00415                 weights[i]=ws[i];
00416         }
00417         return true;
00418     }
00419     else
00420         return false;
00421 }
00422 */
00423 
00424 bool CSpectrumMismatchRBFKernel::set_AA_matrix(
00425     float64_t* AA_matrix_)
00426 {
00427 
00428     if (AA_matrix_)
00429     {
00430         SG_DEBUG("Setting AA_matrix\n") ;
00431         memcpy(AA_matrix, AA_matrix_, 128*128*sizeof(float64_t)) ;
00432         return true ;
00433     }
00434 
00435     return false;
00436 }
00437 
00438 bool CSpectrumMismatchRBFKernel::set_max_mismatch(int32_t max)
00439 {
00440     max_mismatch=max;
00441 
00442     if (lhs!=NULL && rhs!=NULL)
00443         return init(lhs, rhs);
00444     else
00445         return true;
00446 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation