Skip to content

Commit

Permalink
Refactor NearestCentroid class (#5053)
Browse files Browse the repository at this point in the history
* Add NonParametricMachine class (#5055)
* add nonparametric machine
* fix notebooks
* Refactor NearestCentroid class
  • Loading branch information
LiuYuHui authored and gf712 committed Dec 8, 2020
1 parent 0d54ccf commit 2b093cf
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 38 deletions.
4 changes: 2 additions & 2 deletions doc/ipython-notebooks/classification/Classification.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@
"distances_linear.init(shogun_feats_linear, shogun_feats_linear)\n",
"knn_linear = sg.create_machine(\"KNN\", k=number_of_neighbors, distance=distances_linear, \n",
" labels=shogun_labels_linear)\n",
"knn_linear.train()\n",
"knn_linear.train(shogun_feats_linear)\n",
"classifiers_linear.append(knn_linear)\n",
"classifiers_names.append(\"Nearest Neighbors\")\n",
"fadings.append(False)\n",
Expand All @@ -455,7 +455,7 @@
"distances_non_linear.init(shogun_feats_non_linear, shogun_feats_non_linear)\n",
"knn_non_linear = sg.create_machine(\"KNN\", k=number_of_neighbors, distance=distances_non_linear, \n",
" labels=shogun_labels_non_linear)\n",
"knn_non_linear.train()\n",
"knn_non_linear.train(shogun_feats_non_linear)\n",
"classifiers_non_linear.append(knn_non_linear)\n",
"\n",
"plt.subplot(122)\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/meta/src/multiclass/lmnn.sg.in
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Machine knn = create_machine("KNN", k=k,distance=lmnn_distance,labels=labels_tra
#![create_instance]

#![train_and_apply]
knn.train()
knn.train(features_train)
Labels labels_predict = knn.apply(features_test)
RealVector output = labels_predict.get_real_vector("labels")
#![train_and_apply]
Expand Down
31 changes: 6 additions & 25 deletions src/shogun/classifier/NearestCentroid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,43 +17,24 @@ namespace shogun{

NearestCentroid::NearestCentroid() : DistanceMachine()
{
init();
}

NearestCentroid::NearestCentroid(const std::shared_ptr<Distance>& d, const std::shared_ptr<Labels>& trainlab) : DistanceMachine()
NearestCentroid::NearestCentroid(const std::shared_ptr<Distance>& d) : DistanceMachine()
{
init();
ASSERT(d)
ASSERT(trainlab)
set_distance(d);
set_labels(trainlab);
}

NearestCentroid::~NearestCentroid()
{
}

void NearestCentroid::init()
{
m_shrinking=0;
m_is_trained=false;
}


bool NearestCentroid::train_machine(std::shared_ptr<Features> data)
{
ASSERT(m_labels)
ASSERT(distance)
if (data)
{
if (m_labels->get_num_labels() != data->get_num_vectors())
error("Number of training vectors does not match number of labels");
distance->init(data, data);
}
else
{
data = distance->get_lhs();
}
require(distance, "Distance not set");
require(m_labels->get_num_labels() == data->get_num_vectors(),
"Number of training vectors does not match number of labels");
distance->init(data, data);

auto multiclass_labels = m_labels->as<MulticlassLabels>();
auto dense_data = data->as<DenseFeatures<float64_t>>();
Expand Down Expand Up @@ -83,7 +64,7 @@ namespace shogun{
linalg::scale(centroids, centroids, scale);

auto centroids_feats = std::make_shared<DenseFeatures<float64_t>>(centroids);

m_centroids = centroids_feats;
m_is_trained=true;
distance->init(centroids_feats, distance->get_rhs());

Expand Down
13 changes: 3 additions & 10 deletions src/shogun/classifier/NearestCentroid.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class NearestCentroid : public DistanceMachine{
* @param distance distance
* @param trainlab labels for training
*/
NearestCentroid(const std::shared_ptr<Distance>& distance, const std::shared_ptr<Labels>& trainlab);
NearestCentroid(const std::shared_ptr<Distance>& distance);

/** Destructor
*/
Expand Down Expand Up @@ -92,26 +92,19 @@ class NearestCentroid : public DistanceMachine{
*/
bool train_machine(std::shared_ptr<Features> data=NULL) override;

/** Stores feature data of underlying model.
*
* Sets centroids as lhs
*/

private:
void init();

protected:
/// number of classes (i.e. number of values labels can take)
int32_t m_num_classes;

/// Shrinking parameter
float64_t m_shrinking;
float64_t m_shrinking = 0;

/// The centroids of the trained features
std::shared_ptr<DenseFeatures<float64_t>> m_centroids;

/// Tells if the classifier has been trained or not
bool m_is_trained;
bool m_is_trained = false;
};

}
Expand Down
31 changes: 31 additions & 0 deletions tests/unit/classifier/NearestCentroid_unittest.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* This software is distributed under BSD 3-clause license (see LICENSE file).
*
* Authors: Yuhui Liu
*/
#include <gtest/gtest.h>
#include <shogun/classifier/NearestCentroid.h>
#include <shogun/distance/EuclideanDistance.h>
#include <shogun/labels/MulticlassLabels.h>

using namespace shogun;
TEST(NearestCentroid, fit_and_predict)
{
SGMatrix<float64_t> X{{-10, -1}, {-2, -1}, {-3, -2},
{1, 1}, {2, 1}, {3, 2}};
SGVector<float64_t> y{0, 0, 0, 1, 1, 1};

auto train_data = std::make_shared<DenseFeatures<float64_t>>(X);
auto train_labels = std::make_shared<MulticlassLabels>(y);
auto distance = std::make_shared<EuclideanDistance>();

SGMatrix<float64_t> t{{3, 2}, {-10, -1}, {-100, 100}};
auto test_data = std::make_shared<DenseFeatures<float64_t>>(t);
auto clf = std::make_shared<NearestCentroid>(distance);
clf->train(train_data, train_labels);
auto result_labels = clf->apply(test_data);
auto result = result_labels->as<MulticlassLabels>()->get_labels();
EXPECT_EQ(result[0], 1);
EXPECT_EQ(result[1], 0);
EXPECT_EQ(result[2], 0);
}

0 comments on commit 2b093cf

Please sign in to comment.