From 704d1fb0b4b9bc6af226eca0fccf6199f679d9b2 Mon Sep 17 00:00:00 2001 From: gf712 Date: Tue, 4 Feb 2020 15:55:14 +0000 Subject: [PATCH] fix bugs switched to shared_ptr, std::variant, etc.. --- src/shogun/io/ARFFFile.cpp | 214 +++++++++++++++++-------------------- src/shogun/io/ARFFFile.h | 108 ++++++++----------- 2 files changed, 144 insertions(+), 178 deletions(-) diff --git a/src/shogun/io/ARFFFile.cpp b/src/shogun/io/ARFFFile.cpp index f5cadf66503..542ffef436e 100644 --- a/src/shogun/io/ARFFFile.cpp +++ b/src/shogun/io/ARFFFile.cpp @@ -13,13 +13,6 @@ using namespace shogun; using namespace shogun::arff_detail; -const char* ARFFDeserializer::m_comment_string = "%"; -const char* ARFFDeserializer::m_relation_string = "@relation"; -const char* ARFFDeserializer::m_attribute_string = "@attribute"; -const char* ARFFDeserializer::m_data_string = "@data"; -const char* ARFFDeserializer::m_default_date_format = "%Y-%M-%DT%H:%M:%S"; -const char* ARFFDeserializer::m_missing_value_string = "?"; - /** * Visitor pattern to reserve memory for a std::vector * wrapped in a variant class. @@ -51,9 +44,9 @@ struct VectorSizeVisitor template T buffer_to_type(const std::string& buffer) { - SG_SERROR( - "No conversion from \"%s\" to \"%s\"!\n", buffer.c_str(), - demangled_type()) + error( + "No conversion from {} to {}!\n", buffer.c_str(), + demangled_type()); } template <> int8_t buffer_to_type(const std::string& buffer) @@ -94,7 +87,7 @@ floatmax_t buffer_to_type(const std::string& buffer) template void ARFFDeserializer::read_helper() { - std::vector, std::vector>>> data_vectors; m_line_number = 0; @@ -104,7 +97,7 @@ void ARFFDeserializer::read_helper() if (to_lower(m_current_line.substr(0, 1)) == m_comment_string) m_comments.push_back(m_current_line.substr(1, std::string::npos)); else if ( - to_lower(m_current_line.substr(0, strlen(m_relation_string))) == + to_lower(m_current_line.substr(0, m_relation_string.size())) == m_relation_string) m_state = true; }; @@ -112,12 +105,12 @@ void ARFFDeserializer::read_helper() process_chunk(read_comment, check_comment, false); auto read_relation = [this]() { - if (to_lower(m_current_line.substr(0, strlen(m_relation_string))) == + if (to_lower(m_current_line.substr(0, m_relation_string.size())) == m_relation_string) m_relation = remove_whitespace( - m_current_line.substr(strlen(m_relation_string))); + m_current_line.substr(m_relation_string.size())); else if ( - to_lower(m_current_line.substr(0, strlen(m_attribute_string))) == + to_lower(m_current_line.substr(0, m_attribute_string.size())) == m_attribute_string) m_state = true; }; @@ -127,12 +120,12 @@ void ARFFDeserializer::read_helper() // parse the @attributes section auto read_attributes = [this, &data_vectors]() { - if (to_lower(m_current_line.substr(0, strlen(m_attribute_string))) == + if (to_lower(m_current_line.substr(0, m_attribute_string.size())) == m_attribute_string) { std::string name, type; auto inner_string = - m_current_line.substr(strlen(m_attribute_string)); + m_current_line.substr(m_attribute_string.size()); left_trim(inner_string, [](const auto& val) { return !std::isspace(val); }); @@ -145,10 +138,10 @@ void ARFFDeserializer::read_helper() while (*it != quote_type && it != inner_string.end()) ++it; if (it == inner_string.end()) - SG_SERROR( + error( "Encountered unbalanced parenthesis in attribute " - "declaration on line %d: \"%s\"\n", - m_line_number, m_current_line.c_str()) + "declaration on line {}: \"{}\"\n", + m_line_number, m_current_line); name = {begin, it}; type = trim({std::next(it), inner_string.end()}); } @@ -158,26 +151,26 @@ void ARFFDeserializer::read_helper() while (!std::isspace(*it)) ++it; if (it == inner_string.end() && it != inner_string.end()) - SG_SERROR( + error( "Expected at least two elements in attribute " - "declaration on line %d: \"%s\"", - m_line_number, m_current_line.c_str()) + "declaration on line {}: \"{}\"", + m_line_number, m_current_line); name = {begin, it}; type = trim({std::next(it), inner_string.end()}); } - SG_SDEBUG("name: %s\n", name.c_str()) - SG_SDEBUG("type: %s\n", type.c_str()) + SG_DEBUG("name: {}\n", name); + SG_DEBUG("type: {}\n", type); if (name.empty() || type.empty()) - SG_SERROR( - "Could not find the name and type on line %d: \"%s\".\n", - m_line_number, m_current_line.c_str()) + error( + "Could not find the name and type on line {}: \"{}\".\n", + m_line_number, m_current_line); if (it == inner_string.end()) - SG_SERROR( - "Could not split attibute name and type on line %d: " - "\"%s\".\n", - m_line_number, m_current_line.c_str()) + error( + "Could not split attibute name and type on line {}: " + "\"{}\".\n", + m_line_number, m_current_line); // check if it is nominal if (type[0] == '{') @@ -226,9 +219,9 @@ void ARFFDeserializer::read_helper() } else { - SG_SERROR( - "Error parsing date on line %d: %s\n", m_line_number, - m_current_line.c_str()) + error( + "Error parsing date on line {}: {}\n", m_line_number, + m_current_line); } m_attributes.push_back(Attribute::DATE); data_vectors.emplace_back(std::vector{}); @@ -260,15 +253,15 @@ void ARFFDeserializer::read_helper() std::vector>{}); } else - SG_SERROR( - "Unexpected attribute type identifier \"%s\" " - "on line %d: %s\n", - type.c_str(), m_line_number, m_current_line.c_str()) + error( + "Unexpected attribute type identifier \"{}\" " + "on line {}: {}\n", + type, m_line_number, m_current_line); } else - SG_SERROR( - "Unexpected format in @ATTRIBUTE on line %d: %s\n", - m_line_number, m_current_line.c_str()); + error( + "Unexpected format in @ATTRIBUTE on line {}: {}\n", + m_line_number, m_current_line); auto processed_name = trim(name, [](const auto& val) { return !std::isspace(val) && val != '\'' && val != '\"'; }); @@ -279,7 +272,7 @@ void ARFFDeserializer::read_helper() { } else if ( - to_lower(m_current_line.substr(0, strlen(m_data_string))) == + to_lower(m_current_line.substr(0, m_data_string.size())) == m_data_string) m_state = true; }; @@ -307,7 +300,7 @@ void ARFFDeserializer::read_helper() if (m_current_line.substr(0, 1) == m_comment_string) return; // it's the data string (i.e. @data"), does not provide information - if (to_lower(m_current_line.substr(0, strlen(m_data_string))) == + if (to_lower(m_current_line.substr(0, m_data_string.size())) == m_data_string) return; @@ -315,10 +308,10 @@ void ARFFDeserializer::read_helper() elems.clear(); split(m_current_line, ",", std::back_inserter(elems), "\'\""); if (elems.size() != m_attributes.size()) - SG_SERROR( - "Unexpected number of values on line %d, expected %d " - "values, but found %d.\n", - m_line_number, m_attributes.size(), elems.size()) + error( + "Unexpected number of values on line {}, expected {} " + "values, but found {}.\n", + m_line_number, m_attributes.size(), elems.size()); // only parse rows that do not contain missing values if (std::find(elems.begin(), elems.end(), m_missing_value_string) == elems.end()) @@ -336,22 +329,22 @@ void ARFFDeserializer::read_helper() { try { - shogun::get>(data_vectors[i]) + std::get>(data_vectors[i]) .push_back(buffer_to_type(elems[i])); } catch (const std::invalid_argument&) { - SG_SERROR( - "Failed to covert \"%s\" to numeric on line %d.\n", - elems[i].c_str(), m_line_number) + error( + "Failed to covert \"{}\" to numeric on line %d.\n", + elems[i], m_line_number); } } break; case (Attribute::NOMINAL): { if (nominal_pos == m_nominal_attributes.end()) - SG_SERROR( - "Unexpected nominal value \"%s\" on line %d\n", + error( + "Unexpected nominal value \"{}\" on line {}\n", elems[i].c_str(), m_line_number); auto encoding = (*nominal_pos).second; auto trimmed_el = trim(elems[i]); @@ -359,11 +352,11 @@ void ARFFDeserializer::read_helper() auto pos = std::find(encoding.begin(), encoding.end(), trimmed_el); if (pos == encoding.end()) - SG_SERROR( - "Unexpected value \"%s\" on line %d\n", - trimmed_el.c_str(), m_line_number); + error( + "Unexpected value \"{}\" on line %d\n", + trimmed_el, m_line_number); ScalarType idx = std::distance(encoding.begin(), pos); - shogun::get>(data_vectors[i]) + std::get>(data_vectors[i]) .push_back(idx); ++nominal_pos; } @@ -373,27 +366,27 @@ void ARFFDeserializer::read_helper() date::sys_seconds t; std::istringstream ss(elems[i]); if (date_pos == m_date_formats.end()) - SG_SERROR( - "Unexpected date value \"%s\" on line %d.\n", - elems[i].c_str(), m_line_number); + error( + "Unexpected date value \"{}\" on line {}.\n", + elems[i], m_line_number); ss >> date::parse(*date_pos, t); if (bool(ss)) { auto value_timestamp = t.time_since_epoch().count(); - shogun::get>(data_vectors[i]) + std::get>(data_vectors[i]) .push_back(value_timestamp); } else - SG_SERROR( - "Error parsing date \"%s\" with date format \"%s\" " - "on line %d.\n", - elems[i].c_str(), (*date_pos).c_str(), - m_line_number) + error( + "Error parsing date \"{}\" with date format \"{}\" " + "on line {}.\n", + elems[i], *date_pos, + m_line_number); ++date_pos; } break; case (Attribute::STRING): - shogun::get>>( + std::get>>( data_vectors[i]) .emplace_back(elems[i]); } @@ -405,14 +398,14 @@ void ARFFDeserializer::read_helper() if (!data_vectors.empty()) { auto feature_count = data_vectors.size(); - index_t row_count = - shogun::visit(VectorSizeVisitor{}, data_vectors[0]); + size_t row_count = + std::visit(VectorSizeVisitor{}, data_vectors[0]); for (int i = 1; i < feature_count; ++i) { - REQUIRE( - shogun::visit(VectorSizeVisitor{}, data_vectors[i]) == + require( + std::visit(VectorSizeVisitor{}, data_vectors[i]) == row_count, - "All columns must have the same number of features!\n") + "All columns must have the same number of features!\n"); } } else @@ -422,7 +415,7 @@ void ARFFDeserializer::read_helper() process_chunk(read_data, check_data, true); // transform data into a feature object - index_t row_count = shogun::visit(VectorSizeVisitor{}, data_vectors[0]); + index_t row_count = std::visit(VectorSizeVisitor{}, data_vectors[0]); for (int i = 0; i < data_vectors.size(); ++i) { Attribute att = m_attributes[i]; @@ -435,35 +428,35 @@ void ARFFDeserializer::read_helper() case Attribute::DATE: case Attribute::NOMINAL: { - auto casted_vec = shogun::get>(vec); + auto casted_vec = std::get>(vec); SGMatrix mat(1, row_count); memcpy( mat.matrix, casted_vec.data(), casted_vec.size() * sizeof(ScalarType)); - m_features.emplace_back(new CDenseFeatures(mat)); + m_features.push_back(std::make_shared>(mat)); } break; case Attribute::STRING: { auto casted_vec = - shogun::get>>(vec); + std::get>>(vec); index_t max_string_length = 0; for (const auto& el : casted_vec) { if (max_string_length < el.size()) max_string_length = el.size(); } - SGStringList strings(row_count, max_string_length); + std::vector> strings(row_count, max_string_length); for (int j = 0; j < row_count; ++j) { - SGString current(max_string_length); + SGVector current(max_string_length); memcpy( - current.string, casted_vec[j].data(), + current.vector, casted_vec[j].data(), (casted_vec.size() + 1) * sizeof(CharType)); - strings.strings[j] = current; + strings[j] = current; } - m_features.emplace_back( - new CStringFeatures(strings, EAlphabet::RAWBYTE)); + m_features.push_back( + std::make_shared>(strings, EAlphabet::RAWBYTE)); } } } @@ -481,11 +474,11 @@ void ARFFDeserializer::read_string_dispatcher() break; case EPrimitiveType::PT_UINT16: { - SG_SNOTIMPLEMENTED + error("16-bit wide string conversion not available."); } break; default: - SG_SERROR("The provided type for string parsing is not valid!\n") + error("The provided type for string parsing is not valid!\n"); } } @@ -529,19 +522,19 @@ void ARFFDeserializer::read() } break; default: - SG_SERROR("The provided type for scalar parsing is not valid!\n") + error("The provided type for scalar parsing is not valid!\n"); } } template void ARFFDeserializer::reserve_vector_memory( size_t line_count, - std::vector, std::vector>>>& v) { VectorResizeVisitor visitor{line_count}; for (auto& vec : v) - shogun::visit(visitor, vec); + std::visit(visitor, vec); } /** @@ -549,7 +542,7 @@ void ARFFDeserializer::reserve_vector_memory( * @param obj * @return */ -std::vector features_to_string(CSGObject* obj, Attribute att) +std::vector features_to_string(const std::shared_ptr& obj, Attribute att) { std::vector result_string; switch (att) @@ -597,14 +590,14 @@ std::vector features_to_string(CSGObject* obj, Attribute att) } break; default: - SG_SERROR("Unsupported type: %d\n", static_cast(att)) + error("Unsupported type: {}\n", static_cast(att)); } - SG_SERROR("The provided feature object does not have a feature matrix!\n") + error("The provided feature object does not have a feature matrix!\n"); return std::vector{}; } std::vector features_to_string( - CSGObject* obj, const std::vector& nominal_values) + const std::shared_ptr& obj, const std::vector& nominal_values) { std::vector result_string; auto mat_to_string = [&result_string, &nominal_values](const auto& mat) { @@ -624,7 +617,7 @@ std::vector features_to_string( shogun::None{}, mat_to_string); return result_string; } - SG_SERROR("The provided feature object does not have a feature matrix!\n") + error("The provided feature object does not have a feature matrix!\n"); return std::vector{}; } @@ -657,7 +650,7 @@ std::unique_ptr ARFFSerializer::write() << " string\n"; break; case Attribute::DATE: - SG_SNOTIMPLEMENTED + error("C++ to Java date format conversion is not implement!"); break; case Attribute::NOMINAL: { @@ -679,49 +672,38 @@ std::unique_ptr ARFFSerializer::write() // @data *ss << "\n" << ARFFDeserializer::m_data_string << "\n\n"; - auto* obj = m_feature_list->get_first_element(); - auto num_vectors = obj->as()->get_num_vectors(); - for (int i = 0; i < m_feature_list->get_num_elements(); ++i) - { - auto n_i = obj->as()->get_num_vectors(); - SG_UNREF(obj) - REQUIRE( - n_i == num_vectors, - "Expected all features to have the same number of examples!\n") - // in the last iteration this will be nullptr so don't need to deref - obj = m_feature_list->get_next_element(); - } - + auto num_vectors = m_feature_list.back()->as()->get_num_vectors(); std::vector> result; auto att_iter = m_attributes.begin(); - obj = m_feature_list->get_first_element(); - - for (int i = 0; i < m_feature_list->get_num_elements(); ++i) + for (const auto& feature: m_feature_list) { + auto n_i = feature->as()->get_num_vectors(); + require( + n_i == num_vectors, + "Expected all features to have the same number of examples!\n"); + switch (att_iter->second) { case Attribute::NUMERIC: case Attribute::REAL: case Attribute::INTEGER: - result.push_back(features_to_string(obj, att_iter->second)); + result.push_back(features_to_string(feature, att_iter->second)); break; case Attribute::NOMINAL: result.push_back( - features_to_string(obj, m_nominal_mapping.at(att_iter->first))); + features_to_string(feature, m_nominal_mapping.at(att_iter->first))); break; case Attribute::DATE: case Attribute::STRING: - SG_SNOTIMPLEMENTED + error("Writing out strings and dates has not been implemented!"); } - SG_UNREF(obj) - obj = m_feature_list->get_next_element(); ++att_iter; } std::vector result_rows(num_vectors); - for (auto col = 0; col != result.size(); ++col) + for (size_t col = 0; col != result.size(); ++col) { if (col != result.size() - 1) for (auto row = 0; row != num_vectors; ++row) diff --git a/src/shogun/io/ARFFFile.h b/src/shogun/io/ARFFFile.h index 171e6cc0222..622ea3fca2c 100644 --- a/src/shogun/io/ARFFFile.h +++ b/src/shogun/io/ARFFFile.h @@ -7,8 +7,6 @@ #ifndef SHOGUN_ARFFFILE_H #define SHOGUN_ARFFFILE_H -#include -#include #include #include #include @@ -20,6 +18,7 @@ #include #include #include +#include namespace shogun { @@ -131,9 +130,9 @@ namespace shogun ++it; } if (it == s.end()) - SG_SERROR( - "Encountered unbalanced parenthesis in \"%s\"\n", - std::string(std::prev(begin), it).c_str()) + error( + "Encountered unbalanced parenthesis in \"{}\"\n", + std::string(std::prev(begin), it)); *(result++) = {begin, it}; } else @@ -231,8 +230,8 @@ namespace shogun if (java_token == "Z") return "%z"; if (java_token == "z") - SG_SERROR( - "Timezone abbreviations are currently not supported.\n") + error( + "Timezone abbreviations are currently not supported.\n"); if (java_token.empty()) return ""; if (java_token == "SSS") @@ -264,10 +263,10 @@ namespace shogun if (auto cpp_token = process_javatoken(java_time_token)) return cpp_token; else - SG_SERROR( - "Could not convert Java time token \"%s\" to C++ time " + error( + "Could not convert Java time token \"{}\" to C++ time " "token.\n", - java_time_token.c_str()) + java_time_token); return nullptr; } @@ -277,10 +276,10 @@ namespace shogun if (auto cpp_token = process_javatoken(java_time_token)) return cpp_token; else - SG_SERROR( - "Could not convert Java time token \"%c\" to C++ time " + error( + "Could not convert Java time token \"{}\" to C++ time " "token.\n", - java_time_token) + java_time_token); return nullptr; } @@ -329,10 +328,10 @@ namespace shogun cpp_time.append(cpp_token); } else - SG_SERROR( - "Could not convert Java time token %s to C++ time " + error( + "Could not convert Java time token {} to C++ time " "token.\n", - token.c_str()) + token); } ++it; } @@ -385,10 +384,10 @@ namespace shogun auto* file_stream = new std::ifstream(filename); if (file_stream->fail()) { - SG_SERROR( - "Cannot open %s. Please check if file exists and if you " + error( + "Cannot open {}. Please check if file exists and if you " "have the right permissions to open it.\n", - filename.c_str()) + filename); } m_stream = std::unique_ptr(file_stream); } @@ -439,17 +438,16 @@ namespace shogun * column to be excluded, i.e. it's a label and not a feature. * @return a list of features */ - CList* get_features(const std::string& label_name) const + std::vector> get_features(const std::string& label_name) const { auto find_label = std::find( m_attribute_names.begin(), m_attribute_names.end(), label_name); if (find_label == m_attribute_names.end()) - SG_SERROR( - "The provided label \"%s\" was not found!\n", - label_name.c_str()) + error( + "The provided label \"{}\" was not found!\n", + label_name); - auto result = new CList(true); - SG_REF(result) + std::vector> result; int idx = 0; int label_idx = @@ -457,10 +455,7 @@ namespace shogun for (const auto& feat : m_features) { if (idx != label_idx) - { - auto* feat_i = feat.get(); - result->append_element(feat_i); - } + result.push_back(feat); ++idx; } @@ -471,37 +466,27 @@ namespace shogun * Get list of features from parsed data. * @return a list of features */ - CList* get_features() const + std::vector> get_features() const { - auto result = new CList(true); - SG_REF(result) - - for (const auto& feat : m_features) - { - auto* feat_i = feat.get(); - result->append_element(feat_i); - } - - return result; + return m_features; } /** * Get feature by name. * @return the requested feature if it exists. */ - CFeatures* get_feature(const std::string& feature_name) const + std::shared_ptr get_feature(const std::string& feature_name) const { auto find_feature = std::find( m_attribute_names.begin(), m_attribute_names.end(), feature_name); if (find_feature == m_attribute_names.end()) - SG_SERROR( - "The provided label \"%s\" was not found!\n", - feature_name.c_str()) + error( + "The provided label \"{}\" was not found!\n", + feature_name); int feature_idx = std::distance(m_attribute_names.begin(), find_feature); - auto* result = m_features[feature_idx].get(); - SG_REF(result) + auto result = m_features[feature_idx]; return result; } @@ -530,23 +515,23 @@ namespace shogun if (nom_att.first == feature_name) return nom_att.second; } - SG_SERROR("The provided feature name is not a nominal feature!\n") + error("The provided feature name is not a nominal feature!\n"); return std::vector{}; } protected: /** character used in file to comment out a line */ - static const char* m_comment_string; + static constexpr std::string_view m_comment_string = "%"; /** characters to declare relations, i.e. @relation */ - static const char* m_relation_string; + static constexpr std::string_view m_relation_string = "@relation"; /** characters to declare attributes, i.e. @attribute */ - static const char* m_attribute_string; + static constexpr std::string_view m_attribute_string = "@attribute"; /** characters to declare data fields, i.e. @data */ - static const char* m_data_string; + static constexpr std::string_view m_data_string = "@data"; /** the default C++ date format specified by the ARFF standard */ - static const char* m_default_date_format; + static constexpr std::string_view m_default_date_format = "%Y-%M-%DT%H:%M:%S"; /** missing data */ - static const char* m_missing_value_string; + static constexpr std::string_view m_missing_value_string = "?"; private: /** @@ -590,9 +575,9 @@ namespace shogun } if (!check_func()) { - SG_SERROR( - "Parsing error on line %d: %s\n", m_line_number, - m_current_line.c_str()); + error( + "Parsing error on line {}: {}\n", m_line_number, + m_current_line); } } @@ -635,7 +620,7 @@ namespace shogun template void reserve_vector_memory( size_t line_count, - std::vector, std::vector>>>& v); @@ -672,7 +657,7 @@ namespace shogun m_nominal_attributes; /** the parsed features */ - std::vector> m_features; + std::vector> m_features; }; /** @@ -694,15 +679,14 @@ namespace shogun * strings whose index will be used to infer the nominal value */ ARFFSerializer( - const std::string& name, CList* feature_list, + const std::string& name, std::vector> feature_list, const std::vector>& attributes, const std::unordered_map>& nominal_mapping) : m_name(name), m_attributes(attributes), m_nominal_mapping(nominal_mapping) { - SG_REF(feature_list) - m_feature_list = feature_list; + m_feature_list = std::move(feature_list); } #ifndef SWIG @@ -725,7 +709,7 @@ namespace shogun /** the name of the dataset */ std::string m_name; /** the list of features to write out */ - CList* m_feature_list; + std::vector> m_feature_list; /** the attributes */ std::vector> m_attributes; /** the nominal attributes, if any */