diff --git a/doc/ipython-notebooks/multiclass/Tree/DecisionTrees.ipynb b/doc/ipython-notebooks/multiclass/Tree/DecisionTrees.ipynb
index 29de81445cf..e229a277421 100644
--- a/doc/ipython-notebooks/multiclass/Tree/DecisionTrees.ipynb
+++ b/doc/ipython-notebooks/multiclass/Tree/DecisionTrees.ipynb
@@ -974,10 +974,9 @@
     "    c = sg.create_machine(\"CARTree\", nominal=feat_types,\n",
     "                          mode=problem_type,\n",
     "                          folds=num_folds,\n",
-    "                          apply_cv_pruning=use_cv_pruning,\n",
-    "                          labels=labels)\n",
+    "                          apply_cv_pruning=use_cv_pruning)\n",
     "    # train using training features\n",
-    "    c.train(feats)\n",
+    "    c.train(feats, labels)\n",
     "    \n",
     "    return c\n",
     "\n",
@@ -1722,9 +1721,9 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.3"
+   "version": "3.6.9"
   }
  },
  "nbformat": 4,
 "nbformat_minor": 1
-}
\ No newline at end of file
+}
diff --git a/doc/ipython-notebooks/multiclass/Tree/TreeEnsemble.ipynb b/doc/ipython-notebooks/multiclass/Tree/TreeEnsemble.ipynb
index d2991dacf25..84231482df7 100644
--- a/doc/ipython-notebooks/multiclass/Tree/TreeEnsemble.ipynb
+++ b/doc/ipython-notebooks/multiclass/Tree/TreeEnsemble.ipynb
@@ -112,8 +112,7 @@
    "outputs": [],
    "source": [
     "# train forest\n",
-    "rand_forest.put('labels', train_labels)\n",
-    "rand_forest.train(train_feats)\n",
+    "rand_forest.train(train_feats, train_labels)\n",
     "\n",
     "# load test dataset\n",
     "testfeat_file= os.path.join(SHOGUN_DATA_DIR, 'uci/letter/test_fm_letter.dat')\n",
@@ -142,9 +141,8 @@
     "    c=sg.create_machine(\"CARTree\", nominal=feature_types,\n",
     "                        mode=problem_type,\n",
     "                        folds=2,\n",
-    "                        apply_cv_pruning=False,\n",
-    "                        labels=train_labels)\n",
-    "    c.train(train_feats)\n",
+    "                        apply_cv_pruning=False)\n",
+    "    c.train(train_feats, train_labels)\n",
     "    \n",
     "    return c\n",
     "\n",
@@ -213,8 +211,7 @@
    "source": [
     "def get_rf_accuracy(num_trees,rand_subset_size):\n",
     "    rf=setup_random_forest(num_trees,rand_subset_size,comb_rule,feat_types)\n",
-    "    rf.put('labels', train_labels)\n",
-    "    rf.train(train_feats)\n",
+    "    rf.train(train_feats, train_labels)\n",
     "    out_test=rf.apply_multiclass(test_feats)\n",
     "    acc=sg.create_evaluation(\"MulticlassAccuracy\")\n",
     "    return acc.evaluate(out_test,test_labels)"
@@ -365,8 +362,7 @@
    "outputs": [],
    "source": [
     "rf=setup_random_forest(100,2,comb_rule,feat_types)\n",
-    "rf.put('labels', train_labels)\n",
-    "rf.train(train_feats)\n",
+    "rf.train(train_feats, train_labels)\n",
     " \n",
     "# set evaluation strategy\n",
     "rf.put(\"oob_evaluation_metric\", sg.create_evaluation(\"MulticlassAccuracy\"))\n",
@@ -411,8 +407,7 @@
     "def get_oob_errors_wine(num_trees,rand_subset_size):\n",
     "    feat_types=np.array([False]*13)\n",
     "    rf=setup_random_forest(num_trees,rand_subset_size,sg.create_combination_rule(\"MajorityVote\"),feat_types)\n",
-    "    rf.put('labels', train_labels)\n",
-    "    rf.train(train_feats)\n",
+    "    rf.train(train_feats, train_labels)\n",
     "    rf.put(\"oob_evaluation_metric\", sg.create_evaluation(\"MulticlassAccuracy\"))\n",
     "    return rf.get(\"oob_error\") \n",
     "\n",
@@ -494,7 +489,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.3"
+   "version": "3.6.9"
   }
  },
  "nbformat": 4,
diff --git a/examples/meta/src/multiclass/cartree.sg.in b/examples/meta/src/multiclass/cartree.sg.in
index d842ac93624..2064342f8a3 100644
--- a/examples/meta/src/multiclass/cartree.sg.in
+++ b/examples/meta/src/multiclass/cartree.sg.in
@@ -19,11 +19,10 @@ ft[1] = False
 
 #![create_instance]
 Machine classifier = create_machine("CARTree", nominal=ft, mode=enum EProblemType.PT_MULTICLASS, folds=5, apply_cv_pruning=True, seed=1)
-classifier.set_labels(labels_train)
 #![create_instance]
 
 #![train_and_apply]
-classifier.train(features_train)
+classifier.train(features_train, labels_train)
 MulticlassLabels labels_predict = classifier.apply_multiclass(features_test)
 #![train_and_apply]
 
diff --git a/examples/meta/src/multiclass/random_forest.sg.in b/examples/meta/src/multiclass/random_forest.sg.in
index ed6a09bee40..540c8b75e6b 100644
--- a/examples/meta/src/multiclass/random_forest.sg.in
+++ b/examples/meta/src/multiclass/random_forest.sg.in
@@ -15,13 +15,13 @@ CombinationRule m_vote = create_combination_rule("MajorityVote")
 #![create_combination_rule]
 
 #![create_instance]
-Machine rand_forest = create_machine("RandomForest", labels=labels_train, num_bags=100, combination_rule=m_vote, seed=1)
+Machine rand_forest = create_machine("RandomForest", num_bags=100, combination_rule=m_vote, seed=1)
 Parallel p = rand_forest.get_global_parallel()
 p.set_num_threads(1)
 #![create_instance]
 
 #![train_and_apply]
-rand_forest.train(features_train)
+rand_forest.train(features_train, labels_train)
 MulticlassLabels labels_predict = rand_forest.apply_multiclass(features_test)
 #![train_and_apply]
 
diff --git a/examples/meta/src/regression/cartree.sg.in b/examples/meta/src/regression/cartree.sg.in
index 8c9846b8f8e..faea334f478 100644
--- a/examples/meta/src/regression/cartree.sg.in
+++ b/examples/meta/src/regression/cartree.sg.in
@@ -14,11 +14,11 @@ ft[0] = False
 #![set_attribute_types]
 
 #![create_machine]
-Machine cartree = create_machine("CARTree", labels=labels_train, nominal=ft, mode=enum EProblemType.PT_REGRESSION, folds=5, apply_cv_pruning=True, seed=1)
+Machine cartree = create_machine("CARTree", nominal=ft, mode=enum EProblemType.PT_REGRESSION, folds=5, apply_cv_pruning=True, seed=1)
 #![create_machine]
 
 #![train_and_apply]
-cartree.train(feats_train)
+cartree.train(feats_train, labels_train)
 Labels labels_predict = cartree.apply(feats_test)
 #![train_and_apply]
 
diff --git a/examples/meta/src/regression/random_forest_regression.sg.in b/examples/meta/src/regression/random_forest_regression.sg.in
index 346d73a0119..1828ebb69aa 100644
--- a/examples/meta/src/regression/random_forest_regression.sg.in
+++ b/examples/meta/src/regression/random_forest_regression.sg.in
@@ -15,12 +15,12 @@ CombinationRule mean_rule = create_combination_rule("MeanRule")
 #![create_combination_rule]
 
 #![create_instance]
-Machine rand_forest = create_machine("RandomForest", labels=labels_train, num_bags=5, seed=1, combination_rule=mean_rule)
+Machine rand_forest = create_machine("RandomForest", num_bags=5, seed=1, combination_rule=mean_rule)
 #![create_instance]
 
 #![train_and_apply]
-rand_forest.train(features_train)
-RegressionLabels labels_predict = rand_forest.apply_regression(features_test)
+rand_forest.train(features_train, labels_train)
+Labels labels_predict = rand_forest.apply_regression(features_test)
 #![train_and_apply]
 
 #![evaluate_error]
@@ -32,3 +32,4 @@ real mserror = mse.evaluate(labels_predict, labels_test)
 
 # additional integration testing variables
 RealVector output = labels_predict.get_real_vector("labels")
+
diff --git a/src/shogun/machine/BaggingMachine.cpp b/src/shogun/machine/BaggingMachine.cpp
index 632a5150ad6..cbfd444f313 100644
--- a/src/shogun/machine/BaggingMachine.cpp
+++ b/src/shogun/machine/BaggingMachine.cpp
@@ -24,12 +24,6 @@ BaggingMachine::BaggingMachine() : RandomMixin<Machine>()
     register_parameters();
 }
 
-BaggingMachine::BaggingMachine(std::shared_ptr<Features> features, std::shared_ptr<Labels> labels)
-    : BaggingMachine()
-{
-    set_labels(std::move(labels));
-    m_features = std::move(features);
-}
 
 std::shared_ptr<BinaryLabels> BaggingMachine::apply_binary(std::shared_ptr<Features> data)
 {
@@ -48,21 +42,12 @@ std::shared_ptr<MulticlassLabels> BaggingMachine::apply_multiclass(std::shared_ptr<Features> data)
 {
     SGMatrix<float64_t> bagged_outputs =
         apply_outputs_without_combination(data);
 
-    require(m_labels, "Labels not set.");
-    require(
-        m_labels->get_label_type() == LT_MULTICLASS,
-        "Labels ({}) are not compatible with multiclass.",
-        m_labels->get_name());
-
-    auto labels_multiclass = std::dynamic_pointer_cast<MulticlassLabels>(m_labels);
     auto num_samples = bagged_outputs.size() / m_num_bags;
-    auto num_classes = labels_multiclass->get_num_classes();
 
     auto pred = std::make_shared<MulticlassLabels>(num_samples);
-    pred->allocate_confidences_for(num_classes);
+    pred->allocate_confidences_for(m_num_classes);
 
-    SGMatrix<float64_t> class_probabilities(num_classes, num_samples);
+    SGMatrix<float64_t> class_probabilities(m_num_classes, num_samples);
     class_probabilities.zero();
 
     for (auto i = 0; i < num_samples; ++i)
@@ -125,27 +110,24 @@ BaggingMachine::apply_outputs_without_combination(std::shared_ptr<Features> data)
     return output;
 }
 
-bool BaggingMachine::train_machine(std::shared_ptr<Features> data)
+bool BaggingMachine::train_machine(const std::shared_ptr<Features>& data, const std::shared_ptr<Labels>& labs)
 {
     require(m_machine != NULL, "Machine is not set!");
     require(m_num_bags > 0, "Number of bag is not set!");
-
-    if (data)
+    m_num_vectors = data->get_num_vectors();
+    if (auto multiclass_labs = std::dynamic_pointer_cast<MulticlassLabels>(labs))
     {
-        m_features = data;
-
-        ASSERT(m_features->get_num_vectors() == m_labels->get_num_labels());
+        m_num_classes = multiclass_labs->get_num_classes();
     }
-
     // if bag size is not provided, set it equal to number of training vectors
     if (m_bag_size == 0)
-        m_bag_size = m_features->get_num_vectors();
+        m_bag_size = data->get_num_vectors();
 
     // clear the array, if previously trained
     m_bags.clear();
 
     // reset the oob index vector
-    m_all_oob_idx = SGVector<bool>(m_features->get_num_vectors());
+    m_all_oob_idx = SGVector<bool>(data->get_num_vectors());
     m_all_oob_idx.zero();
@@ -160,24 +142,27 @@ bool BaggingMachine::train_machine(std::shared_ptr<Features> data)
     {
         auto c = std::dynamic_pointer_cast<Machine>(m_machine->clone());
         ASSERT(c != NULL);
-        SGVector<index_t> idx(
-            rnd_indicies.get_column_vector(i), m_bag_size, false);
+        SGVector<index_t> idx(rnd_indicies.get_column_vector(i), m_bag_size, false);
 
         std::shared_ptr<Features> features;
         std::shared_ptr<Labels> labels;
         if (env()->get_num_threads() == 1)
         {
-            features = m_features;
-            labels = m_labels;
+            features = data;
+            labels = labs;
         }
         else
        {
-            features = m_features->shallow_subset_copy();
-            labels = m_labels->shallow_subset_copy();
+            features = data->shallow_subset_copy();
+            labels = labs->shallow_subset_copy();
        }
-
-        labels->add_subset(idx);
+#pragma omp critical
+        {
+            labels->add_subset(idx);
+            features->add_subset(idx);
+        }
+
         /* TODO:
            if it's a binary labeling ensure that
            there's always samples of both classes
@@ -194,12 +179,15 @@ bool BaggingMachine::train_machine(std::shared_ptr<Features> data)
            }
         */
-        features->add_subset(idx);
+
         set_machine_parameters(c, idx);
-        c->set_labels(labels);
-        c->train(features);
-        features->remove_subset();
-        labels->remove_subset();
+        c->train(features, labels);
+#pragma omp critical
+        {
+            features->remove_subset();
+            labels->remove_subset();
+        }
+
 #pragma omp critical
         {
@@ -214,7 +202,7 @@ bool BaggingMachine::train_machine(std::shared_ptr<Features> data)
         pb.print_progress();
     }
     pb.complete();
-
+    get_oob_error_lambda = [=]() { return get_oob_error_impl(data, labs); };
     return true;
 }
 
@@ -224,7 +212,6 @@ void BaggingMachine::set_machine_parameters(std::shared_ptr<Machine> m, SGVector<index_t>
 
 void BaggingMachine::register_parameters()
 {
-    SG_ADD(&m_features, kFeatures, "Train features for bagging");
     SG_ADD(
         &m_num_bags, kNBags, "Number of bags", ParameterProperties::HYPER);
     SG_ADD(
@@ -275,9 +262,7 @@ void BaggingMachine::set_machine(std::shared_ptr<Machine> machine)
 void BaggingMachine::init()
 {
     m_machine = nullptr;
-    m_features = nullptr;
     m_combination_rule = nullptr;
-    m_labels = nullptr;
     m_num_bags = 0;
     m_bag_size = 0;
     m_all_oob_idx = SGVector<bool>();
@@ -294,7 +279,7 @@ std::shared_ptr<CombinationRule> BaggingMachine::get_combination_rule() const
     return m_combination_rule;
 }
 
-float64_t BaggingMachine::get_oob_error() const
+float64_t BaggingMachine::get_oob_error_impl(const std::shared_ptr<Features>& data, const std::shared_ptr<Labels>& labs) const
 {
     require(
         m_oob_evaluation_metric, "Out of bag evaluation metric is not set!");
@@ -302,8 +287,8 @@ float64_t BaggingMachine::get_oob_error() const
     require(m_bags.size() > 0, "BaggingMachine is not trained!");
 
     SGMatrix<float64_t> output(
-        m_features->get_num_vectors(), m_bags.size());
-    if (m_labels->get_label_type() == LT_REGRESSION)
+        m_num_vectors, m_bags.size());
+    if (labs->get_label_type() == LT_REGRESSION)
         output.zero();
     else
         output.set_const(NAN);
@@ -318,9 +303,9 @@ float64_t BaggingMachine::get_oob_error() const
         auto current_oob = m_oob_indices[i];
 
         SGVector<index_t> oob(current_oob.data(), current_oob.size(), false);
-        m_features->add_subset(oob);
+        data->add_subset(oob);
 
-        auto l = m->apply(m_features);
+        auto l = m->apply(data);
         SGVector<float64_t> lv;
         if (l!=NULL)
             lv = std::dynamic_pointer_cast<DenseLabels>(l)->get_labels();
@@ -331,14 +316,14 @@ float64_t BaggingMachine::get_oob_error() const
         for (index_t j = 0; j < oob.vlen; j++)
             output(oob[j], i) = lv[j];
 
-        m_features->remove_subset();
+        data->remove_subset();
     }
 
     std::vector<index_t> idx;
-    for (index_t i = 0; i < m_features->get_num_vectors(); i++)
+    for (index_t i = 0; i < data->get_num_vectors(); i++)
     {
         if (m_all_oob_idx[i])
             idx.push_back(i);
@@ -350,7 +335,7 @@ float64_t BaggingMachine::get_oob_error() const
         lab[i] = combined[idx[i]];
 
     std::shared_ptr<Labels> predicted = NULL;
-    switch (m_labels->get_label_type())
+    switch (labs->get_label_type())
     {
     case LT_BINARY:
         predicted = std::make_shared<BinaryLabels>(lab);
@@ -369,16 +354,16 @@ float64_t BaggingMachine::get_oob_error() const
     }
 
-    m_labels->add_subset(SGVector<index_t>(idx.data(), idx.size(), false));
-    float64_t res = m_oob_evaluation_metric->evaluate(predicted, m_labels);
-    m_labels->remove_subset();
+    labs->add_subset(SGVector<index_t>(idx.data(), idx.size(), false));
+    float64_t res = m_oob_evaluation_metric->evaluate(predicted, labs);
+    labs->remove_subset();
 
     return res;
 }
 
 std::vector<index_t> BaggingMachine::get_oob_indices(const SGVector<index_t>& in_bag)
 {
-    SGVector<bool> out_of_bag(m_features->get_num_vectors());
+    SGVector<bool> out_of_bag(m_num_vectors);
     out_of_bag.set_const(true);
 
     // mark the ones that are in_bag
diff --git a/src/shogun/machine/BaggingMachine.h b/src/shogun/machine/BaggingMachine.h
index a08ff0fb1f2..a4693ce891d 100644
--- a/src/shogun/machine/BaggingMachine.h
+++ b/src/shogun/machine/BaggingMachine.h
@@ -30,19 +30,11 @@ namespace shogun
         /** default ctor */
         BaggingMachine();
 
-        /**
-         * constructor
-         *
-         * @param features training features
-         * @param labels training labels
-         */
-        BaggingMachine(std::shared_ptr<Features> features, std::shared_ptr<Labels> labels);
-
         ~BaggingMachine() override = default;
 
-        std::shared_ptr<BinaryLabels> apply_binary(std::shared_ptr<Features> data=NULL) override;
-        std::shared_ptr<MulticlassLabels> apply_multiclass(std::shared_ptr<Features> data=NULL) override;
-        std::shared_ptr<RegressionLabels> apply_regression(std::shared_ptr<Features> data=NULL) override;
+        std::shared_ptr<BinaryLabels> apply_binary(std::shared_ptr<Features> data) override;
+        std::shared_ptr<MulticlassLabels> apply_multiclass(std::shared_ptr<Features> data) override;
+        std::shared_ptr<RegressionLabels> apply_regression(std::shared_ptr<Features> data) override;
 
         /**
          * Set number of bags/machine to create
@@ -118,8 +110,10 @@ namespace shogun
         * @param eval Evaluation method to use for calculating the error
         * @return out-of-bag error.
         */
-        float64_t get_oob_error() const;
-
+        float64_t get_oob_error() const
+        {
+            return get_oob_error_lambda();
+        }
        /** name **/
        const char* get_name() const override
        {
@@ -127,7 +121,7 @@ namespace shogun
        }
 
    protected:
-        bool train_machine(std::shared_ptr<Features> data=NULL) override;
+        bool train_machine(const std::shared_ptr<Features>&, const std::shared_ptr<Labels>& labs) override;
 
        /**
         * sets parameters of Machine - useful in Random Forest
@@ -170,13 +164,11 @@ namespace shogun
        std::vector<index_t> get_oob_indices(const SGVector<index_t>& in_bag);
 
+        float64_t get_oob_error_impl(const std::shared_ptr<Features>& data, const std::shared_ptr<Labels>& labs) const;
    protected:
        /** bags array */
        std::vector<std::shared_ptr<Machine>> m_bags;
 
-        /** features to train on */
-        std::shared_ptr<Features> m_features;
-
        /** machine to use for bagging */
        std::shared_ptr<Machine> m_machine;
 
@@ -198,9 +190,15 @@ namespace shogun
        /** metric to calculate the oob error */
        std::shared_ptr<Evaluation> m_oob_evaluation_metric;
 
+        int32_t m_num_classes;
+
+        int32_t m_num_vectors;
+
+        std::function<float64_t()> get_oob_error_lambda;
+
+
 #ifndef SWIG
    public:
-        static constexpr std::string_view kFeatures = "features";
        static constexpr std::string_view kNBags = "num_bags";
        static constexpr std::string_view kBagSize = "bag_size";
        static constexpr std::string_view kBags = "bags";
@@ -208,8 +206,8 @@ namespace shogun
        static constexpr std::string_view kAllOobIdx = "all_oob_idx";
        static constexpr std::string_view kOobIndices = "oob_indices";
        static constexpr std::string_view kMachine = "machine";
+        static constexpr std::string_view kOobEvaluationMetric = "oob_evaluation_metric";
        static constexpr std::string_view kOobError = "oob_error";
-        static constexpr std::string_view kOobEvaluationMetric = "oob_evaluation_metric";
 #endif
    };
 } // namespace shogun
diff --git a/src/shogun/machine/RandomForest.cpp b/src/shogun/machine/RandomForest.cpp
index 410d99379f0..436fe1b78c6 100644
--- a/src/shogun/machine/RandomForest.cpp
+++ b/src/shogun/machine/RandomForest.cpp
@@ -53,26 +53,12 @@ RandomForest::RandomForest(int32_t rand_numfeats, int32_t num_bags)
     m_machine->as<RandomCARTree>()->set_feature_subset_size(rand_numfeats);
 }
 
-RandomForest::RandomForest(std::shared_ptr<Features> features, std::shared_ptr<Labels> labels, int32_t num_bags, int32_t rand_numfeats)
-: BaggingMachine()
-{
-    init();
-
-    m_features=std::move(features);
-    set_labels(std::move(labels));
-
-    set_num_bags(num_bags);
-
-    if (rand_numfeats>0)
-        m_machine->as<RandomCARTree>()->set_feature_subset_size(rand_numfeats);
-}
 
-RandomForest::RandomForest(std::shared_ptr<Features> features, std::shared_ptr<Labels> labels, SGVector<float64_t> weights, int32_t num_bags, int32_t rand_numfeats)
+RandomForest::RandomForest(SGVector<float64_t> weights, int32_t num_bags, int32_t rand_numfeats)
 : BaggingMachine()
 {
     init();
 
-    m_features=std::move(features);
-    set_labels(std::move(labels));
     m_weights=weights;
 
     set_num_bags(num_bags);
@@ -163,24 +149,17 @@ void RandomForest::set_machine_parameters(std::shared_ptr<Machine> m, SGVector<index_t>
     set_machine_problem_type(m_machine->as<RandomCARTree>()->get_machine_problem_type());
 }
 
-bool RandomForest::train_machine(std::shared_ptr<Features> data)
+bool RandomForest::train_machine(const std::shared_ptr<Features>& data, const std::shared_ptr<Labels>& labs)
 {
-    if (data)
-    {
-        m_features = data;
-    }
-
-    require(m_features, "Training features not set!");
-
-    m_machine->as<RandomCARTree>()->pre_sort_features(m_features, m_sorted_transposed_feats, m_sorted_indices);
-
-    return BaggingMachine::train_machine();
+    m_machine->as<RandomCARTree>()->pre_sort_features(data, m_sorted_transposed_feats, m_sorted_indices);
+    m_num_features = data->as<DenseFeatures<float64_t>>()->get_num_features();
+    return BaggingMachine::train_machine(data, labs);
 }
 
 SGVector<float64_t> RandomForest::get_feature_importances() const
 {
-    auto num_feats =
-        m_features->as<DenseFeatures<float64_t>>()->get_num_features();
+    const auto& num_feats = m_num_features;
     SGVector<float64_t> feat_importances(num_feats);
     feat_importances.zero();
     for (size_t i = 0; i < m_bags.size(); i++)
diff --git a/src/shogun/machine/RandomForest.h b/src/shogun/machine/RandomForest.h
index 8990d5d25ee..e7eebbe93df 100644
--- a/src/shogun/machine/RandomForest.h
+++ b/src/shogun/machine/RandomForest.h
@@ -56,15 +56,6 @@ class RandomForest : public BaggingMachine
     */
    RandomForest(int32_t num_rand_feats, int32_t num_bags=10);
 
-    /** constructor
-     *
-     * @param features training features
-     * @param labels training labels
-     * @param num_bags number of trees in forest
-     * @param num_rand_feats number of attributes chosen randomly during node split in candidate trees
-     */
-    RandomForest(std::shared_ptr<Features> features, std::shared_ptr<Labels> labels, int32_t num_bags=10, int32_t num_rand_feats=0);
-
    /** constructor
     *
-     * @param features training features
-     * @param labels training labels
     * @param weights weights of training feature vectors
     * @param num_bags number of trees in forest
     * @param num_rand_feats number of attributes chosen randomly during node split in candidate trees
     */
-    RandomForest(std::shared_ptr<Features> features, std::shared_ptr<Labels> labels, SGVector<float64_t> weights, int32_t num_bags=10, int32_t num_rand_feats=0);
+    RandomForest(SGVector<float64_t> weights, int32_t num_bags=10, int32_t num_rand_feats=0);
 
    /** destructor */
    ~RandomForest() override;
@@ -146,7 +137,7 @@ class RandomForest : public BaggingMachine
 
 protected:
 
-    bool train_machine(std::shared_ptr<Features> data=NULL) override;
+    bool train_machine(const std::shared_ptr<Features>& data, const std::shared_ptr<Labels>& labs) override;
    /** sets parameters of CARTree - sets machine labels and weights here
     *
     * @param m machine
@@ -159,6 +150,7 @@ class RandomForest : public BaggingMachine
    void init();
 
 private:
+    int32_t m_num_features;
 
    /** weights */
    SGVector<float64_t> m_weights;
diff --git a/src/shogun/machine/StochasticGBMachine.cpp b/src/shogun/machine/StochasticGBMachine.cpp
index 9b8cee1bbce..a8be1f9e0a6 100644
--- a/src/shogun/machine/StochasticGBMachine.cpp
+++ b/src/shogun/machine/StochasticGBMachine.cpp
@@ -237,8 +237,7 @@ std::shared_ptr<Machine> StochasticGBMachine::fit_model(const std::shared_ptr<DenseFeatures<float64_t>>&
     auto c = m_machine->clone()->as<Machine>();
 
     // train cloned machine
-    c->set_labels(labels);
-    c->train(feats);
+    c->train(feats, labels);
 
     return c;
 }
diff --git a/src/shogun/multiclass/tree/CARTree.cpp b/src/shogun/multiclass/tree/CARTree.cpp
index 217081ed3cc..780f4b1edc6 100644
--- a/src/shogun/multiclass/tree/CARTree.cpp
+++ b/src/shogun/multiclass/tree/CARTree.cpp
@@ -75,17 +75,6 @@ CARTree::~CARTree()
 {
 }
 
-void CARTree::set_labels(std::shared_ptr<Labels> lab)
-{
-    if (lab->get_label_type()==LT_MULTICLASS)
-        set_machine_problem_type(PT_MULTICLASS);
-    else if (lab->get_label_type()==LT_REGRESSION)
-        set_machine_problem_type(PT_REGRESSION);
-    else
supplied is not supported"); - - m_labels=lab; -} void CARTree::set_machine_problem_type(EProblemType mode) { @@ -255,11 +244,11 @@ bool CARTree::weights_set() return m_weights.size() != 0; } -bool CARTree::train_machine(std::shared_ptr data) +bool CARTree::train_machine(const std::shared_ptr& data, const std::shared_ptr& labs) { require(data,"Data required for training"); require(data->get_feature_class()==C_DENSE,"Dense data required for training"); - + set_machine_problem_type(labs); auto dense_features = data->as>(); auto num_features = dense_features->get_num_features(); auto num_vectors = dense_features->get_num_vectors(); @@ -292,12 +281,12 @@ bool CARTree::train_machine(std::shared_ptr data) linalg::set_const(m_nominal, false); } - auto dense_labels = m_labels->as(); + auto dense_labels = labs->as(); set_root(CARTtrain(dense_features,m_weights,dense_labels,0)); if (m_apply_cv_pruning) { - prune_by_cross_validation(dense_features,m_folds); + prune_by_cross_validation(dense_features, labs, m_folds); } // compute feature importances and normalize it if (m_root) @@ -1223,7 +1212,7 @@ std::shared_ptr CARTree::apply_from_current_node(const std::shared_ptr>& data, int32_t folds) +void CARTree::prune_by_cross_validation(const std::shared_ptr>& data, const std::shared_ptr& labs, int32_t folds) { auto num_vecs=data->get_num_vectors(); @@ -1254,7 +1243,7 @@ void CARTree::prune_by_cross_validation(const std::shared_ptr subset(train_indices.data(),train_indices.size(),false); - auto dense_labels = m_labels->as(); + auto dense_labels = labs->as(); auto feats_train = view(data, subset); auto labels_train = view(dense_labels, subset); SGVector subset_weights(train_indices.size()); diff --git a/src/shogun/multiclass/tree/CARTree.h b/src/shogun/multiclass/tree/CARTree.h index 7a90eb6e259..4c341176da5 100644 --- a/src/shogun/multiclass/tree/CARTree.h +++ b/src/shogun/multiclass/tree/CARTree.h @@ -105,11 +105,6 @@ class CARTree : public RandomMixin> /** destructor */ ~CARTree() override; - /** set labels - automagically switch machine problem type based on type of labels supplied - * @param lab labels - */ - void set_labels(std::shared_ptr lab) override; - /** get name * @return class name CARTree */ @@ -248,7 +243,7 @@ class CARTree : public RandomMixin> * @param data training data * @return true */ - bool train_machine(std::shared_ptr data=NULL) override; + bool train_machine(const std::shared_ptr& data, const std::shared_ptr& labs) override; /** CARTtrain - recursive CART training method * @@ -387,7 +382,7 @@ class CARTree : public RandomMixin> * @param data training data * @param folds the integer V for V-fold cross validation */ - void prune_by_cross_validation(const std::shared_ptr>& data, int32_t folds); + void prune_by_cross_validation(const std::shared_ptr>& data, const std::shared_ptr& labs, int32_t folds); /** computes error in classification/regression * for classification it eveluates weight_missclassified/total_weight @@ -429,6 +424,16 @@ class CARTree : public RandomMixin> /** initializes members of class */ void init(); + + void set_machine_problem_type(const std::shared_ptr& labs) + { + if (labs->get_label_type()==LT_MULTICLASS) + set_machine_problem_type(PT_MULTICLASS); + else if (labs->get_label_type()==LT_REGRESSION) + set_machine_problem_type(PT_REGRESSION); + else + error("label type supplied is not supported"); + } public: /** denotes that a feature in a vector is missing MISSING = NOT_A_NUMBER */ static const float64_t MISSING; diff --git a/tests/unit/machine/MockMachine.h 
index d29c8532925..6b8479308ff 100644
--- a/tests/unit/machine/MockMachine.h
+++ b/tests/unit/machine/MockMachine.h
@@ -10,6 +10,7 @@ namespace shogun {
    public:
        MOCK_METHOD1(apply, std::shared_ptr<Labels>(std::shared_ptr<Features>));
        MOCK_METHOD1(train_machine, bool(std::shared_ptr<Features>));
+        MOCK_METHOD2(train_machine, bool(const std::shared_ptr<Features>&, const std::shared_ptr<Labels>&));
        MOCK_CONST_METHOD1(clone, std::shared_ptr<SGObject>(ParameterProperties));
 
        virtual const char* get_name() const { return "MockMachine"; }
diff --git a/tests/unit/multiclass/BaggingMachine_unittest.cc b/tests/unit/multiclass/BaggingMachine_unittest.cc
index ea3e3388c93..3a2145eb1ba 100644
--- a/tests/unit/multiclass/BaggingMachine_unittest.cc
+++ b/tests/unit/multiclass/BaggingMachine_unittest.cc
@@ -79,7 +79,7 @@ TEST_F(BaggingMachineTest, mock_train)
    auto features = std::make_shared<NiceMock<MockFeatures>>();
    auto labels = std::make_shared<NiceMock<MockLabels>>();
 
-    auto bm = std::make_shared<BaggingMachine>(features, labels);
+    auto bm = std::make_shared<BaggingMachine>();
    auto mm = std::make_shared<NiceMock<MockMachine>>();
    auto mv = std::make_shared<MajorityVote>();
 
@@ -90,7 +90,7 @@ TEST_F(BaggingMachineTest, mock_train)
    bm->set_combination_rule(mv);
    bm->put("seed", seed);
 
-    ON_CALL(*mm, train_machine(_))
+    ON_CALL(*mm, train_machine(_, _))
        .WillByDefault(Return(true));
 
    ON_CALL(*features, get_num_vectors())
@@ -103,13 +103,13 @@ TEST_F(BaggingMachineTest, mock_train)
            .Times(1)
            .WillRepeatedly(Return(mm));
 
-        EXPECT_CALL(*mm, train_machine(_))
+        EXPECT_CALL(*mm, train_machine(_, _))
            .Times(1)
            .WillRepeatedly(Return(true));
    }
    }
 
-    bm->train();
+    bm->train(features_train, labels_train);
 
    EXPECT_TRUE(Mock::VerifyAndClearExpectations(mm.get()));
 }
@@ -120,7 +120,7 @@ TEST_F(BaggingMachineTest, classify_CART)
    auto cv=std::make_shared<MajorityVote>();
 
    cart->set_feature_types(ft);
-    auto c = std::make_shared<BaggingMachine>(features_train, labels_train);
+    auto c = std::make_shared<BaggingMachine>();
    env()->set_num_threads(1);
    c->set_machine(cart);
@@ -128,7 +128,7 @@ TEST_F(BaggingMachineTest, classify_CART)
    c->set_num_bags(10);
    c->set_combination_rule(cv);
    c->put("seed", seed);
-    c->train(features_train);
+    c->train(features_train, labels_train);
 
    auto result = c->apply_multiclass(features_test);
    SGVector<float64_t> res_vector=result->get_labels();
@@ -151,14 +151,14 @@ TEST_F(BaggingMachineTest, output_binary)
    auto cv = std::make_shared<MajorityVote>();
 
    cart->set_feature_types(ft);
-    auto c = std::make_shared<BaggingMachine>(features_train, labels_train);
+    auto c = std::make_shared<BaggingMachine>();
    env()->set_num_threads(1);
    c->set_machine(cart);
    c->set_bag_size(14);
    c->set_num_bags(10);
    c->set_combination_rule(cv);
    c->put("seed", seed);
-    c->train(features_train);
+    c->train(features_train, labels_train);
 
    auto result = c->apply_binary(features_test);
    SGVector<float64_t> res_vector = result->get_labels();
@@ -185,13 +185,13 @@ TEST_F(BaggingMachineTest, output_multiclass_probs_sum_to_one)
    auto cv = std::make_shared<MajorityVote>();
 
    cart->set_feature_types(ft);
-    auto c = std::make_shared<BaggingMachine>(features_train, labels_train);
+    auto c = std::make_shared<BaggingMachine>();
    c->set_machine(cart);
    c->set_bag_size(14);
    c->set_num_bags(10);
    c->set_combination_rule(cv);
    c->put("seed", seed);
-    c->train(features_train);
+    c->train(features_train, labels_train);
 
    auto result = c->apply_multiclass(features_test);
diff --git a/tests/unit/multiclass/tree/CARTree_unittest.cc b/tests/unit/multiclass/tree/CARTree_unittest.cc
index 2af2a21702f..390e33995d9 100644
--- a/tests/unit/multiclass/tree/CARTree_unittest.cc
+++ b/tests/unit/multiclass/tree/CARTree_unittest.cc
@@ -155,9 +155,8 @@ TEST(CARTree, classify_nominal)
    auto labels=std::make_shared<MulticlassLabels>(lab);
 
    auto c=std::make_shared<CARTree>();
-    c->set_labels(labels);
    c->set_feature_types(ft);
-    c->train(feats);
+    c->train(feats, labels);
 
    SGMatrix<float64_t> test(4,5);
    test(0,0)=overcast;
@@ -218,8 +217,7 @@ TEST(CARTree, comparable_with_sklearn)
    auto labels = std::make_shared<MulticlassLabels>(y);
    auto c = std::make_shared<CARTree>();
 
-    c->set_labels(labels);
-    c->train(feats);
+    c->train(feats, labels);
 
    auto feat_import = c->get_feature_importance();
    // these values are generated by the sklearn program below
    EXPECT_NEAR(0.111111, feat_import[0], 0.00001);
@@ -342,7 +340,7 @@ TEST(CARTree, classify_non_nominal)
    auto c=std::make_shared<CARTree>();
    c->set_labels(labels);
    c->set_feature_types(ft);
-    c->train(feats);
+    c->train(feats, labels);
 
    SGMatrix<float64_t> test(4,5);
    test(0,0)=overcast;
@@ -445,7 +443,7 @@ TEST(CARTree, handle_missing_nominal)
    auto c=std::make_shared<CARTree>();
    c->set_labels(labels);
    c->set_feature_types(ft);
-    c->train(feats);
+    c->train(feats, labels);
 
    auto root=c->get_root()->as<BinaryTreeMachineNode<CARTreeNodeData>>();
    auto left=root->left();
@@ -516,9 +514,8 @@ TEST(CARTree, handle_missing_continuous)
    auto labels=std::make_shared<MulticlassLabels>(lab);
 
    auto c=std::make_shared<CARTree>();
-    c->set_labels(labels);
    c->set_feature_types(ft);
-    c->train(feats);
+    c->train(feats, labels);
 
    auto root=c->get_root()->as<BinaryTreeMachineNode<CARTreeNodeData>>();
    auto left=root->left();
@@ -553,9 +550,8 @@ TEST(CARTree, form_t1_test)
    auto labels=std::make_shared<MulticlassLabels>(lab);
 
    auto c=std::make_shared<CARTree>();
-    c->set_labels(labels);
    c->set_feature_types(ft);
-    c->train(feats);
+    c->train(feats, labels);
 
    auto root=c->get_root();
    EXPECT_EQ(2,root->data.num_leaves);
@@ -643,9 +639,8 @@ TEST(CARTree,cv_prune_simple)
    auto labels=std::make_shared<MulticlassLabels>(lab);
 
    auto c=std::make_shared<CARTree>();
-    c->set_labels(labels);
    c->set_feature_types(ft);
-    c->train(feats);
+    c->train(feats, labels);
 
    auto root=c->get_root()->as<BinaryTreeMachineNode<CARTreeNodeData>>();
 
@@ -654,7 +649,7 @@ TEST(CARTree,cv_prune_simple)
 
    c->set_num_folds(2);
    c->set_cv_pruning(true);
-    c->train(feats);
+    c->train(feats, labels);
 
    root=c->get_root()->as<BinaryTreeMachineNode<CARTreeNodeData>>();
diff --git a/tests/unit/multiclass/tree/RandomCARTree_unittest.cc b/tests/unit/multiclass/tree/RandomCARTree_unittest.cc
index 599d63fb12d..8961d72d042 100644
--- a/tests/unit/multiclass/tree/RandomCARTree_unittest.cc
+++ b/tests/unit/multiclass/tree/RandomCARTree_unittest.cc
@@ -153,11 +153,10 @@ TEST(RandomCARTree, classify_nominal)
    auto labels=std::make_shared<MulticlassLabels>(lab);
 
    auto c=std::make_shared<RandomCARTree>();
-    c->set_labels(labels);
    c->set_feature_types(ft);
    c->set_feature_subset_size(4);
    c->put("seed", seed);
-    c->train(feats);
+    c->train(feats, labels);
 
    SGMatrix<float64_t> test(4,5);
    test(0,0)=overcast;
diff --git a/tests/unit/multiclass/tree/RandomForest_unittest.cc b/tests/unit/multiclass/tree/RandomForest_unittest.cc
index 67766b111a9..82998d3781c 100644
--- a/tests/unit/multiclass/tree/RandomForest_unittest.cc
+++ b/tests/unit/multiclass/tree/RandomForest_unittest.cc
@@ -93,13 +93,13 @@ TEST_F(RandomForestTest, classify_nominal_test)
 {
    int32_t seed = 2343;
    auto c =
-        std::make_shared<RandomForest>(weather_features_train, weather_labels_train, 100, 2);
+        std::make_shared<RandomForest>(2, 100);
    c->set_feature_types(weather_ft);
    auto mv = std::make_shared<MajorityVote>();
    c->set_combination_rule(mv);
    env()->set_num_threads(1);
    c->put("seed", seed);
-    c->train(weather_features_train);
+    c->train(weather_features_train, weather_labels_train);
 
    auto result = c->apply(weather_features_test)->as<MulticlassLabels>();
 
@@ -126,13 +126,13 @@ TEST_F(RandomForestTest, classify_non_nominal_test)
    weather_ft[3] = false;
 
    auto c =
-        std::make_shared<RandomForest>(weather_features_train, weather_labels_train, 100, 2);
+        std::make_shared<RandomForest>(2, 100);
    c->set_feature_types(weather_ft);
    auto mv = std::make_shared<MajorityVote>();
    c->set_combination_rule(mv);
    env()->set_num_threads(1);
    c->put("seed", seed);
-    c->train(weather_features_train);
+    c->train(weather_features_train, weather_labels_train);
 
    auto result = c->apply(weather_features_test)->as<MulticlassLabels>();
 
@@ -147,7 +147,7 @@ TEST_F(RandomForestTest, classify_non_nominal_test)
    std::shared_ptr<Evaluation> eval=std::make_shared<MulticlassAccuracy>();
    c->put(RandomForest::kOobEvaluationMetric, eval);
-    EXPECT_NEAR(0.714285,c->get<float64_t>(RandomForest::kOobError),1e-6);
+    EXPECT_NEAR(0.7142857,c->get<float64_t>(RandomForest::kOobError),1e-6);
 }
 
 TEST_F(RandomForestTest, score_compare_sklearn_toydata)
@@ -166,7 +166,7 @@ TEST_F(RandomForestTest, score_compare_sklearn_toydata)
    SGVector<float64_t> lab {0.0, 0.0, 1.0, 1.0};
    auto labels_train = std::make_shared<BinaryLabels>(lab);
 
-    auto c = std::make_shared<RandomForest>(features_train, labels_train, 10, 2);
+    auto c = std::make_shared<RandomForest>(2, 10);
    SGVector<bool> ft = SGVector<bool>(2);
    ft[0] = false;
    ft[1] = false;
@@ -175,7 +175,7 @@ TEST_F(RandomForestTest, score_compare_sklearn_toydata)
    auto mr = std::make_shared<MajorityVote>();
    c->set_combination_rule(mr);
    c->put("seed", seed);
-    c->train(features_train);
+    c->train(features_train, labels_train);
 
    auto result = c->apply_binary(features_train);
    SGVector<float64_t> res_vector = result->get_labels();
@@ -226,7 +226,7 @@ TEST_F(RandomForestTest, score_consistent_with_binary_trivial_data)
        std::make_shared<DenseFeatures<float64_t>>(test_data);
 
    auto c =
-        std::make_shared<RandomForest>(features_train, labels_train, num_trees, 1);
+        std::make_shared<RandomForest>(1, num_trees);
    SGVector<bool> ft = SGVector<bool>(1);
    ft[0] = false;
    c->set_feature_types(ft);
@@ -234,7 +234,7 @@ TEST_F(RandomForestTest, score_consistent_with_binary_trivial_data)
    auto mr = std::make_shared<MajorityVote>();
    c->set_combination_rule(mr);
    c->put("seed", seed);
-    c->train(features_train);
+    c->train(features_train, labels_train);
 
    auto result = c->apply_binary(features_test);
    SGVector<float64_t> res_vector = result->get_labels();
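For orientation, the usage pattern this patch migrates every call site to — labels passed to `train()` instead of being stored on the machine — looks as follows from the Python API. This is a minimal sketch, assuming a shogun build with the `sg.create_*` factory API; the toy data and variable names are illustrative and not taken from the patch:

```python
import numpy as np
import shogun as sg

# Toy stand-ins for the notebooks' UCI data (illustrative only).
X = np.random.rand(2, 100)                            # 2 attributes, 100 vectors
y = np.random.randint(0, 3, 100).astype(np.float64)   # 3 classes

feats = sg.create_features(X)
labels = sg.create_labels(y)
feat_types = np.array([False, False])  # both attributes continuous

# Old style (removed by this patch): labels bound at construction time.
#   c = sg.create_machine("CARTree", nominal=feat_types, labels=labels)
#   c.train(feats)

# New style: labels travel with the training call; CARTree now infers the
# problem type (multiclass vs. regression) from the labels it is handed.
c = sg.create_machine("CARTree", nominal=feat_types)
c.train(feats, labels)
predictions = c.apply_multiclass(feats)
```

The same two-argument `train(features, labels)` call applies to `RandomForest` and the other machines touched here, which is what allows `BaggingMachine` to drop its `m_features`/`m_labels` members and cache only `m_num_vectors`/`m_num_classes` during training.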