From b70398d08f8973f0a0889e57bfc6bee49f62a160 Mon Sep 17 00:00:00 2001 From: gf712 Date: Tue, 7 May 2019 16:29:34 +0100 Subject: [PATCH] moved json dependency to library --- src/shogun/CMakeLists.txt | 2 +- src/shogun/io/OpenMLFlow.cpp | 197 +++++++++++++++++++++++++++++------ src/shogun/io/OpenMLFlow.h | 97 ++++++++++++----- 3 files changed, 236 insertions(+), 60 deletions(-) diff --git a/src/shogun/CMakeLists.txt b/src/shogun/CMakeLists.txt index 8290d3a49b4..506d4863cfd 100644 --- a/src/shogun/CMakeLists.txt +++ b/src/shogun/CMakeLists.txt @@ -412,7 +412,7 @@ SHOGUN_DEPENDENCIES( CONFIG_FLAG HAVE_XML) # RapidJSON include(external/RapidJSON) -SHOGUN_INCLUDE_DIRS(SCOPE PUBLIC ${RAPIDJSON_INCLUDE_DIR}) +SHOGUN_INCLUDE_DIRS(SCOPE PRIVATE ${RAPIDJSON_INCLUDE_DIR}) if (NOT WIN32) # FIXME: HDF5 linking on WIN32 is broken. diff --git a/src/shogun/io/OpenMLFlow.cpp b/src/shogun/io/OpenMLFlow.cpp index 505869d6789..a7ef3279f97 100644 --- a/src/shogun/io/OpenMLFlow.cpp +++ b/src/shogun/io/OpenMLFlow.cpp @@ -5,6 +5,10 @@ */ #include +#include +#include + +#include #ifdef HAVE_CURL @@ -59,7 +63,7 @@ void OpenMLReader::openml_curl_request_helper(const std::string& url) if (!curl_handle) { - SG_SERROR("Failed to initialise curl handle.") + SG_SERROR("Failed to initialise curl handle.\n") return; } @@ -82,24 +86,61 @@ void OpenMLReader::openml_curl_error_helper(CURL* curl_handle, CURLcode code) // TODO: call curl_easy_cleanup(curl_handle) ? SG_SERROR("Curl error: %s\n", curl_easy_strerror(code)) } -// else -// { -// long response_code; -// curl_easy_getinfo(curl_handle, CURLINFO_RESPONSE_CODE, &response_code); - // if (response_code == 200) - // return; - // else - // { - // if (response_code == 181) - // SG_SERROR("Unknown flow. The flow with the given ID was not - //found in the database.") else if (response_code == 180) SG_SERROR("") - // SG_SERROR("Server code: %d\n", response_code) - // } -// } + // else + // { + // long response_code; + // curl_easy_getinfo(curl_handle, CURLINFO_RESPONSE_CODE, + //&response_code); if (response_code == 200) return; + // else + // { + // if (response_code == 181) + // SG_SERROR("Unknown flow. The flow with the given ID was not + // found in the database.") else if (response_code == 180) + // SG_SERROR("") SG_SERROR("Server code: %d\n", response_code) + // } + // } } -std::shared_ptr -OpenMLFlow::download_flow(const std::string& flow_id, const std::string& api_key) +#endif // HAVE_CURL + +static void check_flow_response(rapidjson::Document& doc) +{ + if (SG_UNLIKELY(doc.HasMember("error"))) + { + const Value& root = doc["error"]; + SG_SERROR( + "Server error %s: %s\n", root["code"].GetString(), + root["message"].GetString()) + return; + } + REQUIRE(doc.HasMember("flow"), "Unexpected format of OpenML flow.\n"); +} + +static SG_FORCED_INLINE void emplace_string_to_map( + const rapidjson::GenericValue>& v, + std::unordered_map& param_dict, + const std::string& name) +{ + if (v[name.c_str()].GetType() == rapidjson::Type::kStringType) + param_dict.emplace(name, v[name.c_str()].GetString()); + else + param_dict.emplace(name, ""); +} + +static SG_FORCED_INLINE void emplace_string_to_map( + const rapidjson::GenericObject< + true, rapidjson::GenericValue>>& v, + std::unordered_map& param_dict, + const std::string& name) +{ + if (v[name.c_str()].GetType() == rapidjson::Type::kStringType) + param_dict.emplace(name, v[name.c_str()].GetString()); + else + param_dict.emplace(name, ""); +} + +std::shared_ptr OpenMLFlow::download_flow( + const std::string& flow_id, const std::string& api_key) { Document document; parameters_type params; @@ -124,7 +165,8 @@ OpenMLFlow::download_flow(const std::string& flow_id, const std::string& api_key if (root["parameter"].IsArray()) { - for (const auto &v : root["parameter"].GetArray()) { + for (const auto& v : root["parameter"].GetArray()) + { emplace_string_to_map(v, param_dict, "data_type"); emplace_string_to_map(v, param_dict, "default_value"); emplace_string_to_map(v, param_dict, "description"); @@ -146,11 +188,22 @@ OpenMLFlow::download_flow(const std::string& flow_id, const std::string& api_key // handle components, i.e. kernels if (root.HasMember("component")) { - for (const auto& v : root["component"].GetArray()) + if (root["component"].IsArray()) + { + for (const auto& v : root["component"].GetArray()) + { + components.emplace( + v["identifier"].GetString(), + OpenMLFlow::download_flow( + v["flow"]["id"].GetString(), api_key)); + } + } + else { components.emplace( - v["identifier"].GetString(), - OpenMLFlow::download_flow(v["flow"]["id"].GetString(), api_key)); + root["component"]["identifier"].GetString(), + OpenMLFlow::download_flow( + root["component"]["flow"]["id"].GetString(), api_key)); } } @@ -162,26 +215,106 @@ OpenMLFlow::download_flow(const std::string& flow_id, const std::string& api_key if (root.HasMember("class_name")) class_name = root["class_name"].GetString(); - auto flow = std::make_shared(name, description, class_name, components, params); + auto flow = std::make_shared( + name, description, class_name, components, params); return flow; } -void OpenMLFlow::check_flow_response(Document& doc) +void OpenMLFlow::upload_flow(const std::shared_ptr& flow) { - if (SG_UNLIKELY(doc.HasMember("error"))) +} + +void OpenMLFlow::dump() +{ +} + +std::shared_ptr OpenMLFlow::from_file() +{ + return std::shared_ptr(); +} + +std::shared_ptr ShogunOpenML::flow_to_model( + std::shared_ptr flow, bool initialize_with_defaults) +{ + std::string name; + std::string val_as_string; + std::shared_ptr obj; + auto params = flow->get_parameters(); + auto components = flow->get_components(); + auto class_name = get_class_info(flow->get_class_name()); + auto module_name = std::get<0>(class_name); + auto algo_name = std::get<1>(class_name); + if (module_name == "machine") + obj = std::shared_ptr(machine(algo_name)); + else if (module_name == "kernel") + obj = std::shared_ptr(kernel(algo_name)); + else if (module_name == "distance") + obj = std::shared_ptr(distance(algo_name)); + else + SG_SERROR("Unsupported factory \"%s\"\n", module_name.c_str()) + auto obj_param = obj->get_params(); + + auto put_lambda = [&obj, &name, &val_as_string](const auto& val) { + // cast value using type from get, i.e. val + auto val_ = char_to_scalar>( + val_as_string.c_str()); + obj->put(name, val_); + }; + + if (initialize_with_defaults) { - const Value& root = doc["error"]; - SG_SERROR( - "Server error %s: %s\n", root["code"].GetString(), - root["message"].GetString()) - return; + for (const auto& param : params) + { + Any any_val = obj_param.at(param.first)->get_value(); + name = param.first; + val_as_string = param.second.at("default_value"); + sg_any_dispatch(any_val, sg_all_typemap, put_lambda); + } } - REQUIRE(doc.HasMember("flow"), "Unexpected format of OpenML flow.\n"); + + for (const auto& component : components) + { + CSGObject* a = + flow_to_model(component.second, initialize_with_defaults).get(); + // obj->put(component.first, a); + } + + return obj; } -void OpenMLFlow::upload_flow(const std::shared_ptr& flow) +std::shared_ptr +ShogunOpenML::model_to_flow(const std::shared_ptr& model) { + return std::shared_ptr(); } -#endif // HAVE_CURL +std::tuple +ShogunOpenML::get_class_info(const std::string& class_name) +{ + std::vector class_components; + auto begin = class_name.begin(); + std::tuple result; + + for (auto it = class_name.begin(); it != class_name.end(); ++it) + { + if (*it == '.') + { + class_components.emplace_back(std::string(begin, it)); + begin = std::next(it); + } + if (std::next(it) == class_name.end()) + class_components.emplace_back(std::string(begin, std::next(it))); + } + if (class_components.size() != 3) + SG_SERROR("Invalid class name format %s\n", class_name.c_str()) + if (class_components[0] == "shogun") + result = std::make_tuple(class_components[1], class_components[2]); + else + SG_SERROR( + "The provided flow is not meant for shogun deserialisation! The " + "required library is \"%s\"\n", + class_components[0].c_str()) + + return result; +} diff --git a/src/shogun/io/OpenMLFlow.h b/src/shogun/io/OpenMLFlow.h index dee68423269..8fc46594a08 100644 --- a/src/shogun/io/OpenMLFlow.h +++ b/src/shogun/io/OpenMLFlow.h @@ -15,7 +15,6 @@ #include #include -#include #include #include @@ -150,9 +149,9 @@ namespace shogun public: using components_type = - std::unordered_map>; + std::unordered_map>; using parameters_type = std::unordered_map< - std::string, std::unordered_map>; + std::string, std::unordered_map>; OpenMLFlow( const std::string& name, const std::string& description, @@ -163,13 +162,15 @@ namespace shogun { } - ~OpenMLFlow()= default; - static std::shared_ptr download_flow(const std::string& flow_id, const std::string& api_key); + static std::shared_ptr from_file(); + static void upload_flow(const std::shared_ptr& flow); + void dump(); + std::shared_ptr get_subflow(const std::string& name) { auto find_flow = m_components.find(name); @@ -181,40 +182,82 @@ namespace shogun return nullptr; } +#ifndef SWIG + SG_FORCED_INLINE parameters_type get_parameters() + { + return m_parameters; + } + + SG_FORCED_INLINE components_type get_components() + { + return m_components; + } + + SG_FORCED_INLINE std::string get_class_name() + { + return m_class_name; + } +#endif // SWIG + private: std::string m_name; std::string m_description; std::string m_class_name; parameters_type m_parameters; components_type m_components; + }; #ifndef SWIG - static void check_flow_response(rapidjson::Document& doc); + template + T char_to_scalar(const char* string_val) + { + SG_SERROR( + "No registered conversion from string to type \"s\"\n", + demangled_type().c_str()) + return 0; + } + + template <> + float32_t char_to_scalar(const char* string_val) + { + char* end; + return std::strtof(string_val, &end); + } - static SG_FORCED_INLINE void emplace_string_to_map( - const rapidjson::GenericValue>& v, - std::unordered_map& param_dict, - const std::string& name) - { - if (v[name.c_str()].GetType() == rapidjson::Type::kStringType) - param_dict.emplace(name, v[name.c_str()].GetString()); - else - param_dict.emplace(name, ""); - } + template <> + float64_t char_to_scalar(const char* string_val) + { + char* end; + return std::strtod(string_val, &end); + } - static SG_FORCED_INLINE void emplace_string_to_map( - const rapidjson::GenericObject< - true, rapidjson::GenericValue>>& v, - std::unordered_map& param_dict, - const std::string& name) - { - if (v[name.c_str()].GetType() == rapidjson::Type::kStringType) - param_dict.emplace(name, v[name.c_str()].GetString()); - else - param_dict.emplace(name, ""); - } + template <> + floatmax_t char_to_scalar(const char* string_val) + { + char* end; + return std::strtold(string_val, &end); + } + + template <> + bool char_to_scalar(const char* string_val) + { + return strcmp(string_val, "true"); + } #endif // SWIG + + class ShogunOpenML + { + public: + static std::shared_ptr flow_to_model( + std::shared_ptr flow, bool initialize_with_defaults); + + static std::shared_ptr + model_to_flow(const std::shared_ptr& model); + + private: + static std::tuple + get_class_info(const std::string& class_name); }; } // namespace shogun #endif // HAVE_CURL