00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #include "lib/common.h"
00013 #include "kernel/HistogramWordStringKernel.h"
00014 #include "features/Features.h"
00015 #include "features/StringFeatures.h"
00016 #include "classifier/PluginEstimate.h"
00017 #include "lib/io.h"
00018
00019 using namespace shogun;
00020
00021 CHistogramWordStringKernel::CHistogramWordStringKernel(int32_t size, CPluginEstimate* pie)
00022 : CStringKernel<uint16_t>(size), estimate(pie), mean(NULL), variance(NULL),
00023 sqrtdiag_lhs(NULL), sqrtdiag_rhs(NULL),
00024 ld_mean_lhs(NULL), ld_mean_rhs(NULL),
00025 plo_lhs(NULL), plo_rhs(NULL), num_params(0), num_params2(0),
00026 num_symbols(0), sum_m2_s2(0), initialized(false)
00027 {
00028 }
00029
00030 CHistogramWordStringKernel::CHistogramWordStringKernel(
00031 CStringFeatures<uint16_t>* l, CStringFeatures<uint16_t>* r, CPluginEstimate* pie)
00032 : CStringKernel<uint16_t>(10), estimate(pie), mean(NULL), variance(NULL),
00033 sqrtdiag_lhs(NULL), sqrtdiag_rhs(NULL),
00034 ld_mean_lhs(NULL), ld_mean_rhs(NULL),
00035 plo_lhs(NULL), plo_rhs(NULL), num_params(0), num_params2(0),
00036 num_symbols(0), sum_m2_s2(0), initialized(false)
00037 {
00038 init(l, r);
00039 }
00040
00041 CHistogramWordStringKernel::~CHistogramWordStringKernel()
00042 {
00043 delete[] variance;
00044 delete[] mean;
00045 if (sqrtdiag_lhs != sqrtdiag_rhs)
00046 delete[] sqrtdiag_rhs;
00047 delete[] sqrtdiag_lhs;
00048 if (ld_mean_lhs!=ld_mean_rhs)
00049 delete[] ld_mean_rhs ;
00050 delete[] ld_mean_lhs ;
00051 if (plo_lhs!=plo_rhs)
00052 delete[] plo_rhs ;
00053 delete[] plo_lhs ;
00054 }
00055
00056 bool CHistogramWordStringKernel::init(CFeatures* p_l, CFeatures* p_r)
00057 {
00058 CStringKernel<uint16_t>::init(p_l,p_r);
00059 CStringFeatures<uint16_t>* l=(CStringFeatures<uint16_t>*) p_l;
00060 CStringFeatures<uint16_t>* r=(CStringFeatures<uint16_t>*) p_r;
00061 ASSERT(l);
00062 ASSERT(r);
00063
00064 SG_DEBUG( "init: lhs: %ld rhs: %ld\n", l, r) ;
00065 int32_t i;
00066 initialized=false;
00067
00068 if (sqrtdiag_lhs != sqrtdiag_rhs)
00069 delete[] sqrtdiag_rhs;
00070 sqrtdiag_rhs=NULL ;
00071 delete[] sqrtdiag_lhs;
00072 sqrtdiag_lhs=NULL ;
00073 if (ld_mean_lhs!=ld_mean_rhs)
00074 delete[] ld_mean_rhs ;
00075 ld_mean_rhs=NULL ;
00076 delete[] ld_mean_lhs ;
00077 ld_mean_lhs=NULL ;
00078 if (plo_lhs!=plo_rhs)
00079 delete[] plo_rhs ;
00080 plo_rhs=NULL ;
00081 delete[] plo_lhs ;
00082 plo_lhs=NULL ;
00083
00084 sqrtdiag_lhs= new float64_t[l->get_num_vectors()];
00085 ld_mean_lhs = new float64_t[l->get_num_vectors()];
00086 plo_lhs = new float64_t[l->get_num_vectors()];
00087
00088 for (i=0; i<l->get_num_vectors(); i++)
00089 sqrtdiag_lhs[i]=1;
00090
00091 if (l==r)
00092 {
00093 sqrtdiag_rhs=sqrtdiag_lhs;
00094 ld_mean_rhs=ld_mean_lhs;
00095 plo_rhs=plo_lhs;
00096 }
00097 else
00098 {
00099 sqrtdiag_rhs=new float64_t[r->get_num_vectors()];
00100 for (i=0; i<r->get_num_vectors(); i++)
00101 sqrtdiag_rhs[i]=1;
00102
00103 ld_mean_rhs=new float64_t[r->get_num_vectors()];
00104 plo_rhs=new float64_t[r->get_num_vectors()];
00105 }
00106
00107 float64_t* l_plo_lhs=plo_lhs;
00108 float64_t* l_plo_rhs=plo_rhs;
00109 float64_t* l_ld_mean_lhs=ld_mean_lhs;
00110 float64_t* l_ld_mean_rhs=ld_mean_rhs;
00111
00112
00113 if (!initialized)
00114 {
00115 int32_t num_vectors=l->get_num_vectors();
00116 num_symbols=(int32_t) l->get_num_symbols();
00117 int32_t llen=l->get_vector_length(0);
00118 int32_t rlen=r->get_vector_length(0);
00119 num_params=llen*((int32_t) l->get_num_symbols());
00120 num_params2=llen*((int32_t) l->get_num_symbols())+rlen*((int32_t) r->get_num_symbols());
00121
00122 if ((!estimate) || (!estimate->check_models()))
00123 {
00124 SG_ERROR( "no estimate available\n");
00125 return false ;
00126 } ;
00127 if (num_params2!=estimate->get_num_params())
00128 {
00129 SG_ERROR( "number of parameters of estimate and feature representation do not match\n");
00130 return false ;
00131 } ;
00132
00133
00134 num_params2++;
00135
00136 delete[] mean;
00137 mean=new float64_t[num_params2];
00138 delete[] variance;
00139 variance=new float64_t[num_params2];
00140
00141 for (i=0; i<num_params2; i++)
00142 {
00143 mean[i]=0;
00144 variance[i]=0;
00145 }
00146
00147
00148 for (i=0; i<num_vectors; i++)
00149 {
00150 int32_t len;
00151 bool free_vec;
00152 uint16_t* vec=l->get_feature_vector(i, len, free_vec);
00153
00154 mean[0]+=estimate->posterior_log_odds_obsolete(vec, len)/num_vectors;
00155
00156 for (int32_t j=0; j<len; j++)
00157 {
00158 int32_t idx=compute_index(j, vec[j]);
00159 mean[idx] += estimate->log_derivative_pos_obsolete(vec[j], j)/num_vectors;
00160 mean[idx+num_params] += estimate->log_derivative_neg_obsolete(vec[j], j)/num_vectors;
00161 }
00162
00163 l->free_feature_vector(vec, i, free_vec);
00164 }
00165
00166
00167 for (i=0; i<num_vectors; i++)
00168 {
00169 int32_t len;
00170 bool free_vec;
00171 uint16_t* vec=l->get_feature_vector(i, len, free_vec);
00172
00173 variance[0] += CMath::sq(estimate->posterior_log_odds_obsolete(vec, len)-mean[0])/num_vectors;
00174
00175 for (int32_t j=0; j<len; j++)
00176 {
00177 for (int32_t k=0; k<4; k++)
00178 {
00179 int32_t idx=compute_index(j, k);
00180 if (k!=vec[j])
00181 {
00182 variance[idx]+=mean[idx]*mean[idx]/num_vectors;
00183 variance[idx+num_params]+=mean[idx+num_params]*mean[idx+num_params]/num_vectors;
00184 }
00185 else
00186 {
00187 variance[idx] += CMath::sq(estimate->log_derivative_pos_obsolete(vec[j], j)
00188 -mean[idx])/num_vectors;
00189 variance[idx+num_params] += CMath::sq(estimate->log_derivative_neg_obsolete(vec[j], j)
00190 -mean[idx+num_params])/num_vectors;
00191 }
00192 }
00193 }
00194
00195 l->free_feature_vector(vec, i, free_vec);
00196 }
00197
00198
00199
00200 sum_m2_s2=0 ;
00201 for (i=1; i<num_params2; i++)
00202 {
00203 if (variance[i]<1e-14)
00204 variance[i]=1 ;
00205
00206
00207 sum_m2_s2 += mean[i]*mean[i]/(variance[i]) ;
00208 } ;
00209 }
00210
00211
00212
00213
00214 for (i=0; i<l->get_num_vectors(); i++)
00215 {
00216 int32_t alen;
00217 bool free_avec;
00218 uint16_t* avec = l->get_feature_vector(i, alen, free_avec);
00219
00220 float64_t result=0 ;
00221 for (int32_t j=0; j<alen; j++)
00222 {
00223 int32_t a_idx = compute_index(j, avec[j]) ;
00224 result -= estimate->log_derivative_pos_obsolete(avec[j], j)*mean[a_idx]/variance[a_idx] ;
00225 result -= estimate->log_derivative_neg_obsolete(avec[j], j)*mean[a_idx+num_params]/variance[a_idx+num_params] ;
00226 }
00227 ld_mean_lhs[i]=result ;
00228
00229
00230 plo_lhs[i] = estimate->posterior_log_odds_obsolete(avec, alen)-mean[0] ;
00231 l->free_feature_vector(avec, alen, free_avec);
00232 } ;
00233
00234 if (ld_mean_lhs!=ld_mean_rhs)
00235 {
00236
00237
00238
00239 for (i=0; i < r->get_num_vectors(); i++)
00240 {
00241 int32_t alen;
00242 bool free_avec;
00243 uint16_t* avec=r->get_feature_vector(i, alen, free_avec);
00244
00245 float64_t result=0 ;
00246 for (int32_t j=0; j<alen; j++)
00247 {
00248 int32_t a_idx = compute_index(j, avec[j]) ;
00249 result -= estimate->log_derivative_pos_obsolete(avec[j], j)*mean[a_idx]/variance[a_idx] ;
00250 result -= estimate->log_derivative_neg_obsolete(avec[j], j)*mean[a_idx+num_params]/variance[a_idx+num_params] ;
00251 }
00252 ld_mean_rhs[i]=result ;
00253
00254
00255 plo_rhs[i] = estimate->posterior_log_odds_obsolete(avec, alen)-mean[0] ;
00256 r->free_feature_vector(avec, alen, free_avec);
00257 } ;
00258 } ;
00259
00260
00261
00262 this->lhs=l;
00263 this->rhs=l;
00264 plo_lhs = l_plo_lhs ;
00265 plo_rhs = l_plo_lhs ;
00266 ld_mean_lhs = l_ld_mean_lhs ;
00267 ld_mean_rhs = l_ld_mean_lhs ;
00268
00269
00270 for (i=0; i<l->get_num_vectors(); i++)
00271 {
00272 sqrtdiag_lhs[i]=sqrt(compute(i,i));
00273
00274
00275 if (sqrtdiag_lhs[i]==0)
00276 sqrtdiag_lhs[i]=1e-16;
00277 }
00278
00279
00280
00281 if (sqrtdiag_lhs!=sqrtdiag_rhs)
00282 {
00283 this->lhs=r;
00284 this->rhs=r;
00285 plo_lhs = l_plo_rhs ;
00286 plo_rhs = l_plo_rhs ;
00287 ld_mean_lhs = l_ld_mean_rhs ;
00288 ld_mean_rhs = l_ld_mean_rhs ;
00289
00290
00291 for (i=0; i<r->get_num_vectors(); i++)
00292 {
00293 sqrtdiag_rhs[i]=sqrt(compute(i,i));
00294
00295
00296 if (sqrtdiag_rhs[i]==0)
00297 sqrtdiag_rhs[i]=1e-16;
00298 }
00299 }
00300
00301 this->lhs=l;
00302 this->rhs=r;
00303 plo_lhs = l_plo_lhs ;
00304 plo_rhs = l_plo_rhs ;
00305 ld_mean_lhs = l_ld_mean_lhs ;
00306 ld_mean_rhs = l_ld_mean_rhs ;
00307
00308 initialized = true ;
00309 return init_normalizer();
00310 }
00311
00312 void CHistogramWordStringKernel::cleanup()
00313 {
00314 delete[] variance;
00315 variance=NULL;
00316
00317 delete[] mean;
00318 mean=NULL;
00319
00320 if (sqrtdiag_lhs != sqrtdiag_rhs)
00321 delete[] sqrtdiag_rhs;
00322 sqrtdiag_rhs=NULL;
00323
00324 delete[] sqrtdiag_lhs;
00325 sqrtdiag_lhs=NULL;
00326
00327 if (ld_mean_lhs!=ld_mean_rhs)
00328 delete[] ld_mean_rhs ;
00329 ld_mean_rhs=NULL;
00330
00331 delete[] ld_mean_lhs ;
00332 ld_mean_lhs=NULL;
00333
00334 if (plo_lhs!=plo_rhs)
00335 delete[] plo_rhs ;
00336 plo_rhs=NULL;
00337
00338 delete[] plo_lhs ;
00339 plo_lhs=NULL;
00340
00341 num_params2=0;
00342 num_params=0;
00343 num_symbols=0;
00344 sum_m2_s2=0;
00345 initialized = false;
00346
00347 CKernel::cleanup();
00348 }
00349
00350 float64_t CHistogramWordStringKernel::compute(int32_t idx_a, int32_t idx_b)
00351 {
00352 int32_t alen, blen;
00353 bool free_avec, free_bvec;
00354 uint16_t* avec=((CStringFeatures<uint16_t>*) lhs)->get_feature_vector(idx_a, alen, free_avec);
00355 uint16_t* bvec=((CStringFeatures<uint16_t>*) rhs)->get_feature_vector(idx_b, blen, free_bvec);
00356
00357 ASSERT(alen==blen);
00358
00359 float64_t result = plo_lhs[idx_a]*plo_rhs[idx_b]/variance[0];
00360 result+= sum_m2_s2 ;
00361
00362 for (int32_t i=0; i<alen; i++)
00363 {
00364 if (avec[i]==bvec[i])
00365 {
00366 int32_t a_idx = compute_index(i, avec[i]) ;
00367 float64_t dd = estimate->log_derivative_pos_obsolete(avec[i], i) ;
00368 result += dd*dd/variance[a_idx] ;
00369 dd = estimate->log_derivative_neg_obsolete(avec[i], i) ;
00370 result += dd*dd/variance[a_idx+num_params] ;
00371 } ;
00372 }
00373 result += ld_mean_lhs[idx_a] + ld_mean_rhs[idx_b] ;
00374
00375 if (initialized)
00376 result /= (sqrtdiag_lhs[idx_a]*sqrtdiag_rhs[idx_b]) ;
00377
00378 #ifdef BLABLA
00379 float64_t result2 = compute_slow(idx_a, idx_b) ;
00380 if (fabs(result - result2)>1e-10)
00381 SG_ERROR("new=%e old = %e diff = %e\n", result, result2, result - result2);
00382 #endif
00383 ((CStringFeatures<uint16_t>*) lhs)->free_feature_vector(avec, idx_a, free_avec);
00384 ((CStringFeatures<uint16_t>*) rhs)->free_feature_vector(bvec, idx_b, free_bvec);
00385 return result;
00386 }
00387
#ifdef BLABLA

/** Reference implementation of compute(), only built when BLABLA is
 * defined.
 *
 * It evaluates the posterior log-odds and the mean-correction terms
 * directly from the estimate instead of reading the plo_* and ld_mean_*
 * tables; compute() compares its own result against this one.
 *
 * @param idx_a index into the lhs features
 * @param idx_b index into the rhs features
 * @return kernel value
 */
float64_t CHistogramWordStringKernel::compute_slow(int32_t idx_a, int32_t idx_b)
{
	CStringFeatures<uint16_t>* left=(CStringFeatures<uint16_t>*) lhs;
	CStringFeatures<uint16_t>* right=(CStringFeatures<uint16_t>*) rhs;

	int32_t alen;
	int32_t blen;
	bool release_a;
	bool release_b;
	uint16_t* seq_a=left->get_feature_vector(idx_a, alen, release_a);
	uint16_t* seq_b=right->get_feature_vector(idx_b, blen, release_b);

	// Both vectors must have the same length.
	ASSERT(alen==blen);

	float64_t result=(estimate->posterior_log_odds_obsolete(seq_a, alen)-mean[0])*
		(estimate->posterior_log_odds_obsolete(seq_b, blen)-mean[0])/(variance[0]);
	result+=sum_m2_s2;

	for (int32_t pos=0; pos<alen; pos++)
	{
		const int32_t ia=compute_index(pos, seq_a[pos]);
		const int32_t ib=compute_index(pos, seq_b[pos]);

		// Squared-derivative terms only where both symbols agree.
		if (seq_a[pos]==seq_b[pos])
		{
			float64_t dd=estimate->log_derivative_pos_obsolete(seq_a[pos], pos);
			result+=dd*dd/variance[ia];
			dd=estimate->log_derivative_neg_obsolete(seq_a[pos], pos);
			result+=dd*dd/variance[ia+num_params];
		}

		// Mean-correction terms, computed on the fly for both vectors.
		result-=estimate->log_derivative_pos_obsolete(seq_a[pos], pos)*mean[ia]/variance[ia];
		result-=estimate->log_derivative_pos_obsolete(seq_b[pos], pos)*mean[ib]/variance[ib];
		result-=estimate->log_derivative_neg_obsolete(seq_a[pos], pos)*mean[ia+num_params]/variance[ia+num_params];
		result-=estimate->log_derivative_neg_obsolete(seq_b[pos], pos)*mean[ib+num_params]/variance[ib+num_params];
	}

	if (initialized)
		result/=(sqrtdiag_lhs[idx_a]*sqrtdiag_rhs[idx_b]);

	left->free_feature_vector(seq_a, idx_a, release_a);
	right->free_feature_vector(seq_b, idx_b, release_b);
	return result;
}

#endif