/* SVM.h
 *
 * Generic Support Vector Machine interface of the Shogun toolbox.
 * (Recovered from a Doxygen HTML source listing; the embedded "000NN"
 * numbers in that listing were rendering artifacts, not code.)
 */
00011 #ifndef _SVM_H___
00012 #define _SVM_H___
00013
00014 #include "lib/common.h"
00015 #include "features/Features.h"
00016 #include "kernel/Kernel.h"
00017 #include "kernel/KernelMachine.h"
00018
00019 class CKernelMachine;
00020
00043 class CSVM : public CKernelMachine
00044 {
00045 public:
00049 CSVM(int32_t num_sv=0);
00050
00058 CSVM(float64_t C, CKernel* k, CLabels* lab);
00059 virtual ~CSVM();
00060
00063 void set_defaults(int32_t num_sv=0);
00064
00068 bool load(FILE* svm_file);
00069
00073 bool save(FILE* svm_file);
00074
00079 inline void set_nu(float64_t nue) { nu=nue; }
00080
00089 inline void set_C(float64_t c1, float64_t c2) { C1=c1; C2=c2; }
00090
00095 inline void set_weight_epsilon(float64_t eps) { weight_epsilon=eps; }
00096
00101 inline void set_epsilon(float64_t eps) { epsilon=eps; }
00102
00107 inline void set_tube_epsilon(float64_t eps) { tube_epsilon=eps; }
00108
00113 inline void set_C_mkl(float64_t C) { C_mkl = C; }
00114
00119 inline void set_mkl_norm(float64_t norm)
00120 {
00121 if (norm<=0)
00122 SG_ERROR("Norm must be > 0, e.g., 1-norm is the standard MKL; 2-norm nonsparse MKL\n");
00123 mkl_norm = norm;
00124 }
00125
00130 inline void set_qpsize(int32_t qps) { qpsize=qps; }
00131
00136 inline void set_bias_enabled(bool enable_bias) { use_bias=enable_bias; }
00137
00142 inline bool get_bias_enabled() { return use_bias; }
00143
00148 inline float64_t get_weight_epsilon() { return weight_epsilon; }
00149
00154 inline float64_t get_epsilon() { return epsilon; }
00155
00160 inline float64_t get_nu() { return nu; }
00161
00166 inline float64_t get_C1() { return C1; }
00167
00172 inline float64_t get_C2() { return C2; }
00173
00178 inline int32_t get_qpsize() { return qpsize; }
00179
00185 inline int32_t get_support_vector(int32_t idx)
00186 {
00187 ASSERT(svm_model.svs && idx<svm_model.num_svs);
00188 return svm_model.svs[idx];
00189 }
00190
00196 inline float64_t get_alpha(int32_t idx)
00197 {
00198 ASSERT(svm_model.alpha && idx<svm_model.num_svs);
00199 return svm_model.alpha[idx];
00200 }
00201
00208 inline bool set_support_vector(int32_t idx, int32_t val)
00209 {
00210 if (svm_model.svs && idx<svm_model.num_svs)
00211 svm_model.svs[idx]=val;
00212 else
00213 return false;
00214
00215 return true;
00216 }
00217
00224 inline bool set_alpha(int32_t idx, float64_t val)
00225 {
00226 if (svm_model.alpha && idx<svm_model.num_svs)
00227 svm_model.alpha[idx]=val;
00228 else
00229 return false;
00230
00231 return true;
00232 }
00233
00238 inline float64_t get_bias()
00239 {
00240 return svm_model.b;
00241 }
00242
00247 inline void set_bias(float64_t bias)
00248 {
00249 svm_model.b=bias;
00250 }
00251
00256 inline int32_t get_num_support_vectors()
00257 {
00258 return svm_model.num_svs;
00259 }
00260
00266 void set_alphas(float64_t* alphas, int32_t d)
00267 {
00268 ASSERT(alphas);
00269 ASSERT(d==svm_model.num_svs);
00270
00271 for(int32_t i=0; i<d; i++)
00272 svm_model.alpha[i]=alphas[i];
00273 }
00274
00280 void set_support_vectors(int32_t* svs, int32_t d)
00281 {
00282 ASSERT(svs);
00283 ASSERT(d==svm_model.num_svs);
00284
00285 for(int32_t i=0; i<d; i++)
00286 svm_model.svs[i]=svs[i];
00287 }
00288
00294 void get_support_vectors(int32_t** svs, int32_t* num)
00295 {
00296 int32_t nsv = get_num_support_vectors();
00297
00298 ASSERT(svs && num);
00299 *svs=NULL;
00300 *num=nsv;
00301
00302 if (nsv>0)
00303 {
00304 *svs = (int32_t*) malloc(sizeof(int32_t)*nsv);
00305 for(int32_t i=0; i<nsv; i++)
00306 (*svs)[i] = get_support_vector(i);
00307 }
00308 }
00309
00315 void get_alphas(float64_t** alphas, int32_t* d1)
00316 {
00317 int32_t nsv = get_num_support_vectors();
00318
00319 ASSERT(alphas && d1);
00320 *alphas=NULL;
00321 *d1=nsv;
00322
00323 if (nsv>0)
00324 {
00325 *alphas = (float64_t*) malloc(nsv*sizeof(float64_t));
00326 for(int32_t i=0; i<nsv; i++)
00327 (*alphas)[i] = get_alpha(i);
00328 }
00329 }
00330
00335 inline bool create_new_model(int32_t num)
00336 {
00337 delete[] svm_model.alpha;
00338 delete[] svm_model.svs;
00339
00340 svm_model.b=0;
00341 svm_model.num_svs=num;
00342
00343 if (num>0)
00344 {
00345 svm_model.alpha= new float64_t[num];
00346 svm_model.svs= new int32_t[num];
00347 return (svm_model.alpha!=NULL && svm_model.svs!=NULL);
00348 }
00349 else
00350 {
00351 svm_model.alpha= NULL;
00352 svm_model.svs=NULL;
00353 return true;
00354 }
00355 }
00356
00361 inline void set_shrinking_enabled(bool enable)
00362 {
00363 use_shrinking=enable;
00364 }
00365
00370 inline bool get_shrinking_enabled()
00371 {
00372 return use_shrinking;
00373 }
00374
00379 inline void set_mkl_enabled(bool enable)
00380 {
00381 use_mkl=enable;
00382 }
00383
00388 inline bool get_mkl_enabled()
00389 {
00390 return use_mkl;
00391 }
00392
00397 float64_t compute_objective();
00398
00403 inline void set_objective(float64_t v)
00404 {
00405 objective=v;
00406 }
00407
00412 inline float64_t get_objective()
00413 {
00414 return objective;
00415 }
00416
00421 bool init_kernel_optimization();
00422
00428 virtual CLabels* classify(CLabels* lab=NULL);
00429
00435 virtual float64_t classify_example(int32_t num);
00436
00442 static void* classify_example_helper(void* p);
00443
00445 inline virtual const char* get_name() const { return "SVM"; }
00446
00451 inline int32_t get_mkl_iterations() { return mkl_iterations; }
00452
00453 protected:
00456 struct TModel
00457 {
00459 float64_t b;
00461 float64_t* alpha;
00463 int32_t* svs;
00465 int32_t num_svs;
00466 };
00467
00469 TModel svm_model;
00471 bool svm_loaded;
00473 float64_t weight_epsilon;
00475 float64_t epsilon;
00477 float64_t tube_epsilon;
00479 float64_t nu;
00481 float64_t C1;
00483 float64_t C2;
00485 float64_t mkl_norm;
00487 float64_t C_mkl;
00489 float64_t objective;
00491 int32_t qpsize;
00493 bool use_bias;
00495 bool use_shrinking;
00497 bool use_mkl;
00499 int32_t mkl_iterations;
00500 };
00501 #endif