diff --git a/src/shogun/base/SGObject.h b/src/shogun/base/SGObject.h index 74ad866d2f9..87eae280f08 100644 --- a/src/shogun/base/SGObject.h +++ b/src/shogun/base/SGObject.h @@ -131,7 +131,7 @@ SG_FORCED_INLINE const char* convert_string_to_char(const char* name) */ class SGObject: public std::enable_shared_from_this { - template + template struct ParameterGetterInterface { ReturnType& m_value; @@ -529,21 +529,21 @@ class SGObject: public std::enable_shared_from_this const auto& value = param.get_value(); try { - if (param.get_properties().has_property(ParameterProperties::CONSTFUNCTION)) - { - ParameterGetterInterface> visitor{result}; - value.visit_with(&visitor); - } - else if (param.get_properties().has_property(ParameterProperties::AUTO)) - { - ParameterGetterInterface> visitor{result}; - value.visit_with(&visitor); - } - else - { - ParameterGetterInterface visitor{result}; + // if (param.get_properties().has_property(ParameterProperties::CONSTFUNCTION)) + // { + ParameterGetterInterface visitor{result}; value.visit_with(&visitor); - } + // } + // else if (param.get_properties().has_property(ParameterProperties::AUTO)) + // { + // ParameterGetterInterface visitor{result}; + // value.visit_with(&visitor); + // } + // else + // { + // ParameterGetterInterface visitor{result}; + // value.visit_with(&visitor); + // } } catch (const std::bad_optional_access&) { @@ -767,24 +767,22 @@ class SGObject: public std::enable_shared_from_this template void register_parameter_visitor() const { - if constexpr (is_auto_value_v) + using Type = std::conditional_t, traits::variant_type_t<0, T>, T>; + + if constexpr (std::is_arithmetic_v) { - using ReturnType = traits::get_variant_type_t<0, T>; - Any::register_visitor>( - [](T* value, auto* visitor) - { - *value = visitor->m_value; - } - ); + Any::register_visitor>( + [](T* value, auto* visitor) { *value = utils::safe_convert(visitor->m_value);}); + Any::register_visitor>( + [](T* value, auto* visitor) { *value = utils::safe_convert(visitor->m_value);}); + Any::register_visitor>( + [](T* value, auto* visitor) { *value = utils::safe_convert(visitor->m_value);}); + Any::register_visitor>( + [](T* value, auto* visitor) { *value = utils::safe_convert(visitor->m_value);}); } - else - { - Any::register_visitor>( - [](T* value, auto* visitor) - { - *value = visitor->m_value; - } - ); + else { + Any::register_visitor>( + [](Type* value, auto* visitor) { *value = visitor->m_value;}); } if constexpr (traits::is_functional::value) @@ -792,20 +790,28 @@ class SGObject: public std::enable_shared_from_this if constexpr (!traits::returns_void::value) { using ReturnType = typename T::result_type; - Any::register_visitor>( - [](T* value, auto* visitor) - { - visitor->m_value = value->operator()(); - } - ); + if constexpr (std::is_arithmetic_v) { + Any::register_visitor>( + [](T* value, auto* visitor) {visitor->m_value = utils::safe_convert(value->operator()());}); + Any::register_visitor>( + [](T* value, auto* visitor) {visitor->m_value = utils::safe_convert(value->operator()());}); + Any::register_visitor>( + [](T* value, auto* visitor) {visitor->m_value = utils::safe_convert(value->operator()());}); + Any::register_visitor>( + [](T* value, auto* visitor) {visitor->m_value = utils::safe_convert(value->operator()());}); + } + else { + Any::register_visitor>( + [](T* value, auto* visitor) {visitor->m_value = value->operator()();}); + } } } else if constexpr (is_auto_value_v) { - using ReturnType = traits::get_variant_type_t<0, T>; - Any::register_visitor>( - [](T* value, auto* visitor) - { + using ReturnType = traits::variant_type_t<0, T>; + static_assert(std::is_arithmetic_v, "Cannot handle non arithmetic types in AutoValue yet"); + Any::register_visitor>( + [](T* value, auto* visitor) { if (std::holds_alternative(*value)) { // std::bad_optional_access does not support error messages @@ -814,18 +820,50 @@ class SGObject: public std::enable_shared_from_this throw std::bad_optional_access{}; } else - visitor->m_value = std::get(*value); - } - ); + visitor->m_value = utils::safe_convert(std::get(*value)); + }); + Any::register_visitor>( + [](T* value, auto* visitor) { + if (std::holds_alternative(*value)) + throw std::bad_optional_access{}; + else + visitor->m_value = utils::safe_convert(std::get(*value)); + }); + Any::register_visitor>( + [](T* value, auto* visitor) { + if (std::holds_alternative(*value)) + throw std::bad_optional_access{}; + else + visitor->m_value = utils::safe_convert(std::get(*value)); + }); + Any::register_visitor>( + [](T* value, auto* visitor) { + if (std::holds_alternative(*value)) + throw std::bad_optional_access{}; + else + visitor->m_value = utils::safe_convert(std::get(*value)); + }); } else { - Any::register_visitor>( - [](T* value, auto* visitor) - { - visitor->m_value = *value; - } - ); + if constexpr(std::is_arithmetic_v) { + Any::register_visitor>( + [](T* value, auto* visitor) {visitor->m_value = utils::safe_convert(*value);}); + Any::register_visitor>( + [](T* value, auto* visitor) {visitor->m_value = utils::safe_convert(*value);}); + Any::register_visitor>( + [](T* value, auto* visitor) {visitor->m_value = utils::safe_convert(*value);}); + Any::register_visitor>( + [](T* value, auto* visitor) {visitor->m_value = utils::safe_convert(*value);}); + } + else { + Any::register_visitor>( + [](T* value, auto* visitor) + { + visitor->m_value = *value; + } + ); + } } } /** Registers a class parameter which is identified by a tag. @@ -840,6 +878,7 @@ class SGObject: public std::enable_shared_from_this void register_param(Tag& _tag, const T& value) { create_parameter(_tag, AnyParameter(make_any(value))); + register_parameter_visitor(); } /** Registers a class parameter which is identified by a name. @@ -854,6 +893,7 @@ class SGObject: public std::enable_shared_from_this void register_param(std::string_view name, const T& value) { create_parameter(BaseTag(name), AnyParameter(make_any(value))); + register_parameter_visitor(); } /** Puts a pointer to some parameter into the parameter map. diff --git a/src/shogun/lib/any.h b/src/shogun/lib/any.h index ecef7026af7..115e3ccf47e 100644 --- a/src/shogun/lib/any.h +++ b/src/shogun/lib/any.h @@ -1316,11 +1316,11 @@ namespace shogun if constexpr (std::is_base_of_v) { Any::register_caster( - [](T value) { return dynamic_cast(value); }); + [](T value) { return static_cast(value); }); if constexpr (!std::is_same_v> && !std::is_same_v>) Any::register_caster*>([](T value) { - return dynamic_cast*>(value); + return static_cast*>(value); }); } if constexpr (traits::is_shared_ptr::value) @@ -1329,11 +1329,11 @@ namespace shogun if constexpr (std::is_base_of_v) { Any::register_caster>( - [](T value) { return std::dynamic_pointer_cast(value); }); + [](T value) { return std::static_pointer_cast(value); }); if constexpr (!std::is_same_v> && !std::is_same_v>) Any::register_caster>>([](T value) { - return std::dynamic_pointer_cast>(value); + return std::static_pointer_cast>(value); }); } } diff --git a/src/shogun/util/traits.h b/src/shogun/util/traits.h index e80dbc125d0..fb1dcf4d340 100644 --- a/src/shogun/util/traits.h +++ b/src/shogun/util/traits.h @@ -175,15 +175,19 @@ namespace shogun inline constexpr bool is_any_of_v = is_any_of::value; template - struct get_variant_type{}; + struct variant_type{ + using type = Ts; + static constexpr bool value = false; + }; template - struct get_variant_type>{ + struct variant_type>{ using type = typename std::tuple_element>::type; + static constexpr bool value = true; }; template - using get_variant_type_t = typename get_variant_type::type; + using variant_type_t = typename variant_type::type; #endif // DOXYGEN_SHOULD_SKIP_THIS } // namespace traits } // namespace shogun diff --git a/tests/unit/base/SGObject_unittest.cc b/tests/unit/base/SGObject_unittest.cc index 8f1f4e3c035..d34e57f93a2 100644 --- a/tests/unit/base/SGObject_unittest.cc +++ b/tests/unit/base/SGObject_unittest.cc @@ -483,9 +483,10 @@ TEST(SGObject, tags_set_get_int) EXPECT_THROW(obj->get("foo"), ShogunException); obj->put(MockObject::kInt, 10); + EXPECT_NO_THROW(obj->put(MockObject::kInt, 10.0)); EXPECT_EQ(obj->get(Tag(MockObject::kInt)), 10); - EXPECT_THROW(obj->get(MockObject::kInt), ShogunException); - EXPECT_THROW(obj->get(Tag(MockObject::kInt)), ShogunException); + EXPECT_EQ(obj->get(MockObject::kInt), 10.0); + EXPECT_EQ(obj->get(Tag(MockObject::kInt)), 10.0); EXPECT_EQ(obj->get(MockObject::kInt), 10); }