Skip to content

Commit

Permalink
[xla:ffi] Add auto-binding for FFI results
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 622006316
  • Loading branch information
ezhulenev authored and copybara-github committed Apr 4, 2024
1 parent df104c0 commit 4d135db
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 0 deletions.
23 changes: 23 additions & 0 deletions xla/ffi/api/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,20 @@ struct ArgBinding {
using Arg = void;
};

// XLA FFI binding for a returned result.
//
// Example: binding for the `MyType` result
//
// template <>
// struct RetBinding<MyType> {
// using Ret = MyType;
// };
//
template <typename T>
struct RetBinding {
using Ret = void;
};

// XLA FFI binding for a named attribute.
//
// Example: binding for the `MyType` attribute
Expand Down Expand Up @@ -394,6 +408,10 @@ template <typename Param>
inline constexpr bool is_arg_binding_v =
!std::is_void_v<typename ArgBinding<Param>::Arg>;

template <typename Param>
inline constexpr bool is_ret_binding_v =
!std::is_void_v<typename RetBinding<Param>::Ret>;

template <typename Param>
inline constexpr bool is_attr_binding_v =
!std::is_void_v<typename AttrBinding<Param>::Attr>;
Expand Down Expand Up @@ -422,6 +440,11 @@ struct BindOne<Fn, Param, Params...> {
return BindOne<Fn, Params...>::To(
std::move(fn),
std::move(binding).template Arg<typename ArgBinding<Param>::Arg>());
} else if constexpr (is_ret_binding_v<Param>) {
// Bind parameter as an FFI handler result.
return BindOne<Fn, Params...>::To(
std::move(fn),
std::move(binding).template Ret<typename RetBinding<Param>::Ret>());

} else if constexpr (is_attr_binding_v<Param>) {
// Bind parameter as a named FFI handler attribute.
Expand Down
14 changes: 14 additions & 0 deletions xla/ffi/api/ffi.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,20 @@ struct ArgBinding<Buffer<dtype, rank>> {
using Arg = Buffer<dtype, rank>;
};

//===----------------------------------------------------------------------===//
// Results binding
//===----------------------------------------------------------------------===//

template <>
struct RetBinding<Result<BufferBase>> {
using Ret = BufferBase;
};

template <DataType dtype, size_t rank>
struct RetBinding<Result<Buffer<dtype, rank>>> {
using Ret = Buffer<dtype, rank>;
};

//===----------------------------------------------------------------------===//
// Arguments decoding
//===----------------------------------------------------------------------===//
Expand Down
12 changes: 12 additions & 0 deletions xla/ffi/api/ffi_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,18 @@ TEST(FfiTest, AutoBinding) {
TF_ASSERT_OK(status);
}

TEST(FfiTest, AutoBindingResult) {
auto handler =
Ffi::BindTo(+[](Result<BufferBase> buffer) { return Error::Success(); });

CallFrameBuilder builder;
builder.AddBufferRet(se::DeviceMemoryBase(), PrimitiveType::F32, /*dims=*/{});
auto call_frame = builder.Build();

auto status = Call(*handler, call_frame);
TF_ASSERT_OK(status);
}

struct I32AndF32 {
int32_t i32;
float f32;
Expand Down

0 comments on commit 4d135db

Please sign in to comment.