Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #ifndef _MULTITASKKERNELTREENORMALIZER_H___
00012 #define _MULTITASKKERNELTREENORMALIZER_H___
00013
#include "kernel/KernelNormalizer.h"
#include "kernel/MultitaskKernelMklNormalizer.h"
#include "kernel/Kernel.h"

#include <algorithm>
#include <deque>
#include <iterator>
#include <map>
#include <set>
#include <string>
#include <vector>
00021
00022 namespace shogun
00023 {
00024
00029 class CNode: public CSGObject
00030 {
00031
00032 public:
00033
00034
00037 CNode()
00038 {
00039 parent = NULL;
00040 beta = 1.0;
00041 node_id = 0;
00042 }
00043
00047 std::set<CNode*> get_path_root()
00048 {
00049 std::set<CNode*> nodes_on_path = std::set<CNode*>();
00050 CNode *node = this;
00051 while (node != NULL) {
00052 nodes_on_path.insert(node);
00053 node = node->parent;
00054 }
00055 return nodes_on_path;
00056 }
00057
00061 std::vector<int32_t> get_task_ids_below()
00062 {
00063
00064 std::vector<int32_t> task_ids;
00065 std::deque<CNode*> grey_nodes;
00066 grey_nodes.push_back(this);
00067
00068 while(grey_nodes.size() > 0)
00069 {
00070
00071 CNode *current_node = grey_nodes.front();
00072 grey_nodes.pop_front();
00073
00074 for(int32_t i = 0; i!=int32_t(current_node->children.size()); i++){
00075 grey_nodes.push_back(current_node->children[i]);
00076 }
00077
00078 if(current_node->is_leaf()){
00079 task_ids.push_back(current_node->getNode_id());
00080 }
00081 }
00082
00083 return task_ids;
00084 }
00085
00089 void add_child(CNode *node)
00090 {
00091 node->parent = this;
00092 this->children.push_back(node);
00093 }
00094
00096 inline virtual const char *get_name() const
00097 {
00098 return "CNode";
00099 }
00100
00102 bool is_leaf()
00103 {
00104 return children.empty();
00105
00106 }
00107
00109 int32_t getNode_id() const
00110 {
00111 return node_id;
00112 }
00113
00115 void setNode_id(int32_t node_idx)
00116 {
00117 this->node_id = node_idx;
00118 }
00119
00121 float64_t beta;
00122
00123 protected:
00124
00126 CNode* parent;
00127
00129 std::vector<CNode*> children;
00130
00132 int32_t node_id;
00133
00134 };
00135
00136
00141 class CTaxonomy : public CSGObject
00142 {
00143
00144 public:
00145
00148 CTaxonomy(){
00149 root = new CNode();
00150 nodes.push_back(root);
00151
00152 name2id = std::map<std::string, int32_t>();
00153 name2id["root"] = 0;
00154 }
00155
00160 CNode* get_node(int32_t task_id) {
00161 return nodes[task_id];
00162 }
00163
00167 void set_root_beta(float64_t beta)
00168 {
00169 nodes[0]->beta = beta;
00170 }
00171
00177 CNode* add_node(std::string parent_name, std::string child_name, float64_t beta) {
00178
00179
00180 if (child_name=="") SG_ERROR("child_name empty");
00181 if (parent_name=="") SG_ERROR("parent_name empty");
00182
00183
00184 CNode* child_node = new CNode();
00185
00186 child_node->beta = beta;
00187
00188 nodes.push_back(child_node);
00189 int32_t id = nodes.size()-1;
00190
00191 name2id[child_name] = id;
00192
00193 child_node->setNode_id(id);
00194
00195
00196
00197 CNode* parent = nodes[name2id[parent_name]];
00198
00199 parent->add_child(child_node);
00200
00201 return child_node;
00202
00203 }
00204
00209 int32_t get_id(std::string name) {
00210 return name2id[name];
00211 }
00212
00218 std::set<CNode*> intersect_root_path(CNode* node_lhs, CNode* node_rhs) {
00219
00220 std::set<CNode*> root_path_lhs = node_lhs->get_path_root();
00221 std::set<CNode*> root_path_rhs = node_rhs->get_path_root();
00222
00223 std::set<CNode*> intersection;
00224
00225 std::set_intersection(root_path_lhs.begin(), root_path_lhs.end(),
00226 root_path_rhs.begin(), root_path_rhs.end(),
00227 std::inserter(intersection, intersection.end()));
00228
00229 return intersection;
00230
00231 }
00232
00238 float64_t compute_node_similarity(int32_t task_lhs, int32_t task_rhs)
00239 {
00240
00241 CNode* node_lhs = get_node(task_lhs);
00242 CNode* node_rhs = get_node(task_rhs);
00243
00244
00245 std::set<CNode*> intersection = intersect_root_path(node_lhs, node_rhs);
00246
00247
00248 float64_t gamma = 0;
00249 for (std::set<CNode*>::const_iterator p = intersection.begin(); p != intersection.end(); ++p) {
00250
00251 gamma += (*p)->beta;
00252 }
00253
00254 return gamma;
00255
00256 }
00257
00261 void update_task_histogram(std::vector<int32_t> task_vector_lhs) {
00262
00263
00264 task_histogram.clear();
00265
00266
00267
00268 for (std::vector<int32_t>::const_iterator it=task_vector_lhs.begin(); it!=task_vector_lhs.end(); it++)
00269 {
00270 task_histogram[*it] = 0.0;
00271 }
00272
00273
00274 for (std::vector<int32_t>::const_iterator it=task_vector_lhs.begin(); it!=task_vector_lhs.end(); it++)
00275 {
00276 task_histogram[*it] += 1.0;
00277 }
00278
00279
00280 for (std::map<int32_t, float64_t>::const_iterator it=task_histogram.begin(); it!=task_histogram.end(); it++)
00281 {
00282 task_histogram[it->first] = task_histogram[it->first] / float64_t(task_vector_lhs.size());
00283
00284 std::cout << "task_histogram:" << task_histogram[it->first] << std::endl;
00285
00286 }
00287
00288 }
00289
00291 int32_t get_num_nodes()
00292 {
00293 return (int32_t)(nodes.size());
00294 }
00295
00297 int32_t get_num_leaves()
00298 {
00299 int32_t num_leaves = 0;
00300
00301 for (int32_t i=0; i!=get_num_nodes(); i++)
00302 {
00303 if (get_node(i)->is_leaf()==true)
00304 {
00305 num_leaves++;
00306 }
00307 }
00308
00309 return num_leaves;
00310 }
00311
00313 float64_t get_node_weight(int32_t idx)
00314 {
00315 CNode* node = get_node(idx);
00316 return node->beta;
00317 }
00318
00323 void set_node_weight(int32_t idx, float64_t weight)
00324 {
00325 CNode* node = get_node(idx);
00326 node->beta = weight;
00327 }
00328
00330 inline virtual const char* get_name() const
00331 {
00332 return "CTaxonomy";
00333 }
00334
00336 std::map<std::string, int32_t> get_name2id() {
00337 return name2id;
00338 }
00339
00345 int32_t get_id_by_name(std::string name)
00346 {
00347 return name2id[name];
00348 }
00349
00350
00351 protected:
00352
00353 CNode* root;
00354 std::map<std::string, int32_t> name2id;
00355 std::vector<CNode*> nodes;
00356 std::map<int32_t, float64_t> task_histogram;
00357
00358 };
00359
00360
00361
00362
00363 class CMultitaskKernelMklNormalizer;
00364
00368 class CMultitaskKernelTreeNormalizer: public CMultitaskKernelMklNormalizer
00369 {
00370
00371
00372
00373 public:
00374
00377 CMultitaskKernelTreeNormalizer()
00378 {
00379 }
00380
00387 CMultitaskKernelTreeNormalizer(std::vector<std::string> task_lhs,
00388 std::vector<std::string> task_rhs,
00389 CTaxonomy tax)
00390 {
00391
00392 taxonomy = tax;
00393 set_task_vector_lhs(task_lhs);
00394 set_task_vector_rhs(task_rhs);
00395
00396 num_nodes = taxonomy.get_num_nodes();
00397
00398 std::cout << "num nodes:" << num_nodes << std::endl;
00399
00400 dependency_matrix = std::vector<float64_t>(num_nodes * num_nodes);
00401
00402 update_cache();
00403 }
00404
00405
00407 virtual ~CMultitaskKernelTreeNormalizer()
00408 {
00409 }
00410
00411
00413 void update_cache()
00414 {
00415
00416
00417 for (int32_t i=0; i!=num_nodes; i++)
00418 {
00419 for (int32_t j=0; j!=num_nodes; j++)
00420 {
00421
00422 float64_t similarity = taxonomy.compute_node_similarity(i, j);
00423 set_node_similarity(i,j,similarity);
00424
00425 }
00426
00427 }
00428 }
00429
00430
00431
00437 inline virtual float64_t normalize(float64_t value, int32_t idx_lhs, int32_t idx_rhs)
00438 {
00439
00440
00441 int32_t task_idx_lhs = task_vector_lhs[idx_lhs];
00442 int32_t task_idx_rhs = task_vector_rhs[idx_rhs];
00443
00444
00445
00446
00447 float64_t task_similarity = get_node_similarity(task_idx_lhs, task_idx_rhs);
00448
00449
00450
00451 float64_t similarity = (value/scale) * task_similarity;
00452
00453
00454 return similarity;
00455
00456 }
00457
00462 inline virtual float64_t normalize_lhs(float64_t value, int32_t idx_lhs)
00463 {
00464 SG_ERROR("normalize_lhs not implemented");
00465 return 0;
00466 }
00467
00472 inline virtual float64_t normalize_rhs(float64_t value, int32_t idx_rhs)
00473 {
00474 SG_ERROR("normalize_rhs not implemented");
00475 return 0;
00476 }
00477
00478
00480 void set_task_vector_lhs(std::vector<std::string> vec)
00481 {
00482
00483 task_vector_lhs.clear();
00484
00485 for (int32_t i = 0; i != (int32_t)(vec.size()); ++i)
00486 {
00487 task_vector_lhs.push_back(taxonomy.get_id(vec[i]));
00488 }
00489
00490
00491 taxonomy.update_task_histogram(task_vector_lhs);
00492
00493 }
00494
00496 void set_task_vector_rhs(std::vector<std::string> vec)
00497 {
00498
00499 task_vector_rhs.clear();
00500
00501 for (int32_t i = 0; i != (int32_t)(vec.size()); ++i)
00502 {
00503 task_vector_rhs.push_back(taxonomy.get_id(vec[i]));
00504 }
00505
00506 }
00507
00509 void set_task_vector(std::vector<std::string> vec)
00510 {
00511 set_task_vector_lhs(vec);
00512 set_task_vector_rhs(vec);
00513 }
00514
00516 int32_t get_num_betas()
00517 {
00518
00519 return taxonomy.get_num_nodes();
00520
00521 }
00522
00526 float64_t get_beta(int32_t idx)
00527 {
00528
00529 return taxonomy.get_node_weight(idx);
00530
00531 }
00532
00536 void set_beta(int32_t idx, float64_t weight)
00537 {
00538
00539 taxonomy.set_node_weight(idx, weight);
00540
00541 update_cache();
00542
00543 }
00544
00545
00551 float64_t get_node_similarity(int32_t node_lhs, int32_t node_rhs)
00552 {
00553
00554 ASSERT(node_lhs < num_nodes && node_lhs >= 0);
00555 ASSERT(node_rhs < num_nodes && node_rhs >= 0);
00556
00557 return dependency_matrix[node_lhs * num_nodes + node_rhs];
00558
00559 }
00560
00566 void set_node_similarity(int32_t node_lhs, int32_t node_rhs,
00567 float64_t similarity)
00568 {
00569
00570 ASSERT(node_lhs < num_nodes && node_lhs >= 0);
00571 ASSERT(node_rhs < num_nodes && node_rhs >= 0);
00572
00573 dependency_matrix[node_lhs * num_nodes + node_rhs] = similarity;
00574
00575 }
00576
00577
00579 inline virtual const char* get_name() const
00580 {
00581 return "MultitaskKernelTreeNormalizer";
00582 }
00583
00584
00585
00586 protected:
00587
00588
00590 CTaxonomy taxonomy;
00591
00593 int32_t num_nodes;
00594
00596 std::vector<int32_t> task_vector_lhs;
00597
00599 std::vector<int32_t> task_vector_rhs;
00600
00602 std::vector<float64_t> dependency_matrix;
00603
00604 };
00605 }
00606 #endif