KernelMachine.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #ifndef _KERNEL_MACHINE_H__
00012 #define _KERNEL_MACHINE_H__
00013 
00014 #include "lib/common.h"
00015 #include "lib/io.h"
00016 #include "kernel/Kernel.h"
00017 #include "features/Labels.h"
00018 #include "classifier/Classifier.h"
00019 
00020 #include <stdio.h>
00021 
00022 namespace shogun
00023 {
00024 class CClassifier;
00025 class CLabels;
00026 class CKernel;
00027 
00043 class CKernelMachine : public CClassifier
00044 {
00045     public:
00047         CKernelMachine();
00048 
00050         virtual ~CKernelMachine();
00051 
00056         inline void set_kernel(CKernel* k)
00057         {
00058             SG_UNREF(kernel);
00059             SG_REF(k);
00060             kernel=k;
00061         }
00062 
00067         inline CKernel* get_kernel()
00068         {
00069             SG_REF(kernel);
00070             return kernel;
00071         }
00072 
00077         inline void set_batch_computation_enabled(bool enable)
00078         {
00079             use_batch_computation=enable;
00080         }
00081 
00086         inline bool get_batch_computation_enabled()
00087         {
00088             return use_batch_computation;
00089         }
00090 
00095         inline void set_linadd_enabled(bool enable)
00096         {
00097             use_linadd=enable;
00098         }
00099 
00104         inline bool get_linadd_enabled()
00105         {
00106             return use_linadd ;
00107         }
00108 
00113         inline void set_bias_enabled(bool enable_bias) { use_bias=enable_bias; }
00114 
00119         inline bool get_bias_enabled() { return use_bias; }
00120 
00125         inline float64_t get_bias()
00126         {
00127             return m_bias;
00128         }
00129 
00134         inline void set_bias(float64_t bias)
00135         {
00136             m_bias=bias;
00137         }
00138 
00144         inline int32_t get_support_vector(int32_t idx)
00145         {
00146             ASSERT(m_svs && idx<num_svs);
00147             return m_svs[idx];
00148         }
00149 
00155         inline float64_t get_alpha(int32_t idx)
00156         {
00157             ASSERT(m_alpha && idx<num_svs);
00158             return m_alpha[idx];
00159         }
00160 
00167         inline bool set_support_vector(int32_t idx, int32_t val)
00168         {
00169             if (m_svs && idx<num_svs)
00170                 m_svs[idx]=val;
00171             else
00172                 return false;
00173 
00174             return true;
00175         }
00176 
00183         inline bool set_alpha(int32_t idx, float64_t val)
00184         {
00185             if (m_alpha && idx<num_svs)
00186                 m_alpha[idx]=val;
00187             else
00188                 return false;
00189 
00190             return true;
00191         }
00192 
00197         inline int32_t get_num_support_vectors()
00198         {
00199             return num_svs;
00200         }
00201 
00207         void set_alphas(float64_t* alphas, int32_t d)
00208         {
00209             ASSERT(alphas);
00210             ASSERT(m_alpha);
00211             ASSERT(d==num_svs);
00212 
00213             for(int32_t i=0; i<d; i++)
00214                 m_alpha[i]=alphas[i];
00215         }
00216 
00222         void set_support_vectors(int32_t* svs, int32_t d)
00223         {
00224             ASSERT(m_svs);
00225             ASSERT(svs);
00226             ASSERT(d==num_svs);
00227 
00228             for(int32_t i=0; i<d; i++)
00229                 m_svs[i]=svs[i];
00230         }
00231 
00237         void get_support_vectors(int32_t** svs, int32_t* num)
00238         {
00239             int32_t nsv = get_num_support_vectors();
00240 
00241             ASSERT(svs && num);
00242             *svs=NULL;
00243             *num=nsv;
00244 
00245             if (nsv>0)
00246             {
00247                 *svs = (int32_t*) malloc(sizeof(int32_t)*nsv);
00248                 for(int32_t i=0; i<nsv; i++)
00249                     (*svs)[i] = get_support_vector(i);
00250             }
00251         }
00252 
00258         void get_alphas(float64_t** alphas, int32_t* d1)
00259         {
00260             int32_t nsv = get_num_support_vectors();
00261 
00262             ASSERT(alphas && d1);
00263             *alphas=NULL;
00264             *d1=nsv;
00265 
00266             if (nsv>0)
00267             {
00268                 *alphas = (float64_t*) malloc(nsv*sizeof(float64_t));
00269                 for(int32_t i=0; i<nsv; i++)
00270                     (*alphas)[i] = get_alpha(i);
00271             }
00272         }
00273 
00278         inline bool create_new_model(int32_t num)
00279         {
00280             delete[] m_alpha;
00281             delete[] m_svs;
00282 
00283             m_bias=0;
00284             num_svs=num;
00285 
00286             if (num>0)
00287             {
00288                 m_alpha= new float64_t[num];
00289                 m_svs= new int32_t[num];
00290                 return (m_alpha!=NULL && m_svs!=NULL);
00291             }
00292             else
00293             {
00294                 m_alpha= NULL;
00295                 m_svs=NULL;
00296                 return true;
00297             }
00298         }
00299 
00304         bool init_kernel_optimization();
00305 
00310         virtual CLabels* classify();
00311 
00317         virtual CLabels* classify(CFeatures* data);
00318 
00324         virtual float64_t classify_example(int32_t num);
00325 
00331         static void* classify_example_helper(void* p);
00332 
00333 #ifdef HAVE_BOOST_SERIALIZATION
00334     private:
00335 
00336 
00337         friend class ::boost::serialization::access;
00338         // When the class Archive corresponds to an output archive, the
00339         // & operator is defined similar to <<.  Likewise, when the class Archive
00340         // is a type of input archive the & operator is defined similar to >>.
00341         template<class Archive>
00342 
00343             void serialize(Archive & ar, const unsigned int archive_version)
00344             {
00345 
00346                 SG_DEBUG("archiving CKernelMachine\n");
00347                 ar & ::boost::serialization::base_object<CClassifier>(*this);
00348 
00349                 ar & kernel;
00350                 ar & use_batch_computation;
00351                 ar & use_linadd;
00352                 ar & use_bias;
00353                 ar & m_bias;
00354 
00355                 SG_DEBUG("done with CKernelMachine\n");
00356             }
00357 
00358 
00359         /*
00360 
00364         friend class ::boost::serialization::access;
00365         template<class Archive>
00366             void save(Archive & ar, const unsigned int archive_version) const
00367             {
00368 
00369                 SG_DEBUG("archiving CKernelMachine\n");
00370 
00371                 ar & kernel;
00372                 ar & use_batch_computation;
00373                 ar & use_linadd;
00374                 ar & use_bias;
00375                 ar & m_bias;
00376                 ar & num_svs;
00377 
00378 
00379                 for (int32_t i=0; i < num_svs; ++i) {
00380                     ar & m_alpha[i];
00381                     ar & m_svs[i];
00382                 }
00383 
00384                 SG_DEBUG("done with CKernelMachine\n");
00385 
00386             }
00387 
00388         template<class Archive>
00389             void load(Archive & ar, const unsigned int archive_version)
00390             {
00391 
00392                 SG_DEBUG("archiving CKernelMachine\n");
00393 
00394                 ar & kernel;
00395                 ar & use_batch_computation;
00396                 ar & use_linadd;
00397                 ar & use_bias;
00398                 ar & m_bias;
00399                 ar & num_svs;
00400 
00401 
00402                 if (num_svs > 0)
00403                 {
00404 
00405                     m_alpha = new float64_t[num_svs];
00406                     m_svs = new int32_t[num_svs];
00407                     for (int32_t i=0; i< num_svs; ++i){
00408                         ar & m_alpha[i];
00409                         //ar & m_svs[i];
00410                     }
00411 
00412                 }
00413 
00414 
00415                 SG_DEBUG("done with CKernelMachine\n");
00416             }
00417 
00418         GLOBAL_BOOST_SERIALIZATION_SPLIT_MEMBER()
00419 
00420         */
00421 
00422 #endif //HAVE_BOOST_SERIALIZATION
00423 
00424     protected:
00426         CKernel* kernel;
00428         bool use_batch_computation;
00430         bool use_linadd;
00432         bool use_bias;
00434         float64_t m_bias;
00436         float64_t* m_alpha;
00438         int32_t* m_svs;
00440         int32_t num_svs;
00441 };
00442 }
00443 #endif /* _KERNEL_MACHINE_H__ */
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation