Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implicit conversions for put/get #5056

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 79 additions & 51 deletions src/shogun/base/SGObject.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ SG_FORCED_INLINE const char* convert_string_to_char(const char* name)
*/
class SGObject: public std::enable_shared_from_this<SGObject>
{
template <typename ReturnType, typename CastType>
template <typename ReturnType>
struct ParameterGetterInterface
{
ReturnType& m_value;
Expand Down Expand Up @@ -529,21 +529,8 @@ class SGObject: public std::enable_shared_from_this<SGObject>
const auto& value = param.get_value();
try
{
if (param.get_properties().has_property(ParameterProperties::CONSTFUNCTION))
{
ParameterGetterInterface<ReturnType, std::function<ReturnType()>> visitor{result};
value.visit_with(&visitor);
}
else if (param.get_properties().has_property(ParameterProperties::AUTO))
{
ParameterGetterInterface<ReturnType, AutoValue<ReturnType>> visitor{result};
value.visit_with(&visitor);
}
else
{
ParameterGetterInterface<ReturnType, ReturnType> visitor{result};
value.visit_with(&visitor);
}
ParameterGetterInterface<ReturnType> visitor{result};
value.visit_with(&visitor);
}
catch (const std::bad_optional_access&)
{
Expand Down Expand Up @@ -767,45 +754,52 @@ class SGObject: public std::enable_shared_from_this<SGObject>
template<typename T>
void register_parameter_visitor() const
{
if constexpr (is_auto_value_v<T>)
using Type = std::conditional_t<is_auto_value_v<T>, traits::variant_type_t<0, T>, T>;

if constexpr (std::is_arithmetic_v<Type> && !std::is_same_v<Type, bool>)
{
using ReturnType = traits::get_variant_type_t<0, T>;
Any::register_visitor<T, ParameterPutInterface<ReturnType>>(
[](T* value, auto* visitor)
{
*value = visitor->m_value;
}
);
Any::register_visitor<T, ParameterPutInterface<float32_t>>(
[](T* value, auto* visitor) { *value = utils::safe_convert<Type>(visitor->m_value);});
Any::register_visitor<T, ParameterPutInterface<float64_t>>(
[](T* value, auto* visitor) { *value = utils::safe_convert<Type>(visitor->m_value);});
Any::register_visitor<T, ParameterPutInterface<int32_t>>(
[](T* value, auto* visitor) { *value = utils::safe_convert<Type>(visitor->m_value);});
Any::register_visitor<T, ParameterPutInterface<int64_t>>(
[](T* value, auto* visitor) { *value = utils::safe_convert<Type>(visitor->m_value);});
}
else
{
Any::register_visitor<T, ParameterPutInterface<T>>(
[](T* value, auto* visitor)
{
*value = visitor->m_value;
}
);
else {
Any::register_visitor<T, ParameterPutInterface<Type>>(
[](Type* value, auto* visitor) { *value = visitor->m_value;});
}

if constexpr (traits::is_functional<T>::value)
{
if constexpr (!traits::returns_void<T>::value)
{
using ReturnType = typename T::result_type;
Any::register_visitor<T, ParameterGetterInterface<ReturnType, T>>(
[](T* value, auto* visitor)
{
visitor->m_value = value->operator()();
}
);
if constexpr (std::is_arithmetic_v<ReturnType> && !std::is_same_v<Type, bool>) {
Any::register_visitor<T, ParameterGetterInterface<float32_t>>(
[](T* value, auto* visitor) {visitor->m_value = utils::safe_convert<float32_t>(value->operator()());});
Any::register_visitor<T, ParameterGetterInterface<float64_t>>(
[](T* value, auto* visitor) {visitor->m_value = utils::safe_convert<float64_t>(value->operator()());});
Any::register_visitor<T, ParameterGetterInterface<int32_t>>(
[](T* value, auto* visitor) {visitor->m_value = utils::safe_convert<int32_t>(value->operator()());});
Any::register_visitor<T, ParameterGetterInterface<int64_t>>(
[](T* value, auto* visitor) {visitor->m_value = utils::safe_convert<int64_t>(value->operator()());});
}
else {
Any::register_visitor<T, ParameterGetterInterface<ReturnType>>(
[](T* value, auto* visitor) {visitor->m_value = value->operator()();});
}
}
}
else if constexpr (is_auto_value_v<T>)
{
using ReturnType = traits::get_variant_type_t<0, T>;
Any::register_visitor<T, ParameterGetterInterface<ReturnType, T>>(
[](T* value, auto* visitor)
{
using ReturnType = traits::variant_type_t<0, T>;
static_assert(std::is_arithmetic_v<ReturnType> && !std::is_same_v<Type, bool>,
"Cannot handle non arithmetic types in AutoValue yet");
Any::register_visitor<T, ParameterGetterInterface<float32_t>>(
[](T* value, auto* visitor) {
if (std::holds_alternative<AutoValueEmpty>(*value))
{
// std::bad_optional_access does not support error messages
Expand All @@ -814,18 +808,50 @@ class SGObject: public std::enable_shared_from_this<SGObject>
throw std::bad_optional_access{};
}
else
visitor->m_value = std::get<ReturnType>(*value);
}
);
visitor->m_value = utils::safe_convert<float32_t>(std::get<ReturnType>(*value));
});
Any::register_visitor<T, ParameterGetterInterface<float64_t>>(
[](T* value, auto* visitor) {
if (std::holds_alternative<AutoValueEmpty>(*value))
throw std::bad_optional_access{};
else
visitor->m_value = utils::safe_convert<float64_t>(std::get<ReturnType>(*value));
});
Any::register_visitor<T, ParameterGetterInterface<int32_t>>(
[](T* value, auto* visitor) {
if (std::holds_alternative<AutoValueEmpty>(*value))
throw std::bad_optional_access{};
else
visitor->m_value = utils::safe_convert<int32_t>(std::get<ReturnType>(*value));
});
Any::register_visitor<T, ParameterGetterInterface<int64_t>>(
[](T* value, auto* visitor) {
if (std::holds_alternative<AutoValueEmpty>(*value))
throw std::bad_optional_access{};
else
visitor->m_value = utils::safe_convert<int64_t>(std::get<ReturnType>(*value));
});
}
else
{
Any::register_visitor<T, ParameterGetterInterface<T, T>>(
[](T* value, auto* visitor)
{
visitor->m_value = *value;
}
);
if constexpr(std::is_arithmetic_v<T> && !std::is_same_v<Type, bool>) {
Any::register_visitor<T, ParameterGetterInterface<float32_t>>(
[](T* value, auto* visitor) {visitor->m_value = utils::safe_convert<float32_t>(*value);});
Any::register_visitor<T, ParameterGetterInterface<float64_t>>(
[](T* value, auto* visitor) {visitor->m_value = utils::safe_convert<float64_t>(*value);});
Any::register_visitor<T, ParameterGetterInterface<int32_t>>(
[](T* value, auto* visitor) {visitor->m_value = utils::safe_convert<int32_t>(*value);});
Any::register_visitor<T, ParameterGetterInterface<int64_t>>(
[](T* value, auto* visitor) {visitor->m_value = utils::safe_convert<int64_t>(*value);});
}
else {
Any::register_visitor<T, ParameterGetterInterface<T>>(
[](T* value, auto* visitor)
{
visitor->m_value = *value;
}
);
}
}
}
/** Registers a class parameter which is identified by a tag.
Expand All @@ -840,6 +866,7 @@ class SGObject: public std::enable_shared_from_this<SGObject>
void register_param(Tag<T>& _tag, const T& value)
{
create_parameter(_tag, AnyParameter(make_any(value)));
register_parameter_visitor<T>();
}

/** Registers a class parameter which is identified by a name.
Expand All @@ -854,6 +881,7 @@ class SGObject: public std::enable_shared_from_this<SGObject>
void register_param(std::string_view name, const T& value)
{
create_parameter(BaseTag(name), AnyParameter(make_any(value)));
register_parameter_visitor<T>();
}

/** Puts a pointer to some parameter into the parameter map.
Expand Down
8 changes: 4 additions & 4 deletions src/shogun/lib/any.h
Original file line number Diff line number Diff line change
Expand Up @@ -1316,11 +1316,11 @@ namespace shogun
if constexpr (std::is_base_of_v<SGObject, Derived>)
{
Any::register_caster<T, SGObject*>(
[](T value) { return dynamic_cast<SGObject*>(value); });
[](T value) { return static_cast<SGObject*>(value); });
if constexpr (!std::is_same_v<std::nullptr_t, base_type<Derived>>
&& !std::is_same_v<Derived, base_type<Derived>>)
Any::register_caster<T, base_type<Derived>*>([](T value) {
return dynamic_cast<base_type<Derived>*>(value);
return static_cast<base_type<Derived>*>(value);
});
}
if constexpr (traits::is_shared_ptr<T>::value)
Expand All @@ -1329,11 +1329,11 @@ namespace shogun
if constexpr (std::is_base_of_v<SGObject, SharedType>)
{
Any::register_caster<T, std::shared_ptr<SGObject>>(
[](T value) { return std::dynamic_pointer_cast<SGObject>(value); });
[](T value) { return std::static_pointer_cast<SGObject>(value); });
if constexpr (!std::is_same_v<std::nullptr_t, base_type<SharedType>>
&& !std::is_same_v<SharedType, base_type<SharedType>>)
Any::register_caster<T, std::shared_ptr<base_type<SharedType>>>([](T value) {
return std::dynamic_pointer_cast<base_type<SharedType>>(value);
return std::static_pointer_cast<base_type<SharedType>>(value);
});
}
}
Expand Down
10 changes: 7 additions & 3 deletions src/shogun/util/traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,15 +175,19 @@ namespace shogun
inline constexpr bool is_any_of_v = is_any_of<T, Ts...>::value;

template<uint32_t idx, typename Ts>
struct get_variant_type{};
struct variant_type{
using type = Ts;
static constexpr bool value = false;
};

template<uint32_t idx, typename ...Ts>
struct get_variant_type<idx, std::variant<Ts...>>{
struct variant_type<idx, std::variant<Ts...>>{
using type = typename std::tuple_element<idx, std::tuple<Ts...>>::type;
static constexpr bool value = true;
};

template<uint32_t idx, typename ...Ts>
using get_variant_type_t = typename get_variant_type<idx, Ts...>::type;
using variant_type_t = typename variant_type<idx, Ts...>::type;
#endif // DOXYGEN_SHOULD_SKIP_THIS
} // namespace traits
} // namespace shogun
Expand Down
5 changes: 3 additions & 2 deletions tests/unit/base/SGObject_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -483,9 +483,10 @@ TEST(SGObject, tags_set_get_int)

EXPECT_THROW(obj->get<int32_t>("foo"), ShogunException);
obj->put(MockObject::kInt, 10);
EXPECT_NO_THROW(obj->put(MockObject::kInt, 10.0));
EXPECT_EQ(obj->get(Tag<int32_t>(MockObject::kInt)), 10);
EXPECT_THROW(obj->get<float64_t>(MockObject::kInt), ShogunException);
EXPECT_THROW(obj->get(Tag<float64_t>(MockObject::kInt)), ShogunException);
EXPECT_EQ(obj->get<float64_t>(MockObject::kInt), 10.0);
EXPECT_EQ(obj->get(Tag<float64_t>(MockObject::kInt)), 10.0);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@karlnapf so this won't throw anymore

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep.
But everything that looses information throws?
maybe a test for some examples?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

by lose information you mean something like 1.3 becomes 1? Right now I think this function only supports checks for overflow, e.g. int -10 is converted to unsigned int.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we wouldnt want that to happen for now I think.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vigsterkr would it make sense to extend safe_convert to check floating point casting to integers and see if the delta is larger than fepsilon?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I always thought that the safe_convert functions would be for overflow checks, not for loss of information when going from discrete to floating?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(on the other hand, why not :) )

EXPECT_EQ(obj->get<int>(MockObject::kInt), 10);
}

Expand Down