00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020 #include <stdlib.h>
00021 #include <string.h>
00022 #include <math.h>
00023 #include <sys/time.h>
00024 #include <time.h>
00025 #include <stdio.h>
00026 #include <stdint.h>
00027
00028 #include "classifier/svm/libocas.h"
00029 #include "classifier/svm/libocas_common.h"
00030 #include "classifier/svm/qpssvmlib.h"
00031
00032 namespace shogun
00033 {
00034
00035 static const uint32_t QPSolverMaxIter = 10000000;
00036
00037 static float64_t *H;
00038 static uint32_t BufSize;
00039
00040
00041
00042
00043 static const void *get_col( uint32_t i)
00044 {
00045 return( &H[ BufSize*i ] );
00046 }
00047
00048
00049
00050
00051 static float64_t get_time()
00052 {
00053 struct timeval tv;
00054 if (gettimeofday(&tv, NULL)==0)
00055 return (float64_t) (tv.tv_sec+((double)(tv.tv_usec))/1e6);
00056 else
00057 return 0.0;
00058 }
00059
00060
00061
00062
00063 ocas_return_value_T svm_ocas_solver(
00064 float64_t C,
00065 uint32_t nData,
00066 float64_t TolRel,
00067 float64_t TolAbs,
00068 float64_t QPBound,
00069 uint32_t _BufSize,
00070 uint8_t Method,
00071 void (*compute_W)(float64_t*, float64_t*, float64_t*, uint32_t, void*),
00072 float64_t (*update_W)(float64_t, void*),
00073 void (*add_new_cut)(float64_t*, uint32_t*, uint32_t, uint32_t, void*),
00074 void (*compute_output)(float64_t*, void* ),
00075 void (*sort)(float64_t*, uint32_t*, uint32_t),
00076 void* user_data)
00077 {
00078 ocas_return_value_T ocas;
00079 float64_t *b, *alpha, *diag_H;
00080 float64_t *output, *old_output;
00081 float64_t xi, sq_norm_W, QPSolverTolRel, dot_prod_WoldW, dummy, sq_norm_oldW;
00082 float64_t A0, B0, GradVal, t, t1=0, t2=0, *Ci, *Bi, *hpf;
00083 float64_t start_time;
00084 uint32_t *hpi;
00085 uint32_t cut_length;
00086 uint32_t i, *new_cut;
00087 uint16_t *I;
00088 int8_t qp_exitflag;
00089 float64_t gap;
00090
00091 ocas.ocas_time = get_time();
00092 ocas.solver_time = 0;
00093 ocas.output_time = 0;
00094 ocas.sort_time = 0;
00095 ocas.add_time = 0;
00096 ocas.w_time = 0;
00097
00098 BufSize = _BufSize;
00099
00100 QPSolverTolRel = TolRel*0.5;
00101
00102 H=NULL;
00103 b=NULL;
00104 alpha=NULL;
00105 new_cut=NULL;
00106 I=NULL;
00107 diag_H=NULL;
00108 output=NULL;
00109 old_output=NULL;
00110 hpf=NULL;
00111 hpi=NULL;
00112 Ci=NULL;
00113 Bi=NULL;
00114
00115
00116 H = (float64_t*)OCAS_CALLOC(BufSize*BufSize,sizeof(float64_t));
00117 if(H == NULL)
00118 {
00119 ocas.exitflag=-2;
00120 goto cleanup;
00121 }
00122
00123
00124 b = (float64_t*)OCAS_CALLOC(BufSize,sizeof(float64_t));
00125 if(b == NULL)
00126 {
00127 ocas.exitflag=-2;
00128 goto cleanup;
00129 }
00130
00131 alpha = (float64_t*)OCAS_CALLOC(BufSize,sizeof(float64_t));
00132 if(alpha == NULL)
00133 {
00134 ocas.exitflag=-2;
00135 goto cleanup;
00136 }
00137
00138
00139 new_cut = (uint32_t*)OCAS_CALLOC(nData,sizeof(uint32_t));
00140 if(new_cut == NULL)
00141 {
00142 ocas.exitflag=-2;
00143 goto cleanup;
00144 }
00145
00146 I = (uint16_t*)OCAS_CALLOC(BufSize,sizeof(uint16_t));
00147 if(I == NULL)
00148 {
00149 ocas.exitflag=-2;
00150 goto cleanup;
00151 }
00152
00153 for(i=0; i< BufSize; i++) I[i] = 1;
00154
00155 diag_H = (float64_t*)OCAS_CALLOC(BufSize,sizeof(float64_t));
00156 if(diag_H == NULL)
00157 {
00158 ocas.exitflag=-2;
00159 goto cleanup;
00160 }
00161
00162 output = (float64_t*)OCAS_CALLOC(nData,sizeof(float64_t));
00163 if(output == NULL)
00164 {
00165 ocas.exitflag=-2;
00166 goto cleanup;
00167 }
00168
00169 old_output = (float64_t*)OCAS_CALLOC(nData,sizeof(float64_t));
00170 if(old_output == NULL)
00171 {
00172 ocas.exitflag=-2;
00173 goto cleanup;
00174 }
00175
00176
00177 hpf = (float64_t*) OCAS_CALLOC(nData, sizeof(hpf[0]));
00178 if(hpf == NULL)
00179 {
00180 ocas.exitflag=-2;
00181 goto cleanup;
00182 }
00183
00184 hpi = (uint32_t*) OCAS_CALLOC(nData, sizeof(hpi[0]));
00185 if(hpi == NULL)
00186 {
00187 ocas.exitflag=-2;
00188 goto cleanup;
00189 }
00190
00191
00192 Ci = (float64_t*)OCAS_CALLOC(nData,sizeof(float64_t));
00193 if(Ci == NULL)
00194 {
00195 ocas.exitflag=-2;
00196 goto cleanup;
00197 }
00198
00199 Bi = (float64_t*)OCAS_CALLOC(nData,sizeof(float64_t));
00200 if(Bi == NULL)
00201 {
00202 ocas.exitflag=-2;
00203 goto cleanup;
00204 }
00205
00206 ocas.nCutPlanes = 0;
00207 ocas.exitflag = 0;
00208 ocas.nIter = 0;
00209
00210
00211 sq_norm_W = 0;
00212 xi = nData;
00213 ocas.Q_P = 0.5*sq_norm_W + C*xi;
00214 ocas.Q_D = 0;
00215
00216
00217 cut_length = nData;
00218 for(i=0; i < nData; i++)
00219 new_cut[i] = i;
00220
00221 gap=(ocas.Q_P-ocas.Q_D)/CMath::abs(ocas.Q_P);
00222 SG_SABS_PROGRESS(gap, -CMath::log10(gap), -CMath::log10(1), -CMath::log10(TolRel), 6);
00223
00224
00225 while( ocas.exitflag == 0 )
00226 {
00227 ocas.nIter++;
00228
00229
00230 b[ocas.nCutPlanes] = -(float64_t)cut_length;
00231
00232 start_time = get_time();
00233
00234 add_new_cut( &H[INDEX2(0,ocas.nCutPlanes,BufSize)], new_cut, cut_length, ocas.nCutPlanes, user_data );
00235
00236 ocas.add_time += get_time() - start_time;
00237
00238
00239 diag_H[ocas.nCutPlanes] = H[INDEX2(ocas.nCutPlanes,ocas.nCutPlanes,BufSize)];
00240 for(i=0; i < ocas.nCutPlanes; i++) {
00241 H[INDEX2(ocas.nCutPlanes,i,BufSize)] = H[INDEX2(i,ocas.nCutPlanes,BufSize)];
00242 }
00243
00244 ocas.nCutPlanes++;
00245
00246
00247 start_time = get_time();
00248
00249 qp_exitflag = qpssvm_solver( &get_col, diag_H, b, C, I, alpha,
00250 ocas.nCutPlanes, QPSolverMaxIter, 0.0, QPSolverTolRel, &ocas.Q_D, &dummy, 0 );
00251
00252 ocas.solver_time += get_time() - start_time;
00253
00254 ocas.Q_D = -ocas.Q_D;
00255
00256 ocas.nNZAlpha = 0;
00257 for(i=0; i < ocas.nCutPlanes; i++) {
00258 if( alpha[i] != 0) ocas.nNZAlpha++;
00259 }
00260
00261 sq_norm_oldW = sq_norm_W;
00262 start_time = get_time();
00263 compute_W( &sq_norm_W, &dot_prod_WoldW, alpha, ocas.nCutPlanes, user_data );
00264 ocas.w_time += get_time() - start_time;
00265
00266
00267 switch( Method )
00268 {
00269
00270 case 0:
00271
00272 start_time = get_time();
00273 compute_output( output, user_data );
00274 ocas.output_time += get_time()-start_time;
00275
00276 xi = 0;
00277 cut_length = 0;
00278 ocas.trn_err = 0;
00279 for(i=0; i < nData; i++)
00280 {
00281 if(output[i] <= 0) ocas.trn_err++;
00282
00283 if(output[i] <= 1) {
00284 xi += 1 - output[i];
00285 new_cut[cut_length] = i;
00286 cut_length++;
00287 }
00288 }
00289 ocas.Q_P = 0.5*sq_norm_W + C*xi;
00290
00291 gap=(ocas.Q_P-ocas.Q_D)/CMath::abs(ocas.Q_P);
00292 SG_SABS_PROGRESS(gap, -CMath::log10(gap), -CMath::log10(1), -CMath::log10(TolRel), 6);
00293
00294 break;
00295
00296
00297
00298 case 1:
00299
00300
00301 A0 = sq_norm_W -2*dot_prod_WoldW + sq_norm_oldW;
00302 B0 = dot_prod_WoldW - sq_norm_oldW;
00303
00304 memcpy( old_output, output, sizeof(float64_t)*nData );
00305
00306 start_time = get_time();
00307 compute_output( output, user_data );
00308 ocas.output_time += get_time()-start_time;
00309
00310 uint32_t num_hp = 0;
00311 GradVal = B0;
00312 for(i=0; i< nData; i++) {
00313
00314 Ci[i] = C*(1-old_output[i]);
00315 Bi[i] = C*(old_output[i] - output[i]);
00316
00317 float64_t val;
00318 if(Bi[i] != 0)
00319 val = -Ci[i]/Bi[i];
00320 else
00321 val = -OCAS_PLUS_INF;
00322
00323 if (val>0)
00324 {
00325 hpi[num_hp] = i;
00326 hpf[num_hp] = val;
00327 num_hp++;
00328 }
00329
00330 if( (Bi[i] < 0 && val > 0) || (Bi[i] > 0 && val <= 0))
00331 GradVal += Bi[i];
00332
00333 }
00334
00335 t = 0;
00336 if( GradVal < 0 )
00337 {
00338 start_time = get_time();
00339 sort(hpf, hpi, num_hp);
00340 ocas.sort_time += get_time() - start_time;
00341
00342 float64_t t_new, GradVal_new;
00343 i = 0;
00344 while( GradVal < 0 && i < num_hp )
00345 {
00346 t_new = hpf[i];
00347 GradVal_new = GradVal + CMath::abs(Bi[hpi[i]]) + A0*(t_new-t);
00348
00349 if( GradVal_new >= 0 )
00350 {
00351 t = t + GradVal*(t-t_new)/(GradVal_new - GradVal);
00352 }
00353 else
00354 {
00355 t = t_new;
00356 i++;
00357 }
00358
00359 GradVal = GradVal_new;
00360 }
00361 }
00362
00363
00364
00365
00366
00367
00368
00369
00370
00371
00372
00373
00374 t = CMath::max(t,0.0);
00375
00376 t1 = t;
00377 t2 = t+(1.0-t)/10.0;
00378
00379
00380 sq_norm_W = update_W( t1, user_data );
00381
00382
00383 xi = 0;
00384 cut_length = 0;
00385 ocas.trn_err = 0;
00386 for(i=0; i < nData; i++ ) {
00387
00388 if( (old_output[i]*(1-t2) + t2*output[i]) <= 1 )
00389 {
00390 new_cut[cut_length] = i;
00391 cut_length++;
00392 }
00393
00394 output[i] = old_output[i]*(1-t1) + t1*output[i];
00395
00396 if( output[i] <= 1) xi += 1-output[i];
00397 if( output[i] <= 0) ocas.trn_err++;
00398
00399 }
00400
00401 ocas.Q_P = 0.5*sq_norm_W + C*xi;
00402
00403 gap=(ocas.Q_P-ocas.Q_D)/CMath::abs(ocas.Q_P);
00404 SG_SABS_PROGRESS(gap, -CMath::log10(gap), -CMath::log10(1), -CMath::log10(TolRel), 6);
00405
00406 break;
00407 }
00408
00409
00410 if( ocas.Q_P - ocas.Q_D <= TolRel*CMath::abs(ocas.Q_P)) ocas.exitflag = 1;
00411 if( ocas.Q_P - ocas.Q_D <= TolAbs) ocas.exitflag = 2;
00412 if( ocas.Q_P <= QPBound) ocas.exitflag = 3;
00413 if(ocas.nCutPlanes >= BufSize) ocas.exitflag = -1;
00414
00415 }
00416
00417 cleanup:
00418
00419 OCAS_FREE(H);
00420 OCAS_FREE(b);
00421 OCAS_FREE(alpha);
00422 OCAS_FREE(new_cut);
00423 OCAS_FREE(I);
00424 OCAS_FREE(diag_H);
00425 OCAS_FREE(output);
00426 OCAS_FREE(old_output);
00427 OCAS_FREE(hpf);
00428 OCAS_FREE(hpi);
00429 OCAS_FREE(Ci);
00430 OCAS_FREE(Bi);
00431
00432 ocas.ocas_time = get_time() - ocas.ocas_time;
00433
00434 return(ocas);
00435 }
00436 }