KNN.cpp
Go to the documentation of this file.
00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 #include "classifier/KNN.h"
00014 #include "features/Labels.h"
00015 #include "lib/Mathematics.h"
00016
00017 CKNN::CKNN()
00018 : CDistanceMachine(), k(3), num_classes(0), num_train_labels(0), train_labels(NULL)
00019 {
00020 }
00021
00022 CKNN::CKNN(int32_t k_, CDistance* d, CLabels* trainlab)
00023 : CDistanceMachine(), k(k_), num_classes(0), train_labels(NULL)
00024 {
00025 set_distance(d);
00026 set_labels(trainlab);
00027 num_train_labels=trainlab->get_num_labels();
00028 }
00029
00030
00031 CKNN::~CKNN()
00032 {
00033 delete[] train_labels;
00034 }
00035
00036 bool CKNN::train()
00037 {
00038 ASSERT(labels);
00039
00040 train_labels=labels->get_int_labels(num_train_labels);
00041 ASSERT(train_labels);
00042 ASSERT(num_train_labels>0);
00043
00044 int32_t max_class=train_labels[0];
00045 int32_t min_class=train_labels[0];
00046
00047 int32_t i;
00048 for (i=1; i<num_train_labels; i++)
00049 {
00050 max_class=CMath::max(max_class, train_labels[i]);
00051 min_class=CMath::min(min_class, train_labels[i]);
00052 }
00053
00054 for (i=0; i<num_train_labels; i++)
00055 train_labels[i]-=min_class;
00056
00057 min_label=min_class;
00058 num_classes=max_class-min_class+1;
00059
00060 SG_INFO( "num_classes: %d (%+d to %+d) num_train: %d\n", num_classes, min_class, max_class, num_train_labels);
00061 return true;
00062 }
00063
00064 CLabels* CKNN::classify(CLabels* output)
00065 {
00066 ASSERT(num_classes>0);
00067 ASSERT(distance);
00068 ASSERT(distance->get_num_vec_rhs());
00069
00070 int32_t num_lab=distance->get_num_vec_rhs();
00071 ASSERT(k<=num_lab);
00072
00073 if (output && output->get_num_labels()!=num_lab)
00074 SG_ERROR("Number of labels mismatches number of outputs\n");
00075
00076 if (!output)
00077 output=new CLabels(num_lab);
00078
00079
00080 float64_t* dists=new float64_t[num_train_labels];
00081 int32_t* train_lab=new int32_t[num_train_labels];
00082
00084 int32_t* classes=new int32_t[num_classes];
00085
00086 ASSERT(dists);
00087 ASSERT(train_lab);
00088 ASSERT(classes);
00089
00090 SG_INFO( "%d test examples\n", num_lab);
00091 for (int32_t i=0; i<num_lab; i++)
00092 {
00093 if ((i%(num_lab/10+1))== 0)
00094 SG_PROGRESS(i, 0, num_lab);
00095
00096 int32_t j;
00097 for (j=0; j<num_train_labels; j++)
00098 {
00099
00100 train_lab[j]=train_labels[j];
00101
00102 dists[j]=distance->distance(j,i);
00103 }
00104
00105
00106
00107 CMath::qsort_index(dists, train_lab, num_train_labels);
00108
00109
00110 for (j=0; j<num_classes; j++)
00111 classes[j]=0;
00112
00113 for (j=0; j<k; j++)
00114 classes[train_lab[j]]++;
00115
00116
00117 int32_t out_idx=0;
00118 int32_t out_max=0;
00119
00120 for (j=0; j<num_classes; j++)
00121 {
00122 if (out_max< classes[j])
00123 {
00124 out_idx= j;
00125 out_max= classes[j];
00126 }
00127 }
00128
00129 output->set_label(i, out_idx+min_label);
00130 }
00131
00132 delete[] dists;
00133 delete[] train_lab;
00134 delete[] classes;
00135
00136 return output;
00137 }
00138
00139 bool CKNN::load(FILE* srcfile)
00140 {
00141 return false;
00142 }
00143
00144 bool CKNN::save(FILE* dstfile)
00145 {
00146 return false;
00147 }