LibLinear.cpp
Go to the documentation of this file.
00002
00003
00004
00005
00006
00007
00008
00009
00010 #include "lib/config.h"
00011
00012 #ifdef HAVE_LAPACK
00013 #include "lib/io.h"
00014 #include "classifier/svm/LibLinear.h"
00015 #include "classifier/svm/SVM_linear.h"
00016 #include "classifier/svm/Tron.h"
00017 #include "features/DotFeatures.h"
00018
00019 CLibLinear::CLibLinear(LIBLINEAR_LOSS l)
00020 : CLinearClassifier()
00021 {
00022 loss=l;
00023 use_bias=false;
00024 C1=1;
00025 C2=1;
00026 }
00027
00028 CLibLinear::CLibLinear(
00029 float64_t C, CDotFeatures* traindat, CLabels* trainlab)
00030 : CLinearClassifier(), C1(C), C2(C), use_bias(true), epsilon(1e-5)
00031 {
00032 set_features(traindat);
00033 set_labels(trainlab);
00034 loss=LR;
00035 }
00036
00037
00038 CLibLinear::~CLibLinear()
00039 {
00040 }
00041
00042 bool CLibLinear::train()
00043 {
00044 ASSERT(labels);
00045 ASSERT(features);
00046 ASSERT(labels->is_two_class_labeling());
00047
00048 int32_t num_train_labels=labels->get_num_labels();
00049 int32_t num_feat=features->get_dim_feature_space();
00050 int32_t num_vec=features->get_num_vectors();
00051
00052 ASSERT(num_vec==num_train_labels);
00053 delete[] w;
00054 if (use_bias)
00055 w=new float64_t[num_feat+1];
00056 else
00057 w=new float64_t[num_feat+0];
00058 w_dim=num_feat;
00059
00060 problem prob;
00061 if (use_bias)
00062 {
00063 prob.n=w_dim+1;
00064 memset(w, 0, sizeof(float64_t)*(w_dim+1));
00065 }
00066 else
00067 {
00068 prob.n=w_dim;
00069 memset(w, 0, sizeof(float64_t)*(w_dim+0));
00070 }
00071 prob.l=num_vec;
00072 prob.x=features;
00073 prob.y=new int[prob.l];
00074 prob.use_bias=use_bias;
00075
00076 for (int32_t i=0; i<prob.l; i++)
00077 prob.y[i]=labels->get_int_label(i);
00078
00079 SG_INFO( "%d training points %d dims\n", prob.l, prob.n);
00080
00081 function *fun_obj=NULL;
00082
00083 switch (loss)
00084 {
00085 case LR:
00086 fun_obj=new l2_lr_fun(&prob, get_C1(), get_C2());
00087 break;
00088 case L2:
00089 fun_obj=new l2loss_svm_fun(&prob, get_C1(), get_C2());
00090 break;
00091 default:
00092 SG_ERROR("unknown loss\n");
00093 break;
00094 }
00095
00096 if (fun_obj)
00097 {
00098 CTron tron_obj(fun_obj, epsilon);
00099 tron_obj.tron(w);
00100 float64_t sgn=prob.y[0];
00101
00102 for (int32_t i=0; i<w_dim; i++)
00103 w[i]*=sgn;
00104
00105 if (use_bias)
00106 set_bias(sgn*w[w_dim]);
00107 else
00108 set_bias(0);
00109
00110 delete fun_obj;
00111 }
00112
00113 delete[] prob.y;
00114
00115 return true;
00116 }
00117 #endif //HAVE_LAPACK