SVM_linear.h
Go to the documentation of this file.00001 #ifndef DOXYGEN_SHOULD_SKIP_THIS
00002
00003 #ifndef _LIBLINEAR_H
00004 #define _LIBLINEAR_H
00005
00006 #include "lib/config.h"
00007
00008 #ifdef HAVE_LAPACK
00009 #include "classifier/svm/Tron.h"
00010 #include "features/DotFeatures.h"
00011
00012 #ifdef __cplusplus
00013 extern "C" {
00014 #endif
00015
00017 struct problem
00018 {
00020 int32_t l;
00022 int32_t n;
00024 int32_t *y;
00026 CDotFeatures* x;
00028 bool use_bias;
00029 };
00030
00032 struct parameter
00033 {
00035 int32_t solver_type;
00036
00037
00039 float64_t eps;
00041 float64_t C;
00043 int32_t nr_weight;
00045 int32_t *weight_label;
00047 float64_t* weight;
00048 };
00049
00051 struct model
00052 {
00054 struct parameter param;
00056 int32_t nr_class;
00058 int32_t nr_feature;
00060 float64_t *w;
00062 int32_t *label;
00064 float64_t bias;
00065 };
00066
00067 struct model* train(const struct problem *prob, const struct parameter *param);
00068 void cross_validation(
00069 const struct problem *prob, const struct parameter *param, int32_t nr_fold,
00070 int32_t *target);
00071
00072 int32_t predict_values(
00073 const struct model *model_, const struct feature_node *x,
00074 float64_t* dec_values);
00075 int32_t predict(const struct model *model_, const struct feature_node *x);
00076 int32_t predict_probability(
00077 const struct model *model_, const struct feature_node *x,
00078 float64_t* prob_estimates);
00079
00080 int32_t save_model(const char *model_file_name, const struct model *model_);
00081 struct model *load_model(const char *model_file_name);
00082
00083 int32_t get_nr_feature(const struct model *model_);
00084 int32_t get_nr_class(const struct model *model_);
00085 void get_labels(const struct model *model_, int32_t* label);
00086
00087 void destroy_model(struct model *model_);
00088 void destroy_param(struct parameter *param);
00089 const char *check_parameter(
00090 const struct problem *prob, const struct parameter *param);
00091
00092 #ifdef __cplusplus
00093 }
00094 #endif
00095
00097 class l2loss_svm_fun : public function
00098 {
00099 public:
00106 l2loss_svm_fun(const problem *prob, float64_t Cp, float64_t Cn);
00107 ~l2loss_svm_fun();
00108
00114 float64_t fun(float64_t *w);
00115
00121 void grad(float64_t *w, float64_t *g);
00122
00128 void Hv(float64_t *s, float64_t *Hs);
00129
00134 int32_t get_nr_variable(void);
00135
00136 private:
00137 void Xv(float64_t *v, float64_t *Xv);
00138 void subXv(float64_t *v, float64_t *Xv);
00139 void subXTv(float64_t *v, float64_t *XTv);
00140
00141 float64_t *C;
00142 float64_t *z;
00143 float64_t *D;
00144 int32_t *I;
00145 int32_t sizeI;
00146 const problem *prob;
00147 };
00148
00150 class l2_lr_fun : public function
00151 {
00152 public:
00159 l2_lr_fun(const problem *prob, float64_t Cp, float64_t Cn);
00160 ~l2_lr_fun();
00161
00167 float64_t fun(float64_t *w);
00168
00174 void grad(float64_t *w, float64_t *g);
00175
00181 void Hv(float64_t *s, float64_t *Hs);
00182
00183 int32_t get_nr_variable(void);
00184
00185 private:
00186 void Xv(float64_t *v, float64_t *Xv);
00187 void XTv(float64_t *v, float64_t *XTv);
00188
00189 float64_t *C;
00190 float64_t *z;
00191 float64_t *D;
00192 const problem *prob;
00193 };
00194 #endif //HAVE_LAPACK
00195 #endif //_LIBLINEAR_H
00196
00197 #endif // DOXYGEN_SHOULD_SKIP_THIS