Refactor BaggingMachine (#5103)
* make BaggingMachine stateless
* change get_oob_error to lambda
* fix meta example
* fix segfault
LiuYuHui authored Aug 5, 2020
1 parent afbdeac commit 551a102
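
The net effect for callers: labels move from machine state to an argument of train(). A minimal before/after sketch of the Python-level usage this commit implies, grounded in the notebook and meta-example changes below (train_feats, train_labels, and test_feats are placeholder variables):

import shogun as sg

# before this commit: labels were attached to the machine as state
#   rf.put('labels', train_labels)
#   rf.train(train_feats)

# after this commit: the machine is stateless; labels accompany the features
rf = sg.create_machine("RandomForest",
                       num_bags=100,
                       combination_rule=sg.create_combination_rule("MajorityVote"))
rf.train(train_feats, train_labels)
predictions = rf.apply_multiclass(test_feats)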
Showing 18 changed files with 135 additions and 199 deletions.
11 changes: 5 additions & 6 deletions doc/ipython-notebooks/multiclass/Tree/DecisionTrees.ipynb
@@ -974,10 +974,9 @@
" c = sg.create_machine(\"CARTree\", nominal=feat_types,\n",
" mode=problem_type,\n",
" folds=num_folds,\n",
" apply_cv_pruning=use_cv_pruning,\n",
" labels=labels)\n",
" apply_cv_pruning=use_cv_pruning)\n",
" # train using training features\n",
" c.train(feats)\n",
" c.train(feats, labels)\n",
" \n",
" return c\n",
"\n",
@@ -1408,7 +1407,7 @@
" c = sg.create_machine(\"CHAIDTree\", dependent_vartype=dependent_var_type,\n",
" feature_types=feature_types,\n",
" num_breakpoints=num_bins,\n",
" labels=labels)\n",
" labels = labels)\n",
" # train using training features\n",
" c.train(feats)\n",
" \n",
@@ -1722,9 +1721,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
19 changes: 7 additions & 12 deletions doc/ipython-notebooks/multiclass/Tree/TreeEnsemble.ipynb
@@ -112,8 +112,7 @@
"outputs": [],
"source": [
"# train forest\n",
"rand_forest.put('labels', train_labels)\n",
"rand_forest.train(train_feats)\n",
"rand_forest.train(train_feats, train_labels)\n",
"\n",
"# load test dataset\n",
"testfeat_file= os.path.join(SHOGUN_DATA_DIR, 'uci/letter/test_fm_letter.dat')\n",
@@ -142,9 +141,8 @@
" c=sg.create_machine(\"CARTree\", nominal=feature_types,\n",
" mode=problem_type,\n",
" folds=2,\n",
" apply_cv_pruning=False,\n",
" labels=train_labels)\n",
" c.train(train_feats)\n",
" apply_cv_pruning=False)\n",
" c.train(train_feats, train_labels)\n",
" \n",
" return c\n",
"\n",
@@ -213,8 +211,7 @@
"source": [
"def get_rf_accuracy(num_trees,rand_subset_size):\n",
" rf=setup_random_forest(num_trees,rand_subset_size,comb_rule,feat_types)\n",
" rf.put('labels', train_labels)\n",
" rf.train(train_feats)\n",
" rf.train(train_feats, train_labels)\n",
" out_test=rf.apply_multiclass(test_feats)\n",
" acc=sg.create_evaluation(\"MulticlassAccuracy\")\n",
" return acc.evaluate(out_test,test_labels)"
@@ -365,8 +362,7 @@
"outputs": [],
"source": [
"rf=setup_random_forest(100,2,comb_rule,feat_types)\n",
"rf.put('labels', train_labels)\n",
"rf.train(train_feats)\n",
"rf.train(train_feats, train_labels)\n",
" \n",
"# set evaluation strategy\n",
"rf.put(\"oob_evaluation_metric\", sg.create_evaluation(\"MulticlassAccuracy\"))\n",
@@ -411,8 +407,7 @@
"def get_oob_errors_wine(num_trees,rand_subset_size):\n",
" feat_types=np.array([False]*13)\n",
" rf=setup_random_forest(num_trees,rand_subset_size,sg.create_combination_rule(\"MajorityVote\"),feat_types)\n",
" rf.put('labels', train_labels)\n",
" rf.train(train_feats)\n",
" rf.train(train_feats, train_labels)\n",
" rf.put(\"oob_evaluation_metric\", sg.create_evaluation(\"MulticlassAccuracy\"))\n",
" return rf.get(\"oob_error\") \n",
"\n",
@@ -494,7 +489,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
"version": "3.6.9"
}
},
"nbformat": 4,
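The TreeEnsemble notebook above also exercises the out-of-bag (OOB) path, whose caller-facing shape survives the refactor. Condensed from the cells above (setup_random_forest is the notebook's own helper, with the notebook's imports and data):

rf = setup_random_forest(100, 2, comb_rule, feat_types)
rf.train(train_feats, train_labels)

# the metric may be attached after training; the OOB error itself
# is computed when requested
rf.put("oob_evaluation_metric", sg.create_evaluation("MulticlassAccuracy"))
oob_error = rf.get("oob_error")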
3 changes: 1 addition & 2 deletions examples/meta/src/multiclass/cartree.sg.in
@@ -19,11 +19,10 @@ ft[1] = False
#![create_instance]
Machine classifier = create_machine("CARTree", nominal = ft,mode = enum EProblemType.PT_MULTICLASS, folds=5, apply_cv_pruning=True, seed=1)

classifier.set_labels(labels_train)
#![create_instance]

#![train_and_apply]
classifier.train(features_train)
classifier.train(features_train, labels_train)
MulticlassLabels labels_predict = classifier.apply_multiclass(features_test)
#![train_and_apply]

4 changes: 2 additions & 2 deletions examples/meta/src/multiclass/random_forest.sg.in
@@ -15,13 +15,13 @@ CombinationRule m_vote = create_combination_rule("MajorityVote")
#![create_combination_rule]

#![create_instance]
Machine rand_forest = create_machine("RandomForest", labels=labels_train, num_bags=100, combination_rule=m_vote, seed=1)
Machine rand_forest = create_machine("RandomForest", num_bags=100, combination_rule=m_vote, seed=1)
Parallel p = rand_forest.get_global_parallel()
p.set_num_threads(1)
#![create_instance]

#![train_and_apply]
rand_forest.train(features_train)
rand_forest.train(features_train, labels_train)
MulticlassLabels labels_predict = rand_forest.apply_multiclass(features_test)
#![train_and_apply]

4 changes: 2 additions & 2 deletions examples/meta/src/regression/cartree.sg.in
@@ -14,11 +14,11 @@ ft[0] = False
#![set_attribute_types]

#![create_machine]
Machine cartree = create_machine("CARTree", labels=labels_train, nominal=ft, mode=enum EProblemType.PT_REGRESSION, folds=5, apply_cv_pruning=True, seed=1)
Machine cartree = create_machine("CARTree", nominal=ft, mode=enum EProblemType.PT_REGRESSION, folds=5, apply_cv_pruning=True, seed=1)
#![create_machine]

#![train_and_apply]
cartree.train(feats_train)
cartree.train(feats_train, labels_train)
Labels labels_predict = cartree.apply(feats_test)
#![train_and_apply]

7 changes: 4 additions & 3 deletions examples/meta/src/regression/random_forest_regression.sg.in
@@ -15,12 +15,12 @@ CombinationRule mean_rule = create_combination_rule("MeanRule")
#![create_combination_rule]

#![create_instance]
Machine rand_forest = create_machine("RandomForest", labels=labels_train, num_bags=5, seed=1, combination_rule=mean_rule)
Machine rand_forest = create_machine("RandomForest", num_bags=5, seed=1, combination_rule=mean_rule)
#![create_instance]

#![train_and_apply]
rand_forest.train(features_train)
RegressionLabels labels_predict = rand_forest.apply_regression(features_test)
rand_forest.train(features_train, labels_train)
Labels labels_predict = rand_forest.apply_regression(features_test)
#![train_and_apply]

#![evaluate_error]
@@ -32,3 +32,4 @@ real mserror = mse.evaluate(labels_predict, labels_test)

# additional integration testing variables
RealVector output = labels_predict.get_real_vector("labels")

95 changes: 40 additions & 55 deletions src/shogun/machine/BaggingMachine.cpp
@@ -24,12 +24,6 @@ BaggingMachine::BaggingMachine() : RandomMixin<Machine>()
register_parameters();
}

BaggingMachine::BaggingMachine(std::shared_ptr<Features> features, std::shared_ptr<Labels> labels)
: BaggingMachine()
{
set_labels(std::move(labels));
m_features = std::move(features);
}

std::shared_ptr<BinaryLabels> BaggingMachine::apply_binary(std::shared_ptr<Features> data)
{
@@ -48,21 +42,12 @@ std::shared_ptr<MulticlassLabels> BaggingMachine::apply_multiclass(std::shared_p
{
SGMatrix<float64_t> bagged_outputs =
apply_outputs_without_combination(data);

require(m_labels, "Labels not set.");
require(
m_labels->get_label_type() == LT_MULTICLASS,
"Labels ({}) are not compatible with multiclass.",
m_labels->get_name());

auto labels_multiclass = std::dynamic_pointer_cast<MulticlassLabels>(m_labels);
auto num_samples = bagged_outputs.size() / m_num_bags;
auto num_classes = labels_multiclass->get_num_classes();

auto pred = std::make_shared<MulticlassLabels>(num_samples);
pred->allocate_confidences_for(num_classes);
pred->allocate_confidences_for(m_num_classes);

SGMatrix<float64_t> class_probabilities(num_classes, num_samples);
SGMatrix<float64_t> class_probabilities(m_num_classes, num_samples);
class_probabilities.zero();

for (auto i = 0; i < num_samples; ++i)
@@ -125,27 +110,24 @@ BaggingMachine::apply_outputs_without_combination(std::shared_ptr<Features> data
return output;
}

bool BaggingMachine::train_machine(std::shared_ptr<Features> data)
bool BaggingMachine::train_machine(const std::shared_ptr<Features>& data, const std::shared_ptr<Labels>& labs)
{
require(m_machine != NULL, "Machine is not set!");
require(m_num_bags > 0, "Number of bag is not set!");

if (data)
m_num_vectors = data->get_num_vectors();
if(auto multiclass_labs = std::dynamic_pointer_cast<MulticlassLabels>(labs))
{
m_features = data;

ASSERT(m_features->get_num_vectors() == m_labels->get_num_labels());
m_num_classes = multiclass_labs->get_num_classes();
}

// if bag size is not provided, set it equal to number of training vectors
if (m_bag_size == 0)
m_bag_size = m_features->get_num_vectors();
m_bag_size = data->get_num_vectors();

// clear the array, if previously trained
m_bags.clear();

// reset the oob index vector
m_all_oob_idx = SGVector<bool>(m_features->get_num_vectors());
m_all_oob_idx = SGVector<bool>(data->get_num_vectors());
m_all_oob_idx.zero();


@@ -160,24 +142,27 @@ bool BaggingMachine::train_machine(std::shared_ptr<Features> data)
{
auto c=std::dynamic_pointer_cast<Machine>(m_machine->clone());
ASSERT(c != NULL);
SGVector<index_t> idx(
rnd_indicies.get_column_vector(i), m_bag_size, false);
SGVector<index_t> idx(rnd_indicies.get_column_vector(i), m_bag_size, false);

std::shared_ptr<Features> features;
std::shared_ptr<Labels> labels;

if (env()->get_num_threads() == 1)
{
features = m_features;
labels = m_labels;
features = data;
labels = labs;
}
else
{
features = m_features->shallow_subset_copy();
labels = m_labels->shallow_subset_copy();
features = data->shallow_subset_copy();
labels = labs->shallow_subset_copy();
}

labels->add_subset(idx);
#pragma omp critical
{
labels->add_subset(idx);
features->add_subset(idx);
}

/* TODO:
if it's a binary labeling ensure that
there's always samples of both classes
@@ -194,12 +179,15 @@
}
}
*/
features->add_subset(idx);

set_machine_parameters(c, idx);
c->set_labels(labels);
c->train(features);
features->remove_subset();
labels->remove_subset();
c->train(features, labels);
#pragma omp critical
{
features->remove_subset();
labels->remove_subset();
}


#pragma omp critical
{
Expand All @@ -214,7 +202,7 @@ bool BaggingMachine::train_machine(std::shared_ptr<Features> data)
pb.print_progress();
}
pb.complete();

get_oob_error_lambda = [=](){return get_oob_error_impl(data, labs);};
return true;
}
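
The get_oob_error_lambda assignment at the end of the hunk above implements the "change get_oob_error to lambda" bullet: since the machine no longer stores m_features/m_labels, train_machine captures the training set in a closure so the OOB error can still be evaluated lazily after training. A rough sketch of the same pattern in Python, with hypothetical helper names (_fit_bags, _oob_error_impl), not shogun API:

class BaggingSketch:
    def train(self, data, labels):
        self._bags = self._fit_bags(data, labels)
        # capture the training set in a closure instead of storing it as state
        self._oob_error = lambda: self._oob_error_impl(data, labels)

    def get_oob_error(self):
        # computed on demand, using the captured training set
        return self._oob_error()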

@@ -224,7 +212,6 @@ void BaggingMachine::set_machine_parameters(std::shared_ptr<Machine> m, SGVector

void BaggingMachine::register_parameters()
{
SG_ADD(&m_features, kFeatures, "Train features for bagging");
SG_ADD(
&m_num_bags, kNBags, "Number of bags", ParameterProperties::HYPER);
SG_ADD(
@@ -275,9 +262,7 @@ void BaggingMachine::set_machine(std::shared_ptr<Machine> machine)
void BaggingMachine::init()
{
m_machine = nullptr;
m_features = nullptr;
m_combination_rule = nullptr;
m_labels = nullptr;
m_num_bags = 0;
m_bag_size = 0;
m_all_oob_idx = SGVector<bool>();
@@ -294,16 +279,16 @@ std::shared_ptr<CombinationRule> BaggingMachine::get_combination_rule() const
return m_combination_rule;
}

float64_t BaggingMachine::get_oob_error() const
float64_t BaggingMachine::get_oob_error_impl(const std::shared_ptr<Features>& data, const std::shared_ptr<Labels>& labs) const
{
require(
m_oob_evaluation_metric, "Out of bag evaluation metric is not set!");
require(m_combination_rule, "Combination rule is not set!");
require(m_bags.size() > 0, "BaggingMachine is not trained!");

SGMatrix<float64_t> output(
m_features->get_num_vectors(), m_bags.size());
if (m_labels->get_label_type() == LT_REGRESSION)
m_num_vectors, m_bags.size());
if (labs->get_label_type() == LT_REGRESSION)
output.zero();
else
output.set_const(NAN);
@@ -318,9 +303,9 @@ float64_t BaggingMachine::get_oob_error() const
auto current_oob = m_oob_indices[i];

SGVector<index_t> oob(current_oob.data(), current_oob.size(), false);
m_features->add_subset(oob);
data->add_subset(oob);

auto l = m->apply(m_features);
auto l = m->apply(data);
SGVector<float64_t> lv;
if (l!=NULL)
lv = std::dynamic_pointer_cast<DenseLabels>(l)->get_labels();
@@ -331,14 +316,14 @@
for (index_t j = 0; j < oob.vlen; j++)
output(oob[j], i) = lv[j];

m_features->remove_subset();
data->remove_subset();



}

std::vector<index_t> idx;
for (index_t i = 0; i < m_features->get_num_vectors(); i++)
for (index_t i = 0; i < data->get_num_vectors(); i++)
{
if (m_all_oob_idx[i])
idx.push_back(i);
@@ -350,7 +335,7 @@ float64_t BaggingMachine::get_oob_error() const
lab[i] = combined[idx[i]];

std::shared_ptr<Labels> predicted = NULL;
switch (m_labels->get_label_type())
switch (labs->get_label_type())
{
case LT_BINARY:
predicted = std::make_shared<BinaryLabels>(lab);
@@ -369,16 +354,16 @@
}


m_labels->add_subset(SGVector<index_t>(idx.data(), idx.size(), false));
float64_t res = m_oob_evaluation_metric->evaluate(predicted, m_labels);
m_labels->remove_subset();
labs->add_subset(SGVector<index_t>(idx.data(), idx.size(), false));
float64_t res = m_oob_evaluation_metric->evaluate(predicted, labs);
labs->remove_subset();

return res;
}

std::vector<index_t> BaggingMachine::get_oob_indices(const SGVector<index_t>& in_bag)
{
SGVector<bool> out_of_bag(m_features->get_num_vectors());
SGVector<bool> out_of_bag(m_num_vectors);
out_of_bag.set_const(true);

// mark the ones that are in_bag
Expand Down