00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #include "lib/common.h"
00013 #include "kernel/HistogramWordStringKernel.h"
00014 #include "features/Features.h"
00015 #include "features/StringFeatures.h"
00016 #include "classifier/PluginEstimate.h"
00017 #include "lib/io.h"
00018
00019 CHistogramWordStringKernel::CHistogramWordStringKernel(int32_t size, CPluginEstimate* pie)
00020 : CStringKernel<uint16_t>(size), estimate(pie), mean(NULL), variance(NULL),
00021 sqrtdiag_lhs(NULL), sqrtdiag_rhs(NULL),
00022 ld_mean_lhs(NULL), ld_mean_rhs(NULL),
00023 plo_lhs(NULL), plo_rhs(NULL), num_params(0), num_params2(0),
00024 num_symbols(0), sum_m2_s2(0), initialized(false)
00025 {
00026 }
00027
00028 CHistogramWordStringKernel::CHistogramWordStringKernel(
00029 CStringFeatures<uint16_t>* l, CStringFeatures<uint16_t>* r, CPluginEstimate* pie)
00030 : CStringKernel<uint16_t>(10), estimate(pie), mean(NULL), variance(NULL),
00031 sqrtdiag_lhs(NULL), sqrtdiag_rhs(NULL),
00032 ld_mean_lhs(NULL), ld_mean_rhs(NULL),
00033 plo_lhs(NULL), plo_rhs(NULL), num_params(0), num_params2(0),
00034 num_symbols(0), sum_m2_s2(0), initialized(false)
00035 {
00036 init(l, r);
00037 }
00038
00039 CHistogramWordStringKernel::~CHistogramWordStringKernel()
00040 {
00041 delete[] variance;
00042 delete[] mean;
00043 if (sqrtdiag_lhs != sqrtdiag_rhs)
00044 delete[] sqrtdiag_rhs;
00045 delete[] sqrtdiag_lhs;
00046 if (ld_mean_lhs!=ld_mean_rhs)
00047 delete[] ld_mean_rhs ;
00048 delete[] ld_mean_lhs ;
00049 if (plo_lhs!=plo_rhs)
00050 delete[] plo_rhs ;
00051 delete[] plo_lhs ;
00052 }
00053
00054 bool CHistogramWordStringKernel::init(CFeatures* p_l, CFeatures* p_r)
00055 {
00056 CStringKernel<uint16_t>::init(p_l,p_r);
00057 CStringFeatures<uint16_t>* l=(CStringFeatures<uint16_t>*) p_l;
00058 CStringFeatures<uint16_t>* r=(CStringFeatures<uint16_t>*) p_r;
00059 ASSERT(l);
00060 ASSERT(r);
00061
00062 SG_DEBUG( "init: lhs: %ld rhs: %ld\n", l, r) ;
00063 int32_t i;
00064 initialized=false;
00065
00066 if (sqrtdiag_lhs != sqrtdiag_rhs)
00067 delete[] sqrtdiag_rhs;
00068 sqrtdiag_rhs=NULL ;
00069 delete[] sqrtdiag_lhs;
00070 sqrtdiag_lhs=NULL ;
00071 if (ld_mean_lhs!=ld_mean_rhs)
00072 delete[] ld_mean_rhs ;
00073 ld_mean_rhs=NULL ;
00074 delete[] ld_mean_lhs ;
00075 ld_mean_lhs=NULL ;
00076 if (plo_lhs!=plo_rhs)
00077 delete[] plo_rhs ;
00078 plo_rhs=NULL ;
00079 delete[] plo_lhs ;
00080 plo_lhs=NULL ;
00081
00082 sqrtdiag_lhs= new float64_t[l->get_num_vectors()];
00083 ld_mean_lhs = new float64_t[l->get_num_vectors()];
00084 plo_lhs = new float64_t[l->get_num_vectors()];
00085
00086 for (i=0; i<l->get_num_vectors(); i++)
00087 sqrtdiag_lhs[i]=1;
00088
00089 if (l==r)
00090 {
00091 sqrtdiag_rhs=sqrtdiag_lhs;
00092 ld_mean_rhs=ld_mean_lhs;
00093 plo_rhs=plo_lhs;
00094 }
00095 else
00096 {
00097 sqrtdiag_rhs=new float64_t[r->get_num_vectors()];
00098 for (i=0; i<r->get_num_vectors(); i++)
00099 sqrtdiag_rhs[i]=1;
00100
00101 ld_mean_rhs=new float64_t[r->get_num_vectors()];
00102 plo_rhs=new float64_t[r->get_num_vectors()];
00103 }
00104
00105 float64_t* l_plo_lhs=plo_lhs;
00106 float64_t* l_plo_rhs=plo_rhs;
00107 float64_t* l_ld_mean_lhs=ld_mean_lhs;
00108 float64_t* l_ld_mean_rhs=ld_mean_rhs;
00109
00110
00111 if (!initialized)
00112 {
00113 int32_t num_vectors=l->get_num_vectors();
00114 num_symbols=(int32_t) l->get_num_symbols();
00115 int32_t llen=l->get_vector_length(0);
00116 int32_t rlen=r->get_vector_length(0);
00117 num_params=llen*((int32_t) l->get_num_symbols());
00118 num_params2=llen*((int32_t) l->get_num_symbols())+rlen*((int32_t) r->get_num_symbols());
00119
00120 if ((!estimate) || (!estimate->check_models()))
00121 {
00122 SG_ERROR( "no estimate available\n");
00123 return false ;
00124 } ;
00125 if (num_params2!=estimate->get_num_params())
00126 {
00127 SG_ERROR( "number of parameters of estimate and feature representation do not match\n");
00128 return false ;
00129 } ;
00130
00131
00132 num_params2++;
00133
00134 delete[] mean;
00135 mean=new float64_t[num_params2];
00136 delete[] variance;
00137 variance=new float64_t[num_params2];
00138
00139 for (i=0; i<num_params2; i++)
00140 {
00141 mean[i]=0;
00142 variance[i]=0;
00143 }
00144
00145
00146 for (i=0; i<num_vectors; i++)
00147 {
00148 int32_t len;
00149 uint16_t* vec=l->get_feature_vector(i, len);
00150
00151 mean[0]+=estimate->posterior_log_odds_obsolete(vec, len)/num_vectors;
00152
00153 for (int32_t j=0; j<len; j++)
00154 {
00155 int32_t idx=compute_index(j, vec[j]);
00156 mean[idx] += estimate->log_derivative_pos_obsolete(vec[j], j)/num_vectors;
00157 mean[idx+num_params] += estimate->log_derivative_neg_obsolete(vec[j], j)/num_vectors;
00158 }
00159 }
00160
00161
00162 for (i=0; i<num_vectors; i++)
00163 {
00164 int32_t len;
00165 uint16_t* vec=l->get_feature_vector(i, len);
00166
00167 variance[0] += CMath::sq(estimate->posterior_log_odds_obsolete(vec, len)-mean[0])/num_vectors;
00168
00169 for (int32_t j=0; j<len; j++)
00170 {
00171 for (int32_t k=0; k<4; k++)
00172 {
00173 int32_t idx=compute_index(j, k);
00174 if (k!=vec[j])
00175 {
00176 variance[idx]+=mean[idx]*mean[idx]/num_vectors;
00177 variance[idx+num_params]+=mean[idx+num_params]*mean[idx+num_params]/num_vectors;
00178 }
00179 else
00180 {
00181 variance[idx] += CMath::sq(estimate->log_derivative_pos_obsolete(vec[j], j)
00182 -mean[idx])/num_vectors;
00183 variance[idx+num_params] += CMath::sq(estimate->log_derivative_neg_obsolete(vec[j], j)
00184 -mean[idx+num_params])/num_vectors;
00185 }
00186 }
00187 }
00188 }
00189
00190
00191
00192 sum_m2_s2=0 ;
00193 for (i=1; i<num_params2; i++)
00194 {
00195 if (variance[i]<1e-14)
00196 variance[i]=1 ;
00197
00198
00199 sum_m2_s2 += mean[i]*mean[i]/(variance[i]) ;
00200 } ;
00201 }
00202
00203
00204
00205
00206 for (i=0; i<l->get_num_vectors(); i++)
00207 {
00208 int32_t alen ;
00209 uint16_t* avec = l->get_feature_vector(i, alen);
00210 float64_t result=0 ;
00211 for (int32_t j=0; j<alen; j++)
00212 {
00213 int32_t a_idx = compute_index(j, avec[j]) ;
00214 result -= estimate->log_derivative_pos_obsolete(avec[j], j)*mean[a_idx]/variance[a_idx] ;
00215 result -= estimate->log_derivative_neg_obsolete(avec[j], j)*mean[a_idx+num_params]/variance[a_idx+num_params] ;
00216 }
00217 ld_mean_lhs[i]=result ;
00218
00219
00220 plo_lhs[i] = estimate->posterior_log_odds_obsolete(avec, alen)-mean[0] ;
00221 } ;
00222
00223 if (ld_mean_lhs!=ld_mean_rhs)
00224 {
00225
00226
00227
00228 for (i=0; i < r->get_num_vectors(); i++)
00229 {
00230 int32_t alen ;
00231 uint16_t* avec=r->get_feature_vector(i, alen);
00232 float64_t result=0 ;
00233 for (int32_t j=0; j<alen; j++)
00234 {
00235 int32_t a_idx = compute_index(j, avec[j]) ;
00236 result -= estimate->log_derivative_pos_obsolete(avec[j], j)*mean[a_idx]/variance[a_idx] ;
00237 result -= estimate->log_derivative_neg_obsolete(avec[j], j)*mean[a_idx+num_params]/variance[a_idx+num_params] ;
00238 }
00239 ld_mean_rhs[i]=result ;
00240
00241
00242 plo_rhs[i] = estimate->posterior_log_odds_obsolete(avec, alen)-mean[0] ;
00243 } ;
00244 } ;
00245
00246
00247
00248 this->lhs=l;
00249 this->rhs=l;
00250 plo_lhs = l_plo_lhs ;
00251 plo_rhs = l_plo_lhs ;
00252 ld_mean_lhs = l_ld_mean_lhs ;
00253 ld_mean_rhs = l_ld_mean_lhs ;
00254
00255
00256 for (i=0; i<l->get_num_vectors(); i++)
00257 {
00258 sqrtdiag_lhs[i]=sqrt(compute(i,i));
00259
00260
00261 if (sqrtdiag_lhs[i]==0)
00262 sqrtdiag_lhs[i]=1e-16;
00263 }
00264
00265
00266
00267 if (sqrtdiag_lhs!=sqrtdiag_rhs)
00268 {
00269 this->lhs=r;
00270 this->rhs=r;
00271 plo_lhs = l_plo_rhs ;
00272 plo_rhs = l_plo_rhs ;
00273 ld_mean_lhs = l_ld_mean_rhs ;
00274 ld_mean_rhs = l_ld_mean_rhs ;
00275
00276
00277 for (i=0; i<r->get_num_vectors(); i++)
00278 {
00279 sqrtdiag_rhs[i]=sqrt(compute(i,i));
00280
00281
00282 if (sqrtdiag_rhs[i]==0)
00283 sqrtdiag_rhs[i]=1e-16;
00284 }
00285 }
00286
00287 this->lhs=l;
00288 this->rhs=r;
00289 plo_lhs = l_plo_lhs ;
00290 plo_rhs = l_plo_rhs ;
00291 ld_mean_lhs = l_ld_mean_lhs ;
00292 ld_mean_rhs = l_ld_mean_rhs ;
00293
00294 initialized = true ;
00295 return init_normalizer();
00296 }
00297
00298 void CHistogramWordStringKernel::cleanup()
00299 {
00300 delete[] variance;
00301 variance=NULL;
00302
00303 delete[] mean;
00304 mean=NULL;
00305
00306 if (sqrtdiag_lhs != sqrtdiag_rhs)
00307 delete[] sqrtdiag_rhs;
00308 sqrtdiag_rhs=NULL;
00309
00310 delete[] sqrtdiag_lhs;
00311 sqrtdiag_lhs=NULL;
00312
00313 if (ld_mean_lhs!=ld_mean_rhs)
00314 delete[] ld_mean_rhs ;
00315 ld_mean_rhs=NULL;
00316
00317 delete[] ld_mean_lhs ;
00318 ld_mean_lhs=NULL;
00319
00320 if (plo_lhs!=plo_rhs)
00321 delete[] plo_rhs ;
00322 plo_rhs=NULL;
00323
00324 delete[] plo_lhs ;
00325 plo_lhs=NULL;
00326
00327 num_params2=0;
00328 num_params=0;
00329 num_symbols=0;
00330 sum_m2_s2=0;
00331 initialized = false;
00332
00333 CKernel::cleanup();
00334 }
00335
00336 bool CHistogramWordStringKernel::load_init(FILE* src)
00337 {
00338 return false;
00339 }
00340
00341 bool CHistogramWordStringKernel::save_init(FILE* dest)
00342 {
00343 return false;
00344 }
00345
00346
00347
00348 float64_t CHistogramWordStringKernel::compute(int32_t idx_a, int32_t idx_b)
00349 {
00350 int32_t alen, blen;
00351 uint16_t* avec=((CStringFeatures<uint16_t>*) lhs)->get_feature_vector(idx_a, alen);
00352 uint16_t* bvec=((CStringFeatures<uint16_t>*) rhs)->get_feature_vector(idx_b, blen);
00353
00354 ASSERT(alen==blen);
00355
00356 float64_t result = plo_lhs[idx_a]*plo_rhs[idx_b]/variance[0];
00357 result+= sum_m2_s2 ;
00358
00359 for (int32_t i=0; i<alen; i++)
00360 {
00361 if (avec[i]==bvec[i])
00362 {
00363 int32_t a_idx = compute_index(i, avec[i]) ;
00364 float64_t dd = estimate->log_derivative_pos_obsolete(avec[i], i) ;
00365 result += dd*dd/variance[a_idx] ;
00366 dd = estimate->log_derivative_neg_obsolete(avec[i], i) ;
00367 result += dd*dd/variance[a_idx+num_params] ;
00368 } ;
00369 }
00370 result += ld_mean_lhs[idx_a] + ld_mean_rhs[idx_b] ;
00371
00372 if (initialized)
00373 result /= (sqrtdiag_lhs[idx_a]*sqrtdiag_rhs[idx_b]) ;
00374
00375
00376 #ifdef BLABLA
00377 float64_t result2 = compute_slow(idx_a, idx_b) ;
00378 if (fabs(result - result2)>1e-10)
00379 {
00380 fprintf(stderr, "new=%e old = %e diff = %e\n", result, result2, result - result2) ;
00381 ASSERT(0) ;
00382 } ;
00383 #endif
00384 return result;
00385 }
00386
00387 #ifdef BLABLA
00388
00389 float64_t CHistogramWordStringKernel::compute_slow(int32_t idx_a, int32_t idx_b)
00390 {
00391 int32_t alen, blen;
00392 uint16_t* avec=((CStringFeatures<uint16_t>*) lhs)->get_feature_vector(idx_a, alen);
00393 uint16_t* bvec=((CStringFeatures<uint16_t>*) rhs)->get_feature_vector(idx_b, blen);
00394
00395 ASSERT(alen==blen);
00396
00397 float64_t result=(estimate->posterior_log_odds_obsolete(avec, alen)-mean[0])*
00398 (estimate->posterior_log_odds_obsolete(bvec, blen)-mean[0])/(variance[0]);
00399 result+= sum_m2_s2 ;
00400
00401 for (int32_t i=0; i<alen; i++)
00402 {
00403 int32_t a_idx = compute_index(i, avec[i]) ;
00404 int32_t b_idx = compute_index(i, bvec[i]) ;
00405
00406 if (avec[i]==bvec[i])
00407 {
00408 float64_t dd = estimate->log_derivative_pos_obsolete(avec[i], i) ;
00409 result += dd*dd/variance[a_idx] ;
00410 dd = estimate->log_derivative_neg_obsolete(avec[i], i) ;
00411 result += dd*dd/variance[a_idx+num_params] ;
00412 } ;
00413
00414 result -= estimate->log_derivative_pos_obsolete(avec[i], i)*mean[a_idx]/variance[a_idx] ;
00415 result -= estimate->log_derivative_pos_obsolete(bvec[i], i)*mean[b_idx]/variance[b_idx] ;
00416 result -= estimate->log_derivative_neg_obsolete(avec[i], i)*mean[a_idx+num_params]/variance[a_idx+num_params] ;
00417 result -= estimate->log_derivative_neg_obsolete(bvec[i], i)*mean[b_idx+num_params]/variance[b_idx+num_params] ;
00418 }
00419
00420 if (initialized)
00421 result /= (sqrtdiag_lhs[idx_a]*sqrtdiag_rhs[idx_b]) ;
00422
00423 return result;
00424 }
00425
00426 #endif