LibSVM.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include "classifier/svm/LibSVM.h"
00012 #include "lib/io.h"
00013 
00014 CLibSVM::CLibSVM()
00015 : CSVM(), model(NULL)
00016 {
00017 }
00018 
00019 CLibSVM::CLibSVM(float64_t C, CKernel* k, CLabels* lab)
00020 : CSVM(C, k, lab), model(NULL)
00021 {
00022 }
00023 
00024 CLibSVM::~CLibSVM()
00025 {
00026     //SG_PRINT("deleting LibSVM\n");
00027 }
00028 
00029 bool CLibSVM::train()
00030 {
00031     struct svm_node* x_space;
00032 
00033     ASSERT(labels && labels->get_num_labels());
00034     ASSERT(labels->is_two_class_labeling());
00035 
00036     problem.l=labels->get_num_labels();
00037     SG_INFO( "%d trainlabels\n", problem.l);
00038 
00039     problem.y=new float64_t[problem.l];
00040     problem.x=new struct svm_node*[problem.l];
00041     x_space=new struct svm_node[2*problem.l];
00042 
00043     for (int32_t i=0; i<problem.l; i++)
00044     {
00045         problem.y[i]=labels->get_label(i);
00046         problem.x[i]=&x_space[2*i];
00047         x_space[2*i].index=i;
00048         x_space[2*i+1].index=-1;
00049     }
00050 
00051     int32_t weights_label[2]={-1,+1};
00052     float64_t weights[2]={1.0,get_C2()/get_C1()};
00053 
00054     ASSERT(kernel && kernel->has_features());
00055     ASSERT(kernel->get_num_vec_lhs()==problem.l);
00056 
00057     param.svm_type=C_SVC; // C SVM
00058     param.kernel_type = LINEAR;
00059     param.degree = 3;
00060     param.gamma = 0;    // 1/k
00061     param.coef0 = 0;
00062     param.nu = 0.5;
00063     param.kernel=kernel;
00064     param.cache_size = kernel->get_cache_size();
00065     param.C = get_C1();
00066     param.eps = epsilon;
00067     param.p = 0.1;
00068     param.shrinking = 1;
00069     param.nr_weight = 2;
00070     param.weight_label = weights_label;
00071     param.weight = weights;
00072 
00073     const char* error_msg = svm_check_parameter(&problem,&param);
00074 
00075     if(error_msg)
00076         SG_ERROR("Error: %s\n",error_msg);
00077 
00078     model = svm_train(&problem, &param);
00079 
00080     if (model)
00081     {
00082         ASSERT(model->nr_class==2);
00083         ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef[0]));
00084 
00085         int32_t num_sv=model->l;
00086 
00087         create_new_model(num_sv);
00088         CSVM::set_objective(model->objective);
00089 
00090         float64_t sgn=model->label[0];
00091 
00092         set_bias(-sgn*model->rho[0]);
00093 
00094         for (int32_t i=0; i<num_sv; i++)
00095         {
00096             set_support_vector(i, (model->SV[i])->index);
00097             set_alpha(i, sgn*model->sv_coef[0][i]);
00098         }
00099 
00100         delete[] problem.x;
00101         delete[] problem.y;
00102         delete[] x_space;
00103 
00104         svm_destroy_model(model);
00105         model=NULL;
00106         return true;
00107     }
00108     else
00109         return false;
00110 }

SHOGUN Machine Learning Toolbox - Documentation