diff --git a/xla/ffi/api/api.h b/xla/ffi/api/api.h index a9d93ecdc1e91..6169f06b2ce82 100644 --- a/xla/ffi/api/api.h +++ b/xla/ffi/api/api.h @@ -352,6 +352,20 @@ struct ArgBinding { using Arg = void; }; +// XLA FFI binding for a returned result. +// +// Example: binding for the `MyType` result +// +// template <> +// struct RetBinding { +// using Ret = MyType; +// }; +// +template +struct RetBinding { + using Ret = void; +}; + // XLA FFI binding for a named attribute. // // Example: binding for the `MyType` attribute @@ -394,6 +408,10 @@ template inline constexpr bool is_arg_binding_v = !std::is_void_v::Arg>; +template +inline constexpr bool is_ret_binding_v = + !std::is_void_v::Ret>; + template inline constexpr bool is_attr_binding_v = !std::is_void_v::Attr>; @@ -422,6 +440,11 @@ struct BindOne { return BindOne::To( std::move(fn), std::move(binding).template Arg::Arg>()); + } else if constexpr (is_ret_binding_v) { + // Bind parameter as an FFI handler result. + return BindOne::To( + std::move(fn), + std::move(binding).template Ret::Ret>()); } else if constexpr (is_attr_binding_v) { // Bind parameter as a named FFI handler attribute. diff --git a/xla/ffi/api/ffi.h b/xla/ffi/api/ffi.h index a3238e4102a32..b652d5accda0d 100644 --- a/xla/ffi/api/ffi.h +++ b/xla/ffi/api/ffi.h @@ -216,6 +216,20 @@ struct ArgBinding> { using Arg = Buffer; }; +//===----------------------------------------------------------------------===// +// Results binding +//===----------------------------------------------------------------------===// + +template <> +struct RetBinding> { + using Ret = BufferBase; +}; + +template +struct RetBinding>> { + using Ret = Buffer; +}; + //===----------------------------------------------------------------------===// // Arguments decoding //===----------------------------------------------------------------------===// diff --git a/xla/ffi/api/ffi_test.cc b/xla/ffi/api/ffi_test.cc index 0d9ef1df688e3..b1dc769ca8de9 100644 --- a/xla/ffi/api/ffi_test.cc +++ b/xla/ffi/api/ffi_test.cc @@ -189,6 +189,18 @@ TEST(FfiTest, AutoBinding) { TF_ASSERT_OK(status); } +TEST(FfiTest, AutoBindingResult) { + auto handler = + Ffi::BindTo(+[](Result buffer) { return Error::Success(); }); + + CallFrameBuilder builder; + builder.AddBufferRet(se::DeviceMemoryBase(), PrimitiveType::F32, /*dims=*/{}); + auto call_frame = builder.Build(); + + auto status = Call(*handler, call_frame); + TF_ASSERT_OK(status); +} + struct I32AndF32 { int32_t i32; float f32;