00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #ifndef _KERNEL_MACHINE_H__ 00012 #define _KERNEL_MACHINE_H__ 00013 00014 #include "lib/common.h" 00015 #include "lib/io.h" 00016 #include "kernel/Kernel.h" 00017 #include "features/Labels.h" 00018 #include "classifier/Classifier.h" 00019 00020 #include <stdio.h> 00021 00022 namespace shogun 00023 { 00024 class CClassifier; 00025 class CLabels; 00026 class CKernel; 00027 00043 class CKernelMachine : public CClassifier 00044 { 00045 public: 00047 CKernelMachine(); 00048 00050 virtual ~CKernelMachine(); 00051 00056 inline void set_kernel(CKernel* k) 00057 { 00058 SG_UNREF(kernel); 00059 SG_REF(k); 00060 kernel=k; 00061 } 00062 00067 inline CKernel* get_kernel() 00068 { 00069 SG_REF(kernel); 00070 return kernel; 00071 } 00072 00077 inline void set_batch_computation_enabled(bool enable) 00078 { 00079 use_batch_computation=enable; 00080 } 00081 00086 inline bool get_batch_computation_enabled() 00087 { 00088 return use_batch_computation; 00089 } 00090 00095 inline void set_linadd_enabled(bool enable) 00096 { 00097 use_linadd=enable; 00098 } 00099 00104 inline bool get_linadd_enabled() 00105 { 00106 return use_linadd ; 00107 } 00108 00113 inline void set_bias_enabled(bool enable_bias) { use_bias=enable_bias; } 00114 00119 inline bool get_bias_enabled() { return use_bias; } 00120 00125 inline float64_t get_bias() 00126 { 00127 return m_bias; 00128 } 00129 00134 inline void set_bias(float64_t bias) 00135 { 00136 m_bias=bias; 00137 } 00138 00144 inline int32_t get_support_vector(int32_t idx) 00145 { 00146 ASSERT(m_svs && idx<num_svs); 00147 return m_svs[idx]; 00148 } 00149 00155 inline float64_t get_alpha(int32_t idx) 00156 { 00157 ASSERT(m_alpha && idx<num_svs); 00158 return m_alpha[idx]; 00159 } 00160 00167 inline bool set_support_vector(int32_t idx, int32_t val) 00168 { 00169 if (m_svs && idx<num_svs) 00170 m_svs[idx]=val; 00171 else 00172 return false; 00173 00174 return true; 00175 } 00176 00183 inline bool set_alpha(int32_t idx, float64_t val) 00184 { 00185 if (m_alpha && idx<num_svs) 00186 m_alpha[idx]=val; 00187 else 00188 return false; 00189 00190 return true; 00191 } 00192 00197 inline int32_t get_num_support_vectors() 00198 { 00199 return num_svs; 00200 } 00201 00207 void set_alphas(float64_t* alphas, int32_t d) 00208 { 00209 ASSERT(alphas); 00210 ASSERT(m_alpha); 00211 ASSERT(d==num_svs); 00212 00213 for(int32_t i=0; i<d; i++) 00214 m_alpha[i]=alphas[i]; 00215 } 00216 00222 void set_support_vectors(int32_t* svs, int32_t d) 00223 { 00224 ASSERT(m_svs); 00225 ASSERT(svs); 00226 ASSERT(d==num_svs); 00227 00228 for(int32_t i=0; i<d; i++) 00229 m_svs[i]=svs[i]; 00230 } 00231 00237 void get_support_vectors(int32_t** svs, int32_t* num) 00238 { 00239 int32_t nsv = get_num_support_vectors(); 00240 00241 ASSERT(svs && num); 00242 *svs=NULL; 00243 *num=nsv; 00244 00245 if (nsv>0) 00246 { 00247 *svs = (int32_t*) malloc(sizeof(int32_t)*nsv); 00248 for(int32_t i=0; i<nsv; i++) 00249 (*svs)[i] = get_support_vector(i); 00250 } 00251 } 00252 00258 void get_alphas(float64_t** alphas, int32_t* d1) 00259 { 00260 int32_t nsv = get_num_support_vectors(); 00261 00262 ASSERT(alphas && d1); 00263 *alphas=NULL; 00264 *d1=nsv; 00265 00266 if (nsv>0) 00267 { 00268 *alphas = (float64_t*) malloc(nsv*sizeof(float64_t)); 00269 for(int32_t i=0; i<nsv; i++) 00270 (*alphas)[i] = get_alpha(i); 00271 } 00272 } 00273 00278 inline bool create_new_model(int32_t num) 00279 { 00280 delete[] m_alpha; 00281 delete[] m_svs; 00282 00283 m_bias=0; 00284 num_svs=num; 00285 00286 if (num>0) 00287 { 00288 m_alpha= new float64_t[num]; 00289 m_svs= new int32_t[num]; 00290 return (m_alpha!=NULL && m_svs!=NULL); 00291 } 00292 else 00293 { 00294 m_alpha= NULL; 00295 m_svs=NULL; 00296 return true; 00297 } 00298 } 00299 00304 bool init_kernel_optimization(); 00305 00310 virtual CLabels* classify(); 00311 00317 virtual CLabels* classify(CFeatures* data); 00318 00324 virtual float64_t classify_example(int32_t num); 00325 00331 static void* classify_example_helper(void* p); 00332 00333 #ifdef HAVE_BOOST_SERIALIZATION 00334 private: 00335 00336 00337 friend class ::boost::serialization::access; 00338 // When the class Archive corresponds to an output archive, the 00339 // & operator is defined similar to <<. Likewise, when the class Archive 00340 // is a type of input archive the & operator is defined similar to >>. 00341 template<class Archive> 00342 00343 void serialize(Archive & ar, const unsigned int archive_version) 00344 { 00345 00346 SG_DEBUG("archiving CKernelMachine\n"); 00347 ar & ::boost::serialization::base_object<CClassifier>(*this); 00348 00349 ar & kernel; 00350 ar & use_batch_computation; 00351 ar & use_linadd; 00352 ar & use_bias; 00353 ar & m_bias; 00354 00355 SG_DEBUG("done with CKernelMachine\n"); 00356 } 00357 00358 00359 /* 00360 00364 friend class ::boost::serialization::access; 00365 template<class Archive> 00366 void save(Archive & ar, const unsigned int archive_version) const 00367 { 00368 00369 SG_DEBUG("archiving CKernelMachine\n"); 00370 00371 ar & kernel; 00372 ar & use_batch_computation; 00373 ar & use_linadd; 00374 ar & use_bias; 00375 ar & m_bias; 00376 ar & num_svs; 00377 00378 00379 for (int32_t i=0; i < num_svs; ++i) { 00380 ar & m_alpha[i]; 00381 ar & m_svs[i]; 00382 } 00383 00384 SG_DEBUG("done with CKernelMachine\n"); 00385 00386 } 00387 00388 template<class Archive> 00389 void load(Archive & ar, const unsigned int archive_version) 00390 { 00391 00392 SG_DEBUG("archiving CKernelMachine\n"); 00393 00394 ar & kernel; 00395 ar & use_batch_computation; 00396 ar & use_linadd; 00397 ar & use_bias; 00398 ar & m_bias; 00399 ar & num_svs; 00400 00401 00402 if (num_svs > 0) 00403 { 00404 00405 m_alpha = new float64_t[num_svs]; 00406 m_svs = new int32_t[num_svs]; 00407 for (int32_t i=0; i< num_svs; ++i){ 00408 ar & m_alpha[i]; 00409 //ar & m_svs[i]; 00410 } 00411 00412 } 00413 00414 00415 SG_DEBUG("done with CKernelMachine\n"); 00416 } 00417 00418 GLOBAL_BOOST_SERIALIZATION_SPLIT_MEMBER() 00419 00420 */ 00421 00422 #endif //HAVE_BOOST_SERIALIZATION 00423 00424 protected: 00426 CKernel* kernel; 00428 bool use_batch_computation; 00430 bool use_linadd; 00432 bool use_bias; 00434 float64_t m_bias; 00436 float64_t* m_alpha; 00438 int32_t* m_svs; 00440 int32_t num_svs; 00441 }; 00442 } 00443 #endif /* _KERNEL_MACHINE_H__ */