00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020 #include <stdlib.h>
00021 #include <string.h>
00022 #include <math.h>
00023 #include <sys/time.h>
00024 #include <time.h>
00025 #include <stdio.h>
00026 #include <stdint.h>
00027
00028 #include "classifier/svm/libocas.h"
00029 #include "classifier/svm/libocas_common.h"
00030 #include "classifier/svm/qpssvmlib.h"
00031
00032 static const uint32_t QPSolverMaxIter = 10000000;
00033
00034 static float64_t *H;
00035 static uint32_t BufSize;
00036
00037
00038
00039
00040 static const void *get_col( uint32_t i)
00041 {
00042 return( &H[ BufSize*i ] );
00043 }
00044
00045
00046
00047
00048 static float64_t get_time()
00049 {
00050 struct timeval tv;
00051 if (gettimeofday(&tv, NULL)==0)
00052 return (float64_t) (tv.tv_sec+((double)(tv.tv_usec))/1e6);
00053 else
00054 return 0.0;
00055 }
00056
00057
00058
00059
00060 ocas_return_value_T svm_ocas_solver(
00061 float64_t C,
00062 uint32_t nData,
00063 float64_t TolRel,
00064 float64_t TolAbs,
00065 float64_t QPBound,
00066 uint32_t _BufSize,
00067 uint8_t Method,
00068 void (*compute_W)(float64_t*, float64_t*, float64_t*, uint32_t, void*),
00069 float64_t (*update_W)(float64_t, void*),
00070 void (*add_new_cut)(float64_t*, uint32_t*, uint32_t, uint32_t, void*),
00071 void (*compute_output)(float64_t*, void* ),
00072 void (*sort)(float64_t*, uint32_t*, uint32_t),
00073 void* user_data)
00074 {
00075 ocas_return_value_T ocas;
00076 float64_t *b, *alpha, *diag_H;
00077 float64_t *output, *old_output;
00078 float64_t xi, sq_norm_W, QPSolverTolRel, dot_prod_WoldW, dummy, sq_norm_oldW;
00079 float64_t A0, B0, GradVal, t, t1=0, t2=0, *Ci, *Bi, *hpf;
00080 float64_t start_time;
00081 uint32_t *hpi;
00082 uint32_t cut_length;
00083 uint32_t i, *new_cut;
00084 uint16_t *I;
00085 int8_t qp_exitflag;
00086 float64_t gap;
00087
00088 ocas.ocas_time = get_time();
00089 ocas.solver_time = 0;
00090 ocas.output_time = 0;
00091 ocas.sort_time = 0;
00092 ocas.add_time = 0;
00093 ocas.w_time = 0;
00094
00095 BufSize = _BufSize;
00096
00097 QPSolverTolRel = TolRel*0.5;
00098
00099 H=NULL;
00100 b=NULL;
00101 alpha=NULL;
00102 new_cut=NULL;
00103 I=NULL;
00104 diag_H=NULL;
00105 output=NULL;
00106 old_output=NULL;
00107 hpf=NULL;
00108 hpi=NULL;
00109 Ci=NULL;
00110 Bi=NULL;
00111
00112
00113 H = (float64_t*)OCAS_CALLOC(BufSize*BufSize,sizeof(float64_t));
00114 if(H == NULL)
00115 {
00116 ocas.exitflag=-2;
00117 goto cleanup;
00118 }
00119
00120
00121 b = (float64_t*)OCAS_CALLOC(BufSize,sizeof(float64_t));
00122 if(b == NULL)
00123 {
00124 ocas.exitflag=-2;
00125 goto cleanup;
00126 }
00127
00128 alpha = (float64_t*)OCAS_CALLOC(BufSize,sizeof(float64_t));
00129 if(alpha == NULL)
00130 {
00131 ocas.exitflag=-2;
00132 goto cleanup;
00133 }
00134
00135
00136 new_cut = (uint32_t*)OCAS_CALLOC(nData,sizeof(uint32_t));
00137 if(new_cut == NULL)
00138 {
00139 ocas.exitflag=-2;
00140 goto cleanup;
00141 }
00142
00143 I = (uint16_t*)OCAS_CALLOC(BufSize,sizeof(uint16_t));
00144 if(I == NULL)
00145 {
00146 ocas.exitflag=-2;
00147 goto cleanup;
00148 }
00149
00150 for(i=0; i< BufSize; i++) I[i] = 1;
00151
00152 diag_H = (float64_t*)OCAS_CALLOC(BufSize,sizeof(float64_t));
00153 if(diag_H == NULL)
00154 {
00155 ocas.exitflag=-2;
00156 goto cleanup;
00157 }
00158
00159 output = (float64_t*)OCAS_CALLOC(nData,sizeof(float64_t));
00160 if(output == NULL)
00161 {
00162 ocas.exitflag=-2;
00163 goto cleanup;
00164 }
00165
00166 old_output = (float64_t*)OCAS_CALLOC(nData,sizeof(float64_t));
00167 if(old_output == NULL)
00168 {
00169 ocas.exitflag=-2;
00170 goto cleanup;
00171 }
00172
00173
00174 hpf = (float64_t*) OCAS_CALLOC(nData, sizeof(hpf[0]));
00175 if(hpf == NULL)
00176 {
00177 ocas.exitflag=-2;
00178 goto cleanup;
00179 }
00180
00181 hpi = (uint32_t*) OCAS_CALLOC(nData, sizeof(hpi[0]));
00182 if(hpi == NULL)
00183 {
00184 ocas.exitflag=-2;
00185 goto cleanup;
00186 }
00187
00188
00189 Ci = (float64_t*)OCAS_CALLOC(nData,sizeof(float64_t));
00190 if(Ci == NULL)
00191 {
00192 ocas.exitflag=-2;
00193 goto cleanup;
00194 }
00195
00196 Bi = (float64_t*)OCAS_CALLOC(nData,sizeof(float64_t));
00197 if(Bi == NULL)
00198 {
00199 ocas.exitflag=-2;
00200 goto cleanup;
00201 }
00202
00203 ocas.nCutPlanes = 0;
00204 ocas.exitflag = 0;
00205 ocas.nIter = 0;
00206
00207
00208 sq_norm_W = 0;
00209 xi = nData;
00210 ocas.Q_P = 0.5*sq_norm_W + C*xi;
00211 ocas.Q_D = 0;
00212
00213
00214 cut_length = nData;
00215 for(i=0; i < nData; i++)
00216 new_cut[i] = i;
00217
00218 gap=(ocas.Q_P-ocas.Q_D)/CMath::abs(ocas.Q_P);
00219 SG_SABS_PROGRESS(gap, -CMath::log10(gap), -CMath::log10(1), -CMath::log10(TolRel), 6);
00220
00221
00222 while( ocas.exitflag == 0 )
00223 {
00224 ocas.nIter++;
00225
00226
00227 b[ocas.nCutPlanes] = -(float64_t)cut_length;
00228
00229 start_time = get_time();
00230
00231 add_new_cut( &H[INDEX2(0,ocas.nCutPlanes,BufSize)], new_cut, cut_length, ocas.nCutPlanes, user_data );
00232
00233 ocas.add_time += get_time() - start_time;
00234
00235
00236 diag_H[ocas.nCutPlanes] = H[INDEX2(ocas.nCutPlanes,ocas.nCutPlanes,BufSize)];
00237 for(i=0; i < ocas.nCutPlanes; i++) {
00238 H[INDEX2(ocas.nCutPlanes,i,BufSize)] = H[INDEX2(i,ocas.nCutPlanes,BufSize)];
00239 }
00240
00241 ocas.nCutPlanes++;
00242
00243
00244 start_time = get_time();
00245
00246 qp_exitflag = qpssvm_solver( &get_col, diag_H, b, C, I, alpha,
00247 ocas.nCutPlanes, QPSolverMaxIter, 0.0, QPSolverTolRel, &ocas.Q_D, &dummy, 0 );
00248
00249 ocas.solver_time += get_time() - start_time;
00250
00251 ocas.Q_D = -ocas.Q_D;
00252
00253 ocas.nNZAlpha = 0;
00254 for(i=0; i < ocas.nCutPlanes; i++) {
00255 if( alpha[i] != 0) ocas.nNZAlpha++;
00256 }
00257
00258 sq_norm_oldW = sq_norm_W;
00259 start_time = get_time();
00260 compute_W( &sq_norm_W, &dot_prod_WoldW, alpha, ocas.nCutPlanes, user_data );
00261 ocas.w_time += get_time() - start_time;
00262
00263
00264 switch( Method )
00265 {
00266
00267 case 0:
00268
00269 start_time = get_time();
00270 compute_output( output, user_data );
00271 ocas.output_time += get_time()-start_time;
00272
00273 xi = 0;
00274 cut_length = 0;
00275 ocas.trn_err = 0;
00276 for(i=0; i < nData; i++)
00277 {
00278 if(output[i] <= 0) ocas.trn_err++;
00279
00280 if(output[i] <= 1) {
00281 xi += 1 - output[i];
00282 new_cut[cut_length] = i;
00283 cut_length++;
00284 }
00285 }
00286 ocas.Q_P = 0.5*sq_norm_W + C*xi;
00287
00288 gap=(ocas.Q_P-ocas.Q_D)/CMath::abs(ocas.Q_P);
00289 SG_SABS_PROGRESS(gap, -CMath::log10(gap), -CMath::log10(1), -CMath::log10(TolRel), 6);
00290
00291 break;
00292
00293
00294
00295 case 1:
00296
00297
00298 A0 = sq_norm_W -2*dot_prod_WoldW + sq_norm_oldW;
00299 B0 = dot_prod_WoldW - sq_norm_oldW;
00300
00301 memcpy( old_output, output, sizeof(float64_t)*nData );
00302
00303 start_time = get_time();
00304 compute_output( output, user_data );
00305 ocas.output_time += get_time()-start_time;
00306
00307 uint32_t num_hp = 0;
00308 GradVal = B0;
00309 for(i=0; i< nData; i++) {
00310
00311 Ci[i] = C*(1-old_output[i]);
00312 Bi[i] = C*(old_output[i] - output[i]);
00313
00314 float64_t val;
00315 if(Bi[i] != 0)
00316 val = -Ci[i]/Bi[i];
00317 else
00318 val = -OCAS_PLUS_INF;
00319
00320 if (val>0)
00321 {
00322 hpi[num_hp] = i;
00323 hpf[num_hp] = val;
00324 num_hp++;
00325 }
00326
00327 if( (Bi[i] < 0 && val > 0) || (Bi[i] > 0 && val <= 0))
00328 GradVal += Bi[i];
00329
00330 }
00331
00332 t = 0;
00333 if( GradVal < 0 )
00334 {
00335 start_time = get_time();
00336 sort(hpf, hpi, num_hp);
00337 ocas.sort_time += get_time() - start_time;
00338
00339 float64_t t_new, GradVal_new;
00340 i = 0;
00341 while( GradVal < 0 && i < num_hp )
00342 {
00343 t_new = hpf[i];
00344 GradVal_new = GradVal + CMath::abs(Bi[hpi[i]]) + A0*(t_new-t);
00345
00346 if( GradVal_new >= 0 )
00347 {
00348 t = t + GradVal*(t-t_new)/(GradVal_new - GradVal);
00349 }
00350 else
00351 {
00352 t = t_new;
00353 i++;
00354 }
00355
00356 GradVal = GradVal_new;
00357 }
00358 }
00359
00360
00361
00362
00363
00364
00365
00366
00367
00368
00369
00370
00371 t = CMath::max(t,0.0);
00372
00373 t1 = t;
00374 t2 = t+(1.0-t)/10.0;
00375
00376
00377 sq_norm_W = update_W( t1, user_data );
00378
00379
00380 xi = 0;
00381 cut_length = 0;
00382 ocas.trn_err = 0;
00383 for(i=0; i < nData; i++ ) {
00384
00385 if( (old_output[i]*(1-t2) + t2*output[i]) <= 1 )
00386 {
00387 new_cut[cut_length] = i;
00388 cut_length++;
00389 }
00390
00391 output[i] = old_output[i]*(1-t1) + t1*output[i];
00392
00393 if( output[i] <= 1) xi += 1-output[i];
00394 if( output[i] <= 0) ocas.trn_err++;
00395
00396 }
00397
00398 ocas.Q_P = 0.5*sq_norm_W + C*xi;
00399
00400 gap=(ocas.Q_P-ocas.Q_D)/CMath::abs(ocas.Q_P);
00401 SG_SABS_PROGRESS(gap, -CMath::log10(gap), -CMath::log10(1), -CMath::log10(TolRel), 6);
00402
00403 break;
00404 }
00405
00406
00407 if( ocas.Q_P - ocas.Q_D <= TolRel*CMath::abs(ocas.Q_P)) ocas.exitflag = 1;
00408 if( ocas.Q_P - ocas.Q_D <= TolAbs) ocas.exitflag = 2;
00409 if( ocas.Q_P <= QPBound) ocas.exitflag = 3;
00410 if(ocas.nCutPlanes >= BufSize) ocas.exitflag = -1;
00411
00412 }
00413
00414 cleanup:
00415
00416 OCAS_FREE(H);
00417 OCAS_FREE(b);
00418 OCAS_FREE(alpha);
00419 OCAS_FREE(new_cut);
00420 OCAS_FREE(I);
00421 OCAS_FREE(diag_H);
00422 OCAS_FREE(output);
00423 OCAS_FREE(old_output);
00424 OCAS_FREE(hpf);
00425 OCAS_FREE(hpi);
00426 OCAS_FREE(Ci);
00427 OCAS_FREE(Bi);
00428
00429 ocas.ocas_time = get_time() - ocas.ocas_time;
00430
00431 return(ocas);
00432 }
00433
00434
00435
00436
00437
00438 static void swapf(float64_t* a, float64_t* b)
00439 {
00440 float64_t dummy=*b;
00441 *b=*a;
00442 *a=dummy;
00443 }
00444
/* Exchange the two unsigned 32-bit values pointed to by a and b. */
static void swapi(uint32_t* a, uint32_t* b)
{
    const uint32_t saved = *a;
    *a = *b;
    *b = saved;
}
00451
00452 void qsort_index(float64_t* value, uint32_t* index, uint32_t size)
00453 {
00454 if (size==2)
00455 {
00456 if (value[0] > value[1])
00457 {
00458 swapf(&value[0], &value[1]);
00459 swapi(&index[0], &index[1]);
00460 }
00461 return;
00462 }
00463 float64_t split=value[size/2];
00464
00465 uint32_t left=0;
00466 uint32_t right=size-1;
00467
00468 while (left<=right)
00469 {
00470 while (value[left] < split)
00471 left++;
00472 while (value[right] > split)
00473 right--;
00474
00475 if (left<=right)
00476 {
00477 swapf(&value[left], &value[right]);
00478 swapi(&index[left], &index[right]);
00479 left++;
00480 right--;
00481 }
00482 }
00483
00484 if (right+1> 1)
00485 qsort_index(value,index,right+1);
00486
00487 if (size-left> 1)
00488 qsort_index(&value[left],&index[left], size-left);
00489
00490
00491 return;
00492 }