SVMLin.cpp
Go to the documentation of this file.
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include "classifier/svm/SVMLin.h"
00012 #include "features/Labels.h"
00013 #include "lib/Mathematics.h"
00014 #include "classifier/svm/ssl.h"
00015 #include "classifier/LinearClassifier.h"
00016 #include "features/DotFeatures.h"
00017 #include "features/Labels.h"
00018
00019 CSVMLin::CSVMLin()
00020 : CLinearClassifier(), C1(1), C2(1), epsilon(1e-5), use_bias(true)
00021 {
00022 }
00023
00024 CSVMLin::CSVMLin(
00025 float64_t C, CDotFeatures* traindat, CLabels* trainlab)
00026 : CLinearClassifier(), C1(C), C2(C), epsilon(1e-5), use_bias(true)
00027 {
00028 set_features(traindat);
00029 set_labels(trainlab);
00030 }
00031
00032
00033 CSVMLin::~CSVMLin()
00034 {
00035 }
00036
00037 bool CSVMLin::train()
00038 {
00039 ASSERT(labels);
00040 ASSERT(features);
00041
00042 int32_t num_train_labels=0;
00043 float64_t* train_labels=labels->get_labels(num_train_labels);
00044 int32_t num_feat=features->get_dim_feature_space();
00045 int32_t num_vec=features->get_num_vectors();
00046
00047 ASSERT(num_vec==num_train_labels);
00048 delete[] w;
00049
00050 struct options Options;
00051 struct data Data;
00052 struct vector_double Weights;
00053 struct vector_double Outputs;
00054
00055 Data.l=num_vec;
00056 Data.m=num_vec;
00057 Data.u=0;
00058 Data.n=num_feat+1;
00059 Data.nz=num_feat+1;
00060 Data.Y=train_labels;
00061 Data.features=features;
00062 Data.C = new float64_t[Data.l];
00063
00064 Options.algo = SVM;
00065 Options.lambda=1/(2*get_C1());
00066 Options.lambda_u=1/(2*get_C1());
00067 Options.S=10000;
00068 Options.R=0.5;
00069 Options.epsilon = get_epsilon();
00070 Options.cgitermax=10000;
00071 Options.mfnitermax=50;
00072 Options.Cp = get_C2()/get_C1();
00073 Options.Cn = 1;
00074
00075 if (use_bias)
00076 Options.bias=1.0;
00077 else
00078 Options.bias=0.0;
00079
00080 for (int32_t i=0;i<num_vec;i++)
00081 {
00082 if(train_labels[i]>0)
00083 Data.C[i]=Options.Cp;
00084 else
00085 Data.C[i]=Options.Cn;
00086 }
00087 ssl_train(&Data, &Options, &Weights, &Outputs);
00088 ASSERT(Weights.vec && Weights.d==num_feat+1);
00089
00090 float64_t sgn=train_labels[0];
00091 for (int32_t i=0; i<num_feat+1; i++)
00092 Weights.vec[i]*=sgn;
00093
00094 set_w(Weights.vec, num_feat);
00095 set_bias(Weights.vec[num_feat]);
00096
00097 delete[] Data.C;
00098 delete[] train_labels;
00099 delete[] Outputs.vec;
00100 return true;
00101 }