00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #include <vector>
00013
00014 #include "lib/common.h"
00015 #include "lib/io.h"
00016 #include "lib/Signal.h"
00017 #include "lib/Trie.h"
00018 #include "base/Parallel.h"
00019
00020 #include "kernel/SpectrumMismatchRBFKernel.h"
00021 #include "features/Features.h"
00022 #include "features/StringFeatures.h"
00023
00024
00025 #include <vector>
00026 #include <string>
00027
00028 #include <assert.h>
00029
00030 #ifndef WIN32
00031 #include <pthread.h>
00032 #endif
00033
00034
00035 using namespace shogun;
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057 CSpectrumMismatchRBFKernel::CSpectrumMismatchRBFKernel (
00058 int32_t size, float64_t *AA_matrix_, int32_t degree_, int32_t max_mismatch_, float64_t width_)
00059 : CStringKernel<char>(size), alphabet(NULL), degree(degree_), max_mismatch(max_mismatch_), width(width_)
00060 {
00061 lhs=NULL;
00062 rhs=NULL;
00063
00064 target_letter_0=-1 ;
00065
00066 AA_matrix=new float64_t[128*128];
00067 memcpy(AA_matrix, AA_matrix_, 128*128*sizeof(float64_t)) ;
00068 }
00069
00070 CSpectrumMismatchRBFKernel::CSpectrumMismatchRBFKernel(
00071 CStringFeatures<char>* l, CStringFeatures<char>* r, int32_t size, float64_t* AA_matrix_, int32_t degree_, int32_t max_mismatch_, float64_t width_)
00072 : CStringKernel<char>(size), alphabet(NULL), degree(degree_), max_mismatch(max_mismatch_), width(width_)
00073 {
00074 target_letter_0=-1 ;
00075
00076 AA_matrix=new float64_t[128*128];
00077 memcpy(AA_matrix, AA_matrix_, 128*128*sizeof(float64_t)) ;
00078 init(l, r);
00079 }
00080
00081 CSpectrumMismatchRBFKernel::~CSpectrumMismatchRBFKernel()
00082 {
00083 cleanup();
00084 delete[] AA_matrix ;
00085 }
00086
00087
00088 void CSpectrumMismatchRBFKernel::remove_lhs()
00089 {
00090
00091 CKernel::remove_lhs();
00092 }
00093
00094 bool CSpectrumMismatchRBFKernel::init(CFeatures* l, CFeatures* r)
00095 {
00096 int32_t lhs_changed=(lhs!=l);
00097 int32_t rhs_changed=(rhs!=r);
00098
00099 CStringKernel<char>::init(l,r);
00100
00101 SG_DEBUG("lhs_changed: %i\n", lhs_changed);
00102 SG_DEBUG("rhs_changed: %i\n", rhs_changed);
00103
00104 CStringFeatures<char>* sf_l=(CStringFeatures<char>*) l;
00105 CStringFeatures<char>* sf_r=(CStringFeatures<char>*) r;
00106
00107 SG_UNREF(alphabet);
00108 alphabet=sf_l->get_alphabet();
00109 CAlphabet* ralphabet=sf_r->get_alphabet();
00110
00111 if (!((alphabet->get_alphabet()==DNA) || (alphabet->get_alphabet()==RNA)))
00112 properties &= ((uint64_t) (-1)) ^ (KP_LINADD | KP_BATCHEVALUATION);
00113
00114 ASSERT(ralphabet->get_alphabet()==alphabet->get_alphabet());
00115 SG_UNREF(ralphabet);
00116
00117 compute_all() ;
00118
00119 return init_normalizer();
00120 }
00121
00122 void CSpectrumMismatchRBFKernel::cleanup()
00123 {
00124
00125 SG_UNREF(alphabet);
00126 alphabet=NULL;
00127
00128 CKernel::cleanup();
00129 }
00130
00131 float64_t CSpectrumMismatchRBFKernel::AA_helper(std::string &path, const char* joint_seq, unsigned int index)
00132 {
00133 float64_t diff=0.0 ;
00134
00135 for (unsigned int i=0; i<path.size(); i++)
00136 {
00137 if (path[i]!=joint_seq[index+i])
00138 {
00139 diff += AA_matrix[ (path[i]-1)*128 + path[i] - 1] ;
00140 diff -= 2*AA_matrix[ (path[i]-1)*128 + joint_seq[index+i] - 1] ;
00141 diff += AA_matrix[ (joint_seq[index+i]-1)*128 + joint_seq[index+i] - 1] ;
00142 }
00143 }
00144
00145 return exp( - diff/width) ;
00146 }
00147
00148
00149
00150
00151
00152
00153
00154
00155
00156
00157
00158
00159
00160
00161
00162
00163
00164
00165
00166
00167
00168
00169
00170
00171
00172
00173
00174
00175
00176
00177
00178
00179
00180
00181
00182
00183
00184
00185
00186
00187
00188
00189
00190
00191
00192
00193
00194
00195
00196
00197
00198
00199
00200
00201
00202
00203
00204
00205
00206
00207
00208
00209
00210
00211
00212
00213
00214
00215
00216
00217
00218
00219
00220
00221
00222
00223
00224
00225
00226
00227
00228
00229
00230
00231
00232
00233
00234
00235
00236
00237
00238 void CSpectrumMismatchRBFKernel::compute_helper_all(const char *joint_seq,
00239 std::vector<struct joint_list_struct> &joint_list,
00240 std::string path, unsigned int d)
00241 {
00242 const char* AA = "ACDEFGHIKLMNPQRSTVWY" ;
00243 const unsigned int num_AA = strlen(AA) ;
00244
00245 assert(path.size()==d) ;
00246
00247 for (unsigned int i=0; i<num_AA; i++)
00248 {
00249 std::vector<struct joint_list_struct> joint_list_ ;
00250
00251 if (d==0)
00252 fprintf(stderr, "i=%i: ", i) ;
00253 if (d==0 && target_letter_0!=-1 && (int)i != target_letter_0 )
00254 continue ;
00255
00256 if (d==1)
00257 {
00258 fprintf(stdout, "*") ;
00259 fflush(stdout) ;
00260 }
00261 if (d==2)
00262 {
00263 fprintf(stdout, "+") ;
00264 fflush(stdout) ;
00265 }
00266
00267 for (unsigned int j=0; j<joint_list.size(); j++)
00268 {
00269 if (joint_seq[joint_list[j].index+d] != AA[i])
00270 {
00271 if (joint_list[j].mismatch+1 <= (unsigned int) max_mismatch)
00272 {
00273 struct joint_list_struct list_item ;
00274 list_item = joint_list[j] ;
00275 list_item.mismatch = joint_list[j].mismatch+1 ;
00276 joint_list_.push_back(list_item) ;
00277 }
00278 }
00279 else
00280 joint_list_.push_back(joint_list[j]) ;
00281 }
00282
00283 if (joint_list_.size()>0)
00284 {
00285 std::string path_ = path + AA[i] ;
00286
00287 if (d+1 < (unsigned int) degree)
00288 {
00289 compute_helper_all(joint_seq, joint_list_, path_, d+1) ;
00290 }
00291 else
00292 {
00293 CArray<float64_t> feats ;
00294 feats.resize_array(kernel_matrix.get_dim1()) ;
00295 feats.zero() ;
00296
00297 for (unsigned int j=0; j<joint_list_.size(); j++)
00298 {
00299 if (width==0.0)
00300 {
00301 feats[joint_list_[j].ex_index]++ ;
00302
00303
00304 }
00305 else
00306 {
00307 if (joint_list_[j].mismatch!=0)
00308 feats[joint_list_[j].ex_index] += AA_helper(path_, joint_seq, joint_list_[j].index) ;
00309 else
00310 feats[joint_list_[j].ex_index] ++ ;
00311 }
00312 }
00313
00314 std::vector<int> idx ;
00315 for (int r=0; r<feats.get_array_size(); r++)
00316 if (feats[r]!=0.0)
00317 idx.push_back(r) ;
00318
00319 for (unsigned int r=0; r<idx.size(); r++)
00320 for (unsigned int s=r; s<idx.size(); s++)
00321 if (s==r)
00322 kernel_matrix.set_element(feats[idx[r]]*feats[idx[s]] + kernel_matrix.get_element(idx[r],idx[s]), idx[r], idx[s]) ;
00323 else
00324 {
00325 kernel_matrix.set_element(feats[idx[r]]*feats[idx[s]] + kernel_matrix.get_element(idx[r],idx[s]), idx[r], idx[s]) ;
00326 kernel_matrix.set_element(feats[idx[r]]*feats[idx[s]] + kernel_matrix.get_element(idx[s],idx[r]), idx[s], idx[r]) ;
00327 }
00328 }
00329 }
00330 if (d==0)
00331 fprintf(stdout, "\n") ;
00332 }
00333 }
00334
00335 void CSpectrumMismatchRBFKernel::compute_all()
00336 {
00337 std::string joint_seq ;
00338 std::vector<struct joint_list_struct> joint_list ;
00339
00340 assert(lhs->get_num_vectors()==rhs->get_num_vectors()) ;
00341 kernel_matrix.resize_array(lhs->get_num_vectors(), lhs->get_num_vectors()) ;
00342 for (int i=0; i<lhs->get_num_vectors(); i++)
00343 for (int j=0; j<lhs->get_num_vectors(); j++)
00344 kernel_matrix.set_element(0, i, j) ;
00345
00346 for (int i=0; i<lhs->get_num_vectors(); i++)
00347 {
00348 int32_t alen ;
00349 bool free_avec ;
00350 char* avec = ((CStringFeatures<char>*) lhs)->get_feature_vector(i, alen, free_avec);
00351
00352 for (int apos=0; apos+degree-1<alen; apos++)
00353 {
00354 struct joint_list_struct list_item ;
00355 list_item.ex_index = i ;
00356 list_item.index = apos+joint_seq.size() ;
00357 list_item.mismatch = 0 ;
00358
00359 joint_list.push_back(list_item) ;
00360 }
00361 joint_seq += std::string(avec, alen) ;
00362
00363 ((CStringFeatures<char>*) lhs)->free_feature_vector(avec, i, free_avec);
00364 }
00365
00366 compute_helper_all(joint_seq.c_str(), joint_list, "", 0) ;
00367 }
00368
00369
00370 float64_t CSpectrumMismatchRBFKernel::compute(int32_t idx_a, int32_t idx_b)
00371 {
00372 return kernel_matrix.element(idx_a, idx_b) ;
00373 }
00374
00375
00376
00377
00378
00379
00380
00381
00382
00383
00384
00385
00386
00387
00388
00389
00390
00391
00392
00393
00394
00395
00396
00397
00398
00399
00400
00401
00402
00403
00404
00405
00406
00407
00408
00409
00410
00411
00412
00413
00414
00415
00416
00417
00418
00419
00420
00421
00422
00423
00424 bool CSpectrumMismatchRBFKernel::set_AA_matrix(
00425 float64_t* AA_matrix_)
00426 {
00427
00428 if (AA_matrix_)
00429 {
00430 SG_DEBUG("Setting AA_matrix\n") ;
00431 memcpy(AA_matrix, AA_matrix_, 128*128*sizeof(float64_t)) ;
00432 return true ;
00433 }
00434
00435 return false;
00436 }
00437
00438 bool CSpectrumMismatchRBFKernel::set_max_mismatch(int32_t max)
00439 {
00440 max_mismatch=max;
00441
00442 if (lhs!=NULL && rhs!=NULL)
00443 return init(lhs, rhs);
00444 else
00445 return true;
00446 }