diff --git a/xla/pjrt/cpu/cpu_client.cc b/xla/pjrt/cpu/cpu_client.cc index 53ac1d3175525..559a3fa77f927 100644 --- a/xla/pjrt/cpu/cpu_client.cc +++ b/xla/pjrt/cpu/cpu_client.cc @@ -285,6 +285,12 @@ absl::string_view TfrtCpuDeviceDescription::ToString() const { machine_attributes); } +absl::StatusOr TfrtCpuTopologyDescription::GetDefaultLayout( + PrimitiveType element_type, absl::Span dims) const { + Shape shape = ShapeUtil::MakeShape(element_type, dims); + return LayoutUtil::GetWithDefaultLayout(shape).layout(); +} + absl::StatusOr TfrtCpuTopologyDescription::Serialize() const { std::string result; if (!tsl::SerializeToStringDeterministic(cpu_topology_.ToProto(), &result)) { diff --git a/xla/pjrt/cpu/cpu_client.h b/xla/pjrt/cpu/cpu_client.h index 34a63690b4b0d..b302243ecfecd 100644 --- a/xla/pjrt/cpu/cpu_client.h +++ b/xla/pjrt/cpu/cpu_client.h @@ -183,6 +183,10 @@ class TfrtCpuTopologyDescription : public PjRtTopologyDescription { return attributes_; } + StatusOr GetDefaultLayout( + PrimitiveType element_type, + absl::Span dims) const override; + private: const PjRtPlatformId platform_id_; const std::string platform_name_; diff --git a/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/xla/pjrt/gpu/se_gpu_pjrt_client.cc index 9dec918f99327..c6b1865a6915f 100644 --- a/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -1130,6 +1130,12 @@ absl::StatusOr StreamExecutorGpuTopologyDescription::Serialize() return result; } +absl::StatusOr StreamExecutorGpuTopologyDescription::GetDefaultLayout( + PrimitiveType element_type, absl::Span dims) const { + Shape shape = ShapeUtil::MakeShape(element_type, dims); + return LayoutUtil::GetWithDefaultLayout(shape).layout(); +} + std::vector> BuildLocalDevices( std::map> local_device_states, int node_id) { diff --git a/xla/pjrt/gpu/se_gpu_pjrt_client.h b/xla/pjrt/gpu/se_gpu_pjrt_client.h index 529258e2a90cc..39c9327683d2e 100644 --- a/xla/pjrt/gpu/se_gpu_pjrt_client.h +++ b/xla/pjrt/gpu/se_gpu_pjrt_client.h @@ -132,6 +132,10 @@ class StreamExecutorGpuTopologyDescription : public PjRtTopologyDescription { return attributes_; } + StatusOr GetDefaultLayout( + PrimitiveType element_type, + absl::Span dims) const override; + private: const PjRtPlatformId platform_id_; const std::string platform_name_; diff --git a/xla/pjrt/pjrt_c_api_client.h b/xla/pjrt/pjrt_c_api_client.h index d1d50261f7f6d..92c1c3044cc93 100644 --- a/xla/pjrt/pjrt_c_api_client.h +++ b/xla/pjrt/pjrt_c_api_client.h @@ -218,6 +218,12 @@ class PjRtCApiTopologyDescription : public PjRtTopologyDescription { return attributes_; } + StatusOr GetDefaultLayout( + PrimitiveType element_type, + absl::Span dims) const override { + return Unimplemented("PJRT C API does not support GetDefaultLayout"); + } + private: std::unique_ptr compiler_; const PJRT_Api* c_api_; diff --git a/xla/pjrt/pjrt_compiler.h b/xla/pjrt/pjrt_compiler.h index d624fd0cf99cd..46c363ace1361 100644 --- a/xla/pjrt/pjrt_compiler.h +++ b/xla/pjrt/pjrt_compiler.h @@ -139,6 +139,15 @@ class PjRtTopologyDescription { // Returns vendor specific attributes about the topology. virtual const absl::flat_hash_map& Attributes() const = 0; + + // Returns the default device layout for a buffer with `element_type` and + // `dims`. The default layout is a platform-specific layout used when no other + // layout is specified, e.g. for host-to-device transfers. When compiling, the + // default layout is used for program arguments and outputs unless + // user-specified or compiler-chosen layouts are requested via the + // "mhlo.layout_mode" attribute. + virtual StatusOr GetDefaultLayout( + PrimitiveType element_type, absl::Span dims) const = 0; }; // Abstract interface that all registered compilers must implement. diff --git a/xla/pjrt/pjrt_compiler_test.cc b/xla/pjrt/pjrt_compiler_test.cc index 182e3ba9f7b85..98a2b8e8d5e16 100644 --- a/xla/pjrt/pjrt_compiler_test.cc +++ b/xla/pjrt/pjrt_compiler_test.cc @@ -56,6 +56,11 @@ class PjRtTestTopology : public PjRtTopologyDescription { const override { LOG(FATAL) << "Unused"; } + StatusOr GetDefaultLayout( + PrimitiveType element_type, + absl::Span dims) const override { + return Unimplemented("TestTopology does not support GetDefaultLayout"); + } }; TEST(PjRtCompilerTest, CompilerNotRegistered) { @@ -85,6 +90,11 @@ TEST(PjRtCompilerTest, CompilerRegistered) { const override { LOG(FATAL) << "Unused"; } + StatusOr GetDefaultLayout( + PrimitiveType element_type, + absl::Span dims) const override { + return Unimplemented("TestTopology does not support GetDefaultLayout"); + } }; PjRtTestTopology topology;