#include "lib/common.h"

#ifdef HAVE_LAPACK
#include "classifier/Classifier.h"
#include "classifier/LinearClassifier.h"
#include "classifier/LDA.h"
#include "features/Labels.h"
#include "lib/Mathematics.h"
#include "lib/lapack.h"

CLDA::CLDA(float64_t gamma)
: CLinearClassifier(), m_gamma(gamma)
{
}

CLDA::CLDA(float64_t gamma, CSimpleFeatures<float64_t>* traindat, CLabels* trainlab)
: CLinearClassifier(), m_gamma(gamma)
{
	set_features(traindat);
	set_labels(trainlab);
}

CLDA::~CLDA()
{
}

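/* Train a two-class (regularized) LDA classifier on the attached features
 * and labels.  Labels must be +1/-1.  The pooled within-class scatter
 * matrix is shrunk towards a scaled identity (controlled by m_gamma),
 * pseudo-inverted and applied to the class means to obtain w and bias. */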
bool CLDA::train()
{
	ASSERT(labels);
	ASSERT(features);
	int32_t num_train_labels=0;
	int32_t* train_labels=labels->get_int_labels(num_train_labels);
	ASSERT(train_labels);

	int32_t num_feat=features->get_dim_feature_space();
	int32_t num_vec=features->get_num_vectors();
	ASSERT(num_vec==num_train_labels);

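	/* sort the training examples into negative (-1) and positive (+1) index lists */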
	int32_t* classidx_neg=new int32_t[num_vec];
	int32_t* classidx_pos=new int32_t[num_vec];

	int32_t i=0;
	int32_t j=0;
	int32_t num_neg=0;
	int32_t num_pos=0;
	for (i=0; i<num_train_labels; i++)
	{
		if (train_labels[i]==-1)
			classidx_neg[num_neg++]=i;
		else if (train_labels[i]==+1)
			classidx_pos[num_pos++]=i;
		else
		{
			SG_ERROR("found label != +/-1, bailing...\n");
			return false;
		}
	}

	if (num_neg<=0 || num_pos<=0)
	{
		SG_ERROR("only a single class found, need both positive and negative examples\n");
		return false;
	}

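	/* allocate the weight vector and the per-class mean / scatter work buffers */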
	delete[] w;
	w=new float64_t[num_feat];
	w_dim=num_feat;

	float64_t* mean_neg=new float64_t[num_feat];
	memset(mean_neg, 0, num_feat*sizeof(float64_t));

	float64_t* mean_pos=new float64_t[num_feat];
	memset(mean_pos, 0, num_feat*sizeof(float64_t));

	double* scatter=new double[num_feat*num_feat];
	double* buffer=new double[num_feat*CMath::max(num_neg, num_pos)];
	int nf=(int) num_feat;

	CSimpleFeatures<float64_t>* rf=(CSimpleFeatures<float64_t>*) features;

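	/* accumulate the negative-class mean and copy its vectors into the work buffer */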
	for (i=0; i<num_neg; i++)
	{
		int32_t vlen;
		bool vfree;
		float64_t* vec=
			rf->get_feature_vector(classidx_neg[i], vlen, vfree);
		ASSERT(vec);

		for (j=0; j<vlen; j++)
		{
			mean_neg[j]+=vec[j];
			buffer[num_feat*i+j]=vec[j];
		}

		rf->free_feature_vector(vec, classidx_neg[i], vfree);
	}

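	/* finalize the negative mean, center the buffered vectors and initialize
	 * the within-class scatter with X_neg_centered * X_neg_centered^T */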
	for (j=0; j<num_feat; j++)
		mean_neg[j]/=num_neg;

	for (i=0; i<num_neg; i++)
	{
		for (j=0; j<num_feat; j++)
			buffer[num_feat*i+j]-=mean_neg[j];
	}
	cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, nf, nf,
		(int) num_neg, 1.0, buffer, nf, buffer, nf, 0.0, scatter, nf);

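	/* same for the positive class: accumulate its mean and buffer its vectors */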
	for (i=0; i<num_pos; i++)
	{
		int32_t vlen;
		bool vfree;
		float64_t* vec=
			rf->get_feature_vector(classidx_pos[i], vlen, vfree);
		ASSERT(vec);

		for (j=0; j<vlen; j++)
		{
			mean_pos[j]+=vec[j];
			buffer[num_feat*i+j]=vec[j];
		}

		rf->free_feature_vector(vec, classidx_pos[i], vfree);
	}

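	/* finalize the positive mean, center its vectors and add their scatter;
	 * alpha and beta scale both class contributions by 1/(num_train_labels-1),
	 * giving the pooled within-class scatter */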
	for (j=0; j<num_feat; j++)
		mean_pos[j]/=num_pos;

	for (i=0; i<num_pos; i++)
	{
		for (j=0; j<num_feat; j++)
			buffer[num_feat*i+j]-=mean_pos[j];
	}
	cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, nf, nf, (int) num_pos,
		1.0/(num_train_labels-1), buffer, nf, buffer, nf,
		1.0/(num_train_labels-1), scatter, nf);

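	/* shrinkage regularization: scatter <- (1-gamma)*scatter + gamma*trace(scatter)/num_feat * I */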
	float64_t trace=CMath::trace((float64_t*) scatter, num_feat, num_feat);

	double s=1.0-m_gamma;
	for (i=0; i<num_feat*num_feat; i++)
		scatter[i]*=s;

	for (i=0; i<num_feat; i++)
		scatter[i*num_feat+i]+=trace*m_gamma/num_feat;

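	/* pseudo-invert the regularized scatter and project both class means:
	 * w_pos = S^-1 * mean_pos, w_neg = S^-1 * mean_neg (reusing buffer as storage) */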
	double* inv_scatter=(double*) CMath::pinv(
		scatter, num_feat, num_feat, NULL);

	float64_t* w_pos=buffer;
	float64_t* w_neg=&buffer[num_feat];

	cblas_dsymv(CblasColMajor, CblasUpper, nf, 1.0, inv_scatter, nf,
		(double*) mean_pos, 1, 0.0, (double*) w_pos, 1);
	cblas_dsymv(CblasColMajor, CblasUpper, nf, 1.0, inv_scatter, nf,
		(double*) mean_neg, 1, 0.0, (double*) w_neg, 1);

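	/* LDA solution: w = S^-1 * (mean_pos - mean_neg); the bias places the
	 * decision threshold halfway between the projected class means */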
	bias=0.5*(CMath::dot(w_neg, mean_neg, num_feat)-
		CMath::dot(w_pos, mean_pos, num_feat));
	for (i=0; i<num_feat; i++)
		w[i]=w_pos[i]-w_neg[i];

#ifdef DEBUG_LDA
	SG_PRINT("bias: %f\n", bias);
	CMath::display_vector(w, num_feat, "w");
	CMath::display_vector(w_pos, num_feat, "w_pos");
	CMath::display_vector(w_neg, num_feat, "w_neg");
	CMath::display_vector(mean_pos, num_feat, "mean_pos");
	CMath::display_vector(mean_neg, num_feat, "mean_neg");
#endif

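	/* release all temporary buffers */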
	delete[] train_labels;
	delete[] mean_neg;
	delete[] mean_pos;
	delete[] scatter;
	delete[] inv_scatter;
	delete[] classidx_neg;
	delete[] classidx_pos;
	delete[] buffer;
	return true;
}
#endif