LibSVM.cpp
Go to the documentation of this file.
00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include "classifier/svm/LibSVM.h"
00012 #include "lib/io.h"
00013
00014 CLibSVM::CLibSVM()
00015 : CSVM(), model(NULL)
00016 {
00017 }
00018
00019 CLibSVM::CLibSVM(float64_t C, CKernel* k, CLabels* lab)
00020 : CSVM(C, k, lab), model(NULL)
00021 {
00022 }
00023
00024 CLibSVM::~CLibSVM()
00025 {
00026
00027 }
00028
00029 bool CLibSVM::train()
00030 {
00031 struct svm_node* x_space;
00032
00033 ASSERT(labels && labels->get_num_labels());
00034 ASSERT(labels->is_two_class_labeling());
00035
00036 problem.l=labels->get_num_labels();
00037 SG_INFO( "%d trainlabels\n", problem.l);
00038
00039 problem.y=new float64_t[problem.l];
00040 problem.x=new struct svm_node*[problem.l];
00041 x_space=new struct svm_node[2*problem.l];
00042
00043 for (int32_t i=0; i<problem.l; i++)
00044 {
00045 problem.y[i]=labels->get_label(i);
00046 problem.x[i]=&x_space[2*i];
00047 x_space[2*i].index=i;
00048 x_space[2*i+1].index=-1;
00049 }
00050
00051 int32_t weights_label[2]={-1,+1};
00052 float64_t weights[2]={1.0,get_C2()/get_C1()};
00053
00054 ASSERT(kernel && kernel->has_features());
00055 ASSERT(kernel->get_num_vec_lhs()==problem.l);
00056
00057 param.svm_type=C_SVC;
00058 param.kernel_type = LINEAR;
00059 param.degree = 3;
00060 param.gamma = 0;
00061 param.coef0 = 0;
00062 param.nu = 0.5;
00063 param.kernel=kernel;
00064 param.cache_size = kernel->get_cache_size();
00065 param.C = get_C1();
00066 param.eps = epsilon;
00067 param.p = 0.1;
00068 param.shrinking = 1;
00069 param.nr_weight = 2;
00070 param.weight_label = weights_label;
00071 param.weight = weights;
00072
00073 const char* error_msg = svm_check_parameter(&problem,¶m);
00074
00075 if(error_msg)
00076 SG_ERROR("Error: %s\n",error_msg);
00077
00078 model = svm_train(&problem, ¶m);
00079
00080 if (model)
00081 {
00082 ASSERT(model->nr_class==2);
00083 ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef[0]));
00084
00085 int32_t num_sv=model->l;
00086
00087 create_new_model(num_sv);
00088 CSVM::set_objective(model->objective);
00089
00090 float64_t sgn=model->label[0];
00091
00092 set_bias(-sgn*model->rho[0]);
00093
00094 for (int32_t i=0; i<num_sv; i++)
00095 {
00096 set_support_vector(i, (model->SV[i])->index);
00097 set_alpha(i, sgn*model->sv_coef[0][i]);
00098 }
00099
00100 delete[] problem.x;
00101 delete[] problem.y;
00102 delete[] x_space;
00103
00104 svm_destroy_model(model);
00105 model=NULL;
00106 return true;
00107 }
00108 else
00109 return false;
00110 }