SVM.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include "lib/common.h"
00012 #include "lib/io.h"
00013 #include "base/Parallel.h"
00014 
00015 #include "classifier/svm/SVM.h"
00016 #include "classifier/mkl/MKL.h"
00017 
00018 #include <string.h>
00019 
00020 #ifndef WIN32
00021 #include <pthread.h>
00022 #endif
00023 
00024 #ifdef HAVE_BOOST_SERIALIZATION
00025 #include <boost/serialization/export.hpp>
00026 //BOOST_SERIALIZATION_ASSUME_ABSTRACT(CSVM);
00027 #endif //HAVE_BOOST_SERIALIZATION
00028 
00029 using namespace shogun;
00030 
00031 CSVM::CSVM(int32_t num_sv)
00032 : CKernelMachine()
00033 {
00034     set_defaults(num_sv);
00035 
00036     parameters.add(&C1, "C1");
00037     parameters.add(&C2, "C2");
00038 }
00039 
00040 CSVM::CSVM(float64_t C, CKernel* k, CLabels* lab)
00041 : CKernelMachine()
00042 {
00043     set_defaults();
00044     set_C(C,C);
00045     set_labels(lab);
00046     set_kernel(k);
00047 
00048     parameters.add(&C1, "C1");
00049     parameters.add(&C2, "C2");
00050 }
00051 
00052 CSVM::~CSVM()
00053 {
00054     SG_UNREF(mkl);
00055 }
00056 
00057 void CSVM::set_defaults(int32_t num_sv)
00058 {
00059     callback=NULL;
00060     mkl=NULL;
00061 
00062     svm_loaded=false;
00063 
00064     epsilon=1e-5;
00065     tube_epsilon=1e-2;
00066 
00067     nu=0.5;
00068     C1=1;
00069     C2=1;
00070 
00071     objective=0;
00072 
00073     qpsize=41;
00074     use_bias=true;
00075     use_shrinking=true;
00076     use_batch_computation=true;
00077     use_linadd=true;
00078 
00079     if (num_sv>0)
00080         create_new_model(num_sv);
00081 }
00082 
00083 bool CSVM::load(FILE* modelfl)
00084 {
00085     bool result=true;
00086     char char_buffer[1024];
00087     int32_t int_buffer;
00088     float64_t double_buffer;
00089     int32_t line_number=1;
00090 
00091     if (fscanf(modelfl,"%4s\n", char_buffer)==EOF)
00092     {
00093         result=false;
00094         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00095     }
00096     else
00097     {
00098         char_buffer[4]='\0';
00099         if (strcmp("%SVM", char_buffer)!=0)
00100         {
00101             result=false;
00102             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00103         }
00104         line_number++;
00105     }
00106 
00107     int_buffer=0;
00108     if (fscanf(modelfl," numsv=%d; \n", &int_buffer) != 1)
00109     {
00110         result=false;
00111         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00112     }
00113 
00114     if (!feof(modelfl))
00115         line_number++;
00116 
00117     SG_INFO( "loading %ld support vectors\n",int_buffer);
00118     create_new_model(int_buffer);
00119 
00120     if (fscanf(modelfl," kernel='%s'; \n", char_buffer) != 1)
00121     {
00122         result=false;
00123         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00124     }
00125 
00126     if (!feof(modelfl))
00127         line_number++;
00128 
00129     double_buffer=0;
00130 
00131     if (fscanf(modelfl," b=%lf; \n", &double_buffer) != 1)
00132     {
00133         result=false;
00134         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00135     }
00136 
00137     if (!feof(modelfl))
00138         line_number++;
00139 
00140     set_bias(double_buffer);
00141 
00142     if (fscanf(modelfl,"%8s\n", char_buffer) == EOF)
00143     {
00144         result=false;
00145         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00146     }
00147     else
00148     {
00149         char_buffer[9]='\0';
00150         if (strcmp("alphas=[", char_buffer)!=0)
00151         {
00152             result=false;
00153             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00154         }
00155         line_number++;
00156     }
00157 
00158     for (int32_t i=0; i<get_num_support_vectors(); i++)
00159     {
00160         double_buffer=0;
00161         int_buffer=0;
00162 
00163         if (fscanf(modelfl," \[%lf,%d]; \n", &double_buffer, &int_buffer) != 2)
00164         {
00165             result=false;
00166             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00167         }
00168 
00169         if (!feof(modelfl))
00170             line_number++;
00171 
00172         set_support_vector(i, int_buffer);
00173         set_alpha(i, double_buffer);
00174     }
00175 
00176     if (fscanf(modelfl,"%2s", char_buffer) == EOF)
00177     {
00178         result=false;
00179         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00180     }
00181     else
00182     {
00183         char_buffer[3]='\0';
00184         if (strcmp("];", char_buffer)!=0)
00185         {
00186             result=false;
00187             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00188         }
00189         line_number++;
00190     }
00191 
00192     svm_loaded=result;
00193     return result;
00194 }
00195 
00196 bool CSVM::save(FILE* modelfl)
00197 {
00198     if (!kernel)
00199         SG_ERROR("Kernel not defined!\n");
00200 
00201     SG_INFO( "Writing model file...");
00202     fprintf(modelfl,"%%SVM\n");
00203     fprintf(modelfl,"numsv=%d;\n", get_num_support_vectors());
00204     fprintf(modelfl,"kernel='%s';\n", kernel->get_name());
00205     fprintf(modelfl,"b=%+10.16e;\n",get_bias());
00206 
00207     fprintf(modelfl, "alphas=\[\n");
00208 
00209     for(int32_t i=0; i<get_num_support_vectors(); i++)
00210         fprintf(modelfl,"\t[%+10.16e,%d];\n",
00211                 CSVM::get_alpha(i), get_support_vector(i));
00212 
00213     fprintf(modelfl, "];\n");
00214 
00215     SG_DONE();
00216     return true ;
00217 }
00218 
00219 void CSVM::set_callback_function(CMKL* m, bool (*cb)
00220         (CMKL* mkl, const float64_t* sumw, const float64_t suma))
00221 {
00222     SG_UNREF(mkl);
00223     mkl=m;
00224     SG_REF(mkl);
00225 
00226     callback=cb;
00227 }
00228 
00229 float64_t CSVM::compute_svm_dual_objective()
00230 {
00231     int32_t n=get_num_support_vectors();
00232 
00233     if (labels && kernel)
00234     {
00235         objective=0;
00236         for (int32_t i=0; i<n; i++)
00237         {
00238             int32_t ii=get_support_vector(i);
00239             objective-=get_alpha(i)*labels->get_label(ii);
00240 
00241             for (int32_t j=0; j<n; j++)
00242             {
00243                 int32_t jj=get_support_vector(j);
00244                 objective+=0.5*get_alpha(i)*get_alpha(j)*kernel->kernel(ii,jj);
00245             }
00246         }
00247     }
00248     else
00249         SG_ERROR( "cannot compute objective, labels or kernel not set\n");
00250 
00251     return objective;
00252 }
00253 
00254 float64_t CSVM::compute_svm_primal_objective()
00255 {
00256     int32_t n=get_num_support_vectors();
00257     float64_t regularizer=0;
00258     float64_t loss=0;
00259 
00260     if (labels && kernel)
00261     {
00262         for (int32_t i=0; i<n; i++)
00263         {
00264             int32_t ii=get_support_vector(i);
00265             for (int32_t j=0; j<n; j++)
00266             {
00267                 int32_t jj=get_support_vector(j);
00268                 regularizer-=0.5*get_alpha(i)*get_alpha(j)*kernel->kernel(ii,jj);
00269             }
00270 
00271             loss-=C1*CMath::max(0.0, 1.0-get_label(ii)*classify_example(ii));
00272         }
00273     }
00274     else
00275         SG_ERROR( "cannot compute objective, labels or kernel not set\n");
00276 
00277     return regularizer+loss;
00278 }
00279 
00280 
00281 float64_t* CSVM::get_linear_term_array() {
00282 
00283     float64_t* a = new float64_t[linear_term.size()];
00284     std::copy( linear_term.begin(), linear_term.end(), a);
00285 
00286     return a;
00287 
00288 }
00289 
00290 
00291 
00292 void CSVM::set_linear_term(std::vector<float64_t> lin)
00293 {
00294 
00295     if (!labels)
00296         SG_ERROR("Please assign labels first!\n");
00297 
00298     int32_t num_labels=labels->get_num_labels();
00299 
00300     if (num_labels!=(int32_t) lin.size())
00301     {
00302         SG_ERROR("Number of labels (%d) does not match number"
00303                 "of entries (%d) in linear term \n", num_labels, lin.size());
00304     }
00305 
00306     linear_term = lin;
00307 
00308 }
00309 
00310 
00311 std::vector<float64_t> CSVM::get_linear_term() {
00312 
00313     return linear_term;
00314 
00315 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation