CommWordStringKernel.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include "lib/common.h"
00012 #include "kernel/CommWordStringKernel.h"
00013 #include "kernel/SqrtDiagKernelNormalizer.h"
00014 #include "features/StringFeatures.h"
00015 #include "lib/io.h"
00016 
00017 CCommWordStringKernel::CCommWordStringKernel(int32_t size, bool s)
00018 : CStringKernel<uint16_t>(size), dictionary_size(0), dictionary_weights(NULL),
00019     use_sign(s), use_dict_diagonal_optimization(false), dict_diagonal_optimization(NULL)
00020 {
00021     properties |= KP_LINADD;
00022     init_dictionary(1<<(sizeof(uint16_t)*8));
00023     set_normalizer(new CSqrtDiagKernelNormalizer(use_dict_diagonal_optimization));
00024 }
00025 
00026 CCommWordStringKernel::CCommWordStringKernel(
00027     CStringFeatures<uint16_t>* l, CStringFeatures<uint16_t>* r, bool s,
00028     int32_t size)
00029 : CStringKernel<uint16_t>(size), dictionary_size(0), dictionary_weights(NULL),
00030     use_sign(s), use_dict_diagonal_optimization(false), dict_diagonal_optimization(NULL)
00031 {
00032     properties |= KP_LINADD;
00033 
00034     init_dictionary(1<<(sizeof(uint16_t)*8));
00035     set_normalizer(new CSqrtDiagKernelNormalizer(use_dict_diagonal_optimization));
00036     init(l,r);
00037 }
00038 
00039 
00040 bool CCommWordStringKernel::init_dictionary(int32_t size)
00041 {
00042     dictionary_size= size;
00043     delete[] dictionary_weights;
00044     dictionary_weights=new float64_t[size];
00045     SG_DEBUG( "using dictionary of %d words\n", size);
00046     clear_normal();
00047 
00048     return dictionary_weights!=NULL;
00049 }
00050 
00051 CCommWordStringKernel::~CCommWordStringKernel() 
00052 {
00053     cleanup();
00054 
00055     delete[] dictionary_weights;
00056     delete[] dict_diagonal_optimization ;
00057 }
00058   
00059 bool CCommWordStringKernel::init(CFeatures* l, CFeatures* r)
00060 {
00061     CStringKernel<uint16_t>::init(l,r);
00062 
00063     if (use_dict_diagonal_optimization)
00064     {
00065         delete[] dict_diagonal_optimization ;
00066         dict_diagonal_optimization=new int32_t[int32_t(((CStringFeatures<uint16_t>*)l)->get_num_symbols())];
00067         ASSERT(((CStringFeatures<uint16_t>*)l)->get_num_symbols() == ((CStringFeatures<uint16_t>*)r)->get_num_symbols()) ;
00068     }
00069 
00070     return init_normalizer();
00071 }
00072 
00073 void CCommWordStringKernel::cleanup()
00074 {
00075     delete_optimization();
00076     CKernel::cleanup();
00077 }
00078 
00079 bool CCommWordStringKernel::load_init(FILE* src)
00080 {
00081     return false;
00082 }
00083 
00084 bool CCommWordStringKernel::save_init(FILE* dest)
00085 {
00086     return false;
00087 }
00088 
00089 float64_t CCommWordStringKernel::compute_diag(int32_t idx_a)
00090 {
00091     int32_t alen;
00092     CStringFeatures<uint16_t>* l = (CStringFeatures<uint16_t>*) lhs;
00093     CStringFeatures<uint16_t>* r = (CStringFeatures<uint16_t>*) rhs;
00094 
00095     uint16_t* av=l->get_feature_vector(idx_a, alen);
00096 
00097     float64_t result=0.0 ;
00098     ASSERT(l==r);
00099     ASSERT(sizeof(uint16_t)<=sizeof(float64_t));
00100     ASSERT((1<<(sizeof(uint16_t)*8)) > alen);
00101 
00102     int32_t num_symbols=(int32_t) l->get_num_symbols();
00103     ASSERT(num_symbols<=dictionary_size);
00104 
00105     int32_t* dic = dict_diagonal_optimization;
00106     memset(dic, 0, num_symbols*sizeof(int32_t));
00107 
00108     for (int32_t i=0; i<alen; i++)
00109         dic[av[i]]++;
00110 
00111     if (use_sign)
00112     {
00113         for (int32_t i=0; i<(int32_t) l->get_num_symbols(); i++)
00114         {
00115             if (dic[i]!=0)
00116                 result++;
00117         }
00118     }
00119     else
00120     {
00121         for (int32_t i=0; i<num_symbols; i++)
00122         {
00123             if (dic[i]!=0)
00124                 result+=dic[i]*dic[i];
00125         }
00126     }
00127 
00128     return result;
00129 }
00130 
00131 float64_t CCommWordStringKernel::compute_helper(
00132     int32_t idx_a, int32_t idx_b, bool do_sort)
00133 {
00134     int32_t alen, blen;
00135     CStringFeatures<uint16_t>* l = (CStringFeatures<uint16_t>*) lhs;
00136     CStringFeatures<uint16_t>* r = (CStringFeatures<uint16_t>*) rhs;
00137 
00138     uint16_t* av=l->get_feature_vector(idx_a, alen);
00139     uint16_t* bv=r->get_feature_vector(idx_b, blen);
00140 
00141     uint16_t* avec=av;
00142     uint16_t* bvec=bv;
00143 
00144     if (do_sort)
00145     {
00146         if (alen>0)
00147         {
00148             avec=new uint16_t[alen];
00149             memcpy(avec, av, sizeof(uint16_t)*alen);
00150             CMath::radix_sort(avec, alen);
00151         }
00152         else
00153             avec=NULL;
00154 
00155         if (blen>0)
00156         {
00157             bvec=new uint16_t[blen];
00158             memcpy(bvec, bv, sizeof(uint16_t)*blen);
00159             CMath::radix_sort(bvec, blen);
00160         }
00161         else
00162             bvec=NULL;
00163     }
00164     else
00165     {
00166         if ( (l->get_num_preproc() != l->get_num_preprocessed()) ||
00167                 (r->get_num_preproc() != r->get_num_preprocessed()))
00168         {
00169             SG_ERROR("not all preprocessors have been applied to training (%d/%d)"
00170                     " or test (%d/%d) data\n", l->get_num_preprocessed(), l->get_num_preproc(),
00171                     r->get_num_preprocessed(), r->get_num_preproc());
00172         }
00173     }
00174 
00175     float64_t result=0;
00176 
00177     int32_t left_idx=0;
00178     int32_t right_idx=0;
00179 
00180     if (use_sign)
00181     {
00182         while (left_idx < alen && right_idx < blen)
00183         {
00184             if (avec[left_idx]==bvec[right_idx])
00185             {
00186                 uint16_t sym=avec[left_idx];
00187 
00188                 while (left_idx< alen && avec[left_idx]==sym)
00189                     left_idx++;
00190 
00191                 while (right_idx< blen && bvec[right_idx]==sym)
00192                     right_idx++;
00193 
00194                 result++;
00195             }
00196             else if (avec[left_idx]<bvec[right_idx])
00197                 left_idx++;
00198             else
00199                 right_idx++;
00200         }
00201     }
00202     else
00203     {
00204         while (left_idx < alen && right_idx < blen)
00205         {
00206             if (avec[left_idx]==bvec[right_idx])
00207             {
00208                 int32_t old_left_idx=left_idx;
00209                 int32_t old_right_idx=right_idx;
00210 
00211                 uint16_t sym=avec[left_idx];
00212 
00213                 while (left_idx< alen && avec[left_idx]==sym)
00214                     left_idx++;
00215 
00216                 while (right_idx< blen && bvec[right_idx]==sym)
00217                     right_idx++;
00218 
00219                 result+=((float64_t) (left_idx-old_left_idx))*
00220                     ((float64_t) (right_idx-old_right_idx));
00221             }
00222             else if (avec[left_idx]<bvec[right_idx])
00223                 left_idx++;
00224             else
00225                 right_idx++;
00226         }
00227     }
00228 
00229     if (do_sort)
00230     {
00231         delete[] avec;
00232         delete[] bvec;
00233     }
00234 
00235     return result;
00236 }
00237 
00238 void CCommWordStringKernel::add_to_normal(int32_t vec_idx, float64_t weight)
00239 {
00240     int32_t len=-1;
00241     uint16_t* vec=((CStringFeatures<uint16_t>*) lhs)->
00242         get_feature_vector(vec_idx, len);
00243 
00244     if (len>0)
00245     {
00246         int32_t j, last_j=0;
00247         if (use_sign)
00248         {
00249             for (j=1; j<len; j++)
00250             {
00251                 if (vec[j]==vec[j-1])
00252                     continue;
00253 
00254                 dictionary_weights[(int32_t) vec[j-1]]+=normalizer->
00255                     normalize_lhs(weight, vec_idx);
00256             }
00257 
00258             dictionary_weights[(int32_t) vec[len-1]]+=normalizer->
00259                 normalize_lhs(weight, vec_idx);
00260         }
00261         else
00262         {
00263             for (j=1; j<len; j++)
00264             {
00265                 if (vec[j]==vec[j-1])
00266                     continue;
00267 
00268                 dictionary_weights[(int32_t) vec[j-1]]+=normalizer->
00269                     normalize_lhs(weight*(j-last_j), vec_idx);
00270                 last_j = j;
00271             }
00272 
00273             dictionary_weights[(int32_t) vec[len-1]]+=normalizer->
00274                 normalize_lhs(weight*(len-last_j), vec_idx);
00275         }
00276         set_is_initialized(true);
00277     }
00278 }
00279 
00280 void CCommWordStringKernel::clear_normal()
00281 {
00282     memset(dictionary_weights, 0, dictionary_size*sizeof(float64_t));
00283     set_is_initialized(false);
00284 }
00285 
00286 bool CCommWordStringKernel::init_optimization(
00287     int32_t count, int32_t* IDX, float64_t* weights)
00288 {
00289     delete_optimization();
00290 
00291     if (count<=0)
00292     {
00293         set_is_initialized(true);
00294         SG_DEBUG("empty set of SVs\n");
00295         return true;
00296     }
00297 
00298     SG_DEBUG("initializing CCommWordStringKernel optimization\n");
00299 
00300     for (int32_t i=0; i<count; i++)
00301     {
00302         if ( (i % (count/10+1)) == 0)
00303             SG_PROGRESS(i, 0, count);
00304 
00305         add_to_normal(IDX[i], weights[i]);
00306     }
00307 
00308     set_is_initialized(true);
00309     return true;
00310 }
00311 
00312 bool CCommWordStringKernel::delete_optimization() 
00313 {
00314     SG_DEBUG( "deleting CCommWordStringKernel optimization\n");
00315 
00316     clear_normal();
00317     return true;
00318 }
00319 
00320 float64_t CCommWordStringKernel::compute_optimized(int32_t i)
00321 { 
00322     if (!get_is_initialized())
00323     {
00324       SG_ERROR( "CCommWordStringKernel optimization not initialized\n");
00325         return 0 ; 
00326     }
00327 
00328     float64_t result = 0;
00329     int32_t len = -1;
00330     uint16_t* vec=((CStringFeatures<uint16_t>*) rhs)->
00331         get_feature_vector(i, len);
00332 
00333     int32_t j, last_j=0;
00334     if (vec && len>0)
00335     {
00336         if (use_sign)
00337         {
00338             for (j=1; j<len; j++)
00339             {
00340                 if (vec[j]==vec[j-1])
00341                     continue;
00342 
00343                 result += dictionary_weights[(int32_t) vec[j-1]];
00344             }
00345 
00346             result += dictionary_weights[(int32_t) vec[len-1]];
00347         }
00348         else
00349         {
00350             for (j=1; j<len; j++)
00351             {
00352                 if (vec[j]==vec[j-1])
00353                     continue;
00354 
00355                 result += dictionary_weights[(int32_t) vec[j-1]]*(j-last_j);
00356                 last_j = j;
00357             }
00358 
00359             result += dictionary_weights[(int32_t) vec[len-1]]*(len-last_j);
00360         }
00361 
00362         result=normalizer->normalize_rhs(result, i);
00363     }
00364     return result;
00365 }
00366 
00367 float64_t* CCommWordStringKernel::compute_scoring(
00368     int32_t max_degree, int32_t& num_feat, int32_t& num_sym, float64_t* target,
00369     int32_t num_suppvec, int32_t* IDX, float64_t* alphas, bool do_init)
00370 {
00371     ASSERT(lhs);
00372     CStringFeatures<uint16_t>* str=((CStringFeatures<uint16_t>*) lhs);
00373     num_feat=1;//str->get_max_vector_length();
00374     CAlphabet* alpha=str->get_alphabet();
00375     ASSERT(alpha);
00376     int32_t num_bits=alpha->get_num_bits();
00377     int32_t order=str->get_order();
00378     ASSERT(max_degree<=order);
00379     //int32_t num_words=(int32_t) str->get_num_symbols();
00380     int32_t num_words=(int32_t) str->get_original_num_symbols();
00381     int32_t offset=0;
00382 
00383     num_sym=0;
00384     
00385     for (int32_t i=0; i<order; i++)
00386         num_sym+=CMath::pow((int32_t) num_words,i+1);
00387 
00388     SG_DEBUG("num_words:%d, order:%d, len:%d sz:%d (len*sz:%d)\n", num_words, order,
00389             num_feat, num_sym, num_feat*num_sym);
00390 
00391     if (!target)
00392         target=new float64_t[num_feat*num_sym];
00393     memset(target, 0, num_feat*num_sym*sizeof(float64_t));
00394 
00395     if (do_init)
00396         init_optimization(num_suppvec, IDX, alphas);
00397 
00398     uint32_t kmer_mask=0;
00399     uint32_t words=CMath::pow((int32_t) num_words,(int32_t) order);
00400 
00401     for (int32_t o=0; o<max_degree; o++)
00402     {
00403         float64_t* contrib=&target[offset];
00404         offset+=CMath::pow((int32_t) num_words,(int32_t) o+1);
00405 
00406         kmer_mask=(kmer_mask<<(num_bits)) | str->get_masked_symbols(0xffff, 1);
00407 
00408         for (int32_t p=-o; p<order; p++)
00409         {
00410             int32_t o_sym=0, m_sym=0, il=0,ir=0, jl=0;
00411             uint32_t imer_mask=kmer_mask;
00412             uint32_t jmer_mask=kmer_mask;
00413 
00414             if (p<0)
00415             {
00416                 il=-p;
00417                 m_sym=order-o-p-1;
00418                 o_sym=-p;
00419             }
00420             else if (p<order-o)
00421             {
00422                 ir=p;
00423                 m_sym=order-o-1;
00424             }
00425             else
00426             {
00427                 ir=p;
00428                 m_sym=p;
00429                 o_sym=p-order+o+1;
00430                 jl=order-ir;
00431                 imer_mask=(kmer_mask>>(num_bits*o_sym));
00432                 jmer_mask=(kmer_mask>>(num_bits*jl));
00433             }
00434 
00435             float64_t marginalizer=
00436                 1.0/CMath::pow((int32_t) num_words,(int32_t) m_sym);
00437             
00438             for (uint32_t i=0; i<words; i++)
00439             {
00440                 uint16_t x= ((i << (num_bits*il)) >> (num_bits*ir)) & imer_mask;
00441 
00442                 if (p>=0 && p<order-o)
00443                 {
00444 //#define DEBUG_COMMSCORING
00445 #ifdef DEBUG_COMMSCORING
00446                     SG_PRINT("o=%d/%d p=%d/%d i=0x%x x=0x%x imask=%x jmask=%x kmask=%x il=%d ir=%d marg=%g o_sym:%d m_sym:%d weight(",
00447                             o,order, p,order, i, x, imer_mask, jmer_mask, kmer_mask, il, ir, marginalizer, o_sym, m_sym);
00448 
00449                     SG_PRINT("%c%c%c%c/%c%c%c%c)+=%g/%g\n", 
00450                             alpha->remap_to_char((x>>(3*num_bits))&0x03), alpha->remap_to_char((x>>(2*num_bits))&0x03),
00451                             alpha->remap_to_char((x>>num_bits)&0x03), alpha->remap_to_char(x&0x03),
00452                             alpha->remap_to_char((i>>(3*num_bits))&0x03), alpha->remap_to_char((i>>(2*num_bits))&0x03),
00453                             alpha->remap_to_char((i>>(1*num_bits))&0x03), alpha->remap_to_char(i&0x03),
00454                             dictionary_weights[i]*marginalizer, dictionary_weights[i]);
00455 #endif
00456                     contrib[x]+=dictionary_weights[i]*marginalizer;
00457                 }
00458                 else
00459                 {
00460                     for (uint32_t j=0; j< (uint32_t) CMath::pow((int32_t) num_words, (int32_t) o_sym); j++)
00461                     {
00462                         uint32_t c=x | ((j & jmer_mask) << (num_bits*jl));
00463 #ifdef DEBUG_COMMSCORING
00464 
00465                         SG_PRINT("o=%d/%d p=%d/%d i=0x%x j=0x%x x=0x%x c=0x%x imask=%x jmask=%x kmask=%x il=%d ir=%d jl=%d marg=%g o_sym:%d m_sym:%d weight(",
00466                                 o,order, p,order, i, j, x, c, imer_mask, jmer_mask, kmer_mask, il, ir, jl, marginalizer, o_sym, m_sym);
00467                         SG_PRINT("%c%c%c%c/%c%c%c%c)+=%g/%g\n", 
00468                                 alpha->remap_to_char((c>>(3*num_bits))&0x03), alpha->remap_to_char((c>>(2*num_bits))&0x03),
00469                                 alpha->remap_to_char((c>>num_bits)&0x03), alpha->remap_to_char(c&0x03),
00470                                 alpha->remap_to_char((i>>(3*num_bits))&0x03), alpha->remap_to_char((i>>(2*num_bits))&0x03),
00471                                 alpha->remap_to_char((i>>(1*num_bits))&0x03), alpha->remap_to_char(i&0x03),
00472                                 dictionary_weights[i]*marginalizer, dictionary_weights[i]);
00473 #endif
00474                         contrib[c]+=dictionary_weights[i]*marginalizer;
00475                     }
00476                 }
00477             }
00478         }
00479     }
00480 
00481     for (int32_t i=1; i<num_feat; i++)
00482         memcpy(&target[num_sym*i], target, num_sym*sizeof(float64_t));
00483 
00484     SG_UNREF(alpha);
00485 
00486     return target;
00487 }
00488 
00489 
00490 char* CCommWordStringKernel::compute_consensus(
00491     int32_t &result_len, int32_t num_suppvec, int32_t* IDX, float64_t* alphas)
00492 {
00493     ASSERT(lhs);
00494     ASSERT(IDX);
00495     ASSERT(alphas);
00496 
00497     CStringFeatures<uint16_t>* str=((CStringFeatures<uint16_t>*) lhs);
00498     int32_t num_words=(int32_t) str->get_num_symbols();
00499     int32_t num_feat=str->get_max_vector_length();
00500     int64_t total_len=((int64_t) num_feat) * num_words;
00501     CAlphabet* alpha=((CStringFeatures<uint16_t>*) lhs)->get_alphabet();
00502     ASSERT(alpha);
00503     int32_t num_bits=alpha->get_num_bits();
00504     int32_t order=str->get_order();
00505     int32_t max_idx=-1;
00506     float64_t max_score=0; 
00507     result_len=num_feat+order-1;
00508 
00509     //init
00510     init_optimization(num_suppvec, IDX, alphas);
00511 
00512     char* result=new char[result_len];
00513     int32_t* bt=new int32_t[total_len];
00514     float64_t* score=new float64_t[total_len];
00515 
00516     for (int64_t i=0; i<total_len; i++)
00517     {
00518         bt[i]=-1;
00519         score[i]=0;
00520     }
00521 
00522     for (int32_t t=0; t<num_words; t++)
00523         score[t]=dictionary_weights[t];
00524 
00525     //dynamic program
00526     for (int32_t i=1; i<num_feat; i++)
00527     {
00528         for (int32_t t1=0; t1<num_words; t1++)
00529         {
00530             max_idx=-1;
00531             max_score=0; 
00532 
00533             /* ignore weights the svm does not care about 
00534              * (has not seen in training). note that this assumes that zero 
00535              * weights are very unlikely to appear elsewise */
00536 
00537             //if (dictionary_weights[t1]==0.0)
00538                 //continue;
00539 
00540             /* iterate over words t ending on t1 and find the highest scoring
00541              * pair */
00542             uint16_t suffix=(uint16_t) t1 >> num_bits;
00543 
00544             for (int32_t sym=0; sym<str->get_original_num_symbols(); sym++)
00545             {
00546                 uint16_t t=suffix | sym << (num_bits*(order-1));
00547 
00548                 //if (dictionary_weights[t]==0.0)
00549                 //  continue;
00550 
00551                 float64_t sc=score[num_words*(i-1) + t]+dictionary_weights[t1];
00552                 if (sc > max_score || max_idx==-1)
00553                 {
00554                     max_idx=t;
00555                     max_score=sc;
00556                 }
00557             }
00558             ASSERT(max_idx!=-1);
00559 
00560             score[num_words*i + t1]=max_score;
00561             bt[num_words*i + t1]=max_idx;
00562         }
00563     }
00564 
00565     //backtracking
00566     max_idx=0;
00567     max_score=score[num_words*(num_feat-1) + 0];
00568     for (int32_t t=1; t<num_words; t++)
00569     {
00570         float64_t sc=score[num_words*(num_feat-1) + t];
00571         if (sc>max_score)
00572         {
00573             max_idx=t;
00574             max_score=sc;
00575         }
00576     }
00577 
00578     SG_PRINT("max_idx:%i, max_score:%f\n", max_idx, max_score);
00579     
00580     for (int32_t i=result_len-1; i>=num_feat; i--)
00581         result[i]=alpha->remap_to_char( (uint8_t) str->get_masked_symbols( (uint16_t) max_idx >> (num_bits*(result_len-1-i)), 1) );
00582 
00583     for (int32_t i=num_feat-1; i>=0; i--)
00584     {
00585         result[i]=alpha->remap_to_char( (uint8_t) str->get_masked_symbols( (uint16_t) max_idx >> (num_bits*(order-1)), 1) );
00586         max_idx=bt[num_words*i + max_idx];
00587     }
00588 
00589     delete[] bt;
00590     delete[] score;
00591     SG_UNREF(alpha);
00592     return result;
00593 }

SHOGUN Machine Learning Toolbox - Documentation