Skip to content

Commit

Permalink
implicit conversions
Browse files Browse the repository at this point in the history
  • Loading branch information
gf712 committed Jun 4, 2020
1 parent a3f8d98 commit 96b8998
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 59 deletions.
140 changes: 90 additions & 50 deletions src/shogun/base/SGObject.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ SG_FORCED_INLINE const char* convert_string_to_char(const char* name)
*/
class SGObject: public std::enable_shared_from_this<SGObject>
{
template <typename ReturnType, typename CastType>
template <typename ReturnType>
struct ParameterGetterInterface
{
ReturnType& m_value;
Expand Down Expand Up @@ -529,21 +529,21 @@ class SGObject: public std::enable_shared_from_this<SGObject>
const auto& value = param.get_value();
try
{
if (param.get_properties().has_property(ParameterProperties::CONSTFUNCTION))
{
ParameterGetterInterface<ReturnType, std::function<ReturnType()>> visitor{result};
value.visit_with(&visitor);
}
else if (param.get_properties().has_property(ParameterProperties::AUTO))
{
ParameterGetterInterface<ReturnType, AutoValue<ReturnType>> visitor{result};
value.visit_with(&visitor);
}
else
{
ParameterGetterInterface<ReturnType, ReturnType> visitor{result};
// if (param.get_properties().has_property(ParameterProperties::CONSTFUNCTION))
// {
ParameterGetterInterface<ReturnType> visitor{result};
value.visit_with(&visitor);
}
// }
// else if (param.get_properties().has_property(ParameterProperties::AUTO))
// {
// ParameterGetterInterface<ReturnType> visitor{result};
// value.visit_with(&visitor);
// }
// else
// {
// ParameterGetterInterface<ReturnType> visitor{result};
// value.visit_with(&visitor);
// }
}
catch (const std::bad_optional_access&)
{
Expand Down Expand Up @@ -767,45 +767,51 @@ class SGObject: public std::enable_shared_from_this<SGObject>
template<typename T>
void register_parameter_visitor() const
{
if constexpr (is_auto_value_v<T>)
using Type = std::conditional_t<is_auto_value_v<T>, traits::variant_type_t<0, T>, T>;

if constexpr (std::is_arithmetic_v<Type>)
{
using ReturnType = traits::get_variant_type_t<0, T>;
Any::register_visitor<T, ParameterPutInterface<ReturnType>>(
[](T* value, auto* visitor)
{
*value = visitor->m_value;
}
);
Any::register_visitor<T, ParameterPutInterface<float32_t>>(
[](T* value, auto* visitor) { *value = utils::safe_convert<Type>(visitor->m_value);});
Any::register_visitor<T, ParameterPutInterface<float64_t>>(
[](T* value, auto* visitor) { *value = utils::safe_convert<Type>(visitor->m_value);});
Any::register_visitor<T, ParameterPutInterface<int32_t>>(
[](T* value, auto* visitor) { *value = utils::safe_convert<Type>(visitor->m_value);});
Any::register_visitor<T, ParameterPutInterface<int64_t>>(
[](T* value, auto* visitor) { *value = utils::safe_convert<Type>(visitor->m_value);});
}
else
{
Any::register_visitor<T, ParameterPutInterface<T>>(
[](T* value, auto* visitor)
{
*value = visitor->m_value;
}
);
else {
Any::register_visitor<T, ParameterPutInterface<Type>>(
[](Type* value, auto* visitor) { *value = visitor->m_value;});
}

if constexpr (traits::is_functional<T>::value)
{
if constexpr (!traits::returns_void<T>::value)
{
using ReturnType = typename T::result_type;
Any::register_visitor<T, ParameterGetterInterface<ReturnType, T>>(
[](T* value, auto* visitor)
{
visitor->m_value = value->operator()();
}
);
if constexpr (std::is_arithmetic_v<ReturnType>) {
Any::register_visitor<T, ParameterGetterInterface<float32_t>>(
[](T* value, auto* visitor) {visitor->m_value = utils::safe_convert<float32_t>(value->operator()());});
Any::register_visitor<T, ParameterGetterInterface<float64_t>>(
[](T* value, auto* visitor) {visitor->m_value = utils::safe_convert<float64_t>(value->operator()());});
Any::register_visitor<T, ParameterGetterInterface<int32_t>>(
[](T* value, auto* visitor) {visitor->m_value = utils::safe_convert<int32_t>(value->operator()());});
Any::register_visitor<T, ParameterGetterInterface<int64_t>>(
[](T* value, auto* visitor) {visitor->m_value = utils::safe_convert<int64_t>(value->operator()());});
}
else {
Any::register_visitor<T, ParameterGetterInterface<ReturnType>>(
[](T* value, auto* visitor) {visitor->m_value = value->operator()();});
}
}
}
else if constexpr (is_auto_value_v<T>)
{
using ReturnType = traits::get_variant_type_t<0, T>;
Any::register_visitor<T, ParameterGetterInterface<ReturnType, T>>(
[](T* value, auto* visitor)
{
using ReturnType = traits::variant_type_t<0, T>;
static_assert(std::is_arithmetic_v<ReturnType>, "Cannot handle non arithmetic types in AutoValue yet");
Any::register_visitor<T, ParameterGetterInterface<float32_t>>(
[](T* value, auto* visitor) {
if (std::holds_alternative<AutoValueEmpty>(*value))
{
// std::bad_optional_access does not support error messages
Expand All @@ -814,18 +820,50 @@ class SGObject: public std::enable_shared_from_this<SGObject>
throw std::bad_optional_access{};
}
else
visitor->m_value = std::get<ReturnType>(*value);
}
);
visitor->m_value = utils::safe_convert<float32_t>(std::get<ReturnType>(*value));
});
Any::register_visitor<T, ParameterGetterInterface<float64_t>>(
[](T* value, auto* visitor) {
if (std::holds_alternative<AutoValueEmpty>(*value))
throw std::bad_optional_access{};
else
visitor->m_value = utils::safe_convert<float64_t>(std::get<ReturnType>(*value));
});
Any::register_visitor<T, ParameterGetterInterface<int32_t>>(
[](T* value, auto* visitor) {
if (std::holds_alternative<AutoValueEmpty>(*value))
throw std::bad_optional_access{};
else
visitor->m_value = utils::safe_convert<int32_t>(std::get<ReturnType>(*value));
});
Any::register_visitor<T, ParameterGetterInterface<int64_t>>(
[](T* value, auto* visitor) {
if (std::holds_alternative<AutoValueEmpty>(*value))
throw std::bad_optional_access{};
else
visitor->m_value = utils::safe_convert<int64_t>(std::get<ReturnType>(*value));
});
}
else
{
Any::register_visitor<T, ParameterGetterInterface<T, T>>(
[](T* value, auto* visitor)
{
visitor->m_value = *value;
}
);
if constexpr(std::is_arithmetic_v<T>) {
Any::register_visitor<T, ParameterGetterInterface<float32_t>>(
[](T* value, auto* visitor) {visitor->m_value = utils::safe_convert<float32_t>(*value);});
Any::register_visitor<T, ParameterGetterInterface<float64_t>>(
[](T* value, auto* visitor) {visitor->m_value = utils::safe_convert<float64_t>(*value);});
Any::register_visitor<T, ParameterGetterInterface<int32_t>>(
[](T* value, auto* visitor) {visitor->m_value = utils::safe_convert<int32_t>(*value);});
Any::register_visitor<T, ParameterGetterInterface<int64_t>>(
[](T* value, auto* visitor) {visitor->m_value = utils::safe_convert<int64_t>(*value);});
}
else {
Any::register_visitor<T, ParameterGetterInterface<T>>(
[](T* value, auto* visitor)
{
visitor->m_value = *value;
}
);
}
}
}
/** Registers a class parameter which is identified by a tag.
Expand All @@ -840,6 +878,7 @@ class SGObject: public std::enable_shared_from_this<SGObject>
void register_param(Tag<T>& _tag, const T& value)
{
create_parameter(_tag, AnyParameter(make_any(value)));
register_parameter_visitor<T>();
}

/** Registers a class parameter which is identified by a name.
Expand All @@ -854,6 +893,7 @@ class SGObject: public std::enable_shared_from_this<SGObject>
void register_param(std::string_view name, const T& value)
{
create_parameter(BaseTag(name), AnyParameter(make_any(value)));
register_parameter_visitor<T>();
}

/** Puts a pointer to some parameter into the parameter map.
Expand Down
8 changes: 4 additions & 4 deletions src/shogun/lib/any.h
Original file line number Diff line number Diff line change
Expand Up @@ -1316,11 +1316,11 @@ namespace shogun
if constexpr (std::is_base_of_v<SGObject, Derived>)
{
Any::register_caster<T, SGObject*>(
[](T value) { return dynamic_cast<SGObject*>(value); });
[](T value) { return static_cast<SGObject*>(value); });
if constexpr (!std::is_same_v<std::nullptr_t, base_type<Derived>>
&& !std::is_same_v<Derived, base_type<Derived>>)
Any::register_caster<T, base_type<Derived>*>([](T value) {
return dynamic_cast<base_type<Derived>*>(value);
return static_cast<base_type<Derived>*>(value);
});
}
if constexpr (traits::is_shared_ptr<T>::value)
Expand All @@ -1329,11 +1329,11 @@ namespace shogun
if constexpr (std::is_base_of_v<SGObject, SharedType>)
{
Any::register_caster<T, std::shared_ptr<SGObject>>(
[](T value) { return std::dynamic_pointer_cast<SGObject>(value); });
[](T value) { return std::static_pointer_cast<SGObject>(value); });
if constexpr (!std::is_same_v<std::nullptr_t, base_type<SharedType>>
&& !std::is_same_v<SharedType, base_type<SharedType>>)
Any::register_caster<T, std::shared_ptr<base_type<SharedType>>>([](T value) {
return std::dynamic_pointer_cast<base_type<SharedType>>(value);
return std::static_pointer_cast<base_type<SharedType>>(value);
});
}
}
Expand Down
10 changes: 7 additions & 3 deletions src/shogun/util/traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,15 +175,19 @@ namespace shogun
inline constexpr bool is_any_of_v = is_any_of<T, Ts...>::value;

template<uint32_t idx, typename Ts>
struct get_variant_type{};
struct variant_type{
using type = Ts;
static constexpr bool value = false;
};

template<uint32_t idx, typename ...Ts>
struct get_variant_type<idx, std::variant<Ts...>>{
struct variant_type<idx, std::variant<Ts...>>{
using type = typename std::tuple_element<idx, std::tuple<Ts...>>::type;
static constexpr bool value = true;
};

template<uint32_t idx, typename ...Ts>
using get_variant_type_t = typename get_variant_type<idx, Ts...>::type;
using variant_type_t = typename variant_type<idx, Ts...>::type;
#endif // DOXYGEN_SHOULD_SKIP_THIS
} // namespace traits
} // namespace shogun
Expand Down
5 changes: 3 additions & 2 deletions tests/unit/base/SGObject_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -483,9 +483,10 @@ TEST(SGObject, tags_set_get_int)

EXPECT_THROW(obj->get<int32_t>("foo"), ShogunException);
obj->put(MockObject::kInt, 10);
EXPECT_NO_THROW(obj->put(MockObject::kInt, 10.0));
EXPECT_EQ(obj->get(Tag<int32_t>(MockObject::kInt)), 10);
EXPECT_THROW(obj->get<float64_t>(MockObject::kInt), ShogunException);
EXPECT_THROW(obj->get(Tag<float64_t>(MockObject::kInt)), ShogunException);
EXPECT_EQ(obj->get<float64_t>(MockObject::kInt), 10.0);
EXPECT_EQ(obj->get(Tag<float64_t>(MockObject::kInt)), 10.0);
EXPECT_EQ(obj->get<int>(MockObject::kInt), 10);
}

Expand Down

0 comments on commit 96b8998

Please sign in to comment.