MKLMultiClass.cpp — source listing from the Shogun toolbox documentation.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2009 Alexander Binder
00008  * Copyright (C) 2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include "classifier/mkl/MKLMultiClass.h"
00012 #include "lib/io.h"
00013 
00014 using namespace shogun;
00015 
00016 
00017 CMKLMultiClass::CMKLMultiClass()
00018 : CMultiClassSVM(ONE_VS_REST)
00019 {
00020     svm=NULL;
00021     lpw=NULL;
00022 
00023     mkl_eps=0.01;
00024     max_num_mkl_iters=999;
00025     pnorm=1;
00026 }
00027 
00028 CMKLMultiClass::CMKLMultiClass(float64_t C, CKernel* k, CLabels* lab)
00029 : CMultiClassSVM(ONE_VS_REST, C, k, lab)
00030 {
00031     svm=NULL;
00032     lpw=NULL;
00033     
00034     mkl_eps=0.01;
00035     max_num_mkl_iters=999;
00036     pnorm=1;
00037 }
00038 
00039 
00040 CMKLMultiClass::~CMKLMultiClass()
00041 {
00042     SG_UNREF(svm);
00043     svm=NULL;
00044     delete lpw;
00045     lpw=NULL;
00046 }
00047 
00048 CMKLMultiClass::CMKLMultiClass( const CMKLMultiClass & cm)
00049 : CMultiClassSVM(ONE_VS_REST)
00050 {
00051     svm=NULL;
00052     lpw=NULL;
00053     SG_ERROR(
00054             " CMKLMultiClass::CMKLMultiClass(const CMKLMultiClass & cm): must "
00055             "not be called, glpk structure is currently not copyable");
00056 }
00057 
00058 CMKLMultiClass CMKLMultiClass::operator=( const CMKLMultiClass & cm)
00059 {
00060         SG_ERROR(
00061             " CMKLMultiClass CMKLMultiClass::operator=(...): must "
00062             "not be called, glpk structure is currently not copyable");
00063     return (*this);
00064 }
00065 
00066 
00067 void CMKLMultiClass::initsvm()
00068 {
00069     if (!labels)    
00070     {
00071         SG_ERROR("CMKLMultiClass::initsvm(): the set labels is NULL\n");
00072     }
00073 
00074     SG_UNREF(svm);
00075     svm=new CGMNPSVM;
00076     SG_REF(svm);
00077 
00078     svm->set_C(get_C1(),get_C2());
00079     svm->set_epsilon(epsilon);
00080 
00081     int32_t numlabels;
00082     float64_t * lb=labels->get_labels ( numlabels);
00083 
00084     if (numlabels<=0)   
00085     {
00086         SG_ERROR("CMKLMultiClass::initsvm(): the number of labels is "
00087                 "nonpositive, do not know how to handle this!\n");
00088     }
00089 
00090     CLabels* newlab=new CLabels(lb, labels->get_num_labels() );
00091     delete[] lb;
00092     lb=NULL;
00093 
00094     svm->set_labels(newlab);
00095 
00096     newlab=NULL;
00097 }
00098 
00099 void CMKLMultiClass::initlpsolver()
00100 {
00101     if (!kernel)    
00102     {
00103         SG_ERROR("CMKLMultiClass::initlpsolver(): the set kernel is NULL\n");
00104     }
00105 
00106     if (kernel->get_kernel_type()!=K_COMBINED)
00107     {
00108         SG_ERROR("CMKLMultiClass::initlpsolver(): given kernel is not of type"
00109                 " K_COMBINED %d required by Multiclass Mkl \n",
00110                 kernel->get_kernel_type());
00111     }
00112 
00113     int numker=dynamic_cast<CCombinedKernel *>(kernel)->get_num_subkernels();
00114 
00115     ASSERT(numker>0);
00116     /*
00117     if (lpw)
00118     {
00119         delete lpw;
00120     }
00121     */
00122     
00123     //lpw=new MKLMultiClassGLPK;
00124     if(pnorm>1)
00125     {
00126         lpw=new MKLMultiClassGradient;
00127         lpw->set_mkl_norm(pnorm);
00128     }
00129     else
00130     {
00131         lpw=new MKLMultiClassGLPK;
00132     }
00133     lpw->setup(numker);
00134     
00135 }
00136 
00137 
00138 bool CMKLMultiClass::evaluatefinishcriterion(const int32_t
00139         numberofsilpiterations)
00140 {
00141     if ( (max_num_mkl_iters>0) && (numberofsilpiterations>=max_num_mkl_iters) )
00142     {
00143         return(true);
00144     }
00145 
00146     if (weightshistory.size()>1)
00147     {
00148         std::vector<float64_t> wold,wnew;
00149 
00150         wold=weightshistory[ weightshistory.size()-2 ];
00151         wnew=weightshistory.back();
00152         float64_t delta=0;
00153 
00154         ASSERT (wold.size()==wnew.size());
00155 
00156 
00157         if((pnorm<=1)&&(!normweightssquared.empty()))
00158         {
00159 
00160             delta=0;
00161             for (size_t i=0;i< wnew.size();++i)
00162             {
00163                 delta+=(wold[i]-wnew[i])*(wold[i]-wnew[i]);
00164             }
00165             delta=sqrt(delta);
00166             SG_SDEBUG("L1 Norm chosen, weight delta %f \n",delta);
00167 
00168 
00169             //check dual gap part for mkl
00170             int32_t maxind=0;
00171             float64_t maxval=normweightssquared[maxind];
00172             delta=0;
00173             for (size_t i=0;i< wnew.size();++i)
00174             {
00175                 delta+=normweightssquared[i]*wnew[i];
00176                 if(wnew[i]>maxval)
00177                 {
00178                     maxind=i;
00179                     maxval=wnew[i];
00180                 }
00181             }
00182             delta-=normweightssquared[maxind];
00183             delta=fabs(delta);
00184             SG_SDEBUG("L1 Norm chosen, MKL part of duality gap %f \n",delta);
00185             if( (delta < mkl_eps) && (numberofsilpiterations>=1) )
00186             {
00187                 return(true);
00188             }
00189             
00190 
00191 
00192         }
00193         else
00194         {
00195             delta=0;
00196             for (size_t i=0;i< wnew.size();++i)
00197             {
00198                 delta+=(wold[i]-wnew[i])*(wold[i]-wnew[i]);
00199             }
00200             delta=sqrt(delta);
00201             SG_SDEBUG("Lp Norm chosen, weight delta %f \n",delta);
00202 
00203             if( (delta < mkl_eps) && (numberofsilpiterations>=1) )
00204             {
00205                 return(true);
00206             }
00207 
00208         }
00209     }
00210 
00211     return(false);
00212 }
00213 
00214 void CMKLMultiClass::addingweightsstep( const std::vector<float64_t> &
00215         curweights)
00216 {
00217 
00218     if (weightshistory.size()>2)
00219     {
00220         weightshistory.erase(weightshistory.begin());
00221     }
00222 
00223     float64_t* weights(NULL);
00224     weights=new float64_t[curweights.size()];
00225     std::copy(curweights.begin(),curweights.end(),weights);
00226 
00227     kernel->set_subkernel_weights(  weights, curweights.size());
00228     delete[] weights;
00229     weights=NULL;
00230 
00231     initsvm();
00232 
00233     svm->set_kernel(kernel);
00234     svm->train();
00235 
00236     float64_t sumofsignfreealphas=getsumofsignfreealphas();
00237     int32_t numkernels=
00238             dynamic_cast<CCombinedKernel *>(kernel)->get_num_subkernels();
00239 
00240 
00241     normweightssquared.resize(numkernels);
00242     for (int32_t ind=0; ind < numkernels; ++ind )
00243     {
00244         normweightssquared[ind]=getsquarenormofprimalcoefficients( ind );
00245     }
00246 
00247     lpw->addconstraint(normweightssquared,sumofsignfreealphas);
00248 }
00249 
00250 float64_t CMKLMultiClass::getsumofsignfreealphas()
00251 {
00252 
00253     std::vector<int> trainlabels2(labels->get_num_labels());
00254     int32_t tmpint;
00255     int32_t * lab=labels->get_int_labels ( tmpint);
00256     std::copy(lab,lab+labels->get_num_labels(), trainlabels2.begin());
00257     delete[] lab;
00258     lab=NULL;
00259 
00260 
00261     ASSERT (trainlabels2.size()>0);
00262     float64_t sum=0;
00263 
00264     for (int32_t nc=0; nc< labels->get_num_classes();++nc)
00265     {
00266         CSVM * sm=svm->get_svm(nc);
00267 
00268         float64_t bia=sm->get_bias();
00269         sum+= bia*bia;
00270 
00271         SG_UNREF(sm);
00272     }
00273 
00274     ::std::vector< ::std::vector<float64_t> > basealphas;
00275     svm->getbasealphas( basealphas);
00276 
00277     for (size_t lb=0; lb< trainlabels2.size();++lb)
00278     {
00279         for (int32_t nc=0; nc< labels->get_num_classes();++nc)
00280         {
00281             CSVM * sm=svm->get_svm(nc);
00282 
00283             if ((int)nc!=trainlabels2[lb])
00284             {
00285                 CSVM * sm2=svm->get_svm(trainlabels2[lb]);
00286 
00287                 float64_t bia1=sm2->get_bias();
00288                 float64_t bia2=sm->get_bias();
00289                 SG_UNREF(sm2);
00290 
00291                 sum+= -basealphas[nc][lb]*(bia1-bia2-1);
00292             }
00293             SG_UNREF(sm);
00294         }
00295     }
00296 
00297     return(sum);
00298 }
00299 
00300 float64_t CMKLMultiClass::getsquarenormofprimalcoefficients(
00301         const int32_t ind)
00302 {
00303     CKernel * ker=dynamic_cast<CCombinedKernel *>(kernel)->get_kernel(ind);
00304 
00305     float64_t tmp=0;
00306 
00307     for (int32_t classindex=0; classindex< labels->get_num_classes();
00308             ++classindex)
00309     {
00310         CSVM * sm=svm->get_svm(classindex);
00311 
00312         for (int32_t i=0; i < sm->get_num_support_vectors(); ++i)
00313         {
00314             float64_t alphai=sm->get_alpha(i);
00315             int32_t svindi= sm->get_support_vector(i); 
00316 
00317             for (int32_t k=0; k < sm->get_num_support_vectors(); ++k)
00318             {
00319                 float64_t alphak=sm->get_alpha(k);
00320                 int32_t svindk=sm->get_support_vector(k);
00321 
00322                 tmp+=alphai*ker->kernel(svindi,svindk)
00323                 *alphak;
00324 
00325             }
00326         }
00327         SG_UNREF(sm);
00328     }
00329     SG_UNREF(ker);
00330     ker=NULL;
00331 
00332     return(tmp);
00333 }
00334 
00335 
00336 bool CMKLMultiClass::train(CFeatures* data)
00337 {
00338     int numcl=labels->get_num_classes();
00339     ASSERT(kernel);
00340     ASSERT(labels && labels->get_num_labels());
00341 
00342     if (data)
00343     {
00344         if (labels->get_num_labels() != data->get_num_vectors())
00345             SG_ERROR("Number of training vectors does not match number of "
00346                     "labels\n");
00347         kernel->init(data, data);
00348     }
00349 
00350     initlpsolver();
00351 
00352     weightshistory.clear();
00353 
00354     int32_t numkernels=
00355             dynamic_cast<CCombinedKernel *>(kernel)->get_num_subkernels();
00356 
00357     ::std::vector<float64_t> curweights(numkernels,1.0/numkernels);
00358     weightshistory.push_back(curweights);
00359 
00360     addingweightsstep(curweights);
00361 
00362     int32_t numberofsilpiterations=0;
00363     bool final=false;
00364     while (!final)
00365     {
00366 
00367         //curweights.clear();
00368         lpw->computeweights(curweights);
00369         weightshistory.push_back(curweights);
00370 
00371 
00372         final=evaluatefinishcriterion(numberofsilpiterations);
00373         ++numberofsilpiterations;
00374 
00375         addingweightsstep(curweights);
00376 
00377     } // while(false==final)
00378 
00379 
00380     //set alphas, bias, support vecs
00381     ASSERT(numcl>=1);
00382     create_multiclass_svm(numcl);
00383 
00384     for (int32_t i=0; i<numcl; i++)
00385     {
00386         CSVM* osvm=svm->get_svm(i);
00387         CSVM* nsvm=new CSVM(osvm->get_num_support_vectors());
00388 
00389         for (int32_t k=0; k<osvm->get_num_support_vectors() ; k++)
00390         {
00391             nsvm->set_alpha(k, osvm->get_alpha(k) );
00392             nsvm->set_support_vector(k,osvm->get_support_vector(k) );
00393         }
00394         nsvm->set_bias(osvm->get_bias() );
00395         set_svm(i, nsvm);
00396 
00397         SG_UNREF(osvm);
00398         osvm=NULL;
00399     }
00400 
00401     SG_UNREF(svm);
00402     svm=NULL;
00403     if (lpw)
00404     {
00405         delete lpw;
00406     }
00407     lpw=NULL;
00408     return(true);
00409 }
00410 
00411 
00412 
00413 
00414 float64_t* CMKLMultiClass::getsubkernelweights(int32_t & numweights)
00415 {
00416     if ( weightshistory.empty() )
00417     {
00418         numweights=0;
00419         return NULL;
00420     }
00421 
00422     std::vector<float64_t> subkerw=weightshistory.back();
00423     numweights=weightshistory.back().size();
00424 
00425     float64_t* res=new float64_t[numweights];
00426     std::copy(weightshistory.back().begin(), weightshistory.back().end(),res);
00427     return res;
00428 }
00429 
00430 void CMKLMultiClass::set_mkl_epsilon(float64_t eps )
00431 {
00432     mkl_eps=eps;
00433 }
00434 
00435 void CMKLMultiClass::set_max_num_mkliters(int32_t maxnum)
00436 {
00437     max_num_mkl_iters=maxnum;
00438 }
00439 
00440 void CMKLMultiClass::set_mkl_norm(float64_t norm)
00441 {
00442     pnorm=norm;
00443     if(pnorm<1 )
00444         SG_ERROR("CMKLMultiClass::set_mkl_norm(float64_t norm) : parameter pnorm<1");
00445 }
Index: All Classes | Namespaces | Files | Functions | Variables | Typedefs | Enumerations | Enumerator | Friends | Defines

SHOGUN Machine Learning Toolbox — Documentation