Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #include "classifier/svm/ScatterSVM.h"
00013 #include "lib/io.h"
00014
00015 using namespace shogun;
00016
00017 CScatterSVM::CScatterSVM()
00018 : CMultiClassSVM(ONE_VS_REST), model(NULL), norm_wc(NULL), norm_wcw(NULL), rho(0)
00019 {
00020 }
00021
00022 CScatterSVM::CScatterSVM(float64_t C, CKernel* k, CLabels* lab)
00023 : CMultiClassSVM(ONE_VS_REST, C, k, lab), model(NULL), norm_wc(NULL), norm_wcw(NULL), rho(0)
00024 {
00025 }
00026
00027 CScatterSVM::~CScatterSVM()
00028 {
00029 delete[] norm_wc;
00030 delete[] norm_wcw;
00031
00032 }
00033
00034 bool CScatterSVM::train(CFeatures* data)
00035 {
00036 struct svm_node* x_space;
00037
00038 ASSERT(labels && labels->get_num_labels());
00039 int32_t num_classes = labels->get_num_classes();
00040
00041 if (data)
00042 {
00043 if (labels->get_num_labels() != data->get_num_vectors())
00044 SG_ERROR("Number of training vectors does not match number of labels\n");
00045 kernel->init(data, data);
00046 }
00047
00048 problem.l=labels->get_num_labels();
00049 SG_INFO( "%d trainlabels\n", problem.l);
00050
00051 problem.y=new float64_t[problem.l];
00052 problem.x=new struct svm_node*[problem.l];
00053 x_space=new struct svm_node[2*problem.l];
00054
00055 for (int32_t i=0; i<problem.l; i++)
00056 {
00057 problem.y[i]=labels->get_label(i);
00058 problem.x[i]=&x_space[2*i];
00059 x_space[2*i].index=i;
00060 x_space[2*i+1].index=-1;
00061 }
00062
00063 int32_t weights_label[2]={-1,+1};
00064 float64_t weights[2]={1.0,get_C2()/get_C1()};
00065
00066 ASSERT(kernel && kernel->has_features());
00067 ASSERT(kernel->get_num_vec_lhs()==problem.l);
00068
00069 param.svm_type=NU_MULTICLASS_SVC;
00070 param.kernel_type = LINEAR;
00071 param.degree = 3;
00072 param.gamma = 0;
00073 param.coef0 = 0;
00074 param.nu = get_nu();
00075 param.kernel=kernel;
00076 param.cache_size = kernel->get_cache_size();
00077 param.C = 0;
00078 param.eps = epsilon;
00079 param.p = 0.1;
00080 param.shrinking = 0;
00081 param.nr_weight = 2;
00082 param.weight_label = weights_label;
00083 param.weight = weights;
00084 param.nr_class=num_classes;
00085 param.use_bias = get_bias_enabled();
00086
00087 int32_t* numc=new int32_t[num_classes];
00088 CMath::fill_vector(numc, num_classes, 0);
00089
00090 for (int32_t i=0; i<problem.l; i++)
00091 numc[(int32_t) problem.y[i]]++;
00092
00093 int32_t Nc=0;
00094 int32_t Nmin=problem.l;
00095 for (int32_t i=0; i<num_classes; i++)
00096 {
00097 if (numc[i]>0)
00098 {
00099 Nc++;
00100 Nmin=CMath::min(Nmin, numc[i]);
00101 }
00102
00103 }
00104
00105 float64_t nu_min=((float64_t) Nc)/problem.l;
00106 float64_t nu_max=((float64_t) Nc)*Nmin/problem.l;
00107
00108 SG_INFO("valid nu interval [%f ... %f]\n", nu_min, nu_max);
00109
00110 if (param.nu<nu_min || param.nu>nu_max)
00111 SG_ERROR("nu out of valid range [%f ... %f]\n", nu_min, nu_max);
00112
00113 const char* error_msg = svm_check_parameter(&problem,¶m);
00114
00115 if(error_msg)
00116 SG_ERROR("Error: %s\n",error_msg);
00117
00118 model = svm_train(&problem, ¶m);
00119
00120 if (model)
00121 {
00122 ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef));
00123
00124 ASSERT(model->nr_class==num_classes);
00125 create_multiclass_svm(num_classes);
00126
00127 rho=model->rho[0];
00128
00129 delete[] norm_wcw;
00130 norm_wcw = new float64_t[m_num_svms];
00131
00132 for (int32_t i=0; i<num_classes; i++)
00133 {
00134 int32_t num_sv=model->nSV[i];
00135
00136 CSVM* svm=new CSVM(num_sv);
00137 svm->set_bias(model->rho[i+1]);
00138 norm_wcw[i]=model->normwcw[i];
00139
00140
00141 for (int32_t j=0; j<num_sv; j++)
00142 {
00143 svm->set_alpha(j, model->sv_coef[i][j]);
00144 svm->set_support_vector(j, model->SV[i][j].index);
00145 }
00146
00147 set_svm(i, svm);
00148 }
00149
00150 delete[] problem.x;
00151 delete[] problem.y;
00152 delete[] x_space;
00153 for (int32_t i=0; i<num_classes; i++)
00154 {
00155 free(model->SV[i]);
00156 model->SV[i]=NULL;
00157 }
00158 svm_destroy_model(model);
00159 compute_norm_wc();
00160
00161 model=NULL;
00162 return true;
00163 }
00164 else
00165 return false;
00166 }
00167
00168 void CScatterSVM::compute_norm_wc()
00169 {
00170 delete[] norm_wc;
00171 norm_wc = new float64_t[m_num_svms];
00172 for (int32_t i=0; i<m_num_svms; i++)
00173 norm_wc[i]=0;
00174
00175
00176 for (int c=0; c<m_num_svms; c++)
00177 {
00178 CSVM* svm=m_svms[c];
00179 int32_t num_sv = svm->get_num_support_vectors();
00180
00181 for (int32_t i=0; i<num_sv; i++)
00182 {
00183 int32_t ii=svm->get_support_vector(i);
00184 for (int32_t j=0; j<num_sv; j++)
00185 {
00186 int32_t jj=svm->get_support_vector(j);
00187 norm_wc[c]+=svm->get_alpha(i)*kernel->kernel(ii,jj)*svm->get_alpha(j);
00188 }
00189 }
00190 }
00191
00192 for (int32_t i=0; i<m_num_svms; i++)
00193 norm_wc[i]=CMath::sqrt(norm_wc[i]);
00194
00195 CMath::display_vector(norm_wc, m_num_svms, "norm_wc");
00196 }
00197
00198 CLabels* CScatterSVM::classify_one_vs_rest()
00199 {
00200 ASSERT(m_num_svms>0);
00201 CLabels* output=NULL;
00202 if (!kernel)
00203 {
00204 SG_ERROR( "SVM can not proceed without kernel!\n");
00205 return false ;
00206 }
00207
00208 if ( kernel && kernel->get_num_vec_lhs() && kernel->get_num_vec_rhs())
00209 {
00210 int32_t num_vectors=kernel->get_num_vec_rhs();
00211
00212 output=new CLabels(num_vectors);
00213 SG_REF(output);
00214
00215 for (int32_t i=0; i<num_vectors; i++)
00216 {
00217 output->set_label(i, classify_example(i));
00218 }
00219
00220
00221
00222
00223
00224
00225
00226
00227
00228
00229
00230
00231
00232
00233
00234
00235
00236
00237
00238
00239
00240
00241
00242
00243
00244
00245
00246
00247
00248
00249
00250
00251
00252
00253
00254
00255 }
00256
00257 return output;
00258 }
00259
00260 float64_t CScatterSVM::classify_example(int32_t num)
00261 {
00262
00263
00264
00265
00266
00267
00268
00269
00270
00271
00272
00273
00274
00275
00276
00277
00278
00279
00280
00281
00282 ASSERT(m_num_svms>0);
00283 float64_t* outputs=new float64_t[m_num_svms];
00284 int32_t winner=0;
00285
00286 for (int32_t c=0; c<m_num_svms; c++)
00287 outputs[c]=m_svms[c]->get_bias()-rho;
00288
00289 for (int32_t c=0; c<m_num_svms; c++)
00290 {
00291 float64_t v=0;
00292
00293 for (int32_t i=0; i<m_svms[c]->get_num_support_vectors(); i++)
00294 {
00295 float64_t alpha=m_svms[c]->get_alpha(i);
00296 int32_t svidx=m_svms[c]->get_support_vector(i);
00297 v += alpha*kernel->kernel(svidx, num);
00298 }
00299
00300 outputs[c] += v;
00301 for (int32_t j=0; j<m_num_svms; j++)
00302 outputs[j] -= v/m_num_svms;
00303 }
00304
00305 for (int32_t j=0; j<m_num_svms; j++)
00306 outputs[j]/=norm_wcw[j];
00307
00308 float64_t max_out=outputs[0];
00309 for (int32_t j=0; j<m_num_svms; j++)
00310 {
00311 if (outputs[j]>max_out)
00312 {
00313 max_out=outputs[j];
00314 winner=j;
00315 }
00316 }
00317
00318 delete[] outputs;
00319
00320
00321
00322 return winner;
00323 }