CPLEXSVM.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include "classifier/svm/CPLEXSVM.h"
00012 #include "lib/common.h"
00013 
00014 #ifdef USE_CPLEX
00015 #include "lib/io.h"
00016 #include "lib/Mathematics.h"
00017 #include "lib/Cplex.h"
00018 #include "features/Labels.h"
00019 
00020 CCPLEXSVM::CCPLEXSVM()
00021 : CSVM()
00022 {
00023 }
00024 
00025 CCPLEXSVM::~CCPLEXSVM()
00026 {
00027 }
00028 
00029 bool CCPLEXSVM::train()
00030 {
00031     bool result = false;
00032     CCplex cplex;
00033 
00034     if (cplex.init(E_QP))
00035     {
00036         int32_t n,m;
00037         int32_t num_label=0;
00038         float64_t* y = labels->get_labels(num_label);
00039         float64_t* H = kernel->get_kernel_matrix_real(m, n, NULL);
00040         ASSERT(n>0 && n==m && n==num_label);
00041         float64_t* alphas=new float64_t[n];
00042         float64_t* lb=new float64_t[n];
00043         float64_t* ub=new float64_t[n];
00044 
00045         //hessian y'y.*K
00046         for (int32_t i=0; i<n; i++)
00047         {
00048             lb[i]=0;
00049             ub[i]=get_C1();
00050 
00051             for (int32_t j=0; j<n; j++)
00052                 H[i*n+j]*=y[j]*y[i];
00053         }
00054 
00055         //feed qp to cplex
00056 
00057 
00058         int32_t j=0;
00059         for (int32_t i=0; i<n; i++)
00060         {
00061             if (alphas[i]>0)
00062             {
00063                 //set_alpha(j, alphas[i]*labels->get_label(i)/etas[1]);
00064                 set_alpha(j, alphas[i]*labels->get_label(i));
00065                 set_support_vector(j, i);
00066                 j++;
00067             }
00068         }
00069         compute_objective();
00070         SG_INFO( "obj = %.16f, rho = %.16f\n",get_objective(),get_bias());
00071         SG_INFO( "Number of SV: %ld\n", get_num_support_vectors());
00072 
00073         delete[] alphas;
00074         delete[] lb;
00075         delete[] ub;
00076         delete[] H;
00077         delete[] y;
00078 
00079         result = true;
00080     }
00081 
00082     if (!result)
00083         SG_ERROR( "cplex svm failed");
00084 
00085     return result;
00086 }
00087 #endif

SHOGUN Machine Learning Toolbox - Documentation