Skip to content

Commit

Permalink
initial ShogunOpenML class
Browse files Browse the repository at this point in the history
  • Loading branch information
gf712 committed May 8, 2019
1 parent b70398d commit 45ac04e
Show file tree
Hide file tree
Showing 4 changed files with 342 additions and 83 deletions.
2 changes: 2 additions & 0 deletions src/interfaces/swig/IO.i
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
%rename(MemoryMappedFile) CMemoryMappedFile;

%shared_ptr(shogun::OpenMLFlow)
%shared_ptr(shogun::ShogunOpenML::flow_to_model)
%shared_ptr(shogun::ShogunOpenML::model_to_flow)

%include <shogun/io/File.h>
%include <shogun/io/streaming/StreamingFile.h>
Expand Down
1 change: 0 additions & 1 deletion src/shogun/base/SGObject.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1112,5 +1112,4 @@ std::string CSGObject::string_enum_reverse_lookup(
return p.second == enum_value;
});
return enum_map_it->first;

}
287 changes: 242 additions & 45 deletions src/shogun/io/OpenMLFlow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@
using namespace shogun;
using namespace rapidjson;

/**
* The writer callback function used to write the packets to a C++ string.
* @param data the data received in CURL request
* @param size always 1
* @param nmemb the size of data
* @param buffer_in the buffer to write to
* @return the size of buffer that was written
*/
size_t writer(char* data, size_t size, size_t nmemb, std::string* buffer_in)
{
// adapted from https://stackoverflow.com/a/5780603
Expand All @@ -30,13 +38,16 @@ size_t writer(char* data, size_t size, size_t nmemb, std::string* buffer_in)
return 0;
}

/* OpenML server format: base URLs of the v1 API, one per response format
 * (XML and JSON). */
const char* OpenMLReader::xml_server = "https://www.openml.org/api/v1/xml";
const char* OpenMLReader::json_server = "https://www.openml.org/api/v1/json";
/* DATA API: endpoint paths for dataset queries.
 * NOTE(review): "{}" looks like a placeholder that the caller substitutes
 * (e.g. with a dataset id or a filter) -- confirm against the request code. */
const char* OpenMLReader::dataset_description = "/data/{}";
const char* OpenMLReader::list_data_qualities = "/data/qualities/list";
const char* OpenMLReader::data_features = "/data/features/{}";
const char* OpenMLReader::list_dataset_qualities = "/data/qualities/{}";
const char* OpenMLReader::list_dataset_filter = "/data/list/{}";
/* FLOW API: endpoint path for a single flow (same "{}" placeholder). */
const char* OpenMLReader::flow_file = "/flow/{}";

const std::unordered_map<std::string, std::string>
Expand Down Expand Up @@ -84,25 +95,16 @@ void OpenMLReader::openml_curl_error_helper(CURL* curl_handle, CURLcode code)
if (code != CURLE_OK)
{
// TODO: call curl_easy_cleanup(curl_handle) ?
SG_SERROR("Curl error: %s\n", curl_easy_strerror(code))
SG_SERROR("Connection error: %s.\n", curl_easy_strerror(code))
}
// else
// {
// long response_code;
// curl_easy_getinfo(curl_handle, CURLINFO_RESPONSE_CODE,
//&response_code); if (response_code == 200) return;
// else
// {
// if (response_code == 181)
// SG_SERROR("Unknown flow. The flow with the given ID was not
// found in the database.") else if (response_code == 180)
// SG_SERROR("") SG_SERROR("Server code: %d\n", response_code)
// }
// }
}

#endif // HAVE_CURL

/**
* Checks the returned flow in JSON format
* @param doc the parsed flow
*/
static void check_flow_response(rapidjson::Document& doc)
{
if (SG_UNLIKELY(doc.HasMember("error")))
Expand All @@ -116,24 +118,36 @@ static void check_flow_response(rapidjson::Document& doc)
REQUIRE(doc.HasMember("flow"), "Unexpected format of OpenML flow.\n");
}

/**
 * Helper function that copies one member of a JSON value into a string map.
 * The member called `name` is looked up in `v`; if it holds a JSON string,
 * that string is stored in `param_dict` under the key `name`, otherwise an
 * empty string is stored under that key.
 * @param v a RapidJSON GenericValue that is indexed by `name`
 * @param param_dict the map to write to
 * @param name the name of the member to read, also used as the map key
 */
static SG_FORCED_INLINE void emplace_string_to_map(
const rapidjson::GenericValue<rapidjson::UTF8<char>>& v,
const GenericValue<UTF8<char>>& v,
std::unordered_map<std::string, std::string>& param_dict,
const std::string& name)
{
// only JSON strings are copied verbatim; any other type maps to ""
if (v[name.c_str()].GetType() == rapidjson::Type::kStringType)
if (v[name.c_str()].GetType() == Type::kStringType)
param_dict.emplace(name, v[name.c_str()].GetString());
else
param_dict.emplace(name, "");
}

/**
* Helper function that stores a member of a JSON object as a string in a map
* @param v a RapidJSON GenericObject, i.e. a JSON object (not an array)
* @param param_dict the map to write to
* @param name the name of the member to read, also used as the map key
*/
static SG_FORCED_INLINE void emplace_string_to_map(
const rapidjson::GenericObject<
true, rapidjson::GenericValue<rapidjson::UTF8<char>>>& v,
const GenericObject<
true, GenericValue<UTF8<char>>>& v,
std::unordered_map<std::string, std::string>& param_dict,
const std::string& name)
{
if (v[name.c_str()].GetType() == rapidjson::Type::kStringType)
if (v[name.c_str()].GetType() == Type::kStringType)
param_dict.emplace(name, v[name.c_str()].GetString());
else
param_dict.emplace(name, "");
Expand Down Expand Up @@ -234,52 +248,235 @@ std::shared_ptr<OpenMLFlow> OpenMLFlow::from_file()
return std::shared_ptr<OpenMLFlow>();
}

/**
* Class using the Any visitor pattern to convert
* a string to a C++ type that can be used as a parameter
* in a Shogun model.
*/
class StringToShogun : public AnyVisitor
{
public:
explicit StringToShogun(std::shared_ptr<CSGObject> model)
: m_model(model), m_parameter(""), m_string_val(""){};

StringToShogun(
std::shared_ptr<CSGObject> model, const std::string& parameter,
const std::string& string_val)
: m_model(model), m_parameter(parameter), m_string_val(string_val){};

void on(bool* v) final
{
if (!is_null())
{
SG_SDEBUG("bool: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
bool result = strcmp(m_string_val.c_str(), "true") == 0;
m_model->put(m_parameter, result);
}
}
void on(int32_t* v) final
{
if (!is_null())
{
SG_SDEBUG("int32: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
try
{
int32_t result = std::stoi(m_string_val);
m_model->put(m_parameter, result);
}
catch (const std::invalid_argument&)
{
// it's an option, i.e. internally represented
// as an enum but in swig exposed as a string
m_string_val.erase(
std::remove_if(
m_string_val.begin(), m_string_val.end(),
// remove quotes
[](const auto& val) { return val == '\"'; }),
m_string_val.end());
m_model->put(m_parameter, m_string_val);
}
}
}
void on(int64_t* v) final
{
if (!is_null())
{
SG_SDEBUG("int64: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
int64_t result = std::stol(m_string_val);
m_model->put(m_parameter, result);
}
}
void on(float* v) final
{
if (!is_null())
{
SG_SDEBUG("float: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
char* end;
float32_t result = std::strtof(m_string_val.c_str(), &end);
m_model->put(m_parameter, result);
}
}
void on(double* v) final
{
if (!is_null())
{
SG_SDEBUG("double: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
char* end;
float64_t result = std::strtod(m_string_val.c_str(), &end);
m_model->put(m_parameter, result);
}
}
void on(long double* v)
{
if (!is_null())
{
SG_SDEBUG("long double: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
char* end;
floatmax_t result = std::strtold(m_string_val.c_str(), &end);
m_model->put(m_parameter, result);
}
}
void on(CSGObject** v) final
{
SG_SDEBUG("CSGObject: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
}
void on(SGVector<int>* v) final
{
SG_SDEBUG("SGVector<int>: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
}
void on(SGVector<float>* v) final
{
SG_SDEBUG("SGVector<float>: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
}
void on(SGVector<double>* v) final
{
SG_SDEBUG("SGVector<double>: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
}
void on(SGMatrix<int>* mat) final
{
SG_SDEBUG("SGMatrix<int>: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
}
void on(SGMatrix<float>* mat) final
{
SG_SDEBUG("SGMatrix<float>: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
}
void on(SGMatrix<double>* mat) final
{
SG_SDEBUG("SGMatrix<double>: %s=%s\n", m_parameter.c_str(), m_string_val.c_str())
}

bool is_null()
{
bool result = strcmp(m_string_val.c_str(), "null") == 0;
return result;
}

void set_parameter_name(const std::string& name)
{
m_parameter = name;
}

void set_string_value(const std::string& value)
{
m_string_val = value;
}

private:
std::shared_ptr<CSGObject> m_model;
std::string m_parameter;
std::string m_string_val;
};

/**
 * Creates a CSGObject by dispatching to one of the supported factory
 * functions ("machine", "kernel" or "distance").
 * An unsupported factory name is reported through SG_SERROR.
 * @param factory_name the name of the factory
 * @param algo_name the name of algorithm passed to factory
 * @return the instantiated object using a factory
 */
std::shared_ptr<CSGObject> instantiate_model_from_factory(
    const std::string& factory_name, const std::string& algo_name)
{
    if (factory_name == "machine")
        return std::shared_ptr<CSGObject>(machine(algo_name));

    if (factory_name == "kernel")
        return std::shared_ptr<CSGObject>(kernel(algo_name));

    if (factory_name == "distance")
        return std::shared_ptr<CSGObject>(distance(algo_name));

    // none of the known factories matched
    SG_SERROR("Unsupported factory \"%s\".\n", factory_name.c_str())

    return std::shared_ptr<CSGObject>();
}

/**
* Downcasts a CSGObject and puts it in the map of obj.
* @param obj the main object
* @param nested_obj the object to be casted and put in the obj map.
* @param parameter_name the name of nested_obj
*/
void cast_and_put(
const std::shared_ptr<CSGObject>& obj,
const std::shared_ptr<CSGObject>& nested_obj,
const std::string& parameter_name)
{
if (auto casted_obj = std::dynamic_pointer_cast<CMachine>(nested_obj))
{
// TODO: remove clone
// temporary fix until shared_ptr PR merged
auto* tmp_clone = dynamic_cast<CMachine*>(casted_obj->clone());
obj->put(parameter_name, tmp_clone);
}
else if (auto casted_obj = std::dynamic_pointer_cast<CKernel>(nested_obj))
{
auto* tmp_clone = dynamic_cast<CKernel*>(casted_obj->clone());
obj->put(parameter_name, tmp_clone);
}
else if (auto casted_obj = std::dynamic_pointer_cast<CDistance>(nested_obj))
{
auto* tmp_clone = dynamic_cast<CDistance*>(casted_obj->clone());
obj->put(parameter_name, tmp_clone);
}
else
SG_SERROR("Could not cast SGObject.\n")
}

/**
 * Builds a Shogun object from an OpenML flow.
 * The flow's class name is split by get_class_info into a factory (module)
 * name and an algorithm name, which select the object to instantiate.
 * Every component of the flow is converted by a recursive call to
 * flow_to_model.
 * @param flow the flow to convert
 * @param initialize_with_defaults if true, each parameter of the flow is
 * set on the object from the parameter's "default_value" entry
 * @return the instantiated object
 */
std::shared_ptr<CSGObject> ShogunOpenML::flow_to_model(
std::shared_ptr<OpenMLFlow> flow, bool initialize_with_defaults)
{
std::string name;
std::string val_as_string;
std::shared_ptr<CSGObject> obj;
auto params = flow->get_parameters();
auto components = flow->get_components();
// class_name is a tuple: element 0 is the module (factory) name,
// element 1 the algorithm name
auto class_name = get_class_info(flow->get_class_name());
auto module_name = std::get<0>(class_name);
auto algo_name = std::get<1>(class_name);
if (module_name == "machine")
obj = std::shared_ptr<CSGObject>(machine(algo_name));
else if (module_name == "kernel")
obj = std::shared_ptr<CSGObject>(kernel(algo_name));
else if (module_name == "distance")
obj = std::shared_ptr<CSGObject>(distance(algo_name));
else
SG_SERROR("Unsupported factory \"%s\"\n", module_name.c_str())

auto obj = instantiate_model_from_factory(module_name, algo_name);
auto obj_param = obj->get_params();

auto put_lambda = [&obj, &name, &val_as_string](const auto& val) {
// cast value using type from get, i.e. val
auto val_ = char_to_scalar<std::remove_reference_t<decltype(val)>>(
val_as_string.c_str());
obj->put(name, val_);
};
std::unique_ptr<StringToShogun> visitor(new StringToShogun(obj));

if (initialize_with_defaults)
{
for (const auto& param : params)
{
// the Any held by the object's own parameter decides which type
// the default value string is converted to
Any any_val = obj_param.at(param.first)->get_value();
name = param.first;
val_as_string = param.second.at("default_value");
sg_any_dispatch(any_val, sg_all_typemap, put_lambda);
std::string name = param.first;
std::string val_as_string = param.second.at("default_value");
visitor->set_parameter_name(name);
visitor->set_string_value(val_as_string);
any_val.visit(visitor.get());
}
}

// convert each nested flow recursively
for (const auto& component : components)
{
CSGObject* a =
flow_to_model(component.second, initialize_with_defaults).get();
// obj->put(component.first, a);
std::shared_ptr<CSGObject> nested_obj =
flow_to_model(component.second, initialize_with_defaults);
cast_and_put(obj, nested_obj, component.first);
}

SG_SDEBUG("Final object: %s.\n", obj->to_string().c_str());

return obj;
}

Expand All @@ -306,15 +503,15 @@ ShogunOpenML::get_class_info(const std::string& class_name)
if (std::next(it) == class_name.end())
class_components.emplace_back(std::string(begin, std::next(it)));
}
if (class_components.size() != 3)
SG_SERROR("Invalid class name format %s\n", class_name.c_str())
if (class_components[0] == "shogun")
result = std::make_tuple(class_components[1], class_components[2]);
else
SG_SERROR(
"The provided flow is not meant for shogun deserialisation! The "
"required library is \"%s\"\n",
"required library is \"%s\".\n",
class_components[0].c_str())
if (class_components.size() != 3)
SG_SERROR("Invalid class name format %s.\n", class_name.c_str())

return result;
}
Loading

0 comments on commit 45ac04e

Please sign in to comment.