00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include "lib/common.h"
00012 #include "lib/io.h"
00013 #include "classifier/svm/MultiClassSVM.h"
00014
00015 CMultiClassSVM::CMultiClassSVM(EMultiClassSVM type)
00016 : CSVM(0), multiclass_type(type), m_num_svms(0), m_svms(NULL)
00017 {
00018 }
00019
00020 CMultiClassSVM::CMultiClassSVM(
00021 EMultiClassSVM type, float64_t C, CKernel* k, CLabels* lab)
00022 : CSVM(C, k, lab), multiclass_type(type), m_num_svms(0), m_svms(NULL)
00023 {
00024 }
00025
00026 CMultiClassSVM::~CMultiClassSVM()
00027 {
00028 cleanup();
00029 }
00030
00031 void CMultiClassSVM::cleanup()
00032 {
00033 for (int32_t i=0; i<m_num_svms; i++)
00034 SG_UNREF(m_svms[i]);
00035
00036 delete[] m_svms;
00037 m_num_svms=0;
00038 m_svms=NULL;
00039 }
00040
00041 bool CMultiClassSVM::create_multiclass_svm(int32_t num_classes)
00042 {
00043 if (num_classes>0)
00044 {
00045 m_num_classes=num_classes;
00046
00047 if (multiclass_type==ONE_VS_REST)
00048 m_num_svms=num_classes;
00049 else if (multiclass_type==ONE_VS_ONE)
00050 m_num_svms=num_classes*(num_classes-1)/2;
00051 else
00052 SG_ERROR("unknown multiclass type\n");
00053
00054 m_svms=new CSVM*[m_num_svms];
00055 if (m_svms)
00056 {
00057 memset(m_svms,0, m_num_svms*sizeof(CSVM*));
00058 return true;
00059 }
00060 }
00061 return false;
00062 }
00063
00064 bool CMultiClassSVM::set_svm(int32_t num, CSVM* svm)
00065 {
00066 if (m_num_svms>0 && m_num_svms>num && num>=0 && svm)
00067 {
00068 SG_REF(svm);
00069 m_svms[num]=svm;
00070 return true;
00071 }
00072 return false;
00073 }
00074
00075 CLabels* CMultiClassSVM::classify(CLabels* result)
00076 {
00077 if (multiclass_type==ONE_VS_REST)
00078 return classify_one_vs_rest(result);
00079 else if (multiclass_type==ONE_VS_ONE)
00080 return classify_one_vs_one(result);
00081 else
00082 SG_ERROR("unknown multiclass type\n");
00083
00084 return NULL;
00085 }
00086
00087 CLabels* CMultiClassSVM::classify_one_vs_one(CLabels* result)
00088 {
00089 ASSERT(m_num_svms>0);
00090 ASSERT(m_num_svms==m_num_classes*(m_num_classes-1)/2);
00091
00092 if (!kernel)
00093 {
00094 SG_ERROR( "SVM can not proceed without kernel!\n");
00095 return false ;
00096 }
00097
00098 if ( kernel && kernel->get_num_vec_lhs() && kernel->get_num_vec_rhs())
00099 {
00100 int32_t num_vectors=kernel->get_num_vec_rhs();
00101
00102 if (!result)
00103 {
00104 result=new CLabels(num_vectors);
00105 SG_REF(result);
00106 }
00107
00108 ASSERT(num_vectors==result->get_num_labels());
00109 CLabels** outputs=new CLabels*[m_num_svms];
00110
00111 for (int32_t i=0; i<m_num_svms; i++)
00112 {
00113 SG_INFO("num_svms:%d svm[%d]=0x%0X\n", m_num_svms, i, m_svms[i]);
00114 ASSERT(m_svms[i]);
00115 m_svms[i]->set_kernel(kernel);
00116 m_svms[i]->set_labels(labels);
00117 outputs[i]=m_svms[i]->classify();
00118 }
00119
00120 int32_t* votes=new int32_t[m_num_classes];
00121 for (int32_t v=0; v<num_vectors; v++)
00122 {
00123 int32_t s=0;
00124 memset(votes, 0, sizeof(int32_t)*m_num_classes);
00125
00126 for (int32_t i=0; i<m_num_classes; i++)
00127 {
00128 for (int32_t j=i+1; j<m_num_classes; j++)
00129 {
00130 if (outputs[s++]->get_label(v)>0)
00131 votes[i]++;
00132 else
00133 votes[j]++;
00134 }
00135 }
00136
00137 int32_t winner=0;
00138 int32_t max_votes=votes[0];
00139
00140 for (int32_t i=1; i<m_num_classes; i++)
00141 {
00142 if (votes[i]>max_votes)
00143 {
00144 max_votes=votes[i];
00145 winner=i;
00146 }
00147 }
00148
00149 result->set_label(v, winner);
00150 }
00151
00152 delete[] votes;
00153
00154 for (int32_t i=0; i<m_num_svms; i++)
00155 SG_UNREF(outputs[i]);
00156 delete[] outputs;
00157 }
00158
00159 return result;
00160 }
00161
00162 CLabels* CMultiClassSVM::classify_one_vs_rest(CLabels* result)
00163 {
00164 ASSERT(m_num_svms>0);
00165
00166 if (!kernel)
00167 {
00168 SG_ERROR( "SVM can not proceed without kernel!\n");
00169 return false ;
00170 }
00171
00172 if ( kernel && kernel->get_num_vec_lhs() && kernel->get_num_vec_rhs())
00173 {
00174 int32_t num_vectors=kernel->get_num_vec_rhs();
00175
00176 if (!result)
00177 {
00178 result=new CLabels(num_vectors);
00179 SG_REF(result);
00180 }
00181
00182 ASSERT(num_vectors==result->get_num_labels());
00183 CLabels** outputs=new CLabels*[m_num_svms];
00184
00185 for (int32_t i=0; i<m_num_svms; i++)
00186 {
00187 ASSERT(m_svms[i]);
00188 m_svms[i]->set_kernel(kernel);
00189 m_svms[i]->set_labels(labels);
00190 outputs[i]=m_svms[i]->classify();
00191 }
00192
00193 for (int32_t i=0; i<num_vectors; i++)
00194 {
00195 int32_t winner=0;
00196 float64_t max_out=outputs[0]->get_label(i);
00197
00198 for (int32_t j=1; j<m_num_svms; j++)
00199 {
00200 float64_t out=outputs[j]->get_label(i);
00201
00202 if (out>max_out)
00203 {
00204 winner=j;
00205 max_out=out;
00206 }
00207 }
00208
00209 result->set_label(i, winner);
00210 }
00211
00212 for (int32_t i=0; i<m_num_svms; i++)
00213 SG_UNREF(outputs[i]);
00214
00215 delete[] outputs;
00216 }
00217
00218 return result;
00219 }
00220
00221 float64_t CMultiClassSVM::classify_example(int32_t num)
00222 {
00223 if (multiclass_type==ONE_VS_REST)
00224 return classify_example_one_vs_rest(num);
00225 else if (multiclass_type==ONE_VS_ONE)
00226 return classify_example_one_vs_one(num);
00227 else
00228 SG_ERROR("unknown multiclass type\n");
00229
00230 return 0;
00231 }
00232
00233 float64_t CMultiClassSVM::classify_example_one_vs_rest(int32_t num)
00234 {
00235 ASSERT(m_num_svms>0);
00236 float64_t* outputs=new float64_t[m_num_svms];
00237 int32_t winner=0;
00238 float64_t max_out=m_svms[0]->classify_example(num);
00239
00240 for (int32_t i=1; i<m_num_svms; i++)
00241 {
00242 outputs[i]=m_svms[i]->classify_example(num);
00243 if (outputs[i]>max_out)
00244 {
00245 winner=i;
00246 max_out=outputs[i];
00247 }
00248 }
00249 delete[] outputs;
00250
00251 return winner;
00252 }
00253
00254 float64_t CMultiClassSVM::classify_example_one_vs_one(int32_t num)
00255 {
00256 ASSERT(m_num_svms>0);
00257 ASSERT(m_num_svms==m_num_classes*(m_num_classes-1)/2);
00258
00259 int32_t* votes=new int32_t[m_num_classes];
00260 int32_t s=0;
00261
00262 for (int32_t i=0; i<m_num_classes; i++)
00263 {
00264 for (int32_t j=i+1; j<m_num_classes; j++)
00265 {
00266 if (m_svms[s++]->classify_example(num)>0)
00267 votes[i]++;
00268 else
00269 votes[j]++;
00270 }
00271 }
00272
00273 int32_t winner=0;
00274 int32_t max_votes=votes[0];
00275
00276 for (int32_t i=1; i<m_num_classes; i++)
00277 {
00278 if (votes[i]>max_votes)
00279 {
00280 max_votes=votes[i];
00281 winner=i;
00282 }
00283 }
00284
00285 delete[] votes;
00286
00287 return winner;
00288 }
00289
00290 bool CMultiClassSVM::load(FILE* modelfl)
00291 {
00292 bool result=true;
00293 char char_buffer[1024];
00294 int32_t int_buffer;
00295 float64_t double_buffer;
00296 int32_t line_number=1;
00297 int32_t svm_idx=-1;
00298
00299 if (fscanf(modelfl,"%15s\n", char_buffer)==EOF)
00300 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00301 else
00302 {
00303 char_buffer[15]='\0';
00304 if (strcmp("%MultiClassSVM", char_buffer)!=0)
00305 SG_ERROR( "error in multiclass svm file, line nr:%d\n", line_number);
00306
00307 line_number++;
00308 }
00309
00310 int_buffer=0;
00311 if (fscanf(modelfl," multiclass_type=%d; \n", &int_buffer) != 1)
00312 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00313
00314 if (!feof(modelfl))
00315 line_number++;
00316
00317 if (int_buffer != multiclass_type)
00318 SG_ERROR("multiclass type does not match %ld vs. %ld\n", int_buffer, multiclass_type);
00319
00320 int_buffer=0;
00321 if (fscanf(modelfl," num_classes=%d; \n", &int_buffer) != 1)
00322 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00323
00324 if (!feof(modelfl))
00325 line_number++;
00326
00327 if (int_buffer < 2)
00328 SG_ERROR("less than 2 classes - how is this multiclass?\n");
00329
00330 create_multiclass_svm(int_buffer);
00331
00332 int_buffer=0;
00333 if (fscanf(modelfl," num_svms=%d; \n", &int_buffer) != 1)
00334 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00335
00336 if (!feof(modelfl))
00337 line_number++;
00338
00339 if (m_num_svms != int_buffer)
00340 SG_ERROR("Mismatch in number of svms: m_num_svms=%d vs m_num_svms(file)=%d\n", m_num_svms, int_buffer);
00341
00342 if (fscanf(modelfl," kernel='%s'; \n", char_buffer) != 1)
00343 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00344
00345 if (!feof(modelfl))
00346 line_number++;
00347
00348 for (int32_t n=0; n<m_num_svms; n++)
00349 {
00350 svm_idx=-1;
00351 if (fscanf(modelfl,"\n%4s %d of %d\n", char_buffer, &svm_idx, &int_buffer)==EOF)
00352 {
00353 result=false;
00354 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00355 }
00356 else
00357 {
00358 char_buffer[4]='\0';
00359 if (strncmp("%SVM", char_buffer, 4)!=0)
00360 {
00361 result=false;
00362 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00363 }
00364
00365 if (svm_idx != n)
00366 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00367
00368 line_number++;
00369 }
00370
00371 int_buffer=0;
00372 if (fscanf(modelfl,"numsv%d=%d;\n", &svm_idx, &int_buffer) != 2)
00373 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00374
00375 if (svm_idx != n)
00376 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00377
00378 if (!feof(modelfl))
00379 line_number++;
00380
00381 SG_INFO("loading %ld support vectors for svm %d\n",int_buffer, svm_idx);
00382 CSVM* svm=new CSVM(int_buffer);
00383
00384 double_buffer=0;
00385
00386 if (fscanf(modelfl," b%d=%lf; \n", &svm_idx, &double_buffer) != 2)
00387 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00388
00389 if (svm_idx != n)
00390 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00391
00392 if (!feof(modelfl))
00393 line_number++;
00394
00395 svm->set_bias(double_buffer);
00396
00397 if (fscanf(modelfl,"alphas%d=[\n", &svm_idx) != 1)
00398 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00399
00400 if (svm_idx != n)
00401 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00402
00403 if (!feof(modelfl))
00404 line_number++;
00405
00406 for (int32_t i=0; i<svm->get_num_support_vectors(); i++)
00407 {
00408 double_buffer=0;
00409 int_buffer=0;
00410
00411 if (fscanf(modelfl,"\t[%lf,%d]; \n", &double_buffer, &int_buffer) != 2)
00412 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00413
00414 if (!feof(modelfl))
00415 line_number++;
00416
00417 svm->set_support_vector(i, int_buffer);
00418 svm->set_alpha(i, double_buffer);
00419 }
00420
00421 if (fscanf(modelfl,"%2s", char_buffer) == EOF)
00422 {
00423 result=false;
00424 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00425 }
00426 else
00427 {
00428 char_buffer[3]='\0';
00429 if (strcmp("];", char_buffer)!=0)
00430 {
00431 result=false;
00432 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00433 }
00434 line_number++;
00435 }
00436
00437 set_svm(n, svm);
00438 }
00439
00440 svm_loaded=result;
00441 return result;
00442 }
00443
00444 bool CMultiClassSVM::save(FILE* modelfl)
00445 {
00446 if (!kernel)
00447 SG_ERROR("Kernel not defined!\n");
00448
00449 if (!m_svms || m_num_svms<1 || m_num_classes <=2)
00450 SG_ERROR("Multiclass SVM not trained!\n");
00451
00452 SG_INFO( "Writing model file...");
00453 fprintf(modelfl,"%%MultiClassSVM\n");
00454 fprintf(modelfl,"multiclass_type=%d;\n", multiclass_type);
00455 fprintf(modelfl,"num_classes=%d;\n", m_num_classes);
00456 fprintf(modelfl,"num_svms=%d;\n", m_num_svms);
00457 fprintf(modelfl,"kernel='%s';\n", kernel->get_name());
00458
00459 for (int32_t i=0; i<m_num_svms; i++)
00460 {
00461 CSVM* svm=m_svms[i];
00462 ASSERT(svm);
00463 fprintf(modelfl,"\n%%SVM %d of %d\n", i, m_num_svms-1);
00464 fprintf(modelfl,"numsv%d=%d;\n", i, svm->get_num_support_vectors());
00465 fprintf(modelfl,"b%d=%+10.16e;\n",i,svm->get_bias());
00466
00467 fprintf(modelfl, "alphas%d=[\n", i);
00468
00469 for(int32_t j=0; j<svm->get_num_support_vectors(); j++)
00470 {
00471 fprintf(modelfl,"\t[%+10.16e,%d];\n",
00472 svm->get_alpha(j), svm->get_support_vector(j));
00473 }
00474
00475 fprintf(modelfl, "];\n");
00476 }
00477
00478 SG_DONE();
00479 return true ;
00480 }