SVMLin.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2006-2009 Soeren Sonnenburg
00008  * Copyright (C) 2006-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include "classifier/svm/SVMLin.h"
00012 #include "features/Labels.h"
00013 #include "lib/Mathematics.h"
00014 #include "classifier/svm/ssl.h"
00015 #include "classifier/LinearClassifier.h"
00016 #include "features/DotFeatures.h"
00017 #include "features/Labels.h"
00018 
00019 CSVMLin::CSVMLin()
00020 : CLinearClassifier(), C1(1), C2(1), epsilon(1e-5), use_bias(true)
00021 {
00022 }
00023 
00024 CSVMLin::CSVMLin(
00025     float64_t C, CDotFeatures* traindat, CLabels* trainlab)
00026 : CLinearClassifier(), C1(C), C2(C), epsilon(1e-5), use_bias(true)
00027 {
00028     set_features(traindat);
00029     set_labels(trainlab);
00030 }
00031 
00032 
00033 CSVMLin::~CSVMLin()
00034 {
00035 }
00036 
00037 bool CSVMLin::train()
00038 {
00039     ASSERT(labels);
00040     ASSERT(features);
00041 
00042     int32_t num_train_labels=0;
00043     float64_t* train_labels=labels->get_labels(num_train_labels);
00044     int32_t num_feat=features->get_dim_feature_space();
00045     int32_t num_vec=features->get_num_vectors();
00046 
00047     ASSERT(num_vec==num_train_labels);
00048     delete[] w;
00049 
00050     struct options Options;
00051     struct data Data;
00052     struct vector_double Weights;
00053     struct vector_double Outputs;
00054 
00055     Data.l=num_vec;
00056     Data.m=num_vec;
00057     Data.u=0; 
00058     Data.n=num_feat+1;
00059     Data.nz=num_feat+1;
00060     Data.Y=train_labels;
00061     Data.features=features;
00062     Data.C = new float64_t[Data.l];
00063 
00064     Options.algo = SVM;
00065     Options.lambda=1/(2*get_C1());
00066     Options.lambda_u=1/(2*get_C1());
00067     Options.S=10000;
00068     Options.R=0.5;
00069     Options.epsilon = get_epsilon();
00070     Options.cgitermax=10000;
00071     Options.mfnitermax=50;
00072     Options.Cp = get_C2()/get_C1();
00073     Options.Cn = 1;
00074     
00075     if (use_bias)
00076         Options.bias=1.0;
00077     else
00078         Options.bias=0.0;
00079 
00080     for (int32_t i=0;i<num_vec;i++)
00081     {
00082         if(train_labels[i]>0) 
00083             Data.C[i]=Options.Cp;
00084         else 
00085             Data.C[i]=Options.Cn;
00086     }
00087     ssl_train(&Data, &Options, &Weights, &Outputs);
00088     ASSERT(Weights.vec && Weights.d==num_feat+1);
00089 
00090     float64_t sgn=train_labels[0];
00091     for (int32_t i=0; i<num_feat+1; i++)
00092         Weights.vec[i]*=sgn;
00093 
00094     set_w(Weights.vec, num_feat);
00095     set_bias(Weights.vec[num_feat]);
00096 
00097     delete[] Data.C;
00098     delete[] train_labels;
00099     delete[] Outputs.vec;
00100     return true;
00101 }

SHOGUN Machine Learning Toolbox - Documentation