diff --git a/src/shogun/io/ARFFFile.cpp b/src/shogun/io/ARFFFile.cpp index 998f9e32f6a..a3679dfd4bd 100644 --- a/src/shogun/io/ARFFFile.cpp +++ b/src/shogun/io/ARFFFile.cpp @@ -88,6 +88,7 @@ void ARFFDeserializer::read() std::make_pair(name, attributes)); m_attributes.push_back(Attribute::Nominal); m_data_vectors.emplace_back(std::vector{}); + m_attribute_names.emplace_back(name); return; } @@ -105,6 +106,7 @@ void ARFFDeserializer::read() else m_date_formats.push_back( javatime_to_cpptime(date_elements[1])); + name = ""; } else if (date_elements[1] == "date" && date_elements.size() < 4) { @@ -159,6 +161,7 @@ void ARFFDeserializer::read() SG_SERROR( "Unexpected format in @ATTRIBUTE on line %d: %s\n", m_line_number, m_current_line.c_str()); + m_attribute_names.emplace_back(name); } // comments in this section are ignored else if (m_current_line.substr(0, 1) == m_comment_string) @@ -224,13 +227,14 @@ void ARFFDeserializer::read() "Unexpected nominal value \"%s\" on line %d\n", elems[i].c_str(), m_line_number); auto encoding = (*nominal_pos).second; - remove_char_inplace(elems[i], '\''); + auto trimmed_el = trim(elems[i]); + remove_char_inplace(trimmed_el, '\''); auto pos = - std::find(encoding.begin(), encoding.end(), elems[i]); + std::find(encoding.begin(), encoding.end(), trimmed_el); if (pos == encoding.end()) SG_SERROR( "Unexpected value \"%s\" on line %d\n", - elems[i].c_str(), m_line_number); + trimmed_el.c_str(), m_line_number); float64_t idx = std::distance(encoding.begin(), pos); shogun::get>(m_data_vectors[i]) .push_back(idx); diff --git a/src/shogun/io/ARFFFile.h b/src/shogun/io/ARFFFile.h index 2336ed88048..71e02971782 100644 --- a/src/shogun/io/ARFFFile.h +++ b/src/shogun/io/ARFFFile.h @@ -397,6 +397,15 @@ namespace shogun return m_relation; } + /** + * Returns the names of the parsed attributes + * @return a copy of the attribute names + */ + SG_FORCED_INLINE std::vector get_feature_names() const noexcept + { + return m_attribute_names; + } + /** * Get combined 
features from parsed data * @return @@ -496,6 +505,8 @@ namespace shogun static const char* m_data_string; /** the default C++ date format specified by the ARFF standard */ static const char* m_default_date_format; + /** the names of the attributes */ + std::vector m_attribute_names; /** internal line number counter for exceptions */ size_t m_line_number; diff --git a/tests/unit/io/ARFFFile_unittest.cc b/tests/unit/io/ARFFFile_unittest.cc index cb8e662a8b6..158f3e46446 100644 --- a/tests/unit/io/ARFFFile_unittest.cc +++ b/tests/unit/io/ARFFFile_unittest.cc @@ -143,10 +143,10 @@ TEST(ARFFFileTest, Parse_nominal) "% \n" "% \n" "@data \n" - "\"a\", 50 \n" + " \'a\', 50 \n" "b, 26 \n" - "\"b\", 34 \n" - "\'c 1\', 41 \n" + "\"b\" , 34 \n" + " \'c 1\' , 41 \n" "\"¯\\_(ツ)_/¯\", 44 \n" "a, 45 "; @@ -174,4 +174,7 @@ TEST(ARFFFileTest, Parse_nominal) ASSERT_EQ(mat2[i], solution2[i]); } ASSERT_EQ(parser->get_relation(), "test_nominal"); + ASSERT_EQ(parser->get_feature_names().size(), 2); + ASSERT_EQ(parser->get_feature_names()[0], "VAR1"); + ASSERT_EQ(parser->get_feature_names()[1], "VAR2"); } \ No newline at end of file