Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #ifndef _SQRTDIAGKERNELNORMALIZER_H___
00012 #define _SQRTDIAGKERNELNORMALIZER_H___
00013
00014 #include "kernel/KernelNormalizer.h"
00015 #include "kernel/CommWordStringKernel.h"
00016
00017 namespace shogun
00018 {
00029 class CSqrtDiagKernelNormalizer : public CKernelNormalizer
00030 {
00031 public:
00036 CSqrtDiagKernelNormalizer(bool use_opt_diag=false): sqrtdiag_lhs(NULL),
00037 sqrtdiag_rhs(NULL), use_optimized_diagonal_computation(use_opt_diag)
00038 {
00039 }
00040
00042 virtual ~CSqrtDiagKernelNormalizer()
00043 {
00044 delete[] sqrtdiag_lhs;
00045 delete[] sqrtdiag_rhs;
00046 }
00047
00050 virtual bool init(CKernel* k)
00051 {
00052 ASSERT(k);
00053 int32_t num_lhs=k->get_num_vec_lhs();
00054 int32_t num_rhs=k->get_num_vec_rhs();
00055 ASSERT(num_lhs>0);
00056 ASSERT(num_rhs>0);
00057
00058 CFeatures* old_lhs=k->lhs;
00059 CFeatures* old_rhs=k->rhs;
00060
00061 k->lhs=old_lhs;
00062 k->rhs=old_lhs;
00063 bool r1=alloc_and_compute_diag(k, sqrtdiag_lhs, num_lhs);
00064
00065 k->lhs=old_rhs;
00066 k->rhs=old_rhs;
00067 bool r2=alloc_and_compute_diag(k, sqrtdiag_rhs, num_rhs);
00068
00069 k->lhs=old_lhs;
00070 k->rhs=old_rhs;
00071
00072 return r1 && r2;
00073 }
00074
00080 inline virtual float64_t normalize(
00081 float64_t value, int32_t idx_lhs, int32_t idx_rhs)
00082 {
00083 float64_t sqrt_both=sqrtdiag_lhs[idx_lhs]*sqrtdiag_rhs[idx_rhs];
00084 return value/sqrt_both;
00085 }
00086
00091 inline virtual float64_t normalize_lhs(float64_t value, int32_t idx_lhs)
00092 {
00093 return value/sqrtdiag_lhs[idx_lhs];
00094 }
00095
00100 inline virtual float64_t normalize_rhs(float64_t value, int32_t idx_rhs)
00101 {
00102 return value/sqrtdiag_rhs[idx_rhs];
00103 }
00104
00105 public:
00110 bool alloc_and_compute_diag(CKernel* k, float64_t* &v, int32_t num)
00111 {
00112 delete[] v;
00113 v=new float64_t[num];
00114
00115 for (int32_t i=0; i<num; i++)
00116 {
00117 if (k->get_kernel_type() == K_COMMWORDSTRING)
00118 {
00119 if (use_optimized_diagonal_computation)
00120 v[i]=sqrt(((CCommWordStringKernel*) k)->compute_diag(i));
00121 else
00122 v[i]=sqrt(((CCommWordStringKernel*) k)->compute_helper(i,i, true));
00123 }
00124 else
00125 v[i]=sqrt(k->compute(i,i));
00126
00127 if (v[i]==0.0)
00128 v[i]=1e-16;
00129 }
00130
00131 return (v!=NULL);
00132 }
00133
00135 inline virtual const char* get_name() const { return "SqrtDiagKernelNormalizer"; }
00136
00137 protected:
00139 float64_t* sqrtdiag_lhs;
00141 float64_t* sqrtdiag_rhs;
00143 bool use_optimized_diagonal_computation;
00144 };
00145 }
00146 #endif