37 #ifndef VIGRA_RANDOM_FOREST_HXX
38 #define VIGRA_RANDOM_FOREST_HXX
46 #include "mathutil.hxx"
47 #include "array_vector.hxx"
48 #include "sized_int.hxx"
51 #include "functorexpression.hxx"
52 #include "random_forest/rf_common.hxx"
53 #include "random_forest/rf_nodeproxy.hxx"
54 #include "random_forest/rf_split.hxx"
55 #include "random_forest/rf_decisionTree.hxx"
56 #include "random_forest/rf_visitors.hxx"
57 #include "random_forest/rf_region.hxx"
58 #include "sampling.hxx"
59 #include "random_forest/rf_preprocessing.hxx"
60 #include "random_forest/rf_online_prediction_set.hxx"
61 #include "random_forest/rf_earlystopping.hxx"
62 #include "random_forest/rf_ridge_split.hxx"
82 inline SamplerOptions make_sampler_opt ( RandomForestOptions & RF_opt)
84 SamplerOptions return_opt;
86 return_opt.
stratified(RF_opt.stratification_method_ == RF_EQUAL);
143 template <
class LabelType =
double ,
class PreprocessorTag = ClassificationTag >
150 typedef detail::DecisionTree DecisionTree_t;
157 typedef LabelType LabelT;
165 ProblemSpec_t ext_param_;
193 ProblemSpec_t
const &
ext_param = ProblemSpec_t())
225 template<
class TopologyIterator,
class ParameterIterator>
227 TopologyIterator topology_begin,
228 ParameterIterator parameter_begin,
229 ProblemSpec_t
const & problem_spec,
230 Options_t
const &
options = Options_t())
232 trees_(treeCount, DecisionTree_t(problem_spec)),
233 ext_param_(problem_spec),
236 for(
unsigned int k=0; k<treeCount; ++k, ++topology_begin, ++parameter_begin)
238 trees_[k].topology_ = *topology_begin;
239 trees_[k].parameters_ = *parameter_begin;
258 vigra_precondition(ext_param_.used() ==
true,
259 "RandomForest::ext_param(): "
260 "Random forest has not been trained yet.");
276 vigra_precondition(ext_param_.used() ==
false,
277 "RandomForest::set_ext_param():"
278 "Random forest has been trained! Call reset()"
279 "before specifying new extrinsic parameters.");
303 DecisionTree_t
const &
tree(
int index)
const
305 return trees_[index];
310 DecisionTree_t &
tree(
int index)
312 return trees_[index];
322 return ext_param_.column_count_;
333 return ext_param_.column_count_;
341 return ext_param_.class_count_;
348 return options_.tree_count_;
353 template<
class U,
class C1,
366 bool adjust_thresholds=
false);
368 template <
class U,
class C1,
class U2,
class C2>
373 onlineLearn(features,
383 template<
class U,
class C1,
389 void reLearnTree(MultiArrayView<2,U,C1>
const & features,
390 MultiArrayView<2,U2,C2>
const & response,
397 template<
class U,
class C1,
class U2,
class C2>
398 void reLearnTree(MultiArrayView<2, U, C1>
const & features,
399 MultiArrayView<2, U2, C2>
const & labels,
402 RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
447 template <
class U,
class C1,
453 void learn( MultiArrayView<2, U, C1>
const & features,
454 MultiArrayView<2, U2,C2>
const & response,
458 Random_t
const & random);
460 template <
class U,
class C1,
465 void learn( MultiArrayView<2, U, C1>
const & features,
466 MultiArrayView<2, U2,C2>
const & response,
472 RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
481 template <
class U,
class C1,
class U2,
class C2,
class Visitor_t>
482 void learn( MultiArrayView<2, U, C1>
const & features,
483 MultiArrayView<2, U2,C2>
const & labels,
493 template <
class U,
class C1,
class U2,
class C2,
494 class Visitor_t,
class Split_t>
495 void learn( MultiArrayView<2, U, C1>
const & features,
496 MultiArrayView<2, U2,C2>
const & labels,
525 template <
class U,
class C1,
class U2,
class C2>
553 template <
class U,
class C,
class Stop>
556 template <
class U,
class C>
567 template <
class U,
class C>
568 LabelType
predictLabel(MultiArrayView<2, U, C>
const & features,
569 ArrayVectorView<double> prior)
const;
581 template <
class U,
class C1,
class T,
class C2>
585 vigra_precondition(features.
shape(0) == labels.
shape(0),
586 "RandomForest::predictLabels(): Label array has wrong size.");
587 for(
int k=0; k<features.
shape(0); ++k)
589 vigra_precondition(!detail::contains_nan(
rowVector(features, k)),
590 "RandomForest::predictLabels(): NaN in feature matrix.");
591 labels(k,0) = detail::RequiresExplicitCast<T>::cast(
predictLabel(
rowVector(features, k), rf_default()));
605 template <
class U,
class C1,
class T,
class C2>
608 LabelType nanLabel)
const
610 vigra_precondition(features.
shape(0) == labels.
shape(0),
611 "RandomForest::predictLabels(): Label array has wrong size.");
612 for(
int k=0; k<features.
shape(0); ++k)
614 if(detail::contains_nan(
rowVector(features, k)))
615 labels(k,0) = nanLabel;
617 labels(k,0) = detail::RequiresExplicitCast<T>::cast(
predictLabel(
rowVector(features, k), rf_default()));
630 template <
class U,
class C1,
class T,
class C2,
class Stop>
635 vigra_precondition(features.
shape(0) == labels.
shape(0),
636 "RandomForest::predictLabels(): Label array has wrong size.");
637 for(
int k=0; k<features.
shape(0); ++k)
652 template <
class U,
class C1,
class T,
class C2,
class Stop>
656 template <
class T1,
class T2,
class C>
666 template <
class U,
class C1,
class T,
class C2>
673 template <
class U,
class C1,
class T,
class C2>
683 template <
class LabelType,
class PreprocessorTag>
684 template<
class U,
class C1,
690 void RandomForest<LabelType, PreprocessorTag>::onlineLearn(MultiArrayView<2,U,C1>
const & features,
691 MultiArrayView<2,U2,C2>
const & response,
697 bool adjust_thresholds)
699 online_visitor_.activate();
700 online_visitor_.adjust_thresholds=adjust_thresholds;
704 typedef Processor<PreprocessorTag,LabelType,U,C1,U2,C2> Preprocessor_t;
705 typedef UniformIntRandomFunctor<Random_t>
712 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
713 Default_Stop_t default_stop(options_);
714 typename RF_CHOOSER(Stop_t)::type stop
715 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
716 Default_Split_t default_split;
717 typename RF_CHOOSER(Split_t)::type split
718 = RF_CHOOSER(Split_t)::choose(split_, default_split);
719 rf::visitors::StopVisiting stopvisiting;
720 typedef rf::visitors::detail::VisitorNode
721 <rf::visitors::OnlineLearnVisitor,
722 typename RF_CHOOSER(Visitor_t)::type>
725 visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
727 vigra_precondition(options_.prepare_online_learning_,
"onlineLearn: online learning must be enabled on RandomForest construction");
733 ext_param_.class_count_=0;
734 Preprocessor_t preprocessor( features, response,
735 options_, ext_param_);
738 RandFunctor_t randint ( random);
741 split.set_external_parameters(ext_param_);
742 stop.set_external_parameters(ext_param_);
746 PoissonSampler<RandomTT800> poisson_sampler(1.0,
vigra::Int32(new_start_index),
vigra::Int32(ext_param().row_count_));
752 for(
int ii = 0; ii < (int)trees_.
size(); ++ii)
754 online_visitor_.tree_id=ii;
755 poisson_sampler.sample();
756 std::map<int,int> leaf_parents;
757 leaf_parents.clear();
759 for(
int s=0;s<poisson_sampler.numOfSamples();++s)
761 int sample=poisson_sampler[s];
762 online_visitor_.current_label=preprocessor.response()(sample,0);
763 online_visitor_.last_node_id=StackEntry_t::DecisionTreeNoParent;
764 int leaf=trees_[ii].getToLeaf(
rowVector(features,sample),online_visitor_);
768 online_visitor_.add_to_index_list(ii,leaf,sample);
771 if(Node<e_ConstProbNode>(trees_[ii].topology_,trees_[ii].parameters_,leaf).prob_begin()[preprocessor.response()(sample,0)]!=1.0)
773 leaf_parents[leaf]=online_visitor_.last_node_id;
778 std::map<int,int>::iterator leaf_iterator;
779 for(leaf_iterator=leaf_parents.begin();leaf_iterator!=leaf_parents.end();++leaf_iterator)
781 int leaf=leaf_iterator->first;
782 int parent=leaf_iterator->second;
783 int lin_index=online_visitor_.trees_online_information[ii].exterior_to_index[leaf];
784 ArrayVector<Int32> indeces;
786 indeces.swap(online_visitor_.trees_online_information[ii].index_lists[lin_index]);
787 StackEntry_t stack_entry(indeces.begin(),
789 ext_param_.class_count_);
794 if(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(0)==leaf)
800 vigra_assert(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(1)==leaf,
"last_node_id seems to be wrong");
801 stack_entry.rightParent=parent;
805 trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,-1);
807 online_visitor_.move_exterior_node(ii,trees_[ii].topology_.
size(),ii,leaf);
820 online_visitor_.deactivate();
823 template<
class LabelType,
class PreprocessorTag>
824 template<
class U,
class C1,
845 ext_param_.class_count_=0;
853 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
854 Default_Stop_t default_stop(options_);
855 typename RF_CHOOSER(Stop_t)::type stop
856 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
857 Default_Split_t default_split;
858 typename RF_CHOOSER(Split_t)::type split
859 = RF_CHOOSER(Split_t)::choose(split_, default_split);
862 <rf::visitors::OnlineLearnVisitor,
863 typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
865 visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
867 vigra_precondition(options_.prepare_online_learning_,
"reLearnTree: Re learning trees only makes sense, if online learning is enabled");
868 online_visitor_.activate();
871 RandFunctor_t randint ( random);
877 Preprocessor_t preprocessor( features, response,
878 options_, ext_param_);
881 split.set_external_parameters(ext_param_);
882 stop.set_external_parameters(ext_param_);
889 preprocessor.strata().end(),
890 detail::make_sampler_opt(options_)
891 .sampleSize(ext_param().actual_msample_),
898 first_stack_entry( sampler.sampledIndices().begin(),
899 sampler.sampledIndices().end(),
900 ext_param_.class_count_);
902 .set_oob_range( sampler.oobIndices().begin(),
903 sampler.oobIndices().end());
905 online_visitor_.tree_id=treeId;
906 trees_[treeId].reset();
908 .learn( preprocessor.features(),
909 preprocessor.response(),
916 .visit_after_tree( *
this,
922 online_visitor_.deactivate();
925 template <
class LabelType,
class PreprocessorTag>
926 template <
class U,
class C1,
938 Random_t
const & random)
949 vigra_precondition(features.
shape(0) == response.
shape(0),
950 "RandomForest::learn(): shape mismatch between features and response.");
957 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
958 Default_Stop_t default_stop(options_);
959 typename RF_CHOOSER(Stop_t)::type stop
960 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
961 Default_Split_t default_split;
962 typename RF_CHOOSER(Split_t)::type split
963 = RF_CHOOSER(Split_t)::choose(split_, default_split);
966 rf::visitors::OnlineLearnVisitor,
967 typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
969 visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
971 if(options_.prepare_online_learning_)
972 online_visitor_.activate();
974 online_visitor_.deactivate();
978 RandFunctor_t randint ( random);
985 Preprocessor_t preprocessor( features, response,
986 options_, ext_param_);
989 split.set_external_parameters(ext_param_);
990 stop.set_external_parameters(ext_param_);
994 trees_.resize(options_.tree_count_ , DecisionTree_t(ext_param_));
997 preprocessor.strata().end(),
998 detail::make_sampler_opt(options_)
999 .sampleSize(ext_param().actual_msample_),
1002 visitor.visit_at_beginning(*
this, preprocessor);
1005 for(
int ii = 0; ii < (int)trees_.
size(); ++ii)
1011 first_stack_entry( sampler.sampledIndices().begin(),
1012 sampler.sampledIndices().end(),
1013 ext_param_.class_count_);
1015 .set_oob_range( sampler.oobIndices().begin(),
1016 sampler.oobIndices().end());
1018 .learn( preprocessor.features(),
1019 preprocessor.response(),
1026 .visit_after_tree( *
this,
1033 visitor.visit_at_end(*
this, preprocessor);
1035 online_visitor_.deactivate();
1041 template <
class LabelType,
class Tag>
1042 template <
class U,
class C,
class Stop>
1046 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1047 "RandomForestn::predictLabel():"
1048 " Too few columns in feature matrix.");
1049 vigra_precondition(
rowCount(features) == 1,
1050 "RandomForestn::predictLabel():"
1051 " Feature matrix must have a singlerow.");
1054 predictProbabilities(features, probabilities, stop);
1055 ext_param_.to_classlabel(
argMax(probabilities), d);
1061 template <
class LabelType,
class PreprocessorTag>
1062 template <
class U,
class C>
1067 using namespace functor;
1068 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1069 "RandomForestn::predictLabel(): Too few columns in feature matrix.");
1070 vigra_precondition(
rowCount(features) == 1,
1071 "RandomForestn::predictLabel():"
1072 " Feature matrix must have a single row.");
1073 Matrix<double> prob(1,ext_param_.class_count_);
1074 predictProbabilities(features, prob);
1075 std::transform( prob.begin(), prob.end(),
1076 priors.
begin(), prob.begin(),
1079 ext_param_.to_classlabel(
argMax(prob), d);
1083 template<
class LabelType,
class PreprocessorTag>
1084 template <
class T1,
class T2,
class C>
1093 "RandomFroest::predictProbabilities():"
1094 " Feature matrix and probability matrix size mismatch.");
1097 vigra_precondition(
columnCount(predictionSet.features) >= ext_param_.column_count_,
1098 "RandomForestn::predictProbabilities():"
1099 " Too few columns in feature matrix.");
1102 "RandomForestn::predictProbabilities():"
1103 " Probability matrix must have as many columns as there are classes.");
1106 std::vector<T1> totalWeights(predictionSet.indices[0].size(),0.0);
1109 for(
int k=0; k<options_.tree_count_; ++k)
1111 set_id=(set_id+1) % predictionSet.indices[0].size();
1112 typedef std::set<SampleRange<T1> > my_set;
1113 typedef typename my_set::iterator set_it;
1116 std::vector<std::pair<int,set_it> > stack;
1118 for(set_it i=predictionSet.ranges[set_id].begin();
1119 i!=predictionSet.ranges[set_id].end();++i)
1120 stack.push_back(std::pair<int,set_it>(2,i));
1122 int num_decisions=0;
1123 while(!stack.empty())
1125 set_it range=stack.back().second;
1126 int index=stack.back().first;
1130 if(trees_[k].isLeafNode(trees_[k].topology_[index]))
1133 trees_[k].parameters_,
1134 index).prob_begin();
1135 for(
int i=range->start;i!=range->end;++i)
1138 for(
int l=0; l<ext_param_.class_count_; ++l)
1140 prob(predictionSet.indices[set_id][i], l) += (T2)weights[l];
1142 totalWeights[predictionSet.indices[set_id][i]] += (T1)weights[l];
1149 if(trees_[k].topology_[index]!=i_ThresholdNode)
1151 throw std::runtime_error(
"predicting with online prediction sets is only supported for RFs with threshold nodes");
1153 Node<i_ThresholdNode> node(trees_[k].topology_,trees_[k].parameters_,index);
1154 if(range->min_boundaries[node.column()]>=node.threshold())
1157 stack.push_back(std::pair<int,set_it>(node.child(1),range));
1160 if(range->max_boundaries[node.column()]<node.threshold())
1163 stack.push_back(std::pair<int,set_it>(node.child(0),range));
1167 SampleRange<T1> new_range=*range;
1168 new_range.min_boundaries[node.column()]=FLT_MAX;
1169 range->max_boundaries[node.column()]=-FLT_MAX;
1170 new_range.start=new_range.end=range->end;
1172 while(i!=range->end)
1175 if(predictionSet.features(predictionSet.indices[set_id][i],node.column())>=node.threshold())
1177 new_range.min_boundaries[node.column()]=std::min(new_range.min_boundaries[node.column()],
1178 predictionSet.features(predictionSet.indices[set_id][i],node.column()));
1181 std::swap(predictionSet.indices[set_id][i],predictionSet.indices[set_id][range->end]);
1186 range->max_boundaries[node.column()]=std::max(range->max_boundaries[node.column()],
1187 predictionSet.features(predictionSet.indices[set_id][i],node.column()));
1192 if(range->start==range->end)
1194 predictionSet.ranges[set_id].erase(range);
1198 stack.push_back(std::pair<int,set_it>(node.child(0),range));
1201 if(new_range.start!=new_range.end)
1203 std::pair<set_it,bool> new_it=predictionSet.ranges[set_id].insert(new_range);
1204 stack.push_back(std::pair<int,set_it>(node.child(1),new_it.first));
1208 predictionSet.cumulativePredTime[k]=num_decisions;
1210 for(
unsigned int i=0;i<totalWeights.size();++i)
1214 for(
int l=0; l<ext_param_.class_count_; ++l)
1217 prob(i, l) /= totalWeights[i];
1219 assert(test==totalWeights[i]);
1220 assert(totalWeights[i]>0.0);
1224 template <
class LabelType,
class PreprocessorTag>
1225 template <
class U,
class C1,
class T,
class C2,
class Stop_t>
1228 MultiArrayView<2, T, C2> & prob,
1229 Stop_t & stop_)
const
1235 "RandomForestn::predictProbabilities():"
1236 " Feature matrix and probability matrix size mismatch.");
1240 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1241 "RandomForestn::predictProbabilities():"
1242 " Too few columns in feature matrix.");
1245 "RandomForestn::predictProbabilities():"
1246 " Probability matrix must have as many columns as there are classes.");
1248 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
1249 Default_Stop_t default_stop(options_);
1250 typename RF_CHOOSER(Stop_t)::type & stop
1251 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
1253 stop.set_external_parameters(ext_param_, tree_count());
1254 prob.init(NumericTraits<T>::zero());
1264 for(
int row=0; row <
rowCount(features); ++row)
1266 MultiArrayView<2, U, StridedArrayTag> currentRow(
rowVector(features, row));
1270 if(detail::contains_nan(currentRow))
1276 ArrayVector<double>::const_iterator weights;
1279 double totalWeight = 0.0;
1282 for(
int k=0; k<options_.tree_count_; ++k)
1285 weights = trees_[k ].predict(currentRow);
1288 int weighted = options_.predict_weighted_;
1289 for(
int l=0; l<ext_param_.class_count_; ++l)
1291 double cur_w = weights[l] * (weighted * (*(weights-1))
1293 prob(row, l) += (T)cur_w;
1295 totalWeight += cur_w;
1297 if(stop.after_prediction(weights,
1307 for(
int l=0; l< ext_param_.class_count_; ++l)
1309 prob(row, l) /= detail::RequiresExplicitCast<T>::cast(totalWeight);
1315 template <
class LabelType,
class PreprocessorTag>
1316 template <
class U,
class C1,
class T,
class C2>
1317 void RandomForest<LabelType, PreprocessorTag>
1318 ::predictRaw(MultiArrayView<2, U, C1>
const & features,
1319 MultiArrayView<2, T, C2> & prob)
const
1325 "RandomForestn::predictProbabilities():"
1326 " Feature matrix and probability matrix size mismatch.");
1330 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1331 "RandomForestn::predictProbabilities():"
1332 " Too few columns in feature matrix.");
1335 "RandomForestn::predictProbabilities():"
1336 " Probability matrix must have as many columns as there are classes.");
1338 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
1339 prob.init(NumericTraits<T>::zero());
1349 for(
int row=0; row <
rowCount(features); ++row)
1351 ArrayVector<double>::const_iterator weights;
1354 double totalWeight = 0.0;
1357 for(
int k=0; k<options_.tree_count_; ++k)
1360 weights = trees_[k ].predict(
rowVector(features, row));
1363 int weighted = options_.predict_weighted_;
1364 for(
int l=0; l<ext_param_.class_count_; ++l)
1366 double cur_w = weights[l] * (weighted * (*(weights-1))
1368 prob(row, l) += (T)cur_w;
1370 totalWeight += cur_w;
1374 prob/= options_.tree_count_;
1382 #include "random_forest/rf_algorithm.hxx"
1383 #endif // VIGRA_RANDOM_FOREST_HXX
void predictLabels(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &labels, LabelType nanLabel) const
predict multiple labels with given features
Definition: random_forest.hxx:606
Definition: rf_region.hxx:57
void set_ext_param(ProblemSpec_t const &in)
set external parameters
Definition: random_forest.hxx:274
int class_count() const
return number of classes used while training.
Definition: random_forest.hxx:339
Definition: rf_nodeproxy.hxx:626
MultiArrayIndex rowCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:669
Definition: rf_preprocessing.hxx:62
int feature_count() const
return number of features used while training.
Definition: random_forest.hxx:320
int column_count() const
return number of features used while training.
Definition: random_forest.hxx:331
Create random samples from a sequence of indices.
Definition: sampling.hxx:233
Int32 leftParent
Definition: rf_region.hxx:69
const difference_type & shape() const
Definition: multi_array.hxx:1551
Definition: rf_split.hxx:993
const_iterator begin() const
Definition: array_vector.hxx:223
problem specification class for the random forest.
Definition: rf_common.hxx:533
RandomForest(Options_t const &options=Options_t(), ProblemSpec_t const &ext_param=ProblemSpec_t())
default constructor
Definition: random_forest.hxx:192
void sample()
Definition: sampling.hxx:468
std::ptrdiff_t MultiArrayIndex
Definition: multi_shape.hxx:55
Definition: accessor.hxx:43
Standard early stopping criterion.
Definition: rf_common.hxx:878
void reset_tree(int tree_id)
Definition: rf_visitors.hxx:630
ProblemSpec_t const & ext_param() const
return external parameters for viewing
Definition: random_forest.hxx:256
DecisionTree_t & tree(int index)
access trees
Definition: random_forest.hxx:310
DecisionTree_t const & tree(int index) const
access const trees
Definition: random_forest.hxx:303
Options_t & set_options()
access random forest options
Definition: random_forest.hxx:286
void learn(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &response, Visitor_t visitor, Split_t split, Stop_t stop, Random_t const &random)
learn on data with custom config and random number generator
Definition: random_forest.hxx:933
void reLearnTree(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &response, int treeId, Visitor_t visitor_, Split_t split_, Stop_t stop_, Random_t &random)
Definition: random_forest.hxx:830
Definition: random_forest.hxx:144
Options_t const & options() const
access const random forest options
Definition: random_forest.hxx:296
void predictProbabilities(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &prob, Stop &stop) const
predict the class probabilities for multiple labels
Definition: rf_visitors.hxx:249
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int
Definition: sized_int.hxx:175
Iterator argMax(Iterator first, Iterator last)
Find the maximum element in a sequence.
Definition: algorithm.hxx:96
void predictLabels(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &labels, Stop &stop) const
predict multiple labels with given features
Definition: random_forest.hxx:631
Definition: rf_visitors.hxx:578
void predictLabels(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &labels) const
predict multiple labels with given features
Definition: random_forest.hxx:582
MultiArrayShape< 2 >::type Shape2
shape type for MultiArray<2, T>
Definition: multi_shape.hxx:241
SamplerOptions & stratified(bool in=true)
Draw equally many samples from each "stratum". A stratum is a group of like entities, e.g. pixels belonging to the same object class. This is useful to create balanced samples when the class probabilities are very unbalanced (e.g. when there are many background and few foreground pixels). Stratified sampling thus avoids that a trained classifier is biased towards the majority class.
Definition: sampling.hxx:144
SamplerOptions & withReplacement(bool in=true)
Sample from training population with replacement.
Definition: sampling.hxx:86
void predictProbabilities(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &prob) const
predict the class probabilities for multiple labels
Definition: random_forest.hxx:667
MultiArrayView< 2, T, C > rowVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition: matrix.hxx:695
int tree_count() const
return number of trees
Definition: random_forest.hxx:346
MultiArrayIndex columnCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:682
RandomForest(int treeCount, TopologyIterator topology_begin, ParameterIterator parameter_begin, ProblemSpec_t const &problem_spec, Options_t const &options=Options_t())
Create RF from external source.
Definition: random_forest.hxx:226
Definition: random.hxx:336
Base class for, and view to, vigra::MultiArray.
Definition: multi_array.hxx:655
Options object for the random forest.
Definition: rf_common.hxx:170
MultiArrayView & init(const U &init)
Definition: multi_array.hxx:1150
LabelType predictLabel(MultiArrayView< 2, U, C >const &features, Stop &stop) const
predict a label given a feature.
Definition: random_forest.hxx:1044
size_type size() const
Definition: array_vector.hxx:330
void learn(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &labels)
learn on data with default configuration
Definition: random_forest.hxx:526
Definition: rf_visitors.hxx:229