35 #ifndef RF_VISITORS_HXX
36 #define RF_VISITORS_HXX
39 # include "vigra/hdf5impex.hxx"
42 # include "vigra/multi_array.hxx"
43 # include "vigra/multi_impex.hxx"
44 # include "vigra/inspectimage.hxx"
46 #include <vigra/windows.h>
50 #include <vigra/multi_pointoperators.hxx>
146 template<
class Tree,
class Split,
class Region,
class Feature_t,
class Label_t>
152 Feature_t & features,
165 template<
class RF,
class PR,
class SM,
class ST>
175 template<
class RF,
class PR>
185 template<
class RF,
class PR>
201 template<
class TR,
class IntT,
class TopT,
class Feat>
209 template<
class TR,
class IntT,
class TopT,
class Feat>
248 template <
class Visitor,
class Next = StopVisiting>
258 next_(next), visitor_(visitor)
263 next_(stop_), visitor_(visitor)
266 template<
class Tree,
class Split,
class Region,
class Feature_t,
class Label_t>
267 void visit_after_split( Tree & tree,
272 Feature_t & features,
275 if(visitor_.is_active())
276 visitor_.visit_after_split(tree, split,
277 parent, leftChild, rightChild,
279 next_.visit_after_split(tree, split, parent, leftChild, rightChild,
283 template<
class RF,
class PR,
class SM,
class ST>
284 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st,
int index)
286 if(visitor_.is_active())
287 visitor_.visit_after_tree(rf, pr, sm, st, index);
288 next_.visit_after_tree(rf, pr, sm, st, index);
291 template<
class RF,
class PR>
292 void visit_at_beginning(RF & rf, PR & pr)
294 if(visitor_.is_active())
295 visitor_.visit_at_beginning(rf, pr);
296 next_.visit_at_beginning(rf, pr);
298 template<
class RF,
class PR>
299 void visit_at_end(RF & rf, PR & pr)
301 if(visitor_.is_active())
302 visitor_.visit_at_end(rf, pr);
303 next_.visit_at_end(rf, pr);
306 template<
class TR,
class IntT,
class TopT,
class Feat>
307 void visit_external_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
309 if(visitor_.is_active())
310 visitor_.visit_external_node(tr, index, node_t,features);
311 next_.visit_external_node(tr, index, node_t,features);
313 template<
class TR,
class IntT,
class TopT,
class Feat>
314 void visit_internal_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
316 if(visitor_.is_active())
317 visitor_.visit_internal_node(tr, index, node_t,features);
318 next_.visit_internal_node(tr, index, node_t,features);
323 if(visitor_.is_active() && visitor_.has_value())
324 return visitor_.return_val();
325 return next_.return_val();
349 template<
class A,
class B>
350 detail::VisitorNode<A, detail::VisitorNode<B> >
363 template<
class A,
class B,
class C>
364 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C> > >
379 template<
class A,
class B,
class C,
class D>
380 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
381 detail::VisitorNode<D> > > >
398 template<
class A,
class B,
class C,
class D,
class E>
399 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
400 detail::VisitorNode<D, detail::VisitorNode<E> > > > >
420 template<
class A,
class B,
class C,
class D,
class E,
422 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
423 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F> > > > > >
445 template<
class A,
class B,
class C,
class D,
class E,
447 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
448 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
449 detail::VisitorNode<G> > > > > > >
451 D & d, E & e, F & f, G & g)
473 template<
class A,
class B,
class C,
class D,
class E,
474 class F,
class G,
class H>
475 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
476 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
477 detail::VisitorNode<G, detail::VisitorNode<H> > > > > > > >
504 template<
class A,
class B,
class C,
class D,
class E,
505 class F,
class G,
class H,
class I>
506 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
507 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
508 detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I> > > > > > > > >
536 template<
class A,
class B,
class C,
class D,
class E,
537 class F,
class G,
class H,
class I,
class J>
538 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
539 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
540 detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I,
541 detail::VisitorNode<J> > > > > > > > > >
582 bool adjust_thresholds;
592 adjust_thresholds(
false), tree_id(0), last_node_id(0), current_label(0)
594 struct MarginalDistribution
597 Int32 leftTotalCounts;
599 Int32 rightTotalCounts;
606 struct TreeOnlineInformation
608 std::vector<MarginalDistribution> mag_distributions;
609 std::vector<IndexList> index_lists;
611 std::map<int,int> interior_to_index;
613 std::map<int,int> exterior_to_index;
617 std::vector<TreeOnlineInformation> trees_online_information;
621 template<
class RF,
class PR>
625 trees_online_information.resize(rf.options_.tree_count_);
632 trees_online_information[tree_id].mag_distributions.clear();
633 trees_online_information[tree_id].index_lists.clear();
634 trees_online_information[tree_id].interior_to_index.clear();
635 trees_online_information[tree_id].exterior_to_index.clear();
640 template<
class RF,
class PR,
class SM,
class ST>
646 template<
class Tree,
class Split,
class Region,
class Feature_t,
class Label_t>
647 void visit_after_split( Tree & tree,
652 Feature_t & features,
656 int addr=tree.topology_.size();
657 if(split.createNode().typeID() == i_ThresholdNode)
659 if(adjust_thresholds)
662 linear_index=trees_online_information[tree_id].mag_distributions.size();
663 trees_online_information[tree_id].interior_to_index[addr]=linear_index;
664 trees_online_information[tree_id].mag_distributions.push_back(MarginalDistribution());
666 trees_online_information[tree_id].mag_distributions.back().leftCounts=leftChild.classCounts_;
667 trees_online_information[tree_id].mag_distributions.back().rightCounts=rightChild.classCounts_;
669 trees_online_information[tree_id].mag_distributions.back().leftTotalCounts=leftChild.size_;
670 trees_online_information[tree_id].mag_distributions.back().rightTotalCounts=rightChild.size_;
672 double gap_left,gap_right;
674 gap_left=features(leftChild[0],split.bestSplitColumn());
675 for(i=1;i<leftChild.size();++i)
676 if(features(leftChild[i],split.bestSplitColumn())>gap_left)
677 gap_left=features(leftChild[i],split.bestSplitColumn());
678 gap_right=features(rightChild[0],split.bestSplitColumn());
679 for(i=1;i<rightChild.size();++i)
680 if(features(rightChild[i],split.bestSplitColumn())<gap_right)
681 gap_right=features(rightChild[i],split.bestSplitColumn());
682 trees_online_information[tree_id].mag_distributions.back().gap_left=gap_left;
683 trees_online_information[tree_id].mag_distributions.back().gap_right=gap_right;
689 linear_index=trees_online_information[tree_id].index_lists.size();
690 trees_online_information[tree_id].exterior_to_index[addr]=linear_index;
692 trees_online_information[tree_id].index_lists.push_back(IndexList());
694 trees_online_information[tree_id].index_lists.back().resize(parent.size_,0);
695 std::copy(parent.begin_,parent.end_,trees_online_information[tree_id].index_lists.back().begin());
698 void add_to_index_list(
int tree,
int node,
int index)
702 TreeOnlineInformation &ti=trees_online_information[tree];
703 ti.index_lists[ti.exterior_to_index[node]].push_back(index);
705 void move_exterior_node(
int src_tree,
int src_index,
int dst_tree,
int dst_index)
709 trees_online_information[dst_tree].exterior_to_index[dst_index]=trees_online_information[src_tree].exterior_to_index[src_index];
710 trees_online_information[src_tree].exterior_to_index.erase(src_index);
717 template<
class TR,
class IntT,
class TopT,
class Feat>
721 if(adjust_thresholds)
723 vigra_assert(node_t==i_ThresholdNode,
"We can only visit threshold nodes");
725 double value=features(0, Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).column());
726 TreeOnlineInformation &ti=trees_online_information[tree_id];
727 MarginalDistribution &m=ti.mag_distributions[ti.interior_to_index[index]];
728 if(value>m.gap_left && value<m.gap_right)
731 if(m.leftCounts[current_label]/
double(m.leftTotalCounts)>m.rightCounts[current_label]/double(m.rightTotalCounts))
741 Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold()=(m.gap_right+m.gap_left)/2.0;
744 if(value>Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold())
746 ++m.rightTotalCounts;
747 ++m.rightCounts[current_label];
752 ++m.rightCounts[current_label];
800 template<
class RF,
class PR,
class SM,
class ST>
804 if(
int(oobCount.
size()) != rf.ext_param_.row_count_)
806 oobCount.resize(rf.ext_param_.row_count_, 0);
807 oobErrorCount.resize(rf.ext_param_.row_count_, 0);
810 for(
int l = 0; l < rf.ext_param_.row_count_; ++l)
817 .predictLabel(
rowVector(pr.features(), l))
818 != pr.response()(l,0))
829 template<
class RF,
class PR>
833 for(
int l=0; l < (int)rf.ext_param_.row_count_; ++l)
837 oobError += double(oobErrorCount[l]) / oobCount[l];
841 oobError/=totalOobCount;
874 void save(std::string filen, std::string pathn)
876 if(*(pathn.end()-1) !=
'/')
878 const char* filename = filen.c_str();
881 writeHDF5(filename, (pathn +
"breiman_error").c_str(), temp);
887 template<
class RF,
class PR>
888 void visit_at_beginning(RF & rf, PR & pr)
890 class_count = rf.class_count();
891 tmp_prob.
reshape(Shp(1, class_count), 0);
892 prob_oob.
reshape(Shp(rf.ext_param().row_count_,class_count), 0);
893 is_weighted = rf.options().predict_weighted_;
894 indices.resize(rf.ext_param().row_count_);
895 if(
int(oobCount.
size()) != rf.ext_param_.row_count_)
897 oobCount.
reshape(Shp(rf.ext_param_.row_count_, 1), 0);
899 for(
int ii = 0; ii < rf.ext_param().row_count_; ++ii)
905 template<
class RF,
class PR,
class SM,
class ST>
906 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st,
int index)
913 if(rf.ext_param_.actual_msample_ < pr.features().shape(0) - 10000)
915 ArrayVector<int> oob_indices;
916 ArrayVector<int> cts(class_count, 0);
917 std::random_shuffle(indices.
begin(), indices.
end());
918 for(
int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
920 if(!sm.is_used()[indices[ii]] && cts[pr.response()(indices[ii], 0)] < 40000)
922 oob_indices.push_back(indices[ii]);
923 ++cts[pr.response()(indices[ii], 0)];
926 for(
unsigned int ll = 0; ll < oob_indices.size(); ++ll)
929 ++oobCount[oob_indices[ll]];
934 int pos = rf.tree(index).getToLeaf(
rowVector(pr.features(),oob_indices[ll]));
935 Node<e_ConstProbNode> node ( rf.tree(index).topology_,
936 rf.tree(index).parameters_,
939 for(
int ii = 0; ii < class_count; ++ii)
941 tmp_prob[ii] = node.prob_begin()[ii];
945 for(
int ii = 0; ii < class_count; ++ii)
946 tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
948 rowVector(prob_oob, oob_indices[ll]) += tmp_prob;
953 for(
int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
956 if(!sm.is_used()[ll])
964 int pos = rf.tree(index).getToLeaf(
rowVector(pr.features(),ll));
965 Node<e_ConstProbNode> node ( rf.tree(index).topology_,
966 rf.tree(index).parameters_,
969 for(
int ii = 0; ii < class_count; ++ii)
971 tmp_prob[ii] = node.prob_begin()[ii];
975 for(
int ii = 0; ii < class_count; ++ii)
976 tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
987 template<
class RF,
class PR>
991 int totalOobCount =0;
992 int breimanstyle = 0;
993 for(
int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
1002 oob_breiman = double(breimanstyle)/totalOobCount;
1068 void save(std::string filen, std::string pathn)
1070 if(*(pathn.end()-1) !=
'/')
1072 const char* filename = filen.c_str();
1074 writeHDF5(filename, (pathn +
"oob_per_tree").c_str(), oob_per_tree);
1075 writeHDF5(filename, (pathn +
"oobroc_per_tree").c_str(), oobroc_per_tree);
1076 writeHDF5(filename, (pathn +
"breiman_per_tree").c_str(), breiman_per_tree);
1078 writeHDF5(filename, (pathn +
"per_tree_error").c_str(), temp);
1080 writeHDF5(filename, (pathn +
"per_tree_error_std").c_str(), temp);
1082 writeHDF5(filename, (pathn +
"breiman_error").c_str(), temp);
1084 writeHDF5(filename, (pathn +
"ulli_error").c_str(), temp);
1090 template<
class RF,
class PR>
1091 void visit_at_beginning(RF & rf, PR & pr)
1093 class_count = rf.class_count();
1094 if(class_count == 2)
1098 tmp_prob.
reshape(Shp(1, class_count), 0);
1099 prob_oob.
reshape(Shp(rf.ext_param().row_count_,class_count), 0);
1100 is_weighted = rf.options().predict_weighted_;
1101 oob_per_tree.
reshape(Shp(1, rf.tree_count()), 0);
1102 breiman_per_tree.
reshape(Shp(1, rf.tree_count()), 0);
1104 if(
int(oobCount.
size()) != rf.ext_param_.row_count_)
1106 oobCount.
reshape(Shp(rf.ext_param_.row_count_, 1), 0);
1107 oobErrorCount.
reshape(Shp(rf.ext_param_.row_count_,1), 0);
1111 template<
class RF,
class PR,
class SM,
class ST>
1112 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st,
int index)
1117 for(
int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
1120 if(!sm.is_used()[ll])
1128 int pos = rf.tree(index).getToLeaf(
rowVector(pr.features(),ll));
1129 Node<e_ConstProbNode> node ( rf.tree(index).topology_,
1130 rf.tree(index).parameters_,
1133 for(
int ii = 0; ii < class_count; ++ii)
1135 tmp_prob[ii] = node.prob_begin()[ii];
1139 for(
int ii = 0; ii < class_count; ++ii)
1140 tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
1143 int label =
argMax(tmp_prob);
1145 if(label != pr.response()(ll, 0))
1150 ++oobErrorCount[ll];
1154 int breimanstyle = 0;
1155 int totalOobCount = 0;
1156 for(
int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
1163 if(oobroc_per_tree.
shape(2) == 1)
1169 if(oobroc_per_tree.
shape(2) == 1)
1170 oobroc_per_tree.
bindOuter(index)/=totalOobCount;
1171 if(oobroc_per_tree.
shape(2) > 1)
1173 MultiArrayView<3, double> current_roc
1175 for(
int gg = 0; gg < current_roc.shape(2); ++gg)
1177 for(
int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
1181 int pred = prob_oob(ll, 1) > (double(gg)/double(current_roc.shape(2)))?
1183 current_roc(pr.response()(ll, 0), pred, gg)+= 1;
1186 current_roc.
bindOuter(gg)/= totalOobCount;
1189 breiman_per_tree[index] = double(breimanstyle)/double(totalOobCount);
1190 oob_per_tree[index] = double(wrong_oob)/double(total_oob);
1196 template<
class RF,
class PR>
1201 int totalOobCount =0;
1202 int breimanstyle = 0;
1203 for(
int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
1209 oob_per_tree2 += double(oobErrorCount[ll]) / oobCount[ll];
1213 oob_per_tree2 /= totalOobCount;
1214 oob_breiman = double(breimanstyle)/totalOobCount;
1254 int repetition_count_;
1258 void save(std::string filename, std::string prefix)
1260 prefix =
"variable_importance_" + prefix;
1273 : repetition_count_(rep_cnt)
1280 template<
class Tree,
class Split,
class Region,
class Feature_t,
class Label_t>
1285 Region & rightChild,
1286 Feature_t & features,
1291 Int32 const class_count = tree.ext_param_.class_count_;
1292 Int32 const column_count = tree.ext_param_.column_count_;
1293 if(variable_importance_.
size() == 0)
1296 variable_importance_
1301 if(split.createNode().typeID() == i_ThresholdNode)
1303 Node<i_ThresholdNode> node(split.createNode());
1305 += split.region_gini_ - split.minGini();
1315 template<
class RF,
class PR,
class SM,
class ST>
1319 Int32 column_count = rf.ext_param_.column_count_;
1320 Int32 class_count = rf.ext_param_.class_count_;
1330 typedef typename PR::FeatureWithMemory_t FeatureArray;
1331 typedef typename FeatureArray::value_type FeatureValue;
1333 FeatureArray features = pr.features();
1339 for(
int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
1340 if(!sm.is_used()[ii])
1341 oob_indices.push_back(ii);
1347 #ifdef CLASSIFIER_TEST
1358 oob_right(Shp_t(1, class_count + 1));
1360 perm_oob_right (Shp_t(1, class_count + 1));
1364 for(iter = oob_indices.
begin();
1365 iter != oob_indices.
end();
1369 .predictLabel(
rowVector(features, *iter))
1370 == pr.response()(*iter, 0))
1373 ++oob_right[pr.response()(*iter,0)];
1375 ++oob_right[class_count];
1379 for(
int ii = 0; ii < column_count; ++ii)
1381 perm_oob_right.
init(0.0);
1383 backup_column.clear();
1384 for(iter = oob_indices.
begin();
1385 iter != oob_indices.
end();
1388 backup_column.push_back(features(*iter,ii));
1392 for(
int rr = 0; rr < repetition_count_; ++rr)
1395 int n = oob_indices.
size();
1396 for(
int jj = 1; jj < n; ++jj)
1397 std::swap(features(oob_indices[jj], ii),
1398 features(oob_indices[randint(jj+1)], ii));
1401 for(iter = oob_indices.
begin();
1402 iter != oob_indices.
end();
1406 .predictLabel(
rowVector(features, *iter))
1407 == pr.response()(*iter, 0))
1410 ++perm_oob_right[pr.response()(*iter, 0)];
1412 ++perm_oob_right[class_count];
1419 perm_oob_right /= repetition_count_;
1420 perm_oob_right -=oob_right;
1421 perm_oob_right *= -1;
1422 perm_oob_right /= oob_indices.
size();
1423 variable_importance_
1425 Shp_t(ii+1,class_count+1)) += perm_oob_right;
1427 for(
int jj = 0; jj < int(oob_indices.
size()); ++jj)
1428 features(oob_indices[jj], ii) = backup_column[jj];
1437 template<
class RF,
class PR,
class SM,
class ST>
1445 template<
class RF,
class PR>
1448 variable_importance_ /= rf.trees_.
size();
1458 template<
class RF,
class PR,
class SM,
class ST>
1459 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st,
int index){
1460 if(index != rf.options().tree_count_-1) {
1461 std::cout <<
"\r[" << std::setw(10) << (index+1)/static_cast<double>(rf.options().tree_count_)*100 <<
"%]"
1462 <<
" (" << index+1 <<
" of " << rf.options().tree_count_ <<
") done" << std::flush;
1465 std::cout <<
"\r[" << std::setw(10) << 100.0 <<
"%]" << std::endl;
1469 template<
class RF,
class PR>
1470 void visit_at_end(RF
const & rf, PR
const & pr) {
1471 std::string a =
TOCS;
1472 std::cout <<
"all " << rf.options().tree_count_ <<
" trees have been learned in " << a << std::endl;
1475 template<
class RF,
class PR>
1476 void visit_at_beginning(RF
const & rf, PR
const & pr) {
1478 std::cout <<
"growing random forest, which will have " << rf.options().tree_count_ <<
" trees" << std::endl;
1526 void save(std::string file, std::string prefix)
1543 template<
class RF,
class PR>
1544 void visit_at_beginning(RF
const & rf, PR & pr)
1547 int n = rf.ext_param_.column_count_;
1548 gini_missc.
reshape(Shp(n +1,n+ 1));
1549 corr_noise.
reshape(Shp(n + 1, 10));
1550 corr_l.
reshape(Shp(n +1, 10));
1552 noise.
reshape(Shp(pr.features().shape(0), 10));
1553 noise_l.
reshape(Shp(pr.features().shape(0), 10));
1555 for(
int ii = 0; ii < noise.
size(); ++ii)
1557 noise[ii] = random.uniform53();
1558 noise_l[ii] = random.uniform53() > 0.5;
1560 bgfunc = ColumnDecisionFunctor( rf.ext_param_);
1561 tmp_labels.
reshape(pr.response().shape());
1563 numChoices.resize(n+1);
1566 template<
class RF,
class PR>
1567 void visit_at_end(RF
const & rf, PR
const & pr)
1575 int rC = similarity.
shape(0);
1576 for(
int jj = 0; jj < rC-1; ++jj)
1578 rowVector(similarity, jj) /= numChoices[jj];
1579 rowVector(similarity, jj) -= mean_noise(jj, 0);
1581 for(
int jj = 0; jj < rC; ++jj)
1585 rowVector(similarity, rC - 1) -= mean_noise(rC-1, 0);
1586 similarity =
abs(similarity);
1587 FindMinMax<double> minmax;
1590 for(
int jj = 0; jj < rC; ++jj)
1593 similarity.
subarray(Shp(0,0), Shp(rC-1, rC-1))
1595 similarity.
subarray(Shp(0,0), Shp(rC-1, rC-1))/= 2;
1597 for(
int jj = 0; jj < rC; ++jj)
1600 FindMinMax<double> minmax2;
1602 for(
int jj = 0; jj < rC; ++jj)
1608 template<
class Tree,
class Split,
class Region,
class Feature_t,
class Label_t>
1609 void visit_after_split( Tree & tree,
1613 Region & rightChild,
1614 Feature_t & features,
1617 if(split.createNode().typeID() == i_ThresholdNode)
1621 for(
int ii = 0; ii < parent.size(); ++ii)
1623 tmp_labels[parent[ii]]
1624 = (features(parent[ii], split.bestSplitColumn()) < split.bestSplitThreshold());
1625 ++tmp_cc[tmp_labels[parent[ii]]];
1627 double region_gini = bgfunc.loss_of_region(tmp_labels,
1632 int n = split.bestSplitColumn();
1634 ++(*(numChoices.
end()-1));
1636 for(
int k = 0; k < features.shape(1); ++k)
1640 parent.
begin(), parent.end(),
1642 wgini = (region_gini - bgfunc.min_gini_);
1646 for(
int k = 0; k < 10; ++k)
1650 parent.
begin(), parent.end(),
1652 wgini = (region_gini - bgfunc.min_gini_);
1657 for(
int k = 0; k < 10; ++k)
1661 parent.
begin(), parent.end(),
1663 wgini = (region_gini - bgfunc.min_gini_);
1667 bgfunc(labels, tmp_labels, parent.
begin(), parent.end(),tmp_cc);
1668 wgini = (region_gini - bgfunc.min_gini_);
1672 region_gini = split.region_gini_;
1674 Node<i_ThresholdNode> node(split.createNode());
1677 +=split.region_gini_ - split.minGini();
1679 for(
int k = 0; k < 10; ++k)
1683 parent.begin(), parent.end(),
1684 parent.classCounts());
1690 for(
int k = 0; k < tree.ext_param_.actual_mtry_; ++k)
1692 wgini = region_gini - split.min_gini_[k];
1695 split.splitColumns[k])
1699 for(
int k=tree.ext_param_.actual_mtry_; k<features.shape(1); ++k)
1701 split.bgfunc(
columnVector(features, split.splitColumns[k]),
1703 parent.begin(), parent.end(),
1704 parent.classCounts());
1705 wgini = region_gini - split.bgfunc.min_gini_;
1707 split.splitColumns[k]) += wgini;
1714 SortSamplesByDimensions<Feature_t>
1715 sorter(features, split.bestSplitColumn(), split.bestSplitThreshold());
1716 std::partition(parent.begin(), parent.end(), sorter);
1727 #endif // RF_VISITORS_HXX
#define TIC
Definition: timing.hxx:322
void visit_internal_node(TR &tr, IntT index, TopT node_t, Feat &features)
Definition: rf_visitors.hxx:210
void after_tree_ip_impl(RF &rf, PR &pr, SM &sm, ST &st, int index)
Definition: rf_visitors.hxx:1316
MultiArrayView< 2, T, C > columnVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition: matrix.hxx:725
MultiArray< 2, double > breiman_per_tree
Definition: rf_visitors.hxx:1044
MultiArray< 2, double > gini_missc
Definition: rf_visitors.hxx:1495
MultiArrayIndex rowCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:669
const difference_type & shape() const
Definition: multi_array.hxx:1551
void visit_at_end(RF const &rf, PR const &pr)
Definition: rf_visitors.hxx:176
void visit_at_beginning(RF &rf, const PR &pr)
Definition: rf_visitors.hxx:622
void visit_after_tree(RF &rf, PR &pr, SM &sm, ST &st, int index)
Definition: rf_visitors.hxx:166
const_iterator begin() const
Definition: array_vector.hxx:223
double oobError
Definition: rf_visitors.hxx:782
iterator begin()
Definition: multi_array.hxx:1815
MultiArray< 2, double > similarity
Definition: rf_visitors.hxx:1515
void reshape(const difference_type &shape)
Definition: multi_array.hxx:2738
Definition: rf_visitors.hxx:857
ArrayVector< int > numChoices
Definition: rf_visitors.hxx:1523
Definition: rf_visitors.hxx:1489
Definition: rf_visitors.hxx:1224
Definition: accessor.hxx:43
MultiArrayView< N, T, StridedArrayTag > transpose() const
Definition: multi_array.hxx:1470
double oob_per_tree2
Definition: rf_visitors.hxx:1039
detail::VisitorNode< A > create_visitor(A &a)
Definition: rf_visitors.hxx:339
void reset_tree(int tree_id)
Definition: rf_visitors.hxx:630
difference_type_1 size() const
Definition: multi_array.hxx:1544
MultiArray< 4, double > oobroc_per_tree
Definition: rf_visitors.hxx:1061
double return_val()
Definition: rf_visitors.hxx:220
void visit_at_end(RF &rf, PR &pr)
Definition: rf_visitors.hxx:830
Definition: rf_visitors.hxx:249
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int
Definition: sized_int.hxx:175
MultiArray< 2, double > noise
Definition: rf_visitors.hxx:1499
void init(U const &initial)
Definition: array_vector.hxx:146
Definition: rf_split.hxx:831
MultiArray & init(const U &init)
Definition: multi_array.hxx:2728
Iterator argMax(Iterator first, Iterator last)
Find the maximum element in a sequence.
Definition: algorithm.hxx:96
Definition: rf_visitors.hxx:1009
Definition: rf_visitors.hxx:578
MultiArray< 2, double > oob_per_tree
Definition: rf_visitors.hxx:1019
#define TOCS
Definition: timing.hxx:325
Class for fixed size vectors.This class contains an array of size SIZE of the specified VALUETYPE...
Definition: accessor.hxx:940
void writeHDF5(...)
Store array data in an HDF5 file.
Definition: rf_visitors.hxx:1454
MultiArray< 2, double > distance
Definition: rf_visitors.hxx:1518
MultiArrayView< 2, T, C > rowVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition: matrix.hxx:695
double oob_std
Definition: rf_visitors.hxx:1025
Definition: rf_visitors.hxx:106
MultiArrayIndex columnCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:682
FFTWComplex< R >::NormType abs(const FFTWComplex< R > &a)
absolute value (= magnitude)
Definition: fftw3.hxx:1002
image import and export functions
Definition: random.hxx:336
MultiArray< 2, double > variable_importance_
Definition: rf_visitors.hxx:1253
double oob_breiman
Definition: rf_visitors.hxx:868
const_iterator end() const
Definition: array_vector.hxx:237
const_pointer data() const
Definition: array_vector.hxx:209
size_type size() const
Definition: array_vector.hxx:330
void visit_at_beginning(RF const &rf, PR const &pr)
Definition: rf_visitors.hxx:186
MultiArrayView subarray(difference_type p, difference_type q) const
Definition: multi_array.hxx:1431
void inspectMultiArray(...)
Call an analyzing functor at every element of a multi-dimensional array.
void visit_external_node(TR &tr, IntT index, TopT node_t, Feat &features)
Definition: rf_visitors.hxx:202
Definition: rf_visitors.hxx:777
MultiArrayView< N-M, T, StrideTag > bindOuter(const TinyVector< Index, M > &d) const
Definition: multi_array.hxx:2067
Definition: rf_visitors.hxx:229
void visit_after_split(Tree &tree, Split &split, Region &parent, Region &leftChild, Region &rightChild, Feature_t &features, Label_t &labels)
Definition: rf_visitors.hxx:147
double oob_mean
Definition: rf_visitors.hxx:1022
MultiArray< 2, double > corr_noise
Definition: rf_visitors.hxx:1503