From 2b093cfa8ff88b41e297af2eb60fe203e8550cf0 Mon Sep 17 00:00:00 2001 From: LiuYuhui Date: Fri, 31 Jul 2020 14:41:46 +0800 Subject: [PATCH] Refactor NearestCentroid class (#5053) * Add NonParametricMachine class (#5055) * add nonparametric machine * fix notebooks * Refactor NearestCentroid class --- .../classification/Classification.ipynb | 4 +-- examples/meta/src/multiclass/lmnn.sg.in | 2 +- src/shogun/classifier/NearestCentroid.cpp | 31 ++++--------------- src/shogun/classifier/NearestCentroid.h | 13 ++------ .../classifier/NearestCentroid_unittest.cc | 31 +++++++++++++++++++ 5 files changed, 43 insertions(+), 38 deletions(-) create mode 100644 tests/unit/classifier/NearestCentroid_unittest.cc diff --git a/doc/ipython-notebooks/classification/Classification.ipynb b/doc/ipython-notebooks/classification/Classification.ipynb index a24fa498f63..56277fdd9ac 100644 --- a/doc/ipython-notebooks/classification/Classification.ipynb +++ b/doc/ipython-notebooks/classification/Classification.ipynb @@ -441,7 +441,7 @@ "distances_linear.init(shogun_feats_linear, shogun_feats_linear)\n", "knn_linear = sg.create_machine(\"KNN\", k=number_of_neighbors, distance=distances_linear, \n", " labels=shogun_labels_linear)\n", - "knn_linear.train()\n", + "knn_linear.train(shogun_feats_linear)\n", "classifiers_linear.append(knn_linear)\n", "classifiers_names.append(\"Nearest Neighbors\")\n", "fadings.append(False)\n", @@ -455,7 +455,7 @@ "distances_non_linear.init(shogun_feats_non_linear, shogun_feats_non_linear)\n", "knn_non_linear = sg.create_machine(\"KNN\", k=number_of_neighbors, distance=distances_non_linear, \n", " labels=shogun_labels_non_linear)\n", - "knn_non_linear.train()\n", + "knn_non_linear.train(shogun_feats_non_linear)\n", "classifiers_non_linear.append(knn_non_linear)\n", "\n", "plt.subplot(122)\n", diff --git a/examples/meta/src/multiclass/lmnn.sg.in b/examples/meta/src/multiclass/lmnn.sg.in index ab64a312efa..ccd6c00dc53 100644 --- a/examples/meta/src/multiclass/lmnn.sg.in +++ b/examples/meta/src/multiclass/lmnn.sg.in @@ -20,7 +20,7 @@ Machine knn = create_machine("KNN", k=k,distance=lmnn_distance,labels=labels_tra #![create_instance] #![train_and_apply] -knn.train() +knn.train(features_train) Labels labels_predict = knn.apply(features_test) RealVector output = labels_predict.get_real_vector("labels") #![train_and_apply] diff --git a/src/shogun/classifier/NearestCentroid.cpp b/src/shogun/classifier/NearestCentroid.cpp index 1598df603d7..16ff2913059 100644 --- a/src/shogun/classifier/NearestCentroid.cpp +++ b/src/shogun/classifier/NearestCentroid.cpp @@ -17,43 +17,24 @@ namespace shogun{ NearestCentroid::NearestCentroid() : DistanceMachine() { - init(); } - NearestCentroid::NearestCentroid(const std::shared_ptr& d, const std::shared_ptr& trainlab) : DistanceMachine() + NearestCentroid::NearestCentroid(const std::shared_ptr& d) : DistanceMachine() { - init(); ASSERT(d) - ASSERT(trainlab) set_distance(d); - set_labels(trainlab); } NearestCentroid::~NearestCentroid() { } - void NearestCentroid::init() - { - m_shrinking=0; - m_is_trained=false; - } - - bool NearestCentroid::train_machine(std::shared_ptr data) { - ASSERT(m_labels) - ASSERT(distance) - if (data) - { - if (m_labels->get_num_labels() != data->get_num_vectors()) - error("Number of training vectors does not match number of labels"); - distance->init(data, data); - } - else - { - data = distance->get_lhs(); - } + require(distance, "Distance not set"); + require(m_labels->get_num_labels() == data->get_num_vectors(), + "Number of training vectors does not match number of labels"); + distance->init(data, data); auto multiclass_labels = m_labels->as(); auto dense_data = data->as>(); @@ -83,7 +64,7 @@ namespace shogun{ linalg::scale(centroids, centroids, scale); auto centroids_feats = std::make_shared>(centroids); - + m_centroids = centroids_feats; m_is_trained=true; distance->init(centroids_feats, distance->get_rhs()); diff --git a/src/shogun/classifier/NearestCentroid.h b/src/shogun/classifier/NearestCentroid.h index ecb4e87653d..14fd4af4489 100644 --- a/src/shogun/classifier/NearestCentroid.h +++ b/src/shogun/classifier/NearestCentroid.h @@ -45,7 +45,7 @@ class NearestCentroid : public DistanceMachine{ * @param distance distance * @param trainlab labels for training */ - NearestCentroid(const std::shared_ptr& distance, const std::shared_ptr& trainlab); + NearestCentroid(const std::shared_ptr& distance); /** Destructor */ @@ -92,26 +92,19 @@ class NearestCentroid : public DistanceMachine{ */ bool train_machine(std::shared_ptr data=NULL) override; - /** Stores feature data of underlying model. - * - * Sets centroids as lhs - */ - -private: - void init(); protected: /// number of classes (i.e. number of values labels can take) int32_t m_num_classes; /// Shrinking parameter - float64_t m_shrinking; + float64_t m_shrinking = 0; /// The centroids of the trained features std::shared_ptr> m_centroids; /// Tells if the classifier has been trained or not - bool m_is_trained; + bool m_is_trained = false; }; } diff --git a/tests/unit/classifier/NearestCentroid_unittest.cc b/tests/unit/classifier/NearestCentroid_unittest.cc new file mode 100644 index 00000000000..036a433c549 --- /dev/null +++ b/tests/unit/classifier/NearestCentroid_unittest.cc @@ -0,0 +1,31 @@ +/* + * This software is distributed under BSD 3-clause license (see LICENSE file). + * + * Authors: Yuhui Liu + */ +#include +#include +#include +#include + +using namespace shogun; +TEST(NearestCentroid, fit_and_predict) +{ + SGMatrix X{{-10, -1}, {-2, -1}, {-3, -2}, + {1, 1}, {2, 1}, {3, 2}}; + SGVector y{0, 0, 0, 1, 1, 1}; + + auto train_data = std::make_shared>(X); + auto train_labels = std::make_shared(y); + auto distance = std::make_shared(); + + SGMatrix t{{3, 2}, {-10, -1}, {-100, 100}}; + auto test_data = std::make_shared>(t); + auto clf = std::make_shared(distance); + clf->train(train_data, train_labels); + auto result_labels = clf->apply(test_data); + auto result = result_labels->as()->get_labels(); + EXPECT_EQ(result[0], 1); + EXPECT_EQ(result[1], 0); + EXPECT_EQ(result[2], 0); +} \ No newline at end of file