00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2009 Soeren Sonnenburg 00008 * Copyright (C) 2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #ifndef _VARIANCEKERNELNORMALIZER_H___ 00012 #define _VARIANCEKERNELNORMALIZER_H___ 00013 00014 #include "kernel/KernelNormalizer.h" 00015 00016 namespace shogun 00017 { 00027 class CVarianceKernelNormalizer : public CKernelNormalizer 00028 { 00029 public: 00032 CVarianceKernelNormalizer() : meandiff(1.0), sqrt_meandiff(1.0) 00033 { 00034 } 00035 00037 virtual ~CVarianceKernelNormalizer() 00038 { 00039 } 00040 00043 virtual bool init(CKernel* k) 00044 { 00045 ASSERT(k); 00046 int32_t n=k->get_num_vec_lhs(); 00047 ASSERT(n>0); 00048 00049 CFeatures* old_lhs=k->lhs; 00050 CFeatures* old_rhs=k->rhs; 00051 k->lhs=old_lhs; 00052 k->rhs=old_lhs; 00053 00054 float64_t diag_mean=0; 00055 float64_t overall_mean=0; 00056 for (int32_t i=0; i<n; i++) 00057 { 00058 diag_mean+=k->compute(i, i); 00059 00060 for (int32_t j=0; j<n; j++) 00061 overall_mean+=k->compute(i, j); 00062 } 00063 diag_mean/=n; 00064 overall_mean/=((float64_t) n)*n; 00065 00066 k->lhs=old_lhs; 00067 k->rhs=old_rhs; 00068 00069 meandiff=1.0/(diag_mean-overall_mean); 00070 sqrt_meandiff=CMath::sqrt(meandiff); 00071 00072 return true; 00073 } 00074 00080 inline virtual float64_t normalize( 00081 float64_t value, int32_t idx_lhs, int32_t idx_rhs) 00082 { 00083 return value*meandiff; 00084 } 00085 00090 inline virtual float64_t normalize_lhs(float64_t value, int32_t idx_lhs) 00091 { 00092 return value*sqrt_meandiff; 00093 } 00094 00099 inline virtual float64_t normalize_rhs(float64_t value, int32_t idx_rhs) 00100 { 00101 return value*sqrt_meandiff; 00102 } 00103 00105 inline virtual const char* get_name() const { return "VarianceKernelNormalizer"; } 00106 00107 protected: 00109 float64_t meandiff; 00111 float64_t sqrt_meandiff; 00112 }; 00113 } 00114 #endif