From 873f95f467da28ad6062053e355833d4840d41eb Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Fri, 18 Oct 2024 10:36:30 -0700 Subject: [PATCH] Collapse GpuStream class into StreamCommon. PiperOrigin-RevId: 687352019 --- xla/service/gpu/runtime/nccl_api.cc | 1 + xla/stream_executor/cuda/BUILD | 2 - xla/stream_executor/cuda/cuda_stream.cc | 1 - xla/stream_executor/cuda/cuda_stream.h | 6 +-- xla/stream_executor/cuda/cuda_timer_test.cc | 5 +-- xla/stream_executor/gpu/BUILD | 7 --- xla/stream_executor/gpu/gpu_stream.cc | 27 +----------- xla/stream_executor/gpu/gpu_stream.h | 47 --------------------- xla/stream_executor/host/host_stream.cc | 7 +++ xla/stream_executor/host/host_stream.h | 1 + xla/stream_executor/mock_stream.h | 8 +--- xla/stream_executor/rocm/BUILD | 3 +- xla/stream_executor/rocm/rocm_stream.h | 8 ++-- xla/stream_executor/rocm/rocm_timer_test.cc | 5 +-- xla/stream_executor/stream.h | 27 ++++++++---- xla/stream_executor/stream_common.cc | 17 +++++++- xla/stream_executor/stream_common.h | 14 +++--- 17 files changed, 62 insertions(+), 124 deletions(-) diff --git a/xla/service/gpu/runtime/nccl_api.cc b/xla/service/gpu/runtime/nccl_api.cc index edbd507b7f6d6..95e7151c50d85 100644 --- a/xla/service/gpu/runtime/nccl_api.cc +++ b/xla/service/gpu/runtime/nccl_api.cc @@ -37,6 +37,7 @@ limitations under the License. #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" diff --git a/xla/stream_executor/cuda/BUILD b/xla/stream_executor/cuda/BUILD index 272a91f008449..dab62ad373bd2 100644 --- a/xla/stream_executor/cuda/BUILD +++ b/xla/stream_executor/cuda/BUILD @@ -1215,8 +1215,6 @@ cc_library( "//xla/stream_executor:platform", "//xla/stream_executor:stream", "//xla/stream_executor:stream_common", - "//xla/stream_executor/gpu:gpu_driver_header", - "//xla/stream_executor/gpu:gpu_stream", "@com_google_absl//absl/base", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log", diff --git a/xla/stream_executor/cuda/cuda_stream.cc b/xla/stream_executor/cuda/cuda_stream.cc index 475b07c818751..469c19a8b60b5 100644 --- a/xla/stream_executor/cuda/cuda_stream.cc +++ b/xla/stream_executor/cuda/cuda_stream.cc @@ -40,7 +40,6 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_status.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/event.h" -#include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" diff --git a/xla/stream_executor/cuda/cuda_stream.h b/xla/stream_executor/cuda/cuda_stream.h index a82f15b51a011..7d8be77df9366 100644 --- a/xla/stream_executor/cuda/cuda_stream.h +++ b/xla/stream_executor/cuda/cuda_stream.h @@ -31,16 +31,16 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/event_based_timer.h" -#include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_common.h" namespace stream_executor { namespace gpu { -class CudaStream : public GpuStream { +class CudaStream : public StreamCommon { public: absl::Status WaitFor(Stream* other) override; absl::Status RecordEvent(Event* event) override; @@ -82,7 +82,7 @@ class CudaStream : public GpuStream { CudaStream(StreamExecutor* executor, CudaEvent completed_event, std::optional> priority, CUstream stream_handle) - : GpuStream(executor, priority), + : StreamCommon(executor, priority), executor_(executor), completed_event_(std::move(completed_event)), stream_handle_(stream_handle) {} diff --git a/xla/stream_executor/cuda/cuda_timer_test.cc b/xla/stream_executor/cuda/cuda_timer_test.cc index 3eb5b322cded1..021ce4f7d2cdd 100644 --- a/xla/stream_executor/cuda/cuda_timer_test.cc +++ b/xla/stream_executor/cuda/cuda_timer_test.cc @@ -26,7 +26,6 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_executor.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" #include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/gpu/gpu_test_kernels.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" @@ -73,7 +72,6 @@ class CudaTimerTest : public ::testing::TestWithParam { StreamExecutor* executor_; std::unique_ptr stream_; - GpuStream* gpu_stream_; private: void SetUp() override { @@ -82,13 +80,12 @@ class CudaTimerTest : public ::testing::TestWithParam { stream_executor::cuda::kCudaPlatformId)); TF_ASSERT_OK_AND_ASSIGN(executor_, platform->ExecutorForDevice(0)); TF_ASSERT_OK_AND_ASSIGN(stream_, executor_->CreateStream(std::nullopt)); - gpu_stream_ = AsGpuStream(stream_.get()); } }; TEST_P(CudaTimerTest, Create) { TF_ASSERT_OK_AND_ASSIGN( - CudaTimer timer, CudaTimer::Create(executor_, gpu_stream_, GetParam())); + CudaTimer timer, CudaTimer::Create(executor_, stream_.get(), GetParam())); // We don't really care what kernel we launch here as long as it takes a // non-zero amount of time. diff --git a/xla/stream_executor/gpu/BUILD b/xla/stream_executor/gpu/BUILD index 73db02257d4c1..848330c2c5c0a 100644 --- a/xla/stream_executor/gpu/BUILD +++ b/xla/stream_executor/gpu/BUILD @@ -356,16 +356,9 @@ gpu_only_cc_library( hdrs = ["gpu_stream.h"], deps = [ ":gpu_types_header", - "//xla/stream_executor:kernel", - "//xla/stream_executor:launch_dim", - "//xla/stream_executor:platform", "//xla/stream_executor:stream", - "//xla/stream_executor:stream_common", - "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/base", - "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", ], ) diff --git a/xla/stream_executor/gpu/gpu_stream.cc b/xla/stream_executor/gpu/gpu_stream.cc index af199af296924..ee9b15487bab6 100644 --- a/xla/stream_executor/gpu/gpu_stream.cc +++ b/xla/stream_executor/gpu/gpu_stream.cc @@ -15,43 +15,18 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_stream.h" -#include - #include "absl/base/casts.h" #include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/status/status.h" #include "xla/stream_executor/gpu/gpu_types.h" -#include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/stream.h" namespace stream_executor { namespace gpu { -absl::Status GpuStream::Launch(const ThreadDim& thread_dims, - const BlockDim& block_dims, const Kernel& kernel, - const KernelArgs& args) { - return Launch(thread_dims, block_dims, std::nullopt, kernel, args); -} - -absl::Status GpuStream::Launch(const ThreadDim& thread_dims, - const BlockDim& block_dims, - const ClusterDim& cluster_dims, - const Kernel& kernel, const KernelArgs& args) { - return Launch(thread_dims, block_dims, std::make_optional(cluster_dims), - kernel, args); -} - -GpuStream* AsGpuStream(Stream* stream) { - DCHECK(stream != nullptr); - return static_cast(stream); -} - GpuStreamHandle AsGpuStreamValue(Stream* stream) { DCHECK(stream != nullptr); return absl::bit_cast( - AsGpuStream(stream)->platform_specific_handle().stream); + stream->platform_specific_handle().stream); } } // namespace gpu diff --git a/xla/stream_executor/gpu/gpu_stream.h b/xla/stream_executor/gpu/gpu_stream.h index 9222b3bdffea5..ec95ec50e2522 100644 --- a/xla/stream_executor/gpu/gpu_stream.h +++ b/xla/stream_executor/gpu/gpu_stream.h @@ -19,59 +19,12 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_GPU_GPU_STREAM_H_ #define XLA_STREAM_EXECUTOR_GPU_GPU_STREAM_H_ -#include -#include - -#include "absl/log/check.h" #include "xla/stream_executor/gpu/gpu_types.h" -#include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/launch_dim.h" -#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" -#include "xla/stream_executor/stream_common.h" -#include "xla/stream_executor/stream_executor.h" namespace stream_executor { namespace gpu { -// Wraps a GpuStreamHandle in order to satisfy the platform-independent -// StreamInterface. -// -// Thread-safe post-initialization. -class GpuStream : public StreamCommon { - public: - GpuStream(StreamExecutor* parent, - std::optional> priority) - : StreamCommon(parent) { - if (priority.has_value()) { - stream_priority_ = priority.value(); - } - } - - std::variant priority() const override { - return stream_priority_; - } - - absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims, - const Kernel& k, const KernelArgs& args) override; - absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims, - const ClusterDim& cluster_dims, const Kernel& k, - const KernelArgs& args) override; - - private: - // Helper method to launch a kernel with optional cluster dimensions. - virtual absl::Status Launch(const ThreadDim& thread_dims, - const BlockDim& block_dims, - const std::optional& cluster_dims, - const Kernel& kernel, const KernelArgs& args) = 0; - - std::variant stream_priority_; -}; - -// Helper functions to simplify extremely common flows. -// Converts a Stream to the underlying GpuStream implementation. -GpuStream* AsGpuStream(Stream* stream); - // Extracts a GpuStreamHandle from a GpuStream-backed Stream object. GpuStreamHandle AsGpuStreamValue(Stream* stream); } // namespace gpu diff --git a/xla/stream_executor/host/host_stream.cc b/xla/stream_executor/host/host_stream.cc index 76b66711e03d6..1cbf01298ce21 100644 --- a/xla/stream_executor/host/host_stream.cc +++ b/xla/stream_executor/host/host_stream.cc @@ -22,6 +22,7 @@ limitations under the License. #include // NOLINT #include #include +#include #include #include @@ -197,7 +198,13 @@ absl::Status HostStream::BlockUntilDone() { absl::Status HostStream::Launch(const ThreadDim& thread_dims, const BlockDim& block_dims, + const std::optional& cluster_dims, const Kernel& kernel, const KernelArgs& args) { + if (cluster_dims.has_value()) { + if (cluster_dims->x != 1 || cluster_dims->y != 1 || cluster_dims->z != 1) { + return absl::UnimplementedError("Not implemented for Host"); + } + } const HostKernel* host_kernel = AsHostKernel(&kernel); const KernelArgsDeviceMemoryArray* device_mem = diff --git a/xla/stream_executor/host/host_stream.h b/xla/stream_executor/host/host_stream.h index e1b530386ab41..dc6760f8f629c 100644 --- a/xla/stream_executor/host/host_stream.h +++ b/xla/stream_executor/host/host_stream.h @@ -73,6 +73,7 @@ class HostStream : public StreamCommon { absl::Status DoHostCallbackWithStatus( absl::AnyInvocable callback) override; absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims, + const std::optional& cluster_dims, const Kernel& kernel, const KernelArgs& args) override; private: diff --git a/xla/stream_executor/mock_stream.h b/xla/stream_executor/mock_stream.h index 5e5b51fe51378..41d06aa4f6e60 100644 --- a/xla/stream_executor/mock_stream.h +++ b/xla/stream_executor/mock_stream.h @@ -18,13 +18,13 @@ limitations under the License. #include #include +#include #include #include #include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/string_view.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/event.h" @@ -77,11 +77,7 @@ class MockStream : public Stream { (const, override)); MOCK_METHOD(absl::Status, Launch, (const ThreadDim &thread_dims, const BlockDim &block_dims, - const Kernel &k, const KernelArgs &args), - (override)); - MOCK_METHOD(absl::Status, Launch, - (const ThreadDim &thread_dims, const BlockDim &block_dims, - const ClusterDim &cluster_dims, const Kernel &k, + const std::optional &cluster_dims, const Kernel &k, const KernelArgs &args), (override)); MOCK_METHOD(const std::string &, GetName, (), (const, override)); diff --git a/xla/stream_executor/rocm/BUILD b/xla/stream_executor/rocm/BUILD index 1b453fdb0ec65..6f5a0dad7fac4 100644 --- a/xla/stream_executor/rocm/BUILD +++ b/xla/stream_executor/rocm/BUILD @@ -1002,7 +1002,7 @@ cc_library( "//xla/stream_executor:launch_dim", "//xla/stream_executor:platform", "//xla/stream_executor:stream", - "//xla/stream_executor/gpu:gpu_stream", + "//xla/stream_executor:stream_common", "@com_google_absl//absl/base", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log", @@ -1097,7 +1097,6 @@ xla_test( "//xla/stream_executor:platform_manager", "//xla/stream_executor:stream", "//xla/stream_executor:typed_kernel_factory", - "//xla/stream_executor/gpu:gpu_stream", "//xla/stream_executor/gpu:gpu_test_kernels_rocm", "@com_google_absl//absl/status", "@com_google_absl//absl/time", diff --git a/xla/stream_executor/rocm/rocm_stream.h b/xla/stream_executor/rocm/rocm_stream.h index de5565cdd34fc..693335daa187b 100644 --- a/xla/stream_executor/rocm/rocm_stream.h +++ b/xla/stream_executor/rocm/rocm_stream.h @@ -17,29 +17,27 @@ limitations under the License. #define XLA_STREAM_EXECUTOR_ROCM_ROCM_STREAM_H_ #include -#include #include #include #include #include "absl/functional/any_invocable.h" #include "absl/status/status.h" -#include "absl/status/statusor.h" #include "rocm/include/hip/hip_runtime.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/event_based_timer.h" -#include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/rocm/rocm_event.h" #include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_common.h" namespace stream_executor { namespace gpu { -class RocmStream : public GpuStream { +class RocmStream : public StreamCommon { public: absl::Status WaitFor(Stream* other) override; absl::Status RecordEvent(Event* event) override; @@ -79,7 +77,7 @@ class RocmStream : public GpuStream { RocmStream(StreamExecutor* executor, RocmEvent completed_event, std::optional> priority, hipStream_t stream_handle) - : GpuStream(executor, priority), + : StreamCommon(executor, priority), executor_(executor), completed_event_(std::move(completed_event)), stream_handle_(stream_handle) {} diff --git a/xla/stream_executor/rocm/rocm_timer_test.cc b/xla/stream_executor/rocm/rocm_timer_test.cc index a270e22fe99fb..958c5dfa53316 100644 --- a/xla/stream_executor/rocm/rocm_timer_test.cc +++ b/xla/stream_executor/rocm/rocm_timer_test.cc @@ -24,7 +24,6 @@ limitations under the License. #include "absl/status/status.h" #include "absl/time/time.h" #include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/gpu/gpu_test_kernels.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" @@ -73,7 +72,6 @@ class RocmTimerTest : public ::testing::Test { RocmExecutor* executor_; std::unique_ptr stream_; - GpuStream* gpu_stream_; private: void SetUp() override { @@ -84,13 +82,12 @@ class RocmTimerTest : public ::testing::Test { platform->ExecutorForDevice(0)); executor_ = reinterpret_cast(executor); TF_ASSERT_OK_AND_ASSIGN(stream_, executor_->CreateStream(std::nullopt)); - gpu_stream_ = AsGpuStream(stream_.get()); } }; TEST_F(RocmTimerTest, Create) { TF_ASSERT_OK_AND_ASSIGN(RocmTimer timer, - RocmTimer::Create(executor_, gpu_stream_)); + RocmTimer::Create(executor_, stream_.get())); // We don't really care what kernel we launch here as long as it takes a // non-zero amount of time. diff --git a/xla/stream_executor/stream.h b/xla/stream_executor/stream.h index 7983290852430..220cbf761c24f 100644 --- a/xla/stream_executor/stream.h +++ b/xla/stream_executor/stream.h @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -271,20 +272,19 @@ class Stream { // Launches a data parallel kernel with the given thread/block // dimensionality and already-packed args/sizes to pass to the underlying // platform driver. - virtual absl::Status Launch(const ThreadDim &thread_dims, - const BlockDim &block_dims, const Kernel &k, - const KernelArgs &args) { - return absl::UnimplementedError("Not implemented"); + absl::Status Launch(const ThreadDim &thread_dims, const BlockDim &block_dims, + const Kernel &kernel, const KernelArgs &args) { + return Launch(thread_dims, block_dims, std::nullopt, kernel, args); } // Launches a data parallel kernel with the given thread/block // dimensionality and already-packed args/sizes to pass to the underlying // platform driver. - virtual absl::Status Launch(const ThreadDim &thread_dims, - const BlockDim &block_dims, - const ClusterDim &cluster_dims, const Kernel &k, - const KernelArgs &args) { - return absl::UnimplementedError("Not implemented"); + absl::Status Launch(const ThreadDim &thread_dims, const BlockDim &block_dims, + const ClusterDim &cluster_dims, const Kernel &kernel, + const KernelArgs &args) { + return Launch(thread_dims, block_dims, std::make_optional(cluster_dims), + kernel, args); } // Get/set a name for a stream, which can be shown in profiling tools @@ -305,6 +305,15 @@ class Stream { return absl::UnimplementedError( "This stream does not support EventBasedTimers."); } + + private: + // Helper method to launch a kernel with optional cluster dimensions. + virtual absl::Status Launch(const ThreadDim &thread_dims, + const BlockDim &block_dims, + const std::optional &cluster_dims, + const Kernel &kernel, const KernelArgs &args) { + return absl::UnimplementedError("Not implemented"); + } }; template diff --git a/xla/stream_executor/stream_common.cc b/xla/stream_executor/stream_common.cc index 644c8edbba601..2c4dd7828c539 100644 --- a/xla/stream_executor/stream_common.cc +++ b/xla/stream_executor/stream_common.cc @@ -18,14 +18,16 @@ limitations under the License. #include #include #include +#include #include +#include #include #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" -#include "xla/stream_executor/blas.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" @@ -33,10 +35,21 @@ limitations under the License. namespace stream_executor { StreamCommon::StreamCommon(StreamExecutor *parent) - : parent_(parent), status_(absl::OkStatus()) { + : parent_(parent), + status_(absl::OkStatus()), + stream_priority_(StreamPriority::Default) { CHECK_NE(parent, nullptr); } +StreamCommon::StreamCommon( + StreamExecutor *parent, + std::optional> priority) + : StreamCommon(parent) { + if (priority.has_value()) { + stream_priority_ = priority.value(); + } +} + StreamCommon::PlatformSpecificHandle StreamCommon::platform_specific_handle() const { PlatformSpecificHandle handle; diff --git a/xla/stream_executor/stream_common.h b/xla/stream_executor/stream_common.h index 25c106ae4dd54..5832a8a195014 100644 --- a/xla/stream_executor/stream_common.h +++ b/xla/stream_executor/stream_common.h @@ -22,6 +22,7 @@ limitations under the License. #define XLA_STREAM_EXECUTOR_STREAM_COMMON_H_ #include +#include #include #include #include @@ -31,7 +32,6 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/fft.h" @@ -60,6 +60,9 @@ class StreamCommon : public Stream { // StreamExecutor's platform. explicit StreamCommon(StreamExecutor *parent); + StreamCommon(StreamExecutor *parent, + std::optional> priority); + PlatformSpecificHandle platform_specific_handle() const override; bool ok() const override { return !InErrorState(); } absl::StatusOr GetOrCreateSubStream() override @@ -69,6 +72,9 @@ class StreamCommon : public Stream { CHECK(parent_ != nullptr); return parent_; } + std::variant priority() const override { + return stream_priority_; + } CudaComputeCapability GetCudaComputeCapability() const override { return parent()->GetDeviceDescription().cuda_compute_capability(); @@ -77,9 +83,6 @@ class StreamCommon : public Stream { RocmComputeCapability GetRocmComputeCapability() const override { return parent()->GetDeviceDescription().rocm_compute_capability(); } - std::variant priority() const override { - return StreamPriority::Default; - } // Doesn't do anything interesting by default; GpuStream connects this to NVTX const std::string &GetName() const override { return name_; } @@ -117,8 +120,7 @@ class StreamCommon : public Stream { std::vector, bool>> sub_streams_ ABSL_GUARDED_BY(mu_); - StreamCommon(const StreamCommon &) = delete; - void operator=(const StreamCommon &) = delete; + std::variant stream_priority_; }; } // namespace stream_executor