KNN.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  *
00008  * Written (W) 2006 Christian Gehl
00009  * Written (W) 2006-2009 Soeren Sonnenburg
00010  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00011  */
00012 
00013 #include "classifier/KNN.h"
00014 #include "features/Labels.h"
00015 #include "lib/Mathematics.h"
00016 
00017 CKNN::CKNN()
00018 : CDistanceMachine(), k(3), num_classes(0), num_train_labels(0), train_labels(NULL)
00019 {
00020 }
00021 
00022 CKNN::CKNN(int32_t k_, CDistance* d, CLabels* trainlab)
00023 : CDistanceMachine(), k(k_), num_classes(0), train_labels(NULL)
00024 {
00025     set_distance(d);
00026     set_labels(trainlab);
00027     num_train_labels=trainlab->get_num_labels();
00028 }
00029 
00030 
00031 CKNN::~CKNN()
00032 {
00033     delete[] train_labels;
00034 }
00035 
00036 bool CKNN::train()
00037 {
00038     ASSERT(labels);
00039 
00040     train_labels=labels->get_int_labels(num_train_labels);
00041     ASSERT(train_labels);
00042     ASSERT(num_train_labels>0);
00043 
00044     int32_t max_class=train_labels[0];
00045     int32_t min_class=train_labels[0];
00046 
00047     int32_t i;
00048     for (i=1; i<num_train_labels; i++)
00049     {
00050         max_class=CMath::max(max_class, train_labels[i]);
00051         min_class=CMath::min(min_class, train_labels[i]);
00052     }
00053 
00054     for (i=0; i<num_train_labels; i++)
00055         train_labels[i]-=min_class;
00056 
00057     min_label=min_class;
00058     num_classes=max_class-min_class+1;
00059 
00060     SG_INFO( "num_classes: %d (%+d to %+d) num_train: %d\n", num_classes, min_class, max_class, num_train_labels);
00061     return true;
00062 }
00063 
00064 CLabels* CKNN::classify(CLabels* output)
00065 {
00066     ASSERT(num_classes>0);
00067     ASSERT(distance);
00068     ASSERT(distance->get_num_vec_rhs());
00069 
00070     int32_t num_lab=distance->get_num_vec_rhs();
00071     ASSERT(k<=num_lab);
00072 
00073     if (output && output->get_num_labels()!=num_lab)
00074         SG_ERROR("Number of labels mismatches number of outputs\n");
00075 
00076     if (!output)
00077         output=new CLabels(num_lab);
00078 
00079     //distances to train data and working buffer of train_labels
00080     float64_t* dists=new float64_t[num_train_labels];
00081     int32_t* train_lab=new int32_t[num_train_labels];
00082 
00084     int32_t* classes=new int32_t[num_classes];
00085 
00086     ASSERT(dists);
00087     ASSERT(train_lab);
00088     ASSERT(classes);
00089 
00090     SG_INFO( "%d test examples\n", num_lab);
00091     for (int32_t i=0; i<num_lab; i++)
00092     {
00093         if ((i%(num_lab/10+1))== 0)
00094             SG_PROGRESS(i, 0, num_lab);
00095 
00096         int32_t j;
00097         for (j=0; j<num_train_labels; j++)
00098         {
00099             //copy back train labels and compute distance
00100             train_lab[j]=train_labels[j];
00101             
00102             dists[j]=distance->distance(j,i);
00103         }
00104 
00105         //sort the distance vector for test example j to all train examples
00106         //classes[1..k] then holds the classes for minimum distance
00107         CMath::qsort_index(dists, train_lab, num_train_labels);
00108 
00109         //compute histogram of class outputs of the first k nearest neighbours
00110         for (j=0; j<num_classes; j++)
00111             classes[j]=0;
00112 
00113         for (j=0; j<k; j++)
00114             classes[train_lab[j]]++;
00115 
00116         //choose the class that got 'outputted' most often
00117         int32_t out_idx=0;
00118         int32_t out_max=0;
00119 
00120         for (j=0; j<num_classes; j++)
00121         {
00122             if (out_max< classes[j])
00123             {
00124                 out_idx= j;
00125                 out_max= classes[j];
00126             }
00127         }
00128 
00129         output->set_label(i, out_idx+min_label);
00130     }
00131 
00132     delete[] dists;
00133     delete[] train_lab;
00134     delete[] classes;
00135 
00136     return output;
00137 }
00138 
00139 bool CKNN::load(FILE* srcfile)
00140 {
00141     return false;
00142 }
00143 
00144 bool CKNN::save(FILE* dstfile)
00145 {
00146     return false;
00147 }

SHOGUN Machine Learning Toolbox - Documentation