HistogramWordStringKernel.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Written (W) 1999-2008 Gunnar Raetsch
00009  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00010  */
00011 
00012 #include "lib/common.h"
00013 #include "kernel/HistogramWordStringKernel.h"
00014 #include "features/Features.h"
00015 #include "features/StringFeatures.h"
00016 #include "classifier/PluginEstimate.h"
00017 #include "lib/io.h"
00018 
00019 using namespace shogun;
00020 
00021 CHistogramWordStringKernel::CHistogramWordStringKernel(int32_t size, CPluginEstimate* pie)
00022 : CStringKernel<uint16_t>(size), estimate(pie), mean(NULL), variance(NULL),
00023     sqrtdiag_lhs(NULL), sqrtdiag_rhs(NULL),
00024     ld_mean_lhs(NULL), ld_mean_rhs(NULL),
00025     plo_lhs(NULL), plo_rhs(NULL), num_params(0), num_params2(0),
00026     num_symbols(0), sum_m2_s2(0), initialized(false)
00027 {
00028 }
00029 
00030 CHistogramWordStringKernel::CHistogramWordStringKernel(
00031     CStringFeatures<uint16_t>* l, CStringFeatures<uint16_t>* r, CPluginEstimate* pie)
00032 : CStringKernel<uint16_t>(10), estimate(pie), mean(NULL), variance(NULL),
00033     sqrtdiag_lhs(NULL), sqrtdiag_rhs(NULL),
00034     ld_mean_lhs(NULL), ld_mean_rhs(NULL),
00035     plo_lhs(NULL), plo_rhs(NULL), num_params(0), num_params2(0),
00036     num_symbols(0), sum_m2_s2(0), initialized(false)
00037 {
00038     init(l, r);
00039 }
00040 
00041 CHistogramWordStringKernel::~CHistogramWordStringKernel()
00042 {
00043     delete[] variance;
00044     delete[] mean;
00045     if (sqrtdiag_lhs != sqrtdiag_rhs)
00046         delete[] sqrtdiag_rhs;
00047     delete[] sqrtdiag_lhs;
00048     if (ld_mean_lhs!=ld_mean_rhs)
00049         delete[] ld_mean_rhs ;
00050     delete[] ld_mean_lhs ;
00051     if (plo_lhs!=plo_rhs)
00052         delete[] plo_rhs ;
00053     delete[] plo_lhs ;
00054 }
00055 
00056 bool CHistogramWordStringKernel::init(CFeatures* p_l, CFeatures* p_r)
00057 {
00058     CStringKernel<uint16_t>::init(p_l,p_r);
00059     CStringFeatures<uint16_t>* l=(CStringFeatures<uint16_t>*) p_l;
00060     CStringFeatures<uint16_t>* r=(CStringFeatures<uint16_t>*) p_r;
00061     ASSERT(l);
00062     ASSERT(r);
00063 
00064     SG_DEBUG( "init: lhs: %ld   rhs: %ld\n", l, r) ;
00065     int32_t i;
00066     initialized=false;
00067 
00068     if (sqrtdiag_lhs != sqrtdiag_rhs)
00069         delete[] sqrtdiag_rhs;
00070     sqrtdiag_rhs=NULL ;
00071     delete[] sqrtdiag_lhs;
00072     sqrtdiag_lhs=NULL ;
00073     if (ld_mean_lhs!=ld_mean_rhs)
00074         delete[] ld_mean_rhs ;
00075     ld_mean_rhs=NULL ;
00076     delete[] ld_mean_lhs ;
00077     ld_mean_lhs=NULL ;
00078     if (plo_lhs!=plo_rhs)
00079         delete[] plo_rhs ;
00080     plo_rhs=NULL ;
00081     delete[] plo_lhs ;
00082     plo_lhs=NULL ;
00083 
00084     sqrtdiag_lhs= new float64_t[l->get_num_vectors()];
00085     ld_mean_lhs = new float64_t[l->get_num_vectors()];
00086     plo_lhs     = new float64_t[l->get_num_vectors()];
00087 
00088     for (i=0; i<l->get_num_vectors(); i++)
00089         sqrtdiag_lhs[i]=1;
00090 
00091     if (l==r)
00092     {
00093         sqrtdiag_rhs=sqrtdiag_lhs;
00094         ld_mean_rhs=ld_mean_lhs;
00095         plo_rhs=plo_lhs;
00096     }
00097     else
00098     {
00099         sqrtdiag_rhs=new float64_t[r->get_num_vectors()];
00100         for (i=0; i<r->get_num_vectors(); i++)
00101             sqrtdiag_rhs[i]=1;
00102 
00103         ld_mean_rhs=new float64_t[r->get_num_vectors()];
00104         plo_rhs=new float64_t[r->get_num_vectors()];
00105     }
00106 
00107     float64_t* l_plo_lhs=plo_lhs;
00108     float64_t* l_plo_rhs=plo_rhs;
00109     float64_t* l_ld_mean_lhs=ld_mean_lhs;
00110     float64_t* l_ld_mean_rhs=ld_mean_rhs;
00111 
00112     //from our knowledge first normalize variance to 1 and then norm=1 does the job
00113     if (!initialized)
00114     {
00115         int32_t num_vectors=l->get_num_vectors();
00116         num_symbols=(int32_t) l->get_num_symbols();
00117         int32_t llen=l->get_vector_length(0);
00118         int32_t rlen=r->get_vector_length(0);
00119         num_params=llen*((int32_t) l->get_num_symbols());
00120         num_params2=llen*((int32_t) l->get_num_symbols())+rlen*((int32_t) r->get_num_symbols());
00121 
00122         if ((!estimate) || (!estimate->check_models()))
00123         {
00124             SG_ERROR( "no estimate available\n");
00125             return false ;
00126         } ;
00127         if (num_params2!=estimate->get_num_params())
00128         {
00129             SG_ERROR( "number of parameters of estimate and feature representation do not match\n");
00130             return false ;
00131         } ;
00132 
00133         //add 1 as we have the 'bias' also in this vector
00134         num_params2++;
00135 
00136         delete[] mean;
00137         mean=new float64_t[num_params2];
00138         delete[] variance;
00139         variance=new float64_t[num_params2];
00140 
00141         for (i=0; i<num_params2; i++)
00142         {
00143             mean[i]=0;
00144             variance[i]=0;
00145         }
00146 
00147         // compute mean
00148         for (i=0; i<num_vectors; i++)
00149         {
00150             int32_t len;
00151             bool free_vec;
00152             uint16_t* vec=l->get_feature_vector(i, len, free_vec);
00153 
00154             mean[0]+=estimate->posterior_log_odds_obsolete(vec, len)/num_vectors;
00155 
00156             for (int32_t j=0; j<len; j++)
00157             {
00158                 int32_t idx=compute_index(j, vec[j]);
00159                 mean[idx]             += estimate->log_derivative_pos_obsolete(vec[j], j)/num_vectors;
00160                 mean[idx+num_params] += estimate->log_derivative_neg_obsolete(vec[j], j)/num_vectors;
00161             }
00162 
00163             l->free_feature_vector(vec, i, free_vec);
00164         }
00165 
00166         // compute variance
00167         for (i=0; i<num_vectors; i++)
00168         {
00169             int32_t len;
00170             bool free_vec;
00171             uint16_t* vec=l->get_feature_vector(i, len, free_vec);
00172 
00173             variance[0] += CMath::sq(estimate->posterior_log_odds_obsolete(vec, len)-mean[0])/num_vectors;
00174 
00175             for (int32_t j=0; j<len; j++)
00176             {
00177                 for (int32_t k=0; k<4; k++)
00178                 {
00179                     int32_t idx=compute_index(j, k);
00180                     if (k!=vec[j])
00181                     {
00182                         variance[idx]+=mean[idx]*mean[idx]/num_vectors;
00183                         variance[idx+num_params]+=mean[idx+num_params]*mean[idx+num_params]/num_vectors;
00184                     }
00185                     else
00186                     {
00187                         variance[idx] += CMath::sq(estimate->log_derivative_pos_obsolete(vec[j], j)
00188                                 -mean[idx])/num_vectors;
00189                         variance[idx+num_params] += CMath::sq(estimate->log_derivative_neg_obsolete(vec[j], j)
00190                                 -mean[idx+num_params])/num_vectors;
00191                     }
00192                 }
00193             }
00194 
00195             l->free_feature_vector(vec, i, free_vec);
00196         }
00197 
00198 
00199         // compute sum_i m_i^2/s_i^2
00200         sum_m2_s2=0 ;
00201         for (i=1; i<num_params2; i++)
00202         {
00203             if (variance[i]<1e-14) // then it is likely to be numerical inaccuracy
00204                 variance[i]=1 ;
00205 
00206             //fprintf(stderr, "%i: mean=%1.2e  std=%1.2e\n", i, mean[i], std[i]) ;
00207             sum_m2_s2 += mean[i]*mean[i]/(variance[i]) ;
00208         } ;
00209     } 
00210 
00211     // compute sum of 
00212     //result -= estimate->log_derivative_pos(avec[i], i)*mean[a_idx]/variance[a_idx] ;
00213     //result -= estimate->log_derivative_neg(avec[i], i)*mean[a_idx+num_params]/variance[a_idx+num_params] ;
00214     for (i=0; i<l->get_num_vectors(); i++)
00215     {
00216         int32_t alen;
00217         bool free_avec;
00218         uint16_t* avec = l->get_feature_vector(i, alen, free_avec);
00219 
00220         float64_t  result=0 ;
00221         for (int32_t j=0; j<alen; j++)
00222         {
00223             int32_t a_idx = compute_index(j, avec[j]) ;
00224             result -= estimate->log_derivative_pos_obsolete(avec[j], j)*mean[a_idx]/variance[a_idx] ;
00225             result -= estimate->log_derivative_neg_obsolete(avec[j], j)*mean[a_idx+num_params]/variance[a_idx+num_params] ;
00226         }
00227         ld_mean_lhs[i]=result ;
00228 
00229         // precompute posterior-log-odds
00230         plo_lhs[i] = estimate->posterior_log_odds_obsolete(avec, alen)-mean[0] ;
00231         l->free_feature_vector(avec, alen, free_avec);
00232     } ;
00233 
00234     if (ld_mean_lhs!=ld_mean_rhs)
00235     {
00236         // compute sum of 
00237         //result -= estimate->log_derivative_pos(bvec[i], i)*mean[b_idx]/variance[b_idx] ;
00238         //result -= estimate->log_derivative_neg(bvec[i], i)*mean[b_idx+num_params]/variance[b_idx+num_params] ;    
00239         for (i=0; i < r->get_num_vectors(); i++)
00240         {
00241             int32_t alen;
00242             bool free_avec;
00243             uint16_t* avec=r->get_feature_vector(i, alen, free_avec);
00244 
00245             float64_t  result=0 ;
00246             for (int32_t j=0; j<alen; j++)
00247             {
00248                 int32_t a_idx = compute_index(j, avec[j]) ;
00249                 result -= estimate->log_derivative_pos_obsolete(avec[j], j)*mean[a_idx]/variance[a_idx] ;
00250                 result -= estimate->log_derivative_neg_obsolete(avec[j], j)*mean[a_idx+num_params]/variance[a_idx+num_params] ;
00251             }
00252             ld_mean_rhs[i]=result ;
00253 
00254             // precompute posterior-log-odds
00255             plo_rhs[i] = estimate->posterior_log_odds_obsolete(avec, alen)-mean[0] ;
00256             r->free_feature_vector(avec, alen, free_avec);
00257         } ;
00258     } ;
00259 
00260     //warning hacky
00261     //
00262     this->lhs=l;
00263     this->rhs=l;
00264     plo_lhs = l_plo_lhs ;
00265     plo_rhs = l_plo_lhs ;
00266     ld_mean_lhs = l_ld_mean_lhs ;
00267     ld_mean_rhs = l_ld_mean_lhs ;
00268 
00269     //compute normalize to 1 values
00270     for (i=0; i<l->get_num_vectors(); i++)
00271     {
00272         sqrtdiag_lhs[i]=sqrt(compute(i,i));
00273 
00274         //trap divide by zero exception
00275         if (sqrtdiag_lhs[i]==0)
00276             sqrtdiag_lhs[i]=1e-16;
00277     }
00278 
00279     // if lhs is different from rhs (train/test data)
00280     // compute also the normalization for rhs
00281     if (sqrtdiag_lhs!=sqrtdiag_rhs)
00282     {
00283         this->lhs=r;
00284         this->rhs=r;
00285         plo_lhs = l_plo_rhs ;
00286         plo_rhs = l_plo_rhs ;
00287         ld_mean_lhs = l_ld_mean_rhs ;
00288         ld_mean_rhs = l_ld_mean_rhs ;
00289 
00290         //compute normalize to 1 values
00291         for (i=0; i<r->get_num_vectors(); i++)
00292         {
00293             sqrtdiag_rhs[i]=sqrt(compute(i,i));
00294 
00295             //trap divide by zero exception
00296             if (sqrtdiag_rhs[i]==0)
00297                 sqrtdiag_rhs[i]=1e-16;
00298         }
00299     }
00300 
00301     this->lhs=l;
00302     this->rhs=r;
00303     plo_lhs = l_plo_lhs ;
00304     plo_rhs = l_plo_rhs ;
00305     ld_mean_lhs = l_ld_mean_lhs ;
00306     ld_mean_rhs = l_ld_mean_rhs ;
00307 
00308     initialized = true ;
00309     return init_normalizer();
00310 }
00311 
00312 void CHistogramWordStringKernel::cleanup()
00313 {
00314     delete[] variance;
00315     variance=NULL;
00316 
00317     delete[] mean;
00318     mean=NULL;
00319 
00320     if (sqrtdiag_lhs != sqrtdiag_rhs)
00321         delete[] sqrtdiag_rhs;
00322     sqrtdiag_rhs=NULL;
00323 
00324     delete[] sqrtdiag_lhs;
00325     sqrtdiag_lhs=NULL;
00326 
00327     if (ld_mean_lhs!=ld_mean_rhs)
00328         delete[] ld_mean_rhs ;
00329     ld_mean_rhs=NULL;
00330 
00331     delete[] ld_mean_lhs ;
00332     ld_mean_lhs=NULL;
00333 
00334     if (plo_lhs!=plo_rhs)
00335         delete[] plo_rhs ;
00336     plo_rhs=NULL;
00337 
00338     delete[] plo_lhs ;
00339     plo_lhs=NULL;
00340 
00341     num_params2=0;
00342     num_params=0;
00343     num_symbols=0;
00344     sum_m2_s2=0;
00345     initialized = false;
00346 
00347     CKernel::cleanup();
00348 }
00349 
00350 float64_t CHistogramWordStringKernel::compute(int32_t idx_a, int32_t idx_b)
00351 {
00352     int32_t alen, blen;
00353     bool free_avec, free_bvec;
00354     uint16_t* avec=((CStringFeatures<uint16_t>*) lhs)->get_feature_vector(idx_a, alen, free_avec);
00355     uint16_t* bvec=((CStringFeatures<uint16_t>*) rhs)->get_feature_vector(idx_b, blen, free_bvec);
00356     // can only deal with strings of same length
00357     ASSERT(alen==blen);
00358 
00359     float64_t result = plo_lhs[idx_a]*plo_rhs[idx_b]/variance[0];
00360     result+= sum_m2_s2 ; // does not contain 0-th element
00361 
00362     for (int32_t i=0; i<alen; i++)
00363     {
00364         if (avec[i]==bvec[i])
00365         {
00366             int32_t a_idx = compute_index(i, avec[i]) ;
00367             float64_t dd = estimate->log_derivative_pos_obsolete(avec[i], i) ;
00368             result   += dd*dd/variance[a_idx] ;
00369             dd        = estimate->log_derivative_neg_obsolete(avec[i], i) ;
00370             result   += dd*dd/variance[a_idx+num_params] ;
00371         } ;
00372     }
00373     result += ld_mean_lhs[idx_a] + ld_mean_rhs[idx_b] ;
00374 
00375     if (initialized)
00376         result /=  (sqrtdiag_lhs[idx_a]*sqrtdiag_rhs[idx_b]) ;
00377 
00378 #ifdef BLABLA
00379     float64_t result2 = compute_slow(idx_a, idx_b) ;
00380     if (fabs(result - result2)>1e-10)
00381         SG_ERROR("new=%e  old = %e  diff = %e\n", result, result2, result - result2);
00382 #endif
00383     ((CStringFeatures<uint16_t>*) lhs)->free_feature_vector(avec, idx_a, free_avec);
00384     ((CStringFeatures<uint16_t>*) rhs)->free_feature_vector(bvec, idx_b, free_bvec);
00385     return result;
00386 }
00387 
#ifdef BLABLA

/** Reference (unoptimised) version of compute(), compiled only when BLABLA
 * is defined.
 *
 * It evaluates the same value directly from the estimate and mean/variance,
 * without the precomputed plo_* and ld_mean_* buffers. compute() compares
 * against it under the same macro.
 */
float64_t CHistogramWordStringKernel::compute_slow(int32_t idx_a, int32_t idx_b)
{
    CStringFeatures<uint16_t>* feat_a=(CStringFeatures<uint16_t>*) lhs;
    CStringFeatures<uint16_t>* feat_b=(CStringFeatures<uint16_t>*) rhs;

    int32_t len_a, len_b;
    bool must_free_a, must_free_b;
    uint16_t* seq_a=feat_a->get_feature_vector(idx_a, len_a, must_free_a);
    uint16_t* seq_b=feat_b->get_feature_vector(idx_b, len_b, must_free_b);
    // can only deal with strings of same length
    ASSERT(len_a==len_b);

    float64_t kval=(estimate->posterior_log_odds_obsolete(seq_a, len_a)-mean[0])*
        (estimate->posterior_log_odds_obsolete(seq_b, len_b)-mean[0])/(variance[0]);
    kval+=sum_m2_s2; // does not contain 0-th element

    for (int32_t pos=0; pos<len_a; pos++)
    {
        const int32_t dim_a=compute_index(pos, seq_a[pos]);
        const int32_t dim_b=compute_index(pos, seq_b[pos]);

        if (seq_a[pos]==seq_b[pos])
        {
            float64_t d=estimate->log_derivative_pos_obsolete(seq_a[pos], pos);
            kval+=d*d/variance[dim_a];
            d=estimate->log_derivative_neg_obsolete(seq_a[pos], pos);
            kval+=d*d/variance[dim_a+num_params];
        }

        // mean-correction terms of both strings (kept in the original order)
        kval-=estimate->log_derivative_pos_obsolete(seq_a[pos], pos)*mean[dim_a]/variance[dim_a];
        kval-=estimate->log_derivative_pos_obsolete(seq_b[pos], pos)*mean[dim_b]/variance[dim_b];
        kval-=estimate->log_derivative_neg_obsolete(seq_a[pos], pos)*mean[dim_a+num_params]/variance[dim_a+num_params];
        kval-=estimate->log_derivative_neg_obsolete(seq_b[pos], pos)*mean[dim_b+num_params]/variance[dim_b+num_params];
    }

    if (initialized)
        kval/=(sqrtdiag_lhs[idx_a]*sqrtdiag_rhs[idx_b]);

    feat_a->free_feature_vector(seq_a, idx_a, must_free_a);
    feat_b->free_feature_vector(seq_b, idx_b, must_free_b);
    return kval;
}

#endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation