00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include "lib/common.h"
00012 #include "kernel/CommWordStringKernel.h"
00013 #include "kernel/SqrtDiagKernelNormalizer.h"
00014 #include "features/StringFeatures.h"
00015 #include "lib/io.h"
00016
00017 using namespace shogun;
00018
00019 CCommWordStringKernel::CCommWordStringKernel(int32_t size, bool s)
00020 : CStringKernel<uint16_t>(size), dictionary_size(0), dictionary_weights(NULL),
00021 use_sign(s), use_dict_diagonal_optimization(false), dict_diagonal_optimization(NULL)
00022 {
00023 properties |= KP_LINADD;
00024 init_dictionary(1<<(sizeof(uint16_t)*8));
00025 set_normalizer(new CSqrtDiagKernelNormalizer(use_dict_diagonal_optimization));
00026 }
00027
00028 CCommWordStringKernel::CCommWordStringKernel(
00029 CStringFeatures<uint16_t>* l, CStringFeatures<uint16_t>* r, bool s,
00030 int32_t size)
00031 : CStringKernel<uint16_t>(size), dictionary_size(0), dictionary_weights(NULL),
00032 use_sign(s), use_dict_diagonal_optimization(false), dict_diagonal_optimization(NULL)
00033 {
00034 properties |= KP_LINADD;
00035
00036 init_dictionary(1<<(sizeof(uint16_t)*8));
00037 set_normalizer(new CSqrtDiagKernelNormalizer(use_dict_diagonal_optimization));
00038 init(l,r);
00039 }
00040
00041
00042 bool CCommWordStringKernel::init_dictionary(int32_t size)
00043 {
00044 dictionary_size= size;
00045 delete[] dictionary_weights;
00046 dictionary_weights=new float64_t[size];
00047 SG_DEBUG( "using dictionary of %d words\n", size);
00048 clear_normal();
00049
00050 return dictionary_weights!=NULL;
00051 }
00052
00053 CCommWordStringKernel::~CCommWordStringKernel()
00054 {
00055 cleanup();
00056
00057 delete[] dictionary_weights;
00058 delete[] dict_diagonal_optimization ;
00059 }
00060
00061 bool CCommWordStringKernel::init(CFeatures* l, CFeatures* r)
00062 {
00063 CStringKernel<uint16_t>::init(l,r);
00064
00065 if (use_dict_diagonal_optimization)
00066 {
00067 delete[] dict_diagonal_optimization ;
00068 dict_diagonal_optimization=new int32_t[int32_t(((CStringFeatures<uint16_t>*)l)->get_num_symbols())];
00069 ASSERT(((CStringFeatures<uint16_t>*)l)->get_num_symbols() == ((CStringFeatures<uint16_t>*)r)->get_num_symbols()) ;
00070 }
00071
00072 return init_normalizer();
00073 }
00074
00075 void CCommWordStringKernel::cleanup()
00076 {
00077 delete_optimization();
00078 CKernel::cleanup();
00079 }
00080
00081 float64_t CCommWordStringKernel::compute_diag(int32_t idx_a)
00082 {
00083 int32_t alen;
00084 CStringFeatures<uint16_t>* l = (CStringFeatures<uint16_t>*) lhs;
00085 CStringFeatures<uint16_t>* r = (CStringFeatures<uint16_t>*) rhs;
00086
00087 bool free_av;
00088 uint16_t* av=l->get_feature_vector(idx_a, alen, free_av);
00089
00090 float64_t result=0.0 ;
00091 ASSERT(l==r);
00092 ASSERT(sizeof(uint16_t)<=sizeof(float64_t));
00093 ASSERT((1<<(sizeof(uint16_t)*8)) > alen);
00094
00095 int32_t num_symbols=(int32_t) l->get_num_symbols();
00096 ASSERT(num_symbols<=dictionary_size);
00097
00098 int32_t* dic = dict_diagonal_optimization;
00099 memset(dic, 0, num_symbols*sizeof(int32_t));
00100
00101 for (int32_t i=0; i<alen; i++)
00102 dic[av[i]]++;
00103
00104 if (use_sign)
00105 {
00106 for (int32_t i=0; i<(int32_t) l->get_num_symbols(); i++)
00107 {
00108 if (dic[i]!=0)
00109 result++;
00110 }
00111 }
00112 else
00113 {
00114 for (int32_t i=0; i<num_symbols; i++)
00115 {
00116 if (dic[i]!=0)
00117 result+=dic[i]*dic[i];
00118 }
00119 }
00120 l->free_feature_vector(av, idx_a, free_av);
00121
00122 return result;
00123 }
00124
00125 float64_t CCommWordStringKernel::compute_helper(
00126 int32_t idx_a, int32_t idx_b, bool do_sort)
00127 {
00128 int32_t alen, blen;
00129 bool free_av, free_bv;
00130
00131 CStringFeatures<uint16_t>* l = (CStringFeatures<uint16_t>*) lhs;
00132 CStringFeatures<uint16_t>* r = (CStringFeatures<uint16_t>*) rhs;
00133
00134 uint16_t* av=l->get_feature_vector(idx_a, alen, free_av);
00135 uint16_t* bv=r->get_feature_vector(idx_b, blen, free_bv);
00136
00137 uint16_t* avec=av;
00138 uint16_t* bvec=bv;
00139
00140 if (do_sort)
00141 {
00142 if (alen>0)
00143 {
00144 avec=new uint16_t[alen];
00145 memcpy(avec, av, sizeof(uint16_t)*alen);
00146 CMath::radix_sort(avec, alen);
00147 }
00148 else
00149 avec=NULL;
00150
00151 if (blen>0)
00152 {
00153 bvec=new uint16_t[blen];
00154 memcpy(bvec, bv, sizeof(uint16_t)*blen);
00155 CMath::radix_sort(bvec, blen);
00156 }
00157 else
00158 bvec=NULL;
00159 }
00160 else
00161 {
00162 if ( (l->get_num_preproc() != l->get_num_preprocessed()) ||
00163 (r->get_num_preproc() != r->get_num_preprocessed()))
00164 {
00165 SG_ERROR("not all preprocessors have been applied to training (%d/%d)"
00166 " or test (%d/%d) data\n", l->get_num_preprocessed(), l->get_num_preproc(),
00167 r->get_num_preprocessed(), r->get_num_preproc());
00168 }
00169 }
00170
00171 float64_t result=0;
00172
00173 int32_t left_idx=0;
00174 int32_t right_idx=0;
00175
00176 if (use_sign)
00177 {
00178 while (left_idx < alen && right_idx < blen)
00179 {
00180 if (avec[left_idx]==bvec[right_idx])
00181 {
00182 uint16_t sym=avec[left_idx];
00183
00184 while (left_idx< alen && avec[left_idx]==sym)
00185 left_idx++;
00186
00187 while (right_idx< blen && bvec[right_idx]==sym)
00188 right_idx++;
00189
00190 result++;
00191 }
00192 else if (avec[left_idx]<bvec[right_idx])
00193 left_idx++;
00194 else
00195 right_idx++;
00196 }
00197 }
00198 else
00199 {
00200 while (left_idx < alen && right_idx < blen)
00201 {
00202 if (avec[left_idx]==bvec[right_idx])
00203 {
00204 int32_t old_left_idx=left_idx;
00205 int32_t old_right_idx=right_idx;
00206
00207 uint16_t sym=avec[left_idx];
00208
00209 while (left_idx< alen && avec[left_idx]==sym)
00210 left_idx++;
00211
00212 while (right_idx< blen && bvec[right_idx]==sym)
00213 right_idx++;
00214
00215 result+=((float64_t) (left_idx-old_left_idx))*
00216 ((float64_t) (right_idx-old_right_idx));
00217 }
00218 else if (avec[left_idx]<bvec[right_idx])
00219 left_idx++;
00220 else
00221 right_idx++;
00222 }
00223 }
00224
00225 if (do_sort)
00226 {
00227 delete[] avec;
00228 delete[] bvec;
00229 }
00230
00231 l->free_feature_vector(av, idx_a, free_av);
00232 r->free_feature_vector(bv, idx_b, free_bv);
00233
00234 return result;
00235 }
00236
00237 void CCommWordStringKernel::add_to_normal(int32_t vec_idx, float64_t weight)
00238 {
00239 int32_t len=-1;
00240 bool free_vec;
00241 uint16_t* vec=((CStringFeatures<uint16_t>*) lhs)->
00242 get_feature_vector(vec_idx, len, free_vec);
00243
00244 if (len>0)
00245 {
00246 int32_t j, last_j=0;
00247 if (use_sign)
00248 {
00249 for (j=1; j<len; j++)
00250 {
00251 if (vec[j]==vec[j-1])
00252 continue;
00253
00254 dictionary_weights[(int32_t) vec[j-1]]+=normalizer->
00255 normalize_lhs(weight, vec_idx);
00256 }
00257
00258 dictionary_weights[(int32_t) vec[len-1]]+=normalizer->
00259 normalize_lhs(weight, vec_idx);
00260 }
00261 else
00262 {
00263 for (j=1; j<len; j++)
00264 {
00265 if (vec[j]==vec[j-1])
00266 continue;
00267
00268 dictionary_weights[(int32_t) vec[j-1]]+=normalizer->
00269 normalize_lhs(weight*(j-last_j), vec_idx);
00270 last_j = j;
00271 }
00272
00273 dictionary_weights[(int32_t) vec[len-1]]+=normalizer->
00274 normalize_lhs(weight*(len-last_j), vec_idx);
00275 }
00276 set_is_initialized(true);
00277 }
00278
00279 ((CStringFeatures<uint16_t>*) lhs)->free_feature_vector(vec, vec_idx, free_vec);
00280 }
00281
00282 void CCommWordStringKernel::clear_normal()
00283 {
00284 memset(dictionary_weights, 0, dictionary_size*sizeof(float64_t));
00285 set_is_initialized(false);
00286 }
00287
00288 bool CCommWordStringKernel::init_optimization(
00289 int32_t count, int32_t* IDX, float64_t* weights)
00290 {
00291 delete_optimization();
00292
00293 if (count<=0)
00294 {
00295 set_is_initialized(true);
00296 SG_DEBUG("empty set of SVs\n");
00297 return true;
00298 }
00299
00300 SG_DEBUG("initializing CCommWordStringKernel optimization\n");
00301
00302 for (int32_t i=0; i<count; i++)
00303 {
00304 if ( (i % (count/10+1)) == 0)
00305 SG_PROGRESS(i, 0, count);
00306
00307 add_to_normal(IDX[i], weights[i]);
00308 }
00309
00310 set_is_initialized(true);
00311 return true;
00312 }
00313
00314 bool CCommWordStringKernel::delete_optimization()
00315 {
00316 SG_DEBUG( "deleting CCommWordStringKernel optimization\n");
00317
00318 clear_normal();
00319 return true;
00320 }
00321
00322 float64_t CCommWordStringKernel::compute_optimized(int32_t i)
00323 {
00324 if (!get_is_initialized())
00325 {
00326 SG_ERROR( "CCommWordStringKernel optimization not initialized\n");
00327 return 0 ;
00328 }
00329
00330 float64_t result = 0;
00331 int32_t len = -1;
00332 bool free_vec;
00333 uint16_t* vec=((CStringFeatures<uint16_t>*) rhs)->
00334 get_feature_vector(i, len, free_vec);
00335
00336 int32_t j, last_j=0;
00337 if (vec && len>0)
00338 {
00339 if (use_sign)
00340 {
00341 for (j=1; j<len; j++)
00342 {
00343 if (vec[j]==vec[j-1])
00344 continue;
00345
00346 result += dictionary_weights[(int32_t) vec[j-1]];
00347 }
00348
00349 result += dictionary_weights[(int32_t) vec[len-1]];
00350 }
00351 else
00352 {
00353 for (j=1; j<len; j++)
00354 {
00355 if (vec[j]==vec[j-1])
00356 continue;
00357
00358 result += dictionary_weights[(int32_t) vec[j-1]]*(j-last_j);
00359 last_j = j;
00360 }
00361
00362 result += dictionary_weights[(int32_t) vec[len-1]]*(len-last_j);
00363 }
00364
00365 result=normalizer->normalize_rhs(result, i);
00366 }
00367 ((CStringFeatures<uint16_t>*) rhs)->free_feature_vector(vec, i, free_vec);
00368 return result;
00369 }
00370
00371 float64_t* CCommWordStringKernel::compute_scoring(
00372 int32_t max_degree, int32_t& num_feat, int32_t& num_sym, float64_t* target,
00373 int32_t num_suppvec, int32_t* IDX, float64_t* alphas, bool do_init)
00374 {
00375 ASSERT(lhs);
00376 CStringFeatures<uint16_t>* str=((CStringFeatures<uint16_t>*) lhs);
00377 num_feat=1;
00378 CAlphabet* alpha=str->get_alphabet();
00379 ASSERT(alpha);
00380 int32_t num_bits=alpha->get_num_bits();
00381 int32_t order=str->get_order();
00382 ASSERT(max_degree<=order);
00383
00384 int32_t num_words=(int32_t) str->get_original_num_symbols();
00385 int32_t offset=0;
00386
00387 num_sym=0;
00388
00389 for (int32_t i=0; i<order; i++)
00390 num_sym+=CMath::pow((int32_t) num_words,i+1);
00391
00392 SG_DEBUG("num_words:%d, order:%d, len:%d sz:%d (len*sz:%d)\n", num_words, order,
00393 num_feat, num_sym, num_feat*num_sym);
00394
00395 if (!target)
00396 target=new float64_t[num_feat*num_sym];
00397 memset(target, 0, num_feat*num_sym*sizeof(float64_t));
00398
00399 if (do_init)
00400 init_optimization(num_suppvec, IDX, alphas);
00401
00402 uint32_t kmer_mask=0;
00403 uint32_t words=CMath::pow((int32_t) num_words,(int32_t) order);
00404
00405 for (int32_t o=0; o<max_degree; o++)
00406 {
00407 float64_t* contrib=&target[offset];
00408 offset+=CMath::pow((int32_t) num_words,(int32_t) o+1);
00409
00410 kmer_mask=(kmer_mask<<(num_bits)) | str->get_masked_symbols(0xffff, 1);
00411
00412 for (int32_t p=-o; p<order; p++)
00413 {
00414 int32_t o_sym=0, m_sym=0, il=0,ir=0, jl=0;
00415 uint32_t imer_mask=kmer_mask;
00416 uint32_t jmer_mask=kmer_mask;
00417
00418 if (p<0)
00419 {
00420 il=-p;
00421 m_sym=order-o-p-1;
00422 o_sym=-p;
00423 }
00424 else if (p<order-o)
00425 {
00426 ir=p;
00427 m_sym=order-o-1;
00428 }
00429 else
00430 {
00431 ir=p;
00432 m_sym=p;
00433 o_sym=p-order+o+1;
00434 jl=order-ir;
00435 imer_mask=(kmer_mask>>(num_bits*o_sym));
00436 jmer_mask=(kmer_mask>>(num_bits*jl));
00437 }
00438
00439 float64_t marginalizer=
00440 1.0/CMath::pow((int32_t) num_words,(int32_t) m_sym);
00441
00442 for (uint32_t i=0; i<words; i++)
00443 {
00444 uint16_t x= ((i << (num_bits*il)) >> (num_bits*ir)) & imer_mask;
00445
00446 if (p>=0 && p<order-o)
00447 {
00448
00449 #ifdef DEBUG_COMMSCORING
00450 SG_PRINT("o=%d/%d p=%d/%d i=0x%x x=0x%x imask=%x jmask=%x kmask=%x il=%d ir=%d marg=%g o_sym:%d m_sym:%d weight(",
00451 o,order, p,order, i, x, imer_mask, jmer_mask, kmer_mask, il, ir, marginalizer, o_sym, m_sym);
00452
00453 SG_PRINT("%c%c%c%c/%c%c%c%c)+=%g/%g\n",
00454 alpha->remap_to_char((x>>(3*num_bits))&0x03), alpha->remap_to_char((x>>(2*num_bits))&0x03),
00455 alpha->remap_to_char((x>>num_bits)&0x03), alpha->remap_to_char(x&0x03),
00456 alpha->remap_to_char((i>>(3*num_bits))&0x03), alpha->remap_to_char((i>>(2*num_bits))&0x03),
00457 alpha->remap_to_char((i>>(1*num_bits))&0x03), alpha->remap_to_char(i&0x03),
00458 dictionary_weights[i]*marginalizer, dictionary_weights[i]);
00459 #endif
00460 contrib[x]+=dictionary_weights[i]*marginalizer;
00461 }
00462 else
00463 {
00464 for (uint32_t j=0; j< (uint32_t) CMath::pow((int32_t) num_words, (int32_t) o_sym); j++)
00465 {
00466 uint32_t c=x | ((j & jmer_mask) << (num_bits*jl));
00467 #ifdef DEBUG_COMMSCORING
00468
00469 SG_PRINT("o=%d/%d p=%d/%d i=0x%x j=0x%x x=0x%x c=0x%x imask=%x jmask=%x kmask=%x il=%d ir=%d jl=%d marg=%g o_sym:%d m_sym:%d weight(",
00470 o,order, p,order, i, j, x, c, imer_mask, jmer_mask, kmer_mask, il, ir, jl, marginalizer, o_sym, m_sym);
00471 SG_PRINT("%c%c%c%c/%c%c%c%c)+=%g/%g\n",
00472 alpha->remap_to_char((c>>(3*num_bits))&0x03), alpha->remap_to_char((c>>(2*num_bits))&0x03),
00473 alpha->remap_to_char((c>>num_bits)&0x03), alpha->remap_to_char(c&0x03),
00474 alpha->remap_to_char((i>>(3*num_bits))&0x03), alpha->remap_to_char((i>>(2*num_bits))&0x03),
00475 alpha->remap_to_char((i>>(1*num_bits))&0x03), alpha->remap_to_char(i&0x03),
00476 dictionary_weights[i]*marginalizer, dictionary_weights[i]);
00477 #endif
00478 contrib[c]+=dictionary_weights[i]*marginalizer;
00479 }
00480 }
00481 }
00482 }
00483 }
00484
00485 for (int32_t i=1; i<num_feat; i++)
00486 memcpy(&target[num_sym*i], target, num_sym*sizeof(float64_t));
00487
00488 SG_UNREF(alpha);
00489
00490 return target;
00491 }
00492
00493
00494 char* CCommWordStringKernel::compute_consensus(
00495 int32_t &result_len, int32_t num_suppvec, int32_t* IDX, float64_t* alphas)
00496 {
00497 ASSERT(lhs);
00498 ASSERT(IDX);
00499 ASSERT(alphas);
00500
00501 CStringFeatures<uint16_t>* str=((CStringFeatures<uint16_t>*) lhs);
00502 int32_t num_words=(int32_t) str->get_num_symbols();
00503 int32_t num_feat=str->get_max_vector_length();
00504 int64_t total_len=((int64_t) num_feat) * num_words;
00505 CAlphabet* alpha=((CStringFeatures<uint16_t>*) lhs)->get_alphabet();
00506 ASSERT(alpha);
00507 int32_t num_bits=alpha->get_num_bits();
00508 int32_t order=str->get_order();
00509 int32_t max_idx=-1;
00510 float64_t max_score=0;
00511 result_len=num_feat+order-1;
00512
00513
00514 init_optimization(num_suppvec, IDX, alphas);
00515
00516 char* result=new char[result_len];
00517 int32_t* bt=new int32_t[total_len];
00518 float64_t* score=new float64_t[total_len];
00519
00520 for (int64_t i=0; i<total_len; i++)
00521 {
00522 bt[i]=-1;
00523 score[i]=0;
00524 }
00525
00526 for (int32_t t=0; t<num_words; t++)
00527 score[t]=dictionary_weights[t];
00528
00529
00530 for (int32_t i=1; i<num_feat; i++)
00531 {
00532 for (int32_t t1=0; t1<num_words; t1++)
00533 {
00534 max_idx=-1;
00535 max_score=0;
00536
00537
00538
00539
00540
00541
00542
00543
00544
00545
00546 uint16_t suffix=(uint16_t) t1 >> num_bits;
00547
00548 for (int32_t sym=0; sym<str->get_original_num_symbols(); sym++)
00549 {
00550 uint16_t t=suffix | sym << (num_bits*(order-1));
00551
00552
00553
00554
00555 float64_t sc=score[num_words*(i-1) + t]+dictionary_weights[t1];
00556 if (sc > max_score || max_idx==-1)
00557 {
00558 max_idx=t;
00559 max_score=sc;
00560 }
00561 }
00562 ASSERT(max_idx!=-1);
00563
00564 score[num_words*i + t1]=max_score;
00565 bt[num_words*i + t1]=max_idx;
00566 }
00567 }
00568
00569
00570 max_idx=0;
00571 max_score=score[num_words*(num_feat-1) + 0];
00572 for (int32_t t=1; t<num_words; t++)
00573 {
00574 float64_t sc=score[num_words*(num_feat-1) + t];
00575 if (sc>max_score)
00576 {
00577 max_idx=t;
00578 max_score=sc;
00579 }
00580 }
00581
00582 SG_PRINT("max_idx:%i, max_score:%f\n", max_idx, max_score);
00583
00584 for (int32_t i=result_len-1; i>=num_feat; i--)
00585 result[i]=alpha->remap_to_char( (uint8_t) str->get_masked_symbols( (uint16_t) max_idx >> (num_bits*(result_len-1-i)), 1) );
00586
00587 for (int32_t i=num_feat-1; i>=0; i--)
00588 {
00589 result[i]=alpha->remap_to_char( (uint8_t) str->get_masked_symbols( (uint16_t) max_idx >> (num_bits*(order-1)), 1) );
00590 max_idx=bt[num_words*i + max_idx];
00591 }
00592
00593 delete[] bt;
00594 delete[] score;
00595 SG_UNREF(alpha);
00596 return result;
00597 }