00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2006 Christian Gehl 00008 * Written (W) 1999-2009 Soeren Sonnenburg 00009 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00010 */ 00011 00012 #ifndef _KNN_H__ 00013 #define _KNN_H__ 00014 00015 #include <stdio.h> 00016 #include "lib/common.h" 00017 #include "lib/io.h" 00018 #include "features/Features.h" 00019 #include "distance/Distance.h" 00020 #include "distance/DistanceMachine.h" 00021 00022 class CDistanceMachine; 00023 00037 class CKNN : public CDistanceMachine 00038 { 00039 public: 00041 CKNN(); 00042 00049 CKNN(int32_t k, CDistance* d, CLabels* trainlab); 00050 virtual ~CKNN(); 00051 00056 virtual inline EClassifierType get_classifier_type() { return CT_KNN; } 00057 //inline EDistanceType get_distance_type() { return DT_KNN;} 00058 00063 virtual bool train(); 00064 00070 virtual CLabels* classify(CLabels* output=NULL); 00071 00073 virtual float64_t classify_example(int32_t vec_idx) 00074 { 00075 SG_ERROR( "for performance reasons use classify() instead of classify_example\n"); 00076 return 0; 00077 } 00078 00084 virtual bool load(FILE* srcfile); 00085 00091 virtual bool save(FILE* dstfile); 00092 00097 inline void set_k(float64_t p_k) 00098 { 00099 ASSERT(p_k>0); 00100 this->k=p_k; 00101 } 00102 00107 inline float64_t get_k() 00108 { 00109 return k; 00110 } 00111 00113 inline virtual const char* get_name() const { return "KNN"; } 00114 00115 protected: 00117 float64_t k; 00118 00120 int32_t num_classes; 00121 00123 int32_t min_label; 00124 00126 int32_t num_train_labels; 00127 00129 int32_t* train_labels; 00130 }; 00131 #endif 00132