37 #ifndef VIGRA_RF_COMMON_HXX
38 #define VIGRA_RF_COMMON_HXX
44 struct ClassificationTag
54 inline detail::RF_DEFAULT& rf_default();
69 friend RF_DEFAULT& ::vigra::rf_default();
99 template<
class T,
class C>
104 static T & choose(T & t, C &)
111 class Value_Chooser<detail::RF_DEFAULT, C>
116 static C & choose(detail::RF_DEFAULT &, C & c)
131 detail::RF_DEFAULT& rf_default()
133 static detail::RF_DEFAULT result;
140 enum RF_OptionTag { RF_EQUAL,
176 double training_set_proportion_;
177 int training_set_size_;
178 int (*training_set_func_)(int);
180 training_set_calc_switch_;
182 bool sample_with_replacement_;
184 stratification_method_;
193 RF_OptionTag mtry_switch_;
195 int (*mtry_func_)(int) ;
197 bool predict_weighted_;
199 int min_split_node_size_;
200 bool prepare_online_learning_;
204 typedef std::map<std::string, double_array> map_type;
206 int serialized_size()
const
215 #define COMPARE(field) result = result && (this->field == rhs.field);
216 COMPARE(training_set_proportion_);
217 COMPARE(training_set_size_);
218 COMPARE(training_set_calc_switch_);
219 COMPARE(sample_with_replacement_);
220 COMPARE(stratification_method_);
221 COMPARE(mtry_switch_);
223 COMPARE(tree_count_);
224 COMPARE(min_split_node_size_);
225 COMPARE(predict_weighted_);
232 return !(*
this == rhs_);
235 void unserialize(Iter
const & begin, Iter
const & end)
238 vigra_precondition(static_cast<int>(end - begin) == serialized_size(),
239 "RandomForestOptions::unserialize():"
240 "wrong number of parameters");
241 #define PULL(item_, type_) item_ = type_(*iter); ++iter;
242 PULL(training_set_proportion_,
double);
243 PULL(training_set_size_,
int);
245 PULL(training_set_calc_switch_, (RF_OptionTag)
int);
246 PULL(sample_with_replacement_, 0 != );
247 PULL(stratification_method_, (RF_OptionTag)
int);
248 PULL(mtry_switch_, (RF_OptionTag)
int);
251 PULL(tree_count_,
int);
252 PULL(min_split_node_size_,
int);
253 PULL(predict_weighted_, 0 !=);
257 void serialize(Iter
const & begin, Iter
const & end)
const
260 vigra_precondition(static_cast<int>(end - begin) == serialized_size(),
261 "RandomForestOptions::serialize():"
262 "wrong number of parameters");
263 #define PUSH(item_) *iter = double(item_); ++iter;
264 PUSH(training_set_proportion_);
265 PUSH(training_set_size_);
266 if(training_set_func_ != 0)
274 PUSH(training_set_calc_switch_);
275 PUSH(sample_with_replacement_);
276 PUSH(stratification_method_);
288 PUSH(min_split_node_size_);
289 PUSH(predict_weighted_);
293 void make_from_map(map_type & in)
295 #define PULL(item_, type_) item_ = type_(in[#item_][0]);
296 #define PULLBOOL(item_, type_) item_ = type_(in[#item_][0] > 0);
297 PULL(training_set_proportion_,
double);
298 PULL(training_set_size_,
int);
300 PULL(tree_count_,
int);
301 PULL(min_split_node_size_,
int);
302 PULLBOOL(sample_with_replacement_,
bool);
303 PULLBOOL(prepare_online_learning_,
bool);
304 PULLBOOL(predict_weighted_,
bool);
306 PULL(training_set_calc_switch_, (RF_OptionTag)(
int));
308 PULL(stratification_method_, (RF_OptionTag)(
int));
309 PULL(mtry_switch_, (RF_OptionTag)(
int));
317 void make_map(map_type & in)
const
319 #define PUSH(item_, type_) in[#item_] = double_array(1, double(item_));
320 #define PUSHFUNC(item_, type_) in[#item_] = double_array(1, double(item_!=0));
321 PUSH(training_set_proportion_,
double);
322 PUSH(training_set_size_,
int);
324 PUSH(tree_count_,
int);
325 PUSH(min_split_node_size_,
int);
326 PUSH(sample_with_replacement_,
bool);
327 PUSH(prepare_online_learning_,
bool);
328 PUSH(predict_weighted_,
bool);
330 PUSH(training_set_calc_switch_, RF_OptionTag);
331 PUSH(stratification_method_, RF_OptionTag);
332 PUSH(mtry_switch_, RF_OptionTag);
334 PUSHFUNC(mtry_func_,
int);
335 PUSHFUNC(training_set_func_,
int);
348 training_set_proportion_(1.0),
349 training_set_size_(0),
350 training_set_func_(0),
351 training_set_calc_switch_(RF_PROPORTIONAL),
352 sample_with_replacement_(true),
353 stratification_method_(RF_NONE),
354 mtry_switch_(RF_SQRT),
357 predict_weighted_(false),
359 min_split_node_size_(1),
360 prepare_online_learning_(false)
376 vigra_precondition(in == RF_EQUAL ||
377 in == RF_PROPORTIONAL ||
380 "RandomForestOptions::use_stratification()"
381 "input must be RF_EQUAL, RF_PROPORTIONAL,"
382 "RF_EXTERNAL or RF_NONE");
383 stratification_method_ = in;
389 prepare_online_learning_=in;
399 sample_with_replacement_ = in;
413 training_set_proportion_ = in;
414 training_set_calc_switch_ = RF_PROPORTIONAL;
422 training_set_size_ = in;
423 training_set_calc_switch_ = RF_CONST;
435 training_set_func_ = in;
436 training_set_calc_switch_ = RF_FUNCTION;
444 predict_weighted_ =
true;
457 vigra_precondition(in == RF_LOG ||
460 "RandomForestOptions()::features_per_node():"
461 "input must be of type RF_LOG or RF_SQRT");
475 mtry_switch_ = RF_CONST;
487 mtry_switch_ = RF_FUNCTION;
511 min_split_node_size_ = in;
532 template<
class LabelType =
double>
545 typedef std::map<std::string, double_array> map_type;
563 void to_classlabel(
int index, T & out)
const
565 out = T(classes[index]);
568 int to_classIndex(T index)
const
570 return std::find(classes.
begin(), classes.
end(), index) - classes.
begin();
573 #define EQUALS(field) field(rhs.field)
576 EQUALS(column_count_),
577 EQUALS(class_count_),
579 EQUALS(actual_mtry_),
580 EQUALS(actual_msample_),
581 EQUALS(problem_type_),
583 EQUALS(class_weights_),
584 EQUALS(is_weighted_),
586 EQUALS(response_size_)
588 std::back_insert_iterator<ArrayVector<Label_t> >
590 std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
593 #define EQUALS(field) field(rhs.field)
597 EQUALS(column_count_),
598 EQUALS(class_count_),
600 EQUALS(actual_mtry_),
601 EQUALS(actual_msample_),
602 EQUALS(problem_type_),
604 EQUALS(class_weights_),
605 EQUALS(is_weighted_),
607 EQUALS(response_size_)
609 std::back_insert_iterator<ArrayVector<Label_t> >
611 std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
615 #define EQUALS(field) (this->field = rhs.field);
618 EQUALS(column_count_);
619 EQUALS(class_count_);
621 EQUALS(actual_mtry_);
622 EQUALS(actual_msample_);
623 EQUALS(problem_type_);
625 EQUALS(is_weighted_);
627 EQUALS(response_size_)
628 class_weights_.clear();
629 std::back_insert_iterator<ArrayVector<
double> >
630 iter2(class_weights_);
631 std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2);
633 std::back_insert_iterator<ArrayVector<Label_t> >
635 std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
642 EQUALS(column_count_);
643 EQUALS(class_count_);
645 EQUALS(actual_mtry_);
646 EQUALS(actual_msample_);
647 EQUALS(problem_type_);
649 EQUALS(is_weighted_);
651 EQUALS(response_size_)
652 class_weights_.clear();
653 std::back_insert_iterator<ArrayVector<
double> >
654 iter2(class_weights_);
655 std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2);
657 std::back_insert_iterator<ArrayVector<Label_t> >
659 std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
665 bool operator==(ProblemSpec<T>
const & rhs)
668 #define COMPARE(field) result = result && (this->field == rhs.field);
669 COMPARE(column_count_);
670 COMPARE(class_count_);
672 COMPARE(actual_mtry_);
673 COMPARE(actual_msample_);
674 COMPARE(problem_type_);
675 COMPARE(is_weighted_);
678 COMPARE(class_weights_);
680 COMPARE(response_size_)
687 return !(*
this == rhs);
691 size_t serialized_size()
const
693 return 10 + class_count_ *int(is_weighted_+1);
698 void unserialize(Iter
const & begin, Iter
const & end)
701 vigra_precondition(end - begin >= 10,
702 "ProblemSpec::unserialize():"
703 "wrong number of parameters");
704 #define PULL(item_, type_) item_ = type_(*iter); ++iter;
705 PULL(column_count_,
int);
706 PULL(class_count_,
int);
708 vigra_precondition(end - begin >= 10 + class_count_,
709 "ProblemSpec::unserialize(): 1");
710 PULL(row_count_,
int);
711 PULL(actual_mtry_,
int);
712 PULL(actual_msample_,
int);
714 PULL(is_weighted_,
int);
716 PULL(precision_,
double);
717 PULL(response_size_,
int);
720 vigra_precondition(end - begin == 10 + 2*class_count_,
721 "ProblemSpec::unserialize(): 2");
722 class_weights_.insert(class_weights_.end(),
724 iter + class_count_);
725 iter += class_count_;
727 classes.insert(classes.end(), iter, end);
733 void serialize(Iter
const & begin, Iter
const & end)
const
736 vigra_precondition(end - begin == serialized_size(),
737 "RandomForestOptions::serialize():"
738 "wrong number of parameters");
739 #define PUSH(item_) *iter = double(item_); ++iter;
744 PUSH(actual_msample_);
749 PUSH(response_size_);
752 std::copy(class_weights_.begin(),
753 class_weights_.end(),
755 iter += class_count_;
757 std::copy(classes.begin(),
763 void make_from_map(map_type & in)
765 #define PULL(item_, type_) item_ = type_(in[#item_][0]);
766 PULL(column_count_,
int);
767 PULL(class_count_,
int);
768 PULL(row_count_,
int);
769 PULL(actual_mtry_,
int);
770 PULL(actual_msample_,
int);
772 PULL(is_weighted_,
int);
774 PULL(precision_,
double);
775 PULL(response_size_,
int);
776 class_weights_ = in[
"class_weights_"];
779 void make_map(map_type & in)
const
781 #define PUSH(item_) in[#item_] = double_array(1, double(item_));
786 PUSH(actual_msample_);
791 PUSH(response_size_);
792 in["class_weights_"] = class_weights_;
804 problem_type_(CHECKLATER),
822 template<
class C_Iter>
825 int size = end-begin;
826 for(
int k=0; k<size; ++k, ++begin)
827 classes.push_back(detail::RequiresExplicitCast<LabelType>::cast(*begin));
837 template<
class W_Iter>
840 class_weights_.insert(class_weights_.end(), begin, end);
851 class_weights_.clear();
856 problem_type_ = CHECKLATER;
857 is_weighted_ =
false;
881 int min_split_node_size_;
885 : min_split_node_size_(opt.min_split_node_size_)
889 void set_external_parameters(
ProblemSpec<T>const &,
int = 0,
bool =
false)
892 template<
class Region>
893 bool operator()(Region& region)
895 return region.size() < min_split_node_size_;
898 template<
class WeightIter,
class T,
class C>
908 #endif //VIGRA_RF_COMMON_HXX
RandomForestOptions & features_per_node(RF_OptionTag in)
use built-in mapping to calculate mtry
Definition: rf_common.hxx:455
RandomForestOptions & tree_count(int in)
Definition: rf_common.hxx:495
RandomForestOptions & samples_per_tree(double in)
specify the fraction of the total number of samples used per tree for learning.
Definition: rf_common.hxx:411
RandomForestOptions & features_per_node(int(*in)(int))
use an external function to calculate mtry
Definition: rf_common.hxx:484
const_iterator begin() const
Definition: array_vector.hxx:223
RandomForestOptions & samples_per_tree(int in)
directly specify the number of samples per tree
Definition: rf_common.hxx:420
RandomForestOptions & samples_per_tree(int(*in)(int))
use an external function to calculate the number of samples each tree should be learnt with...
Definition: rf_common.hxx:433
problem specification class for the random forest.
Definition: rf_common.hxx:533
Definition: array_vector.hxx:903
LabelType Label_t
problem class
Definition: rf_common.hxx:542
RandomForestOptions & min_split_node_size(int in)
Number of examples required for a node to be split.
Definition: rf_common.hxx:509
Definition: accessor.hxx:43
RandomForestOptions & features_per_node(int in)
Set mtry to a constant value.
Definition: rf_common.hxx:472
Standard early stopping criterion.
Definition: rf_common.hxx:878
ProblemSpec & classes_(C_Iter begin, C_Iter end)
supply with class labels -
Definition: rf_common.hxx:823
RandomForestOptions()
create a RandomForestOptions object with default initialisation.
Definition: rf_common.hxx:346
ProblemSpec & class_weights(W_Iter begin, W_Iter end)
supply with class weights -
Definition: rf_common.hxx:838
RandomForestOptions & sample_with_replacement(bool in)
sample from training population with or without replacement?
Definition: rf_common.hxx:397
RandomForestOptions & predict_weighted()
weight each tree with the number of samples in that node
Definition: rf_common.hxx:442
Base class for, and view to, vigra::MultiArray.
Definition: multi_array.hxx:655
Options object for the random forest.
Definition: rf_common.hxx:170
RandomForestOptions & use_stratification(RF_OptionTag in)
specify stratification strategy
Definition: rf_common.hxx:374
const_iterator end() const
Definition: array_vector.hxx:237
ProblemSpec()
set default values (-> values not set)
Definition: rf_common.hxx:798
Problem_t
problem types
Definition: rf_common.hxx:519