LinearClassifier.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #ifndef _LINEARCLASSIFIER_H__
00012 #define _LINEARCLASSIFIER_H__
00013 
00014 #include "lib/common.h"
00015 #include "features/Labels.h"
00016 #include "features/DotFeatures.h"
00017 #include "classifier/Classifier.h"
00018 
00019 #include <stdio.h>
00020 
00055 class CLinearClassifier : public CClassifier
00056 {
00057     public:
00059         CLinearClassifier();
00060         virtual ~CLinearClassifier();
00061 
00063         virtual inline float64_t classify_example(int32_t vec_idx)
00064         {
00065             return features->dense_dot(vec_idx, w, w_dim) + bias;
00066         }
00067 
00073         inline void get_w(float64_t*& dst_w, int32_t& dst_dims)
00074         {
00075             ASSERT(w && features);
00076             dst_w=w;
00077             dst_dims=features->get_dim_feature_space();
00078         }
00079 
00085         inline void get_w(float64_t** dst_w, int32_t* dst_dims)
00086         {
00087             ASSERT(dst_w && dst_dims);
00088             ASSERT(w && features);
00089             *dst_dims=features->get_dim_feature_space();
00090             *dst_w=(float64_t*) malloc(sizeof(float64_t)*(*dst_dims));
00091             ASSERT(*dst_w);
00092             memcpy(*dst_w, w, sizeof(float64_t) * (*dst_dims));
00093         }
00094 
00100         inline void set_w(float64_t* src_w, int32_t src_w_dim)
00101         {
00102             w=src_w;
00103             w_dim=src_w_dim;
00104         }
00105 
00110         inline void set_bias(float64_t b)
00111         {
00112             bias=b;
00113         }
00114 
00119         inline float64_t get_bias()
00120         {
00121             return bias;
00122         }
00123 
00129         virtual bool load(FILE* srcfile);
00130 
00136         virtual bool save(FILE* dstfile);
00137 
00143         virtual CLabels* classify(CLabels* output=NULL);
00144 
00149         virtual inline void set_features(CDotFeatures* feat)
00150         {
00151             SG_UNREF(features);
00152             SG_REF(feat);
00153             features=feat;
00154         }
00155 
00160         virtual CDotFeatures* get_features() { SG_REF(features); return features; }
00161 
00162     protected:
00164         int32_t w_dim;
00166         float64_t* w;
00168         float64_t bias;
00170         CDotFeatures* features;
00171 };
00172 #endif

SHOGUN Machine Learning Toolbox - Documentation