Skip to content

Commit

Permalink
moved json dependency to library
Browse files Browse the repository at this point in the history
  • Loading branch information
gf712 committed May 7, 2019
1 parent 7cf1d10 commit b70398d
Show file tree
Hide file tree
Showing 3 changed files with 236 additions and 60 deletions.
2 changes: 1 addition & 1 deletion src/shogun/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ SHOGUN_DEPENDENCIES(
CONFIG_FLAG HAVE_XML)
# RapidJSON
include(external/RapidJSON)
SHOGUN_INCLUDE_DIRS(SCOPE PUBLIC ${RAPIDJSON_INCLUDE_DIR})
SHOGUN_INCLUDE_DIRS(SCOPE PRIVATE ${RAPIDJSON_INCLUDE_DIR})

if (NOT WIN32)
# FIXME: HDF5 linking on WIN32 is broken.
Expand Down
197 changes: 165 additions & 32 deletions src/shogun/io/OpenMLFlow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
*/

#include <shogun/io/OpenMLFlow.h>
#include <shogun/lib/type_case.h>
#include <shogun/util/factory.h>

#include <rapidjson/document.h>

#ifdef HAVE_CURL

Expand Down Expand Up @@ -59,7 +63,7 @@ void OpenMLReader::openml_curl_request_helper(const std::string& url)

if (!curl_handle)
{
SG_SERROR("Failed to initialise curl handle.")
SG_SERROR("Failed to initialise curl handle.\n")
return;
}

Expand All @@ -82,24 +86,61 @@ void OpenMLReader::openml_curl_error_helper(CURL* curl_handle, CURLcode code)
// TODO: call curl_easy_cleanup(curl_handle) ?
SG_SERROR("Curl error: %s\n", curl_easy_strerror(code))
}
// else
// {
// long response_code;
// curl_easy_getinfo(curl_handle, CURLINFO_RESPONSE_CODE, &response_code);
// if (response_code == 200)
// return;
// else
// {
// if (response_code == 181)
// SG_SERROR("Unknown flow. The flow with the given ID was not
//found in the database.") else if (response_code == 180) SG_SERROR("")
// SG_SERROR("Server code: %d\n", response_code)
// }
// }
// else
// {
// long response_code;
// curl_easy_getinfo(curl_handle, CURLINFO_RESPONSE_CODE,
//&response_code); if (response_code == 200) return;
// else
// {
// if (response_code == 181)
// SG_SERROR("Unknown flow. The flow with the given ID was not
// found in the database.") else if (response_code == 180)
// SG_SERROR("") SG_SERROR("Server code: %d\n", response_code)
// }
// }
}

std::shared_ptr<OpenMLFlow>
OpenMLFlow::download_flow(const std::string& flow_id, const std::string& api_key)
#endif // HAVE_CURL

static void check_flow_response(rapidjson::Document& doc)
{
if (SG_UNLIKELY(doc.HasMember("error")))
{
const Value& root = doc["error"];
SG_SERROR(
"Server error %s: %s\n", root["code"].GetString(),
root["message"].GetString())
return;
}
REQUIRE(doc.HasMember("flow"), "Unexpected format of OpenML flow.\n");
}

static SG_FORCED_INLINE void emplace_string_to_map(
const rapidjson::GenericValue<rapidjson::UTF8<char>>& v,
std::unordered_map<std::string, std::string>& param_dict,
const std::string& name)
{
if (v[name.c_str()].GetType() == rapidjson::Type::kStringType)
param_dict.emplace(name, v[name.c_str()].GetString());
else
param_dict.emplace(name, "");
}

static SG_FORCED_INLINE void emplace_string_to_map(
const rapidjson::GenericObject<
true, rapidjson::GenericValue<rapidjson::UTF8<char>>>& v,
std::unordered_map<std::string, std::string>& param_dict,
const std::string& name)
{
if (v[name.c_str()].GetType() == rapidjson::Type::kStringType)
param_dict.emplace(name, v[name.c_str()].GetString());
else
param_dict.emplace(name, "");
}

std::shared_ptr<OpenMLFlow> OpenMLFlow::download_flow(
const std::string& flow_id, const std::string& api_key)
{
Document document;
parameters_type params;
Expand All @@ -124,7 +165,8 @@ OpenMLFlow::download_flow(const std::string& flow_id, const std::string& api_key

if (root["parameter"].IsArray())
{
for (const auto &v : root["parameter"].GetArray()) {
for (const auto& v : root["parameter"].GetArray())
{
emplace_string_to_map(v, param_dict, "data_type");
emplace_string_to_map(v, param_dict, "default_value");
emplace_string_to_map(v, param_dict, "description");
Expand All @@ -146,11 +188,22 @@ OpenMLFlow::download_flow(const std::string& flow_id, const std::string& api_key
// handle components, i.e. kernels
if (root.HasMember("component"))
{
for (const auto& v : root["component"].GetArray())
if (root["component"].IsArray())
{
for (const auto& v : root["component"].GetArray())
{
components.emplace(
v["identifier"].GetString(),
OpenMLFlow::download_flow(
v["flow"]["id"].GetString(), api_key));
}
}
else
{
components.emplace(
v["identifier"].GetString(),
OpenMLFlow::download_flow(v["flow"]["id"].GetString(), api_key));
root["component"]["identifier"].GetString(),
OpenMLFlow::download_flow(
root["component"]["flow"]["id"].GetString(), api_key));
}
}

Expand All @@ -162,26 +215,106 @@ OpenMLFlow::download_flow(const std::string& flow_id, const std::string& api_key
if (root.HasMember("class_name"))
class_name = root["class_name"].GetString();

auto flow = std::make_shared<OpenMLFlow>(name, description, class_name, components, params);
auto flow = std::make_shared<OpenMLFlow>(
name, description, class_name, components, params);

return flow;
}

void OpenMLFlow::check_flow_response(Document& doc)
void OpenMLFlow::upload_flow(const std::shared_ptr<OpenMLFlow>& flow)
{
if (SG_UNLIKELY(doc.HasMember("error")))
}

void OpenMLFlow::dump()
{
}

std::shared_ptr<OpenMLFlow> OpenMLFlow::from_file()
{
return std::shared_ptr<OpenMLFlow>();
}

std::shared_ptr<CSGObject> ShogunOpenML::flow_to_model(
std::shared_ptr<OpenMLFlow> flow, bool initialize_with_defaults)
{
std::string name;
std::string val_as_string;
std::shared_ptr<CSGObject> obj;
auto params = flow->get_parameters();
auto components = flow->get_components();
auto class_name = get_class_info(flow->get_class_name());
auto module_name = std::get<0>(class_name);
auto algo_name = std::get<1>(class_name);
if (module_name == "machine")
obj = std::shared_ptr<CSGObject>(machine(algo_name));
else if (module_name == "kernel")
obj = std::shared_ptr<CSGObject>(kernel(algo_name));
else if (module_name == "distance")
obj = std::shared_ptr<CSGObject>(distance(algo_name));
else
SG_SERROR("Unsupported factory \"%s\"\n", module_name.c_str())
auto obj_param = obj->get_params();

auto put_lambda = [&obj, &name, &val_as_string](const auto& val) {
// cast value using type from get, i.e. val
auto val_ = char_to_scalar<std::remove_reference_t<decltype(val)>>(
val_as_string.c_str());
obj->put(name, val_);
};

if (initialize_with_defaults)
{
const Value& root = doc["error"];
SG_SERROR(
"Server error %s: %s\n", root["code"].GetString(),
root["message"].GetString())
return;
for (const auto& param : params)
{
Any any_val = obj_param.at(param.first)->get_value();
name = param.first;
val_as_string = param.second.at("default_value");
sg_any_dispatch(any_val, sg_all_typemap, put_lambda);
}
}
REQUIRE(doc.HasMember("flow"), "Unexpected format of OpenML flow.\n");

for (const auto& component : components)
{
CSGObject* a =
flow_to_model(component.second, initialize_with_defaults).get();
// obj->put(component.first, a);
}

return obj;
}

void OpenMLFlow::upload_flow(const std::shared_ptr<OpenMLFlow>& flow)
std::shared_ptr<OpenMLFlow>
ShogunOpenML::model_to_flow(const std::shared_ptr<CSGObject>& model)
{
return std::shared_ptr<OpenMLFlow>();
}

#endif // HAVE_CURL
std::tuple<std::string, std::string>
ShogunOpenML::get_class_info(const std::string& class_name)
{
std::vector<std::string> class_components;
auto begin = class_name.begin();
std::tuple<std::string, std::string> result;

for (auto it = class_name.begin(); it != class_name.end(); ++it)
{
if (*it == '.')
{
class_components.emplace_back(std::string(begin, it));
begin = std::next(it);
}
if (std::next(it) == class_name.end())
class_components.emplace_back(std::string(begin, std::next(it)));
}
if (class_components.size() != 3)
SG_SERROR("Invalid class name format %s\n", class_name.c_str())
if (class_components[0] == "shogun")
result = std::make_tuple(class_components[1], class_components[2]);
else
SG_SERROR(
"The provided flow is not meant for shogun deserialisation! The "
"required library is \"%s\"\n",
class_components[0].c_str())

return result;
}
Loading

0 comments on commit b70398d

Please sign in to comment.