00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include "lib/common.h"
00012 #include "lib/io.h"
00013 #include "classifier/svm/MultiClassSVM.h"
00014
00015 using namespace shogun;
00016
00017 CMultiClassSVM::CMultiClassSVM(EMultiClassSVM type)
00018 : CSVM(0), multiclass_type(type), m_num_svms(0), m_svms(NULL)
00019 {
00020 }
00021
00022 CMultiClassSVM::CMultiClassSVM(
00023 EMultiClassSVM type, float64_t C, CKernel* k, CLabels* lab)
00024 : CSVM(C, k, lab), multiclass_type(type), m_num_svms(0), m_svms(NULL)
00025 {
00026 }
00027
00028 CMultiClassSVM::~CMultiClassSVM()
00029 {
00030 cleanup();
00031 }
00032
00033 void CMultiClassSVM::cleanup()
00034 {
00035 for (int32_t i=0; i<m_num_svms; i++)
00036 SG_UNREF(m_svms[i]);
00037
00038 delete[] m_svms;
00039 m_num_svms=0;
00040 m_svms=NULL;
00041 }
00042
00043 bool CMultiClassSVM::create_multiclass_svm(int32_t num_classes)
00044 {
00045 if (num_classes>0)
00046 {
00047 cleanup();
00048
00049 m_num_classes=num_classes;
00050
00051 if (multiclass_type==ONE_VS_REST)
00052 m_num_svms=num_classes;
00053 else if (multiclass_type==ONE_VS_ONE)
00054 m_num_svms=num_classes*(num_classes-1)/2;
00055 else
00056 SG_ERROR("unknown multiclass type\n");
00057
00058 m_svms=new CSVM*[m_num_svms];
00059 if (m_svms)
00060 {
00061 memset(m_svms,0, m_num_svms*sizeof(CSVM*));
00062 return true;
00063 }
00064 }
00065 return false;
00066 }
00067
00068 bool CMultiClassSVM::set_svm(int32_t num, CSVM* svm)
00069 {
00070 if (m_num_svms>0 && m_num_svms>num && num>=0 && svm)
00071 {
00072 SG_REF(svm);
00073 m_svms[num]=svm;
00074 return true;
00075 }
00076 return false;
00077 }
00078
00079 CLabels* CMultiClassSVM::classify()
00080 {
00081 if (multiclass_type==ONE_VS_REST)
00082 return classify_one_vs_rest();
00083 else if (multiclass_type==ONE_VS_ONE)
00084 return classify_one_vs_one();
00085 else
00086 SG_ERROR("unknown multiclass type\n");
00087
00088 return NULL;
00089 }
00090
00091 CLabels* CMultiClassSVM::classify_one_vs_one()
00092 {
00093 ASSERT(m_num_svms>0);
00094 ASSERT(m_num_svms==m_num_classes*(m_num_classes-1)/2);
00095 CLabels* result=NULL;
00096
00097 if (!kernel)
00098 {
00099 SG_ERROR( "SVM can not proceed without kernel!\n");
00100 return false ;
00101 }
00102
00103 if ( kernel && kernel->get_num_vec_lhs() && kernel->get_num_vec_rhs())
00104 {
00105 int32_t num_vectors=kernel->get_num_vec_rhs();
00106
00107 result=new CLabels(num_vectors);
00108 SG_REF(result);
00109
00110 ASSERT(num_vectors==result->get_num_labels());
00111 CLabels** outputs=new CLabels*[m_num_svms];
00112
00113 for (int32_t i=0; i<m_num_svms; i++)
00114 {
00115 SG_INFO("num_svms:%d svm[%d]=0x%0X\n", m_num_svms, i, m_svms[i]);
00116 ASSERT(m_svms[i]);
00117 m_svms[i]->set_kernel(kernel);
00118 outputs[i]=m_svms[i]->classify();
00119 }
00120
00121 int32_t* votes=new int32_t[m_num_classes];
00122 for (int32_t v=0; v<num_vectors; v++)
00123 {
00124 int32_t s=0;
00125 memset(votes, 0, sizeof(int32_t)*m_num_classes);
00126
00127 for (int32_t i=0; i<m_num_classes; i++)
00128 {
00129 for (int32_t j=i+1; j<m_num_classes; j++)
00130 {
00131 if (outputs[s++]->get_label(v)>0)
00132 votes[i]++;
00133 else
00134 votes[j]++;
00135 }
00136 }
00137
00138 int32_t winner=0;
00139 int32_t max_votes=votes[0];
00140
00141 for (int32_t i=1; i<m_num_classes; i++)
00142 {
00143 if (votes[i]>max_votes)
00144 {
00145 max_votes=votes[i];
00146 winner=i;
00147 }
00148 }
00149
00150 result->set_label(v, winner);
00151 }
00152
00153 delete[] votes;
00154
00155 for (int32_t i=0; i<m_num_svms; i++)
00156 SG_UNREF(outputs[i]);
00157 delete[] outputs;
00158 }
00159
00160 return result;
00161 }
00162
00163 CLabels* CMultiClassSVM::classify_one_vs_rest()
00164 {
00165 ASSERT(m_num_svms>0);
00166 CLabels* result=NULL;
00167
00168 if (!kernel)
00169 {
00170 SG_ERROR( "SVM can not proceed without kernel!\n");
00171 return false ;
00172 }
00173
00174 if ( kernel && kernel->get_num_vec_lhs() && kernel->get_num_vec_rhs())
00175 {
00176 int32_t num_vectors=kernel->get_num_vec_rhs();
00177
00178 result=new CLabels(num_vectors);
00179 SG_REF(result);
00180
00181 ASSERT(num_vectors==result->get_num_labels());
00182 CLabels** outputs=new CLabels*[m_num_svms];
00183
00184 for (int32_t i=0; i<m_num_svms; i++)
00185 {
00186 ASSERT(m_svms[i]);
00187 m_svms[i]->set_kernel(kernel);
00188 outputs[i]=m_svms[i]->classify();
00189 }
00190
00191 for (int32_t i=0; i<num_vectors; i++)
00192 {
00193 int32_t winner=0;
00194 float64_t max_out=outputs[0]->get_label(i);
00195
00196 for (int32_t j=1; j<m_num_svms; j++)
00197 {
00198 float64_t out=outputs[j]->get_label(i);
00199
00200 if (out>max_out)
00201 {
00202 winner=j;
00203 max_out=out;
00204 }
00205 }
00206
00207 result->set_label(i, winner);
00208 }
00209
00210 for (int32_t i=0; i<m_num_svms; i++)
00211 SG_UNREF(outputs[i]);
00212
00213 delete[] outputs;
00214 }
00215
00216 return result;
00217 }
00218
00219 float64_t CMultiClassSVM::classify_example(int32_t num)
00220 {
00221 if (multiclass_type==ONE_VS_REST)
00222 return classify_example_one_vs_rest(num);
00223 else if (multiclass_type==ONE_VS_ONE)
00224 return classify_example_one_vs_one(num);
00225 else
00226 SG_ERROR("unknown multiclass type\n");
00227
00228 return 0;
00229 }
00230
00231 float64_t CMultiClassSVM::classify_example_one_vs_rest(int32_t num)
00232 {
00233 ASSERT(m_num_svms>0);
00234 float64_t* outputs=new float64_t[m_num_svms];
00235 int32_t winner=0;
00236 float64_t max_out=m_svms[0]->classify_example(num);
00237
00238 for (int32_t i=1; i<m_num_svms; i++)
00239 {
00240 outputs[i]=m_svms[i]->classify_example(num);
00241 if (outputs[i]>max_out)
00242 {
00243 winner=i;
00244 max_out=outputs[i];
00245 }
00246 }
00247 delete[] outputs;
00248
00249 return winner;
00250 }
00251
00252 float64_t CMultiClassSVM::classify_example_one_vs_one(int32_t num)
00253 {
00254 ASSERT(m_num_svms>0);
00255 ASSERT(m_num_svms==m_num_classes*(m_num_classes-1)/2);
00256
00257 int32_t* votes=new int32_t[m_num_classes];
00258 int32_t s=0;
00259
00260 for (int32_t i=0; i<m_num_classes; i++)
00261 {
00262 for (int32_t j=i+1; j<m_num_classes; j++)
00263 {
00264 if (m_svms[s++]->classify_example(num)>0)
00265 votes[i]++;
00266 else
00267 votes[j]++;
00268 }
00269 }
00270
00271 int32_t winner=0;
00272 int32_t max_votes=votes[0];
00273
00274 for (int32_t i=1; i<m_num_classes; i++)
00275 {
00276 if (votes[i]>max_votes)
00277 {
00278 max_votes=votes[i];
00279 winner=i;
00280 }
00281 }
00282
00283 delete[] votes;
00284
00285 return winner;
00286 }
00287
00288 bool CMultiClassSVM::load(FILE* modelfl)
00289 {
00290 bool result=true;
00291 char char_buffer[1024];
00292 int32_t int_buffer;
00293 float64_t double_buffer;
00294 int32_t line_number=1;
00295 int32_t svm_idx=-1;
00296
00297 if (fscanf(modelfl,"%15s\n", char_buffer)==EOF)
00298 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00299 else
00300 {
00301 char_buffer[15]='\0';
00302 if (strcmp("%MultiClassSVM", char_buffer)!=0)
00303 SG_ERROR( "error in multiclass svm file, line nr:%d\n", line_number);
00304
00305 line_number++;
00306 }
00307
00308 int_buffer=0;
00309 if (fscanf(modelfl," multiclass_type=%d; \n", &int_buffer) != 1)
00310 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00311
00312 if (!feof(modelfl))
00313 line_number++;
00314
00315 if (int_buffer != multiclass_type)
00316 SG_ERROR("multiclass type does not match %ld vs. %ld\n", int_buffer, multiclass_type);
00317
00318 int_buffer=0;
00319 if (fscanf(modelfl," num_classes=%d; \n", &int_buffer) != 1)
00320 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00321
00322 if (!feof(modelfl))
00323 line_number++;
00324
00325 if (int_buffer < 2)
00326 SG_ERROR("less than 2 classes - how is this multiclass?\n");
00327
00328 create_multiclass_svm(int_buffer);
00329
00330 int_buffer=0;
00331 if (fscanf(modelfl," num_svms=%d; \n", &int_buffer) != 1)
00332 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00333
00334 if (!feof(modelfl))
00335 line_number++;
00336
00337 if (m_num_svms != int_buffer)
00338 SG_ERROR("Mismatch in number of svms: m_num_svms=%d vs m_num_svms(file)=%d\n", m_num_svms, int_buffer);
00339
00340 if (fscanf(modelfl," kernel='%s'; \n", char_buffer) != 1)
00341 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00342
00343 if (!feof(modelfl))
00344 line_number++;
00345
00346 for (int32_t n=0; n<m_num_svms; n++)
00347 {
00348 svm_idx=-1;
00349 if (fscanf(modelfl,"\n%4s %d of %d\n", char_buffer, &svm_idx, &int_buffer)==EOF)
00350 {
00351 result=false;
00352 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00353 }
00354 else
00355 {
00356 char_buffer[4]='\0';
00357 if (strncmp("%SVM", char_buffer, 4)!=0)
00358 {
00359 result=false;
00360 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00361 }
00362
00363 if (svm_idx != n)
00364 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00365
00366 line_number++;
00367 }
00368
00369 int_buffer=0;
00370 if (fscanf(modelfl,"numsv%d=%d;\n", &svm_idx, &int_buffer) != 2)
00371 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00372
00373 if (svm_idx != n)
00374 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00375
00376 if (!feof(modelfl))
00377 line_number++;
00378
00379 SG_INFO("loading %ld support vectors for svm %d\n",int_buffer, svm_idx);
00380 CSVM* svm=new CSVM(int_buffer);
00381
00382 double_buffer=0;
00383
00384 if (fscanf(modelfl," b%d=%lf; \n", &svm_idx, &double_buffer) != 2)
00385 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00386
00387 if (svm_idx != n)
00388 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00389
00390 if (!feof(modelfl))
00391 line_number++;
00392
00393 svm->set_bias(double_buffer);
00394
00395 if (fscanf(modelfl,"alphas%d=[\n", &svm_idx) != 1)
00396 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00397
00398 if (svm_idx != n)
00399 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00400
00401 if (!feof(modelfl))
00402 line_number++;
00403
00404 for (int32_t i=0; i<svm->get_num_support_vectors(); i++)
00405 {
00406 double_buffer=0;
00407 int_buffer=0;
00408
00409 if (fscanf(modelfl,"\t[%lf,%d]; \n", &double_buffer, &int_buffer) != 2)
00410 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00411
00412 if (!feof(modelfl))
00413 line_number++;
00414
00415 svm->set_support_vector(i, int_buffer);
00416 svm->set_alpha(i, double_buffer);
00417 }
00418
00419 if (fscanf(modelfl,"%2s", char_buffer) == EOF)
00420 {
00421 result=false;
00422 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00423 }
00424 else
00425 {
00426 char_buffer[3]='\0';
00427 if (strcmp("];", char_buffer)!=0)
00428 {
00429 result=false;
00430 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00431 }
00432 line_number++;
00433 }
00434
00435 set_svm(n, svm);
00436 }
00437
00438 svm_loaded=result;
00439 return result;
00440 }
00441
00442 bool CMultiClassSVM::save(FILE* modelfl)
00443 {
00444 if (!kernel)
00445 SG_ERROR("Kernel not defined!\n");
00446
00447 if (!m_svms || m_num_svms<1 || m_num_classes <=2)
00448 SG_ERROR("Multiclass SVM not trained!\n");
00449
00450 SG_INFO( "Writing model file...");
00451 fprintf(modelfl,"%%MultiClassSVM\n");
00452 fprintf(modelfl,"multiclass_type=%d;\n", multiclass_type);
00453 fprintf(modelfl,"num_classes=%d;\n", m_num_classes);
00454 fprintf(modelfl,"num_svms=%d;\n", m_num_svms);
00455 fprintf(modelfl,"kernel='%s';\n", kernel->get_name());
00456
00457 for (int32_t i=0; i<m_num_svms; i++)
00458 {
00459 CSVM* svm=m_svms[i];
00460 ASSERT(svm);
00461 fprintf(modelfl,"\n%%SVM %d of %d\n", i, m_num_svms-1);
00462 fprintf(modelfl,"numsv%d=%d;\n", i, svm->get_num_support_vectors());
00463 fprintf(modelfl,"b%d=%+10.16e;\n",i,svm->get_bias());
00464
00465 fprintf(modelfl, "alphas%d=[\n", i);
00466
00467 for(int32_t j=0; j<svm->get_num_support_vectors(); j++)
00468 {
00469 fprintf(modelfl,"\t[%+10.16e,%d];\n",
00470 svm->get_alpha(j), svm->get_support_vector(j));
00471 }
00472
00473 fprintf(modelfl, "];\n");
00474 }
00475
00476 SG_DONE();
00477 return true ;
00478 }