Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #ifndef _CLASSIFIER_H__
00012 #define _CLASSIFIER_H__
00013
00014 #include "lib/common.h"
00015 #include "base/SGObject.h"
00016 #include "lib/Mathematics.h"
00017 #include "features/Labels.h"
00018 #include "features/Features.h"
00019
00020 namespace shogun
00021 {
00022
00023 class CFeatures;
00024 class CLabels;
00025 class CMath;
00026
/** @brief Tag identifying a concrete classifier implementation.
 *
 * CClassifier::get_classifier_type() returns CT_NONE in the base class;
 * subclasses presumably report their own tag (confirm per subclass).
 * Tag values are spaced in steps of 10.
 */
enum EClassifierType
{
	CT_NONE              = 0,
	CT_LIGHT             = 10,
	CT_LIBSVM            = 20,
	CT_LIBSVMONECLASS    = 30,
	CT_LIBSVMMULTICLASS  = 40,
	CT_MPD               = 50,
	CT_GPBT              = 60,
	CT_CPLEXSVM          = 70,
	CT_PERCEPTRON        = 80,
	CT_KERNELPERCEPTRON  = 90,
	CT_LDA               = 100,
	CT_LPM               = 110,
	CT_LPBOOST           = 120,
	CT_KNN               = 130,
	CT_SVMLIN            = 140,
	CT_KRR               = 150,
	CT_GNPPSVM           = 160,
	CT_GMNPSVM           = 170,
	CT_SUBGRADIENTSVM    = 180,
	CT_SUBGRADIENTLPM    = 190,
	CT_SVMPERF           = 200,
	CT_LIBSVR            = 210,
	CT_SVRLIGHT          = 220,
	CT_LIBLINEAR         = 230,
	CT_KMEANS            = 240,
	CT_HIERARCHICAL      = 250,
	CT_SVMOCAS           = 260,
	CT_WDSVMOCAS         = 270,
	CT_SVMSGD            = 280,
	CT_MKLMULTICLASS     = 290,
	CT_MKLCLASSIFICATION = 300,
	CT_MKLONECLASS       = 310,
	CT_MKLREGRESSION     = 320,
	CT_SCATTERSVM        = 330,
	CT_DASVM             = 340,
	CT_LARANK            = 350
};
00066
/** @brief Solver backend selector.
 *
 * Stored on a classifier via CClassifier::set_solver_type() and read back
 * with CClassifier::get_solver_type(). ST_AUTO (0) presumably lets the
 * implementation pick a backend — NOTE(review): confirm in subclasses.
 */
enum ESolverType
{
	ST_AUTO       = 0,
	ST_CPLEX      = 1,
	ST_GLPK       = 2,
	ST_NEWTON     = 3,
	ST_DIRECT     = 4,
	ST_ELASTICNET = 5
};
00076
00088 class CClassifier : public CSGObject
00089 {
00090 public:
00092 CClassifier();
00093 virtual ~CClassifier();
00094
00103 virtual bool train(CFeatures* data=NULL)
00104 {
00105 SG_NOTIMPLEMENTED;
00106 return false;
00107 }
00108
00113 virtual CLabels* classify()=0;
00114
00120 virtual CLabels* classify(CFeatures* data)=0;
00121
00129 virtual float64_t classify_example(int32_t num)
00130 {
00131 SG_NOTIMPLEMENTED;
00132 return CMath::INFTY;
00133 }
00134
00142 virtual bool load(FILE* srcfile) { ASSERT(srcfile); return false; }
00143
00151 virtual bool save(FILE* dstfile) { ASSERT(dstfile); return false; }
00152
00157 virtual inline void set_labels(CLabels* lab)
00158 {
00159 SG_UNREF(labels);
00160 SG_REF(lab);
00161 labels=lab;
00162 }
00163
00168 virtual inline CLabels* get_labels() { SG_REF(labels); return labels; }
00169
00175 virtual inline float64_t get_label(int32_t i)
00176 {
00177 if (!labels)
00178 SG_ERROR("No Labels assigned\n");
00179
00180 return labels->get_label(i);
00181 }
00182
00187 inline void set_max_train_time(float64_t t) { max_train_time=t; }
00188
00193 inline float64_t get_max_train_time() { return max_train_time; }
00194
00199 virtual inline EClassifierType get_classifier_type() { return CT_NONE; }
00200
00205 inline void set_solver_type(ESolverType st) { solver_type=st; }
00206
00211 inline ESolverType get_solver_type() { return solver_type; }
00212
00213 #ifdef HAVE_BOOST_SERIALIZATION
00214 private:
00215
00216 friend class ::boost::serialization::access;
00217 template<class Archive>
00218 void serialize(Archive & ar, const unsigned int archive_version)
00219 {
00220
00221 SG_DEBUG("archiving CClassifier\n");
00222
00223 ar & ::boost::serialization::base_object<CSGObject>(*this);
00224 ar & max_train_time;
00225 ar & labels;
00226
00227 SG_DEBUG("done with CClassifier\n");
00228
00229 }
00230 #endif //HAVE_BOOST_SERIALIZATION
00231
00232 protected:
00234 float64_t max_train_time;
00235
00237 CLabels* labels;
00238
00240 ESolverType solver_type;
00241 };
00242 }
00243 #endif // _CLASSIFIER_H__