00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include "lib/common.h"
00012 #include "kernel/CommWordStringKernel.h"
00013 #include "kernel/SqrtDiagKernelNormalizer.h"
00014 #include "features/StringFeatures.h"
00015 #include "lib/io.h"
00016
00017 CCommWordStringKernel::CCommWordStringKernel(int32_t size, bool s)
00018 : CStringKernel<uint16_t>(size), dictionary_size(0), dictionary_weights(NULL),
00019 use_sign(s), use_dict_diagonal_optimization(false), dict_diagonal_optimization(NULL)
00020 {
00021 properties |= KP_LINADD;
00022 init_dictionary(1<<(sizeof(uint16_t)*8));
00023 set_normalizer(new CSqrtDiagKernelNormalizer(use_dict_diagonal_optimization));
00024 }
00025
00026 CCommWordStringKernel::CCommWordStringKernel(
00027 CStringFeatures<uint16_t>* l, CStringFeatures<uint16_t>* r, bool s,
00028 int32_t size)
00029 : CStringKernel<uint16_t>(size), dictionary_size(0), dictionary_weights(NULL),
00030 use_sign(s), use_dict_diagonal_optimization(false), dict_diagonal_optimization(NULL)
00031 {
00032 properties |= KP_LINADD;
00033
00034 init_dictionary(1<<(sizeof(uint16_t)*8));
00035 set_normalizer(new CSqrtDiagKernelNormalizer(use_dict_diagonal_optimization));
00036 init(l,r);
00037 }
00038
00039
00040 bool CCommWordStringKernel::init_dictionary(int32_t size)
00041 {
00042 dictionary_size= size;
00043 delete[] dictionary_weights;
00044 dictionary_weights=new float64_t[size];
00045 SG_DEBUG( "using dictionary of %d words\n", size);
00046 clear_normal();
00047
00048 return dictionary_weights!=NULL;
00049 }
00050
00051 CCommWordStringKernel::~CCommWordStringKernel()
00052 {
00053 cleanup();
00054
00055 delete[] dictionary_weights;
00056 delete[] dict_diagonal_optimization ;
00057 }
00058
00059 bool CCommWordStringKernel::init(CFeatures* l, CFeatures* r)
00060 {
00061 CStringKernel<uint16_t>::init(l,r);
00062
00063 if (use_dict_diagonal_optimization)
00064 {
00065 delete[] dict_diagonal_optimization ;
00066 dict_diagonal_optimization=new int32_t[int32_t(((CStringFeatures<uint16_t>*)l)->get_num_symbols())];
00067 ASSERT(((CStringFeatures<uint16_t>*)l)->get_num_symbols() == ((CStringFeatures<uint16_t>*)r)->get_num_symbols()) ;
00068 }
00069
00070 return init_normalizer();
00071 }
00072
00073 void CCommWordStringKernel::cleanup()
00074 {
00075 delete_optimization();
00076 CKernel::cleanup();
00077 }
00078
00079 bool CCommWordStringKernel::load_init(FILE* src)
00080 {
00081 return false;
00082 }
00083
00084 bool CCommWordStringKernel::save_init(FILE* dest)
00085 {
00086 return false;
00087 }
00088
00089 float64_t CCommWordStringKernel::compute_diag(int32_t idx_a)
00090 {
00091 int32_t alen;
00092 CStringFeatures<uint16_t>* l = (CStringFeatures<uint16_t>*) lhs;
00093 CStringFeatures<uint16_t>* r = (CStringFeatures<uint16_t>*) rhs;
00094
00095 uint16_t* av=l->get_feature_vector(idx_a, alen);
00096
00097 float64_t result=0.0 ;
00098 ASSERT(l==r);
00099 ASSERT(sizeof(uint16_t)<=sizeof(float64_t));
00100 ASSERT((1<<(sizeof(uint16_t)*8)) > alen);
00101
00102 int32_t num_symbols=(int32_t) l->get_num_symbols();
00103 ASSERT(num_symbols<=dictionary_size);
00104
00105 int32_t* dic = dict_diagonal_optimization;
00106 memset(dic, 0, num_symbols*sizeof(int32_t));
00107
00108 for (int32_t i=0; i<alen; i++)
00109 dic[av[i]]++;
00110
00111 if (use_sign)
00112 {
00113 for (int32_t i=0; i<(int32_t) l->get_num_symbols(); i++)
00114 {
00115 if (dic[i]!=0)
00116 result++;
00117 }
00118 }
00119 else
00120 {
00121 for (int32_t i=0; i<num_symbols; i++)
00122 {
00123 if (dic[i]!=0)
00124 result+=dic[i]*dic[i];
00125 }
00126 }
00127
00128 return result;
00129 }
00130
00131 float64_t CCommWordStringKernel::compute_helper(
00132 int32_t idx_a, int32_t idx_b, bool do_sort)
00133 {
00134 int32_t alen, blen;
00135 CStringFeatures<uint16_t>* l = (CStringFeatures<uint16_t>*) lhs;
00136 CStringFeatures<uint16_t>* r = (CStringFeatures<uint16_t>*) rhs;
00137
00138 uint16_t* av=l->get_feature_vector(idx_a, alen);
00139 uint16_t* bv=r->get_feature_vector(idx_b, blen);
00140
00141 uint16_t* avec=av;
00142 uint16_t* bvec=bv;
00143
00144 if (do_sort)
00145 {
00146 if (alen>0)
00147 {
00148 avec=new uint16_t[alen];
00149 memcpy(avec, av, sizeof(uint16_t)*alen);
00150 CMath::radix_sort(avec, alen);
00151 }
00152 else
00153 avec=NULL;
00154
00155 if (blen>0)
00156 {
00157 bvec=new uint16_t[blen];
00158 memcpy(bvec, bv, sizeof(uint16_t)*blen);
00159 CMath::radix_sort(bvec, blen);
00160 }
00161 else
00162 bvec=NULL;
00163 }
00164 else
00165 {
00166 if ( (l->get_num_preproc() != l->get_num_preprocessed()) ||
00167 (r->get_num_preproc() != r->get_num_preprocessed()))
00168 {
00169 SG_ERROR("not all preprocessors have been applied to training (%d/%d)"
00170 " or test (%d/%d) data\n", l->get_num_preprocessed(), l->get_num_preproc(),
00171 r->get_num_preprocessed(), r->get_num_preproc());
00172 }
00173 }
00174
00175 float64_t result=0;
00176
00177 int32_t left_idx=0;
00178 int32_t right_idx=0;
00179
00180 if (use_sign)
00181 {
00182 while (left_idx < alen && right_idx < blen)
00183 {
00184 if (avec[left_idx]==bvec[right_idx])
00185 {
00186 uint16_t sym=avec[left_idx];
00187
00188 while (left_idx< alen && avec[left_idx]==sym)
00189 left_idx++;
00190
00191 while (right_idx< blen && bvec[right_idx]==sym)
00192 right_idx++;
00193
00194 result++;
00195 }
00196 else if (avec[left_idx]<bvec[right_idx])
00197 left_idx++;
00198 else
00199 right_idx++;
00200 }
00201 }
00202 else
00203 {
00204 while (left_idx < alen && right_idx < blen)
00205 {
00206 if (avec[left_idx]==bvec[right_idx])
00207 {
00208 int32_t old_left_idx=left_idx;
00209 int32_t old_right_idx=right_idx;
00210
00211 uint16_t sym=avec[left_idx];
00212
00213 while (left_idx< alen && avec[left_idx]==sym)
00214 left_idx++;
00215
00216 while (right_idx< blen && bvec[right_idx]==sym)
00217 right_idx++;
00218
00219 result+=((float64_t) (left_idx-old_left_idx))*
00220 ((float64_t) (right_idx-old_right_idx));
00221 }
00222 else if (avec[left_idx]<bvec[right_idx])
00223 left_idx++;
00224 else
00225 right_idx++;
00226 }
00227 }
00228
00229 if (do_sort)
00230 {
00231 delete[] avec;
00232 delete[] bvec;
00233 }
00234
00235 return result;
00236 }
00237
00238 void CCommWordStringKernel::add_to_normal(int32_t vec_idx, float64_t weight)
00239 {
00240 int32_t len=-1;
00241 uint16_t* vec=((CStringFeatures<uint16_t>*) lhs)->
00242 get_feature_vector(vec_idx, len);
00243
00244 if (len>0)
00245 {
00246 int32_t j, last_j=0;
00247 if (use_sign)
00248 {
00249 for (j=1; j<len; j++)
00250 {
00251 if (vec[j]==vec[j-1])
00252 continue;
00253
00254 dictionary_weights[(int32_t) vec[j-1]]+=normalizer->
00255 normalize_lhs(weight, vec_idx);
00256 }
00257
00258 dictionary_weights[(int32_t) vec[len-1]]+=normalizer->
00259 normalize_lhs(weight, vec_idx);
00260 }
00261 else
00262 {
00263 for (j=1; j<len; j++)
00264 {
00265 if (vec[j]==vec[j-1])
00266 continue;
00267
00268 dictionary_weights[(int32_t) vec[j-1]]+=normalizer->
00269 normalize_lhs(weight*(j-last_j), vec_idx);
00270 last_j = j;
00271 }
00272
00273 dictionary_weights[(int32_t) vec[len-1]]+=normalizer->
00274 normalize_lhs(weight*(len-last_j), vec_idx);
00275 }
00276 set_is_initialized(true);
00277 }
00278 }
00279
00280 void CCommWordStringKernel::clear_normal()
00281 {
00282 memset(dictionary_weights, 0, dictionary_size*sizeof(float64_t));
00283 set_is_initialized(false);
00284 }
00285
00286 bool CCommWordStringKernel::init_optimization(
00287 int32_t count, int32_t* IDX, float64_t* weights)
00288 {
00289 delete_optimization();
00290
00291 if (count<=0)
00292 {
00293 set_is_initialized(true);
00294 SG_DEBUG("empty set of SVs\n");
00295 return true;
00296 }
00297
00298 SG_DEBUG("initializing CCommWordStringKernel optimization\n");
00299
00300 for (int32_t i=0; i<count; i++)
00301 {
00302 if ( (i % (count/10+1)) == 0)
00303 SG_PROGRESS(i, 0, count);
00304
00305 add_to_normal(IDX[i], weights[i]);
00306 }
00307
00308 set_is_initialized(true);
00309 return true;
00310 }
00311
00312 bool CCommWordStringKernel::delete_optimization()
00313 {
00314 SG_DEBUG( "deleting CCommWordStringKernel optimization\n");
00315
00316 clear_normal();
00317 return true;
00318 }
00319
00320 float64_t CCommWordStringKernel::compute_optimized(int32_t i)
00321 {
00322 if (!get_is_initialized())
00323 {
00324 SG_ERROR( "CCommWordStringKernel optimization not initialized\n");
00325 return 0 ;
00326 }
00327
00328 float64_t result = 0;
00329 int32_t len = -1;
00330 uint16_t* vec=((CStringFeatures<uint16_t>*) rhs)->
00331 get_feature_vector(i, len);
00332
00333 int32_t j, last_j=0;
00334 if (vec && len>0)
00335 {
00336 if (use_sign)
00337 {
00338 for (j=1; j<len; j++)
00339 {
00340 if (vec[j]==vec[j-1])
00341 continue;
00342
00343 result += dictionary_weights[(int32_t) vec[j-1]];
00344 }
00345
00346 result += dictionary_weights[(int32_t) vec[len-1]];
00347 }
00348 else
00349 {
00350 for (j=1; j<len; j++)
00351 {
00352 if (vec[j]==vec[j-1])
00353 continue;
00354
00355 result += dictionary_weights[(int32_t) vec[j-1]]*(j-last_j);
00356 last_j = j;
00357 }
00358
00359 result += dictionary_weights[(int32_t) vec[len-1]]*(len-last_j);
00360 }
00361
00362 result=normalizer->normalize_rhs(result, i);
00363 }
00364 return result;
00365 }
00366
00367 float64_t* CCommWordStringKernel::compute_scoring(
00368 int32_t max_degree, int32_t& num_feat, int32_t& num_sym, float64_t* target,
00369 int32_t num_suppvec, int32_t* IDX, float64_t* alphas, bool do_init)
00370 {
00371 ASSERT(lhs);
00372 CStringFeatures<uint16_t>* str=((CStringFeatures<uint16_t>*) lhs);
00373 num_feat=1;
00374 CAlphabet* alpha=str->get_alphabet();
00375 ASSERT(alpha);
00376 int32_t num_bits=alpha->get_num_bits();
00377 int32_t order=str->get_order();
00378 ASSERT(max_degree<=order);
00379
00380 int32_t num_words=(int32_t) str->get_original_num_symbols();
00381 int32_t offset=0;
00382
00383 num_sym=0;
00384
00385 for (int32_t i=0; i<order; i++)
00386 num_sym+=CMath::pow((int32_t) num_words,i+1);
00387
00388 SG_DEBUG("num_words:%d, order:%d, len:%d sz:%d (len*sz:%d)\n", num_words, order,
00389 num_feat, num_sym, num_feat*num_sym);
00390
00391 if (!target)
00392 target=new float64_t[num_feat*num_sym];
00393 memset(target, 0, num_feat*num_sym*sizeof(float64_t));
00394
00395 if (do_init)
00396 init_optimization(num_suppvec, IDX, alphas);
00397
00398 uint32_t kmer_mask=0;
00399 uint32_t words=CMath::pow((int32_t) num_words,(int32_t) order);
00400
00401 for (int32_t o=0; o<max_degree; o++)
00402 {
00403 float64_t* contrib=&target[offset];
00404 offset+=CMath::pow((int32_t) num_words,(int32_t) o+1);
00405
00406 kmer_mask=(kmer_mask<<(num_bits)) | str->get_masked_symbols(0xffff, 1);
00407
00408 for (int32_t p=-o; p<order; p++)
00409 {
00410 int32_t o_sym=0, m_sym=0, il=0,ir=0, jl=0;
00411 uint32_t imer_mask=kmer_mask;
00412 uint32_t jmer_mask=kmer_mask;
00413
00414 if (p<0)
00415 {
00416 il=-p;
00417 m_sym=order-o-p-1;
00418 o_sym=-p;
00419 }
00420 else if (p<order-o)
00421 {
00422 ir=p;
00423 m_sym=order-o-1;
00424 }
00425 else
00426 {
00427 ir=p;
00428 m_sym=p;
00429 o_sym=p-order+o+1;
00430 jl=order-ir;
00431 imer_mask=(kmer_mask>>(num_bits*o_sym));
00432 jmer_mask=(kmer_mask>>(num_bits*jl));
00433 }
00434
00435 float64_t marginalizer=
00436 1.0/CMath::pow((int32_t) num_words,(int32_t) m_sym);
00437
00438 for (uint32_t i=0; i<words; i++)
00439 {
00440 uint16_t x= ((i << (num_bits*il)) >> (num_bits*ir)) & imer_mask;
00441
00442 if (p>=0 && p<order-o)
00443 {
00444
00445 #ifdef DEBUG_COMMSCORING
00446 SG_PRINT("o=%d/%d p=%d/%d i=0x%x x=0x%x imask=%x jmask=%x kmask=%x il=%d ir=%d marg=%g o_sym:%d m_sym:%d weight(",
00447 o,order, p,order, i, x, imer_mask, jmer_mask, kmer_mask, il, ir, marginalizer, o_sym, m_sym);
00448
00449 SG_PRINT("%c%c%c%c/%c%c%c%c)+=%g/%g\n",
00450 alpha->remap_to_char((x>>(3*num_bits))&0x03), alpha->remap_to_char((x>>(2*num_bits))&0x03),
00451 alpha->remap_to_char((x>>num_bits)&0x03), alpha->remap_to_char(x&0x03),
00452 alpha->remap_to_char((i>>(3*num_bits))&0x03), alpha->remap_to_char((i>>(2*num_bits))&0x03),
00453 alpha->remap_to_char((i>>(1*num_bits))&0x03), alpha->remap_to_char(i&0x03),
00454 dictionary_weights[i]*marginalizer, dictionary_weights[i]);
00455 #endif
00456 contrib[x]+=dictionary_weights[i]*marginalizer;
00457 }
00458 else
00459 {
00460 for (uint32_t j=0; j< (uint32_t) CMath::pow((int32_t) num_words, (int32_t) o_sym); j++)
00461 {
00462 uint32_t c=x | ((j & jmer_mask) << (num_bits*jl));
00463 #ifdef DEBUG_COMMSCORING
00464
00465 SG_PRINT("o=%d/%d p=%d/%d i=0x%x j=0x%x x=0x%x c=0x%x imask=%x jmask=%x kmask=%x il=%d ir=%d jl=%d marg=%g o_sym:%d m_sym:%d weight(",
00466 o,order, p,order, i, j, x, c, imer_mask, jmer_mask, kmer_mask, il, ir, jl, marginalizer, o_sym, m_sym);
00467 SG_PRINT("%c%c%c%c/%c%c%c%c)+=%g/%g\n",
00468 alpha->remap_to_char((c>>(3*num_bits))&0x03), alpha->remap_to_char((c>>(2*num_bits))&0x03),
00469 alpha->remap_to_char((c>>num_bits)&0x03), alpha->remap_to_char(c&0x03),
00470 alpha->remap_to_char((i>>(3*num_bits))&0x03), alpha->remap_to_char((i>>(2*num_bits))&0x03),
00471 alpha->remap_to_char((i>>(1*num_bits))&0x03), alpha->remap_to_char(i&0x03),
00472 dictionary_weights[i]*marginalizer, dictionary_weights[i]);
00473 #endif
00474 contrib[c]+=dictionary_weights[i]*marginalizer;
00475 }
00476 }
00477 }
00478 }
00479 }
00480
00481 for (int32_t i=1; i<num_feat; i++)
00482 memcpy(&target[num_sym*i], target, num_sym*sizeof(float64_t));
00483
00484 SG_UNREF(alpha);
00485
00486 return target;
00487 }
00488
00489
00490 char* CCommWordStringKernel::compute_consensus(
00491 int32_t &result_len, int32_t num_suppvec, int32_t* IDX, float64_t* alphas)
00492 {
00493 ASSERT(lhs);
00494 ASSERT(IDX);
00495 ASSERT(alphas);
00496
00497 CStringFeatures<uint16_t>* str=((CStringFeatures<uint16_t>*) lhs);
00498 int32_t num_words=(int32_t) str->get_num_symbols();
00499 int32_t num_feat=str->get_max_vector_length();
00500 int64_t total_len=((int64_t) num_feat) * num_words;
00501 CAlphabet* alpha=((CStringFeatures<uint16_t>*) lhs)->get_alphabet();
00502 ASSERT(alpha);
00503 int32_t num_bits=alpha->get_num_bits();
00504 int32_t order=str->get_order();
00505 int32_t max_idx=-1;
00506 float64_t max_score=0;
00507 result_len=num_feat+order-1;
00508
00509
00510 init_optimization(num_suppvec, IDX, alphas);
00511
00512 char* result=new char[result_len];
00513 int32_t* bt=new int32_t[total_len];
00514 float64_t* score=new float64_t[total_len];
00515
00516 for (int64_t i=0; i<total_len; i++)
00517 {
00518 bt[i]=-1;
00519 score[i]=0;
00520 }
00521
00522 for (int32_t t=0; t<num_words; t++)
00523 score[t]=dictionary_weights[t];
00524
00525
00526 for (int32_t i=1; i<num_feat; i++)
00527 {
00528 for (int32_t t1=0; t1<num_words; t1++)
00529 {
00530 max_idx=-1;
00531 max_score=0;
00532
00533
00534
00535
00536
00537
00538
00539
00540
00541
00542 uint16_t suffix=(uint16_t) t1 >> num_bits;
00543
00544 for (int32_t sym=0; sym<str->get_original_num_symbols(); sym++)
00545 {
00546 uint16_t t=suffix | sym << (num_bits*(order-1));
00547
00548
00549
00550
00551 float64_t sc=score[num_words*(i-1) + t]+dictionary_weights[t1];
00552 if (sc > max_score || max_idx==-1)
00553 {
00554 max_idx=t;
00555 max_score=sc;
00556 }
00557 }
00558 ASSERT(max_idx!=-1);
00559
00560 score[num_words*i + t1]=max_score;
00561 bt[num_words*i + t1]=max_idx;
00562 }
00563 }
00564
00565
00566 max_idx=0;
00567 max_score=score[num_words*(num_feat-1) + 0];
00568 for (int32_t t=1; t<num_words; t++)
00569 {
00570 float64_t sc=score[num_words*(num_feat-1) + t];
00571 if (sc>max_score)
00572 {
00573 max_idx=t;
00574 max_score=sc;
00575 }
00576 }
00577
00578 SG_PRINT("max_idx:%i, max_score:%f\n", max_idx, max_score);
00579
00580 for (int32_t i=result_len-1; i>=num_feat; i--)
00581 result[i]=alpha->remap_to_char( (uint8_t) str->get_masked_symbols( (uint16_t) max_idx >> (num_bits*(result_len-1-i)), 1) );
00582
00583 for (int32_t i=num_feat-1; i>=0; i--)
00584 {
00585 result[i]=alpha->remap_to_char( (uint8_t) str->get_masked_symbols( (uint16_t) max_idx >> (num_bits*(order-1)), 1) );
00586 max_idx=bt[num_words*i + max_idx];
00587 }
00588
00589 delete[] bt;
00590 delete[] score;
00591 SG_UNREF(alpha);
00592 return result;
00593 }