ScatterSVM.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2009 Soeren Sonnenburg
00008  * Written (W) 2009 Marius Kloft
00009  * Copyright (C) 2009 TU Berlin and Max-Planck-Society
00010  */
00011 
00012 #include "classifier/svm/ScatterSVM.h"
00013 #include "lib/io.h"
00014 
00015 using namespace shogun;
00016 
00017 CScatterSVM::CScatterSVM()
00018 : CMultiClassSVM(ONE_VS_REST), model(NULL), norm_wc(NULL), norm_wcw(NULL), rho(0)
00019 {
00020 }
00021 
00022 CScatterSVM::CScatterSVM(float64_t C, CKernel* k, CLabels* lab)
00023 : CMultiClassSVM(ONE_VS_REST, C, k, lab), model(NULL), norm_wc(NULL), norm_wcw(NULL), rho(0)
00024 {
00025 }
00026 
00027 CScatterSVM::~CScatterSVM()
00028 {
00029     delete[] norm_wc;
00030     delete[] norm_wcw;
00031     //SG_PRINT("deleting ScatterSVM\n");
00032 }
00033 
00034 bool CScatterSVM::train(CFeatures* data)
00035 {
00036     struct svm_node* x_space;
00037 
00038     ASSERT(labels && labels->get_num_labels());
00039     int32_t num_classes = labels->get_num_classes();
00040 
00041     if (data)
00042     {
00043         if (labels->get_num_labels() != data->get_num_vectors())
00044             SG_ERROR("Number of training vectors does not match number of labels\n");
00045         kernel->init(data, data);
00046     }
00047 
00048     problem.l=labels->get_num_labels();
00049     SG_INFO( "%d trainlabels\n", problem.l);
00050 
00051     problem.y=new float64_t[problem.l];
00052     problem.x=new struct svm_node*[problem.l];
00053     x_space=new struct svm_node[2*problem.l];
00054 
00055     for (int32_t i=0; i<problem.l; i++)
00056     {
00057         problem.y[i]=labels->get_label(i);
00058         problem.x[i]=&x_space[2*i];
00059         x_space[2*i].index=i;
00060         x_space[2*i+1].index=-1;
00061     }
00062 
00063     int32_t weights_label[2]={-1,+1};
00064     float64_t weights[2]={1.0,get_C2()/get_C1()};
00065 
00066     ASSERT(kernel && kernel->has_features());
00067     ASSERT(kernel->get_num_vec_lhs()==problem.l);
00068 
00069     param.svm_type=NU_MULTICLASS_SVC; // Nu MC SVM
00070     param.kernel_type = LINEAR;
00071     param.degree = 3;
00072     param.gamma = 0;    // 1/k
00073     param.coef0 = 0;
00074     param.nu = get_nu(); // Nu
00075     param.kernel=kernel;
00076     param.cache_size = kernel->get_cache_size();
00077     param.C = 0;
00078     param.eps = epsilon;
00079     param.p = 0.1;
00080     param.shrinking = 0;
00081     param.nr_weight = 2;
00082     param.weight_label = weights_label;
00083     param.weight = weights;
00084     param.nr_class=num_classes;
00085     param.use_bias = get_bias_enabled();
00086 
00087     int32_t* numc=new int32_t[num_classes];
00088     CMath::fill_vector(numc, num_classes, 0);
00089 
00090     for (int32_t i=0; i<problem.l; i++)
00091         numc[(int32_t) problem.y[i]]++;
00092 
00093     int32_t Nc=0;
00094     int32_t Nmin=problem.l;
00095     for (int32_t i=0; i<num_classes; i++)
00096     {
00097         if (numc[i]>0)
00098         {
00099             Nc++;
00100             Nmin=CMath::min(Nmin, numc[i]);
00101         }
00102 
00103     }
00104 
00105     float64_t nu_min=((float64_t) Nc)/problem.l;
00106     float64_t nu_max=((float64_t) Nc)*Nmin/problem.l;
00107 
00108     SG_INFO("valid nu interval [%f ... %f]\n", nu_min, nu_max);
00109 
00110     if (param.nu<nu_min || param.nu>nu_max)
00111         SG_ERROR("nu out of valid range [%f ... %f]\n", nu_min, nu_max);
00112 
00113     const char* error_msg = svm_check_parameter(&problem,&param);
00114 
00115     if(error_msg)
00116         SG_ERROR("Error: %s\n",error_msg);
00117 
00118     model = svm_train(&problem, &param);
00119 
00120     if (model)
00121     {
00122         ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef));
00123 
00124         ASSERT(model->nr_class==num_classes);
00125         create_multiclass_svm(num_classes);
00126 
00127         rho=model->rho[0];
00128 
00129         delete[] norm_wcw;
00130         norm_wcw = new float64_t[m_num_svms];
00131 
00132         for (int32_t i=0; i<num_classes; i++)
00133         {
00134             int32_t num_sv=model->nSV[i];
00135 
00136             CSVM* svm=new CSVM(num_sv);
00137             svm->set_bias(model->rho[i+1]);
00138             norm_wcw[i]=model->normwcw[i];
00139 
00140 
00141             for (int32_t j=0; j<num_sv; j++)
00142             {
00143                 svm->set_alpha(j, model->sv_coef[i][j]);
00144                 svm->set_support_vector(j, model->SV[i][j].index);
00145             }
00146 
00147             set_svm(i, svm);
00148         }
00149 
00150         delete[] problem.x;
00151         delete[] problem.y;
00152         delete[] x_space;
00153         for (int32_t i=0; i<num_classes; i++)
00154         {
00155             free(model->SV[i]);
00156             model->SV[i]=NULL;
00157         }
00158         svm_destroy_model(model);
00159         compute_norm_wc();
00160 
00161         model=NULL;
00162         return true;
00163     }
00164     else
00165         return false;
00166 }
00167 
00168 void CScatterSVM::compute_norm_wc()
00169 {
00170     delete[] norm_wc;
00171     norm_wc = new float64_t[m_num_svms];
00172     for (int32_t i=0; i<m_num_svms; i++)
00173         norm_wc[i]=0;
00174 
00175 
00176     for (int c=0; c<m_num_svms; c++)
00177     {
00178         CSVM* svm=m_svms[c];
00179         int32_t num_sv = svm->get_num_support_vectors();
00180 
00181         for (int32_t i=0; i<num_sv; i++)
00182         {
00183             int32_t ii=svm->get_support_vector(i);
00184             for (int32_t j=0; j<num_sv; j++)
00185             {
00186                 int32_t jj=svm->get_support_vector(j);
00187                 norm_wc[c]+=svm->get_alpha(i)*kernel->kernel(ii,jj)*svm->get_alpha(j);
00188             }
00189         }
00190     }
00191 
00192     for (int32_t i=0; i<m_num_svms; i++)
00193         norm_wc[i]=CMath::sqrt(norm_wc[i]);
00194 
00195     CMath::display_vector(norm_wc, m_num_svms, "norm_wc");
00196 }
00197 
00198 CLabels* CScatterSVM::classify_one_vs_rest()
00199 {
00200     ASSERT(m_num_svms>0);
00201     CLabels* output=NULL;
00202     if (!kernel)
00203     {
00204         SG_ERROR( "SVM can not proceed without kernel!\n");
00205         return false ;
00206     }
00207 
00208     if ( kernel && kernel->get_num_vec_lhs() && kernel->get_num_vec_rhs())
00209     {
00210         int32_t num_vectors=kernel->get_num_vec_rhs();
00211 
00212         output=new CLabels(num_vectors);
00213         SG_REF(output);
00214 
00215         for (int32_t i=0; i<num_vectors; i++)
00216         {
00217             output->set_label(i, classify_example(i));
00218         }
00219 /*
00220         ASSERT(num_vectors==output->get_num_labels());
00221         CLabels** outputs=new CLabels*[m_num_svms];
00222 
00223         for (int32_t i=0; i<m_num_svms; i++)
00224         {
00225             ASSERT(m_svms[i]);
00226             m_svms[i]->set_kernel(kernel);
00227             m_svms[i]->set_labels(labels);
00228             outputs[i]=m_svms[i]->classify();
00229         }
00230 
00231         for (int32_t i=0; i<num_vectors; i++)
00232         {
00233             int32_t winner=0;
00234             float64_t max_out=outputs[0]->get_label(i)/norm_wc[0];
00235 
00236             for (int32_t j=1; j<m_num_svms; j++)
00237             {
00238                 float64_t out=outputs[j]->get_label(i)/norm_wc[j];
00239 
00240                 if (out>max_out)
00241                 {
00242                     winner=j;
00243                     max_out=out;
00244                 }
00245             }
00246 
00247             output->set_label(i, winner);
00248         }
00249 
00250         for (int32_t i=0; i<m_num_svms; i++)
00251             SG_UNREF(outputs[i]);
00252 
00253         delete[] outputs;
00254         */
00255     }
00256 
00257     return output;
00258 }
00259 
00260 float64_t CScatterSVM::classify_example(int32_t num)
00261 {
00262     /*
00263     ASSERT(m_num_svms>0);
00264     float64_t* outputs=new float64_t[m_num_svms];
00265     int32_t winner=0;
00266     float64_t max_out=m_svms[0]->classify_example(num)/norm_wc[0];
00267 
00268     for (int32_t i=1; i<m_num_svms; i++)
00269     {
00270         outputs[i]=m_svms[i]->classify_example(num)/norm_wc[i];
00271         if (outputs[i]>max_out)
00272         {
00273             winner=i;
00274             max_out=outputs[i];
00275         }
00276     }
00277     delete[] outputs;
00278 
00279     return winner;
00280     */
00281 
00282     ASSERT(m_num_svms>0);
00283     float64_t* outputs=new float64_t[m_num_svms];
00284     int32_t winner=0;
00285 
00286     for (int32_t c=0; c<m_num_svms; c++)
00287         outputs[c]=m_svms[c]->get_bias()-rho;
00288 
00289     for (int32_t c=0; c<m_num_svms; c++)
00290     {
00291         float64_t v=0;
00292 
00293         for (int32_t i=0; i<m_svms[c]->get_num_support_vectors(); i++)
00294         {
00295             float64_t alpha=m_svms[c]->get_alpha(i);
00296             int32_t svidx=m_svms[c]->get_support_vector(i);
00297             v += alpha*kernel->kernel(svidx, num);
00298         }
00299 
00300         outputs[c] += v;
00301         for (int32_t j=0; j<m_num_svms; j++)
00302             outputs[j] -= v/m_num_svms;
00303     }
00304 
00305     for (int32_t j=0; j<m_num_svms; j++)
00306         outputs[j]/=norm_wcw[j];
00307 
00308     float64_t max_out=outputs[0];
00309     for (int32_t j=0; j<m_num_svms; j++)
00310     {
00311         if (outputs[j]>max_out)
00312         {
00313             max_out=outputs[j];
00314             winner=j;
00315         }
00316     }
00317 
00318     delete[] outputs;
00319 
00320     //SG_PRINT("winner = %d\n", winner);
00321 
00322     return winner;
00323 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation