MultiClassSVM.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include "lib/common.h"
00012 #include "lib/io.h"
00013 #include "classifier/svm/MultiClassSVM.h"
00014 
00015 using namespace shogun;
00016 
00017 CMultiClassSVM::CMultiClassSVM(EMultiClassSVM type)
00018 : CSVM(0), multiclass_type(type), m_num_svms(0), m_svms(NULL)
00019 {
00020 }
00021 
00022 CMultiClassSVM::CMultiClassSVM(
00023     EMultiClassSVM type, float64_t C, CKernel* k, CLabels* lab)
00024 : CSVM(C, k, lab), multiclass_type(type), m_num_svms(0), m_svms(NULL)
00025 {
00026 }
00027 
00028 CMultiClassSVM::~CMultiClassSVM()
00029 {
00030     cleanup();
00031 }
00032 
00033 void CMultiClassSVM::cleanup()
00034 {
00035     for (int32_t i=0; i<m_num_svms; i++)
00036         SG_UNREF(m_svms[i]);
00037 
00038     delete[] m_svms;
00039     m_num_svms=0;
00040     m_svms=NULL;
00041 }
00042 
00043 bool CMultiClassSVM::create_multiclass_svm(int32_t num_classes)
00044 {
00045     if (num_classes>0)
00046     {
00047         cleanup();
00048 
00049         m_num_classes=num_classes;
00050 
00051         if (multiclass_type==ONE_VS_REST)
00052             m_num_svms=num_classes;
00053         else if (multiclass_type==ONE_VS_ONE)
00054             m_num_svms=num_classes*(num_classes-1)/2;
00055         else
00056             SG_ERROR("unknown multiclass type\n");
00057 
00058         m_svms=new CSVM*[m_num_svms];
00059         if (m_svms)
00060         {
00061             memset(m_svms,0, m_num_svms*sizeof(CSVM*));
00062             return true;
00063         }
00064     }
00065     return false;
00066 }
00067 
00068 bool CMultiClassSVM::set_svm(int32_t num, CSVM* svm)
00069 {
00070     if (m_num_svms>0 && m_num_svms>num && num>=0 && svm)
00071     {
00072         SG_REF(svm);
00073         m_svms[num]=svm;
00074         return true;
00075     }
00076     return false;
00077 }
00078 
00079 CLabels* CMultiClassSVM::classify()
00080 {
00081     if (multiclass_type==ONE_VS_REST)
00082         return classify_one_vs_rest();
00083     else if (multiclass_type==ONE_VS_ONE)
00084         return classify_one_vs_one();
00085     else
00086         SG_ERROR("unknown multiclass type\n");
00087 
00088     return NULL;
00089 }
00090 
00091 CLabels* CMultiClassSVM::classify_one_vs_one()
00092 {
00093     ASSERT(m_num_svms>0);
00094     ASSERT(m_num_svms==m_num_classes*(m_num_classes-1)/2);
00095     CLabels* result=NULL;
00096 
00097     if (!kernel)
00098     {
00099         SG_ERROR( "SVM can not proceed without kernel!\n");
00100         return false ;
00101     }
00102 
00103     if ( kernel && kernel->get_num_vec_lhs() && kernel->get_num_vec_rhs())
00104     {
00105         int32_t num_vectors=kernel->get_num_vec_rhs();
00106 
00107         result=new CLabels(num_vectors);
00108         SG_REF(result);
00109 
00110         ASSERT(num_vectors==result->get_num_labels());
00111         CLabels** outputs=new CLabels*[m_num_svms];
00112 
00113         for (int32_t i=0; i<m_num_svms; i++)
00114         {
00115             SG_INFO("num_svms:%d svm[%d]=0x%0X\n", m_num_svms, i, m_svms[i]);
00116             ASSERT(m_svms[i]);
00117             m_svms[i]->set_kernel(kernel);
00118             outputs[i]=m_svms[i]->classify();
00119         }
00120 
00121         int32_t* votes=new int32_t[m_num_classes];
00122         for (int32_t v=0; v<num_vectors; v++)
00123         {
00124             int32_t s=0;
00125             memset(votes, 0, sizeof(int32_t)*m_num_classes);
00126 
00127             for (int32_t i=0; i<m_num_classes; i++)
00128             {
00129                 for (int32_t j=i+1; j<m_num_classes; j++)
00130                 {
00131                     if (outputs[s++]->get_label(v)>0)
00132                         votes[i]++;
00133                     else
00134                         votes[j]++;
00135                 }
00136             }
00137 
00138             int32_t winner=0;
00139             int32_t max_votes=votes[0];
00140 
00141             for (int32_t i=1; i<m_num_classes; i++)
00142             {
00143                 if (votes[i]>max_votes)
00144                 {
00145                     max_votes=votes[i];
00146                     winner=i;
00147                 }
00148             }
00149 
00150             result->set_label(v, winner);
00151         }
00152 
00153         delete[] votes;
00154 
00155         for (int32_t i=0; i<m_num_svms; i++)
00156             SG_UNREF(outputs[i]);
00157         delete[] outputs;
00158     }
00159 
00160     return result;
00161 }
00162 
00163 CLabels* CMultiClassSVM::classify_one_vs_rest()
00164 {
00165     ASSERT(m_num_svms>0);
00166     CLabels* result=NULL;
00167 
00168     if (!kernel)
00169     {
00170         SG_ERROR( "SVM can not proceed without kernel!\n");
00171         return false ;
00172     }
00173 
00174     if ( kernel && kernel->get_num_vec_lhs() && kernel->get_num_vec_rhs())
00175     {
00176         int32_t num_vectors=kernel->get_num_vec_rhs();
00177 
00178         result=new CLabels(num_vectors);
00179         SG_REF(result);
00180 
00181         ASSERT(num_vectors==result->get_num_labels());
00182         CLabels** outputs=new CLabels*[m_num_svms];
00183 
00184         for (int32_t i=0; i<m_num_svms; i++)
00185         {
00186             ASSERT(m_svms[i]);
00187             m_svms[i]->set_kernel(kernel);
00188             outputs[i]=m_svms[i]->classify();
00189         }
00190 
00191         for (int32_t i=0; i<num_vectors; i++)
00192         {
00193             int32_t winner=0;
00194             float64_t max_out=outputs[0]->get_label(i);
00195 
00196             for (int32_t j=1; j<m_num_svms; j++)
00197             {
00198                 float64_t out=outputs[j]->get_label(i);
00199 
00200                 if (out>max_out)
00201                 {
00202                     winner=j;
00203                     max_out=out;
00204                 }
00205             }
00206 
00207             result->set_label(i, winner);
00208         }
00209 
00210         for (int32_t i=0; i<m_num_svms; i++)
00211             SG_UNREF(outputs[i]);
00212 
00213         delete[] outputs;
00214     }
00215 
00216     return result;
00217 }
00218 
00219 float64_t CMultiClassSVM::classify_example(int32_t num)
00220 {
00221     if (multiclass_type==ONE_VS_REST)
00222         return classify_example_one_vs_rest(num);
00223     else if (multiclass_type==ONE_VS_ONE)
00224         return classify_example_one_vs_one(num);
00225     else
00226         SG_ERROR("unknown multiclass type\n");
00227 
00228     return 0;
00229 }
00230 
00231 float64_t CMultiClassSVM::classify_example_one_vs_rest(int32_t num)
00232 {
00233     ASSERT(m_num_svms>0);
00234     float64_t* outputs=new float64_t[m_num_svms];
00235     int32_t winner=0;
00236     float64_t max_out=m_svms[0]->classify_example(num);
00237 
00238     for (int32_t i=1; i<m_num_svms; i++)
00239     {
00240         outputs[i]=m_svms[i]->classify_example(num);
00241         if (outputs[i]>max_out)
00242         {
00243             winner=i;
00244             max_out=outputs[i];
00245         }
00246     }
00247     delete[] outputs;
00248 
00249     return winner;
00250 }
00251 
00252 float64_t CMultiClassSVM::classify_example_one_vs_one(int32_t num)
00253 {
00254     ASSERT(m_num_svms>0);
00255     ASSERT(m_num_svms==m_num_classes*(m_num_classes-1)/2);
00256 
00257     int32_t* votes=new int32_t[m_num_classes];
00258     int32_t s=0;
00259 
00260     for (int32_t i=0; i<m_num_classes; i++)
00261     {
00262         for (int32_t j=i+1; j<m_num_classes; j++)
00263         {
00264             if (m_svms[s++]->classify_example(num)>0)
00265                 votes[i]++;
00266             else
00267                 votes[j]++;
00268         }
00269     }
00270 
00271     int32_t winner=0;
00272     int32_t max_votes=votes[0];
00273 
00274     for (int32_t i=1; i<m_num_classes; i++)
00275     {
00276         if (votes[i]>max_votes)
00277         {
00278             max_votes=votes[i];
00279             winner=i;
00280         }
00281     }
00282 
00283     delete[] votes;
00284 
00285     return winner;
00286 }
00287 
00288 bool CMultiClassSVM::load(FILE* modelfl)
00289 {
00290     bool result=true;
00291     char char_buffer[1024];
00292     int32_t int_buffer;
00293     float64_t double_buffer;
00294     int32_t line_number=1;
00295     int32_t svm_idx=-1;
00296 
00297     if (fscanf(modelfl,"%15s\n", char_buffer)==EOF)
00298         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00299     else
00300     {
00301         char_buffer[15]='\0';
00302         if (strcmp("%MultiClassSVM", char_buffer)!=0)
00303             SG_ERROR( "error in multiclass svm file, line nr:%d\n", line_number);
00304 
00305         line_number++;
00306     }
00307 
00308     int_buffer=0;
00309     if (fscanf(modelfl," multiclass_type=%d; \n", &int_buffer) != 1)
00310         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00311 
00312     if (!feof(modelfl))
00313         line_number++;
00314 
00315     if (int_buffer != multiclass_type)
00316         SG_ERROR("multiclass type does not match %ld vs. %ld\n", int_buffer, multiclass_type);
00317 
00318     int_buffer=0;
00319     if (fscanf(modelfl," num_classes=%d; \n", &int_buffer) != 1)
00320         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00321 
00322     if (!feof(modelfl))
00323         line_number++;
00324 
00325     if (int_buffer < 2)
00326         SG_ERROR("less than 2 classes - how is this multiclass?\n");
00327 
00328     create_multiclass_svm(int_buffer);
00329 
00330     int_buffer=0;
00331     if (fscanf(modelfl," num_svms=%d; \n", &int_buffer) != 1)
00332         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00333 
00334     if (!feof(modelfl))
00335         line_number++;
00336 
00337     if (m_num_svms != int_buffer)
00338         SG_ERROR("Mismatch in number of svms: m_num_svms=%d vs m_num_svms(file)=%d\n", m_num_svms, int_buffer);
00339 
00340     if (fscanf(modelfl," kernel='%s'; \n", char_buffer) != 1)
00341         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00342 
00343     if (!feof(modelfl))
00344         line_number++;
00345 
00346     for (int32_t n=0; n<m_num_svms; n++)
00347     {
00348         svm_idx=-1;
00349         if (fscanf(modelfl,"\n%4s %d of %d\n", char_buffer, &svm_idx, &int_buffer)==EOF)
00350         {
00351             result=false;
00352             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00353         }
00354         else
00355         {
00356             char_buffer[4]='\0';
00357             if (strncmp("%SVM", char_buffer, 4)!=0)
00358             {
00359                 result=false;
00360                 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00361             }
00362 
00363             if (svm_idx != n)
00364                 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00365 
00366             line_number++;
00367         }
00368 
00369         int_buffer=0;
00370         if (fscanf(modelfl,"numsv%d=%d;\n", &svm_idx, &int_buffer) != 2)
00371             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00372 
00373         if (svm_idx != n)
00374             SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00375 
00376         if (!feof(modelfl))
00377             line_number++;
00378 
00379         SG_INFO("loading %ld support vectors for svm %d\n",int_buffer, svm_idx);
00380         CSVM* svm=new CSVM(int_buffer);
00381 
00382         double_buffer=0;
00383 
00384         if (fscanf(modelfl," b%d=%lf; \n", &svm_idx, &double_buffer) != 2)
00385             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00386 
00387         if (svm_idx != n)
00388             SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00389 
00390         if (!feof(modelfl))
00391             line_number++;
00392 
00393         svm->set_bias(double_buffer);
00394 
00395         if (fscanf(modelfl,"alphas%d=[\n", &svm_idx) != 1)
00396             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00397 
00398         if (svm_idx != n)
00399             SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00400 
00401         if (!feof(modelfl))
00402             line_number++;
00403 
00404         for (int32_t i=0; i<svm->get_num_support_vectors(); i++)
00405         {
00406             double_buffer=0;
00407             int_buffer=0;
00408 
00409             if (fscanf(modelfl,"\t[%lf,%d]; \n", &double_buffer, &int_buffer) != 2)
00410                 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00411 
00412             if (!feof(modelfl))
00413                 line_number++;
00414 
00415             svm->set_support_vector(i, int_buffer);
00416             svm->set_alpha(i, double_buffer);
00417         }
00418 
00419         if (fscanf(modelfl,"%2s", char_buffer) == EOF)
00420         {
00421             result=false;
00422             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00423         }
00424         else
00425         {
00426             char_buffer[3]='\0';
00427             if (strcmp("];", char_buffer)!=0)
00428             {
00429                 result=false;
00430                 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00431             }
00432             line_number++;
00433         }
00434 
00435         set_svm(n, svm);
00436     }
00437 
00438     svm_loaded=result;
00439     return result;
00440 }
00441 
00442 bool CMultiClassSVM::save(FILE* modelfl)
00443 {
00444     if (!kernel)
00445         SG_ERROR("Kernel not defined!\n");
00446 
00447     if (!m_svms || m_num_svms<1 || m_num_classes <=2)
00448         SG_ERROR("Multiclass SVM not trained!\n");
00449 
00450     SG_INFO( "Writing model file...");
00451     fprintf(modelfl,"%%MultiClassSVM\n");
00452     fprintf(modelfl,"multiclass_type=%d;\n", multiclass_type);
00453     fprintf(modelfl,"num_classes=%d;\n", m_num_classes);
00454     fprintf(modelfl,"num_svms=%d;\n", m_num_svms);
00455     fprintf(modelfl,"kernel='%s';\n", kernel->get_name());
00456 
00457     for (int32_t i=0; i<m_num_svms; i++)
00458     {
00459         CSVM* svm=m_svms[i];
00460         ASSERT(svm);
00461         fprintf(modelfl,"\n%%SVM %d of %d\n", i, m_num_svms-1);
00462         fprintf(modelfl,"numsv%d=%d;\n", i, svm->get_num_support_vectors());
00463         fprintf(modelfl,"b%d=%+10.16e;\n",i,svm->get_bias());
00464 
00465         fprintf(modelfl, "alphas%d=[\n", i);
00466 
00467         for(int32_t j=0; j<svm->get_num_support_vectors(); j++)
00468         {
00469             fprintf(modelfl,"\t[%+10.16e,%d];\n",
00470                     svm->get_alpha(j), svm->get_support_vector(j));
00471         }
00472 
00473         fprintf(modelfl, "];\n");
00474     }
00475 
00476     SG_DONE();
00477     return true ;
00478 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation