KMeans.h
Go to the documentation of this file.
00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #ifndef _KMEANS_H__
00013 #define _KMEANS_H__
00014
00015 #include <stdio.h>
00016 #include "lib/common.h"
00017 #include "lib/io.h"
00018 #include "features/SimpleFeatures.h"
00019 #include "distance/Distance.h"
00020 #include "distance/DistanceMachine.h"
00021
00022 class CDistanceMachine;
00023
00037 class CKMeans : public CDistanceMachine
00038 {
00039 public:
00041 CKMeans();
00042
00048 CKMeans(int32_t k, CDistance* d);
00049 virtual ~CKMeans();
00050
00055 virtual inline EClassifierType get_classifier_type() { return CT_KMEANS; }
00056
00061 virtual bool train();
00062
00068 virtual bool load(FILE* srcfile);
00069
00075 virtual bool save(FILE* dstfile);
00076
00081 inline void set_k(int32_t p_k)
00082 {
00083 ASSERT(p_k>0);
00084 this->k=p_k;
00085 }
00086
00091 inline int32_t get_k()
00092 {
00093 return k;
00094 }
00095
00100 inline void set_max_iter(int32_t iter)
00101 {
00102 ASSERT(iter>0);
00103 max_iter=iter;
00104 }
00105
00110 inline float64_t get_max_iter()
00111 {
00112 return max_iter;
00113 }
00114
00120 inline void get_radi(float64_t*& radi, int32_t& num)
00121 {
00122 radi=R;
00123 num=k;
00124 }
00125
00132 inline void get_centers(float64_t*& centers, int32_t& dim, int32_t& num)
00133 {
00134 centers=mus;
00135 dim=dimensions;
00136 num=k;
00137 }
00138
00144 inline void get_radiuses(float64_t** radii, int32_t* num)
00145 {
00146 size_t sz=sizeof(*R)*k;
00147 *radii=(float64_t*) malloc(sz);
00148 ASSERT(*radii);
00149
00150 memcpy(*radii, R, sz);
00151 *num=k;
00152 }
00153
00160 inline void get_cluster_centers(
00161 float64_t** centers, int32_t* dim, int32_t* num)
00162 {
00163 size_t sz=sizeof(*mus)*dimensions*k;
00164 *centers=(float64_t*) malloc(sz);
00165 ASSERT(*centers);
00166
00167 memcpy(*centers, mus, sz);
00168 *dim=dimensions;
00169 *num=k;
00170 }
00171
00176 inline int32_t get_dimensions()
00177 {
00178 return dimensions;
00179 }
00180
00181
00182 protected:
00193 void sqdist(
00194 float64_t* x, CSimpleFeatures<float64_t>* y, float64_t *z, int32_t n1,
00195 int32_t offs, int32_t n2, int32_t m);
00196
00202 void clustknb(bool use_old_mus, float64_t *mus_start);
00203
00205 inline virtual const char* get_name() const { return "KMeans"; }
00206
00207 protected:
00209 int32_t max_iter;
00210
00212 int32_t k;
00213
00215 int32_t dimensions;
00216
00218 float64_t* R;
00219
00221 float64_t* mus;
00222
00223 private:
00225 float64_t* Weights;
00226 };
00227 #endif
00228