00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #ifndef _LINEARCLASSIFIER_H__ 00012 #define _LINEARCLASSIFIER_H__ 00013 00014 #include "lib/common.h" 00015 #include "features/Labels.h" 00016 #include "features/DotFeatures.h" 00017 #include "classifier/Classifier.h" 00018 00019 #include <stdio.h> 00020 00055 class CLinearClassifier : public CClassifier 00056 { 00057 public: 00059 CLinearClassifier(); 00060 virtual ~CLinearClassifier(); 00061 00063 virtual inline float64_t classify_example(int32_t vec_idx) 00064 { 00065 return features->dense_dot(vec_idx, w, w_dim) + bias; 00066 } 00067 00073 inline void get_w(float64_t*& dst_w, int32_t& dst_dims) 00074 { 00075 ASSERT(w && features); 00076 dst_w=w; 00077 dst_dims=features->get_dim_feature_space(); 00078 } 00079 00085 inline void get_w(float64_t** dst_w, int32_t* dst_dims) 00086 { 00087 ASSERT(dst_w && dst_dims); 00088 ASSERT(w && features); 00089 *dst_dims=features->get_dim_feature_space(); 00090 *dst_w=(float64_t*) malloc(sizeof(float64_t)*(*dst_dims)); 00091 ASSERT(*dst_w); 00092 memcpy(*dst_w, w, sizeof(float64_t) * (*dst_dims)); 00093 } 00094 00100 inline void set_w(float64_t* src_w, int32_t src_w_dim) 00101 { 00102 w=src_w; 00103 w_dim=src_w_dim; 00104 } 00105 00110 inline void set_bias(float64_t b) 00111 { 00112 bias=b; 00113 } 00114 00119 inline float64_t get_bias() 00120 { 00121 return bias; 00122 } 00123 00129 virtual bool load(FILE* srcfile); 00130 00136 virtual bool save(FILE* dstfile); 00137 00143 virtual CLabels* classify(CLabels* output=NULL); 00144 00149 virtual inline void set_features(CDotFeatures* feat) 00150 { 00151 SG_UNREF(features); 00152 SG_REF(feat); 00153 features=feat; 00154 } 00155 00160 virtual CDotFeatures* get_features() { SG_REF(features); return features; } 00161 00162 protected: 00164 int32_t w_dim; 00166 float64_t* w; 00168 float64_t bias; 00170 CDotFeatures* features; 00171 }; 00172 #endif