Skip to content

Commit

Permalink
added feature name getter
Browse files Browse the repository at this point in the history
  • Loading branch information
gf712 committed May 14, 2019
1 parent 4162485 commit 4174168
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 6 deletions.
10 changes: 7 additions & 3 deletions src/shogun/io/ARFFFile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ void ARFFDeserializer::read()
std::make_pair(name, attributes));
m_attributes.push_back(Attribute::Nominal);
m_data_vectors.emplace_back(std::vector<float64_t>{});
m_attribute_names.emplace_back(name);
return;
}

Expand All @@ -105,6 +106,7 @@ void ARFFDeserializer::read()
else
m_date_formats.push_back(
javatime_to_cpptime(date_elements[1]));
name = "";
}
else if (date_elements[1] == "date" && date_elements.size() < 4)
{
Expand Down Expand Up @@ -159,6 +161,7 @@ void ARFFDeserializer::read()
SG_SERROR(
"Unexpected format in @ATTRIBUTE on line %d: %s\n",
m_line_number, m_current_line.c_str());
m_attribute_names.emplace_back(name);
}
// comments in this section are ignored
else if (m_current_line.substr(0, 1) == m_comment_string)
Expand Down Expand Up @@ -224,13 +227,14 @@ void ARFFDeserializer::read()
"Unexpected nominal value \"%s\" on line %d\n",
elems[i].c_str(), m_line_number);
auto encoding = (*nominal_pos).second;
remove_char_inplace(elems[i], '\'');
auto trimmed_el = trim(elems[i]);
remove_char_inplace(trimmed_el, '\'');
auto pos =
std::find(encoding.begin(), encoding.end(), elems[i]);
std::find(encoding.begin(), encoding.end(), trimmed_el);
if (pos == encoding.end())
SG_SERROR(
"Unexpected value \"%s\" on line %d\n",
elems[i].c_str(), m_line_number);
trimmed_el.c_str(), m_line_number);
float64_t idx = std::distance(encoding.begin(), pos);
shogun::get<std::vector<float64_t>>(m_data_vectors[i])
.push_back(idx);
Expand Down
11 changes: 11 additions & 0 deletions src/shogun/io/ARFFFile.h
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,15 @@ namespace shogun
return m_relation;
}

/**
* Returns the attribute (feature) names collected while parsing
* the @ATTRIBUTE declarations.
* NOTE(review): the vector is returned by value, so the copy may
* allocate even though the method is declared noexcept.
* @return a copy of the vector of attribute names
*/
SG_FORCED_INLINE std::vector<std::string> get_feature_names() const noexcept
{
return m_attribute_names;
}

/**
* Get combined features from parsed data
* @return
Expand Down Expand Up @@ -496,6 +505,8 @@ namespace shogun
static const char* m_data_string;
/** the default C++ date format specified by the ARFF standard */
static const char* m_default_date_format;
/** the names of the attributes */
std::vector<std::string> m_attribute_names;

/** internal line number counter for exceptions */
size_t m_line_number;
Expand Down
9 changes: 6 additions & 3 deletions tests/unit/io/ARFFFile_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,10 @@ TEST(ARFFFileTest, Parse_nominal)
"% \n"
"% \n"
"@data \n"
"\"a\", 50 \n"
" \'a\', 50 \n"
"b, 26 \n"
"\"b\", 34 \n"
"\'c 1\', 41 \n"
"\"b\" , 34 \n"
" \'c 1\' , 41 \n"
"\"¯\\_(ツ)_/¯\", 44 \n"
"a, 45 ";

Expand Down Expand Up @@ -174,4 +174,7 @@ TEST(ARFFFileTest, Parse_nominal)
ASSERT_EQ(mat2[i], solution2[i]);
}
ASSERT_EQ(parser->get_relation(), "test_nominal");
ASSERT_EQ(parser->get_feature_names().size(), 2);
ASSERT_EQ(parser->get_feature_names()[0], "VAR1");
ASSERT_EQ(parser->get_feature_names()[1], "VAR2");
}

0 comments on commit 4174168

Please sign in to comment.