LDA.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include "lib/common.h"
00012 
00013 #ifdef HAVE_LAPACK
00014 #include "classifier/Classifier.h"
00015 #include "classifier/LinearClassifier.h"
00016 #include "classifier/LDA.h"
00017 #include "features/Labels.h"
00018 #include "lib/Mathematics.h"
00019 #include "lib/lapack.h"
00020 
00021 CLDA::CLDA(float64_t gamma)
00022 : CLinearClassifier(), m_gamma(gamma)
00023 {
00024 }
00025 
00026 CLDA::CLDA(float64_t gamma, CSimpleFeatures<float64_t>* traindat, CLabels* trainlab)
00027 : CLinearClassifier(), m_gamma(gamma)
00028 {
00029     set_features(traindat);
00030     set_labels(trainlab);
00031 }
00032 
00033 
00034 CLDA::~CLDA()
00035 {
00036 }
00037 
00038 bool CLDA::train()
00039 {
00040     ASSERT(labels);
00041     ASSERT(features);
00042     int32_t num_train_labels=0;
00043     int32_t* train_labels=labels->get_int_labels(num_train_labels);
00044     ASSERT(train_labels);
00045 
00046     int32_t num_feat=features->get_dim_feature_space();
00047     int32_t num_vec=features->get_num_vectors();
00048     ASSERT(num_vec==num_train_labels);
00049 
00050     int32_t* classidx_neg=new int32_t[num_vec];
00051     int32_t* classidx_pos=new int32_t[num_vec];
00052 
00053     int32_t i=0;
00054     int32_t j=0;
00055     int32_t num_neg=0;
00056     int32_t num_pos=0;
00057     for (i=0; i<num_train_labels; i++)
00058     {
00059         if (train_labels[i]==-1)
00060             classidx_neg[num_neg++]=i;
00061         else if (train_labels[i]==+1)
00062             classidx_pos[num_pos++]=i;
00063         else
00064         {
00065             SG_ERROR( "found label != +/- 1 bailing...");
00066             return false;
00067         }
00068     }
00069 
00070     if (num_neg<=0 && num_pos<=0)
00071     {
00072       SG_ERROR( "whooooo ? only a single class found\n");
00073         return false;
00074     }
00075 
00076     delete[] w;
00077     w=new float64_t[num_feat];
00078     w_dim=num_feat;
00079 
00080     float64_t* mean_neg=new float64_t[num_feat];
00081     memset(mean_neg,0,num_feat*sizeof(float64_t));
00082 
00083     float64_t* mean_pos=new float64_t[num_feat];
00084     memset(mean_pos,0,num_feat*sizeof(float64_t));
00085 
00086     /* calling external lib */
00087     double* scatter=new double[num_feat*num_feat];
00088     double* buffer=new double[num_feat*CMath::max(num_neg, num_pos)];
00089     int nf = (int) num_feat;
00090 
00091     CSimpleFeatures<float64_t>* rf = (CSimpleFeatures<float64_t>*) features;
00092     //mean neg
00093     for (i=0; i<num_neg; i++)
00094     {
00095         int32_t vlen;
00096         bool vfree;
00097         float64_t* vec=
00098             rf->get_feature_vector(classidx_neg[i], vlen, vfree);
00099         ASSERT(vec);
00100 
00101         for (j=0; j<vlen; j++)
00102         {
00103             mean_neg[j]+=vec[j];
00104             buffer[num_feat*i+j]=vec[j];
00105         }
00106 
00107         rf->free_feature_vector(vec, classidx_neg[i], vfree);
00108     }
00109 
00110     for (j=0; j<num_feat; j++)
00111         mean_neg[j]/=num_neg;
00112 
00113     for (i=0; i<num_neg; i++)
00114     {
00115         for (j=0; j<num_feat; j++)
00116             buffer[num_feat*i+j]-=mean_neg[j];
00117     }
00118     cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, nf, nf,
00119         (int) num_neg, 1.0, buffer, nf, buffer, nf, 0, scatter, nf);
00120     
00121     //mean pos
00122     for (i=0; i<num_pos; i++)
00123     {
00124         int32_t vlen;
00125         bool vfree;
00126         float64_t* vec=
00127             rf->get_feature_vector(classidx_pos[i], vlen, vfree);
00128         ASSERT(vec);
00129 
00130         for (j=0; j<vlen; j++)
00131         {
00132             mean_pos[j]+=vec[j];
00133             buffer[num_feat*i+j]=vec[j];
00134         }
00135 
00136         rf->free_feature_vector(vec, classidx_pos[i], vfree);
00137     }
00138 
00139     for (j=0; j<num_feat; j++)
00140         mean_pos[j]/=num_pos;
00141 
00142     for (i=0; i<num_pos; i++)
00143     {
00144         for (j=0; j<num_feat; j++)
00145             buffer[num_feat*i+j]-=mean_pos[j];
00146     }
00147     cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, nf, nf, (int) num_pos,
00148         1.0/(num_train_labels-1), buffer, nf, buffer, nf,
00149         1.0/(num_train_labels-1), scatter, nf);
00150 
00151     float64_t trace=CMath::trace((float64_t*) scatter, num_feat, num_feat);
00152 
00153     double s=1.0-m_gamma; /* calling external lib; indirectly */
00154     for (i=0; i<num_feat*num_feat; i++)
00155         scatter[i]*=s;
00156 
00157     for (i=0; i<num_feat; i++)
00158         scatter[i*num_feat+i]+= trace*m_gamma/num_feat;
00159 
00160     double* inv_scatter= (double*) CMath::pinv(
00161         scatter, num_feat, num_feat, NULL);
00162 
00163     float64_t* w_pos=buffer;
00164     float64_t* w_neg=&buffer[num_feat];
00165 
00166     cblas_dsymv(CblasColMajor, CblasUpper, nf, 1.0, inv_scatter, nf,
00167         (double*) mean_pos, 1, 0., (double*) w_pos, 1);
00168     cblas_dsymv(CblasColMajor, CblasUpper, nf, 1.0, inv_scatter, nf,
00169         (double*) mean_neg, 1, 0, (double*) w_neg, 1);
00170     
00171     bias=0.5*(CMath::dot(w_neg, mean_neg, num_feat)-CMath::dot(w_pos, mean_pos, num_feat));
00172     for (i=0; i<num_feat; i++)
00173         w[i]=w_pos[i]-w_neg[i];
00174 
00175 #ifdef DEBUG_LDA
00176     SG_PRINT("bias: %f\n", bias);
00177     CMath::display_vector(w, num_feat, "w");
00178     CMath::display_vector(w_pos, num_feat, "w_pos");
00179     CMath::display_vector(w_neg, num_feat, "w_neg");
00180     CMath::display_vector(mean_pos, num_feat, "mean_pos");
00181     CMath::display_vector(mean_neg, num_feat, "mean_neg");
00182 #endif
00183 
00184     delete[] train_labels;
00185     delete[] mean_neg;
00186     delete[] mean_pos;
00187     delete[] scatter;
00188     delete[] inv_scatter;
00189     delete[] classidx_neg;
00190     delete[] classidx_pos;
00191     delete[] buffer;
00192     return true;
00193 }
00194 #endif

SHOGUN Machine Learning Toolbox - Documentation