Skip to content

Commit

Permalink
Collapse GpuStream class into StreamCommon.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 687352019
  • Loading branch information
klucke authored and Google-ML-Automation committed Oct 18, 2024
1 parent 8761b7e commit 873f95f
Show file tree
Hide file tree
Showing 17 changed files with 62 additions and 124 deletions.
1 change: 1 addition & 0 deletions xla/service/gpu/runtime/nccl_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ limitations under the License.
#include "xla/stream_executor/device_memory_allocator.h"
#include "xla/stream_executor/gpu/gpu_stream.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/tsl/concurrency/ref_count.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"
Expand Down
2 changes: 0 additions & 2 deletions xla/stream_executor/cuda/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1215,8 +1215,6 @@ cc_library(
"//xla/stream_executor:platform",
"//xla/stream_executor:stream",
"//xla/stream_executor:stream_common",
"//xla/stream_executor/gpu:gpu_driver_header",
"//xla/stream_executor/gpu:gpu_stream",
"@com_google_absl//absl/base",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/log",
Expand Down
1 change: 0 additions & 1 deletion xla/stream_executor/cuda/cuda_stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ limitations under the License.
#include "xla/stream_executor/cuda/cuda_status.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/event.h"
#include "xla/stream_executor/gpu/gpu_driver.h"
#include "xla/stream_executor/kernel.h"
#include "xla/stream_executor/launch_dim.h"
#include "xla/stream_executor/platform.h"
Expand Down
6 changes: 3 additions & 3 deletions xla/stream_executor/cuda/cuda_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,16 @@ limitations under the License.
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/event.h"
#include "xla/stream_executor/event_based_timer.h"
#include "xla/stream_executor/gpu/gpu_stream.h"
#include "xla/stream_executor/kernel.h"
#include "xla/stream_executor/launch_dim.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_common.h"

namespace stream_executor {
namespace gpu {

class CudaStream : public GpuStream {
class CudaStream : public StreamCommon {
public:
absl::Status WaitFor(Stream* other) override;
absl::Status RecordEvent(Event* event) override;
Expand Down Expand Up @@ -82,7 +82,7 @@ class CudaStream : public GpuStream {
CudaStream(StreamExecutor* executor, CudaEvent completed_event,
std::optional<std::variant<StreamPriority, int>> priority,
CUstream stream_handle)
: GpuStream(executor, priority),
: StreamCommon(executor, priority),
executor_(executor),
completed_event_(std::move(completed_event)),
stream_handle_(stream_handle) {}
Expand Down
5 changes: 1 addition & 4 deletions xla/stream_executor/cuda/cuda_timer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ limitations under the License.
#include "xla/stream_executor/cuda/cuda_executor.h"
#include "xla/stream_executor/cuda/cuda_platform_id.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/gpu/gpu_stream.h"
#include "xla/stream_executor/gpu/gpu_test_kernels.h"
#include "xla/stream_executor/kernel.h"
#include "xla/stream_executor/kernel_spec.h"
Expand Down Expand Up @@ -73,7 +72,6 @@ class CudaTimerTest : public ::testing::TestWithParam<CudaTimer::TimerType> {

StreamExecutor* executor_;
std::unique_ptr<Stream> stream_;
GpuStream* gpu_stream_;

private:
void SetUp() override {
Expand All @@ -82,13 +80,12 @@ class CudaTimerTest : public ::testing::TestWithParam<CudaTimer::TimerType> {
stream_executor::cuda::kCudaPlatformId));
TF_ASSERT_OK_AND_ASSIGN(executor_, platform->ExecutorForDevice(0));
TF_ASSERT_OK_AND_ASSIGN(stream_, executor_->CreateStream(std::nullopt));
gpu_stream_ = AsGpuStream(stream_.get());
}
};

TEST_P(CudaTimerTest, Create) {
TF_ASSERT_OK_AND_ASSIGN(
CudaTimer timer, CudaTimer::Create(executor_, gpu_stream_, GetParam()));
CudaTimer timer, CudaTimer::Create(executor_, stream_.get(), GetParam()));

// We don't really care what kernel we launch here as long as it takes a
// non-zero amount of time.
Expand Down
7 changes: 0 additions & 7 deletions xla/stream_executor/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -356,16 +356,9 @@ gpu_only_cc_library(
hdrs = ["gpu_stream.h"],
deps = [
":gpu_types_header",
"//xla/stream_executor:kernel",
"//xla/stream_executor:launch_dim",
"//xla/stream_executor:platform",
"//xla/stream_executor:stream",
"//xla/stream_executor:stream_common",
"//xla/stream_executor:stream_executor_h",
"@com_google_absl//absl/base",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
],
)

Expand Down
27 changes: 1 addition & 26 deletions xla/stream_executor/gpu/gpu_stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,43 +15,18 @@ limitations under the License.

#include "xla/stream_executor/gpu/gpu_stream.h"

#include <optional>

#include "absl/base/casts.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "xla/stream_executor/gpu/gpu_types.h"
#include "xla/stream_executor/kernel.h"
#include "xla/stream_executor/launch_dim.h"
#include "xla/stream_executor/stream.h"

namespace stream_executor {
namespace gpu {

absl::Status GpuStream::Launch(const ThreadDim& thread_dims,
const BlockDim& block_dims, const Kernel& kernel,
const KernelArgs& args) {
return Launch(thread_dims, block_dims, std::nullopt, kernel, args);
}

absl::Status GpuStream::Launch(const ThreadDim& thread_dims,
const BlockDim& block_dims,
const ClusterDim& cluster_dims,
const Kernel& kernel, const KernelArgs& args) {
return Launch(thread_dims, block_dims, std::make_optional(cluster_dims),
kernel, args);
}

GpuStream* AsGpuStream(Stream* stream) {
DCHECK(stream != nullptr);
return static_cast<GpuStream*>(stream);
}

GpuStreamHandle AsGpuStreamValue(Stream* stream) {
DCHECK(stream != nullptr);
return absl::bit_cast<GpuStreamHandle>(
AsGpuStream(stream)->platform_specific_handle().stream);
stream->platform_specific_handle().stream);
}

} // namespace gpu
Expand Down
47 changes: 0 additions & 47 deletions xla/stream_executor/gpu/gpu_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,59 +19,12 @@ limitations under the License.
#ifndef XLA_STREAM_EXECUTOR_GPU_GPU_STREAM_H_
#define XLA_STREAM_EXECUTOR_GPU_GPU_STREAM_H_

#include <optional>
#include <variant>

#include "absl/log/check.h"
#include "xla/stream_executor/gpu/gpu_types.h"
#include "xla/stream_executor/kernel.h"
#include "xla/stream_executor/launch_dim.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_common.h"
#include "xla/stream_executor/stream_executor.h"

namespace stream_executor {
namespace gpu {

// Wraps a GpuStreamHandle in order to satisfy the platform-independent
// StreamInterface.
//
// Thread-safe post-initialization.
class GpuStream : public StreamCommon {
public:
GpuStream(StreamExecutor* parent,
std::optional<std::variant<StreamPriority, int>> priority)
: StreamCommon(parent) {
if (priority.has_value()) {
stream_priority_ = priority.value();
}
}

std::variant<StreamPriority, int> priority() const override {
return stream_priority_;
}

absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims,
const Kernel& k, const KernelArgs& args) override;
absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims,
const ClusterDim& cluster_dims, const Kernel& k,
const KernelArgs& args) override;

private:
// Helper method to launch a kernel with optional cluster dimensions.
virtual absl::Status Launch(const ThreadDim& thread_dims,
const BlockDim& block_dims,
const std::optional<ClusterDim>& cluster_dims,
const Kernel& kernel, const KernelArgs& args) = 0;

std::variant<StreamPriority, int> stream_priority_;
};

// Helper functions to simplify extremely common flows.
// Converts a Stream to the underlying GpuStream implementation.
GpuStream* AsGpuStream(Stream* stream);

// Extracts a GpuStreamHandle from a GpuStream-backed Stream object.
GpuStreamHandle AsGpuStreamValue(Stream* stream);
} // namespace gpu
Expand Down
7 changes: 7 additions & 0 deletions xla/stream_executor/host/host_stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License.
#include <cfenv> // NOLINT
#include <cstdint>
#include <memory>
#include <optional>
#include <queue>
#include <utility>

Expand Down Expand Up @@ -197,7 +198,13 @@ absl::Status HostStream::BlockUntilDone() {

absl::Status HostStream::Launch(const ThreadDim& thread_dims,
const BlockDim& block_dims,
const std::optional<ClusterDim>& cluster_dims,
const Kernel& kernel, const KernelArgs& args) {
if (cluster_dims.has_value()) {
if (cluster_dims->x != 1 || cluster_dims->y != 1 || cluster_dims->z != 1) {
return absl::UnimplementedError("Not implemented for Host");
}
}
const HostKernel* host_kernel = AsHostKernel(&kernel);

const KernelArgsDeviceMemoryArray* device_mem =
Expand Down
1 change: 1 addition & 0 deletions xla/stream_executor/host/host_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class HostStream : public StreamCommon {
absl::Status DoHostCallbackWithStatus(
absl::AnyInvocable<absl::Status() &&> callback) override;
absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims,
const std::optional<ClusterDim>& cluster_dims,
const Kernel& kernel, const KernelArgs& args) override;

private:
Expand Down
8 changes: 2 additions & 6 deletions xla/stream_executor/mock_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ limitations under the License.

#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <variant>

#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/event.h"
Expand Down Expand Up @@ -77,11 +77,7 @@ class MockStream : public Stream {
(const, override));
MOCK_METHOD(absl::Status, Launch,
(const ThreadDim &thread_dims, const BlockDim &block_dims,
const Kernel &k, const KernelArgs &args),
(override));
MOCK_METHOD(absl::Status, Launch,
(const ThreadDim &thread_dims, const BlockDim &block_dims,
const ClusterDim &cluster_dims, const Kernel &k,
const std::optional<ClusterDim> &cluster_dims, const Kernel &k,
const KernelArgs &args),
(override));
MOCK_METHOD(const std::string &, GetName, (), (const, override));
Expand Down
3 changes: 1 addition & 2 deletions xla/stream_executor/rocm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1002,7 +1002,7 @@ cc_library(
"//xla/stream_executor:launch_dim",
"//xla/stream_executor:platform",
"//xla/stream_executor:stream",
"//xla/stream_executor/gpu:gpu_stream",
"//xla/stream_executor:stream_common",
"@com_google_absl//absl/base",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/log",
Expand Down Expand Up @@ -1097,7 +1097,6 @@ xla_test(
"//xla/stream_executor:platform_manager",
"//xla/stream_executor:stream",
"//xla/stream_executor:typed_kernel_factory",
"//xla/stream_executor/gpu:gpu_stream",
"//xla/stream_executor/gpu:gpu_test_kernels_rocm",
"@com_google_absl//absl/status",
"@com_google_absl//absl/time",
Expand Down
8 changes: 3 additions & 5 deletions xla/stream_executor/rocm/rocm_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,27 @@ limitations under the License.
#define XLA_STREAM_EXECUTOR_ROCM_ROCM_STREAM_H_

#include <cstdint>
#include <memory>
#include <optional>
#include <utility>
#include <variant>

#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "rocm/include/hip/hip_runtime.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/event.h"
#include "xla/stream_executor/event_based_timer.h"
#include "xla/stream_executor/gpu/gpu_stream.h"
#include "xla/stream_executor/kernel.h"
#include "xla/stream_executor/launch_dim.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/rocm/rocm_event.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_common.h"

namespace stream_executor {
namespace gpu {

class RocmStream : public GpuStream {
class RocmStream : public StreamCommon {
public:
absl::Status WaitFor(Stream* other) override;
absl::Status RecordEvent(Event* event) override;
Expand Down Expand Up @@ -79,7 +77,7 @@ class RocmStream : public GpuStream {
RocmStream(StreamExecutor* executor, RocmEvent completed_event,
std::optional<std::variant<StreamPriority, int>> priority,
hipStream_t stream_handle)
: GpuStream(executor, priority),
: StreamCommon(executor, priority),
executor_(executor),
completed_event_(std::move(completed_event)),
stream_handle_(stream_handle) {}
Expand Down
5 changes: 1 addition & 4 deletions xla/stream_executor/rocm/rocm_timer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/time/time.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/gpu/gpu_stream.h"
#include "xla/stream_executor/gpu/gpu_test_kernels.h"
#include "xla/stream_executor/kernel.h"
#include "xla/stream_executor/kernel_spec.h"
Expand Down Expand Up @@ -73,7 +72,6 @@ class RocmTimerTest : public ::testing::Test {

RocmExecutor* executor_;
std::unique_ptr<Stream> stream_;
GpuStream* gpu_stream_;

private:
void SetUp() override {
Expand All @@ -84,13 +82,12 @@ class RocmTimerTest : public ::testing::Test {
platform->ExecutorForDevice(0));
executor_ = reinterpret_cast<RocmExecutor*>(executor);
TF_ASSERT_OK_AND_ASSIGN(stream_, executor_->CreateStream(std::nullopt));
gpu_stream_ = AsGpuStream(stream_.get());
}
};

TEST_F(RocmTimerTest, Create) {
TF_ASSERT_OK_AND_ASSIGN(RocmTimer timer,
RocmTimer::Create(executor_, gpu_stream_));
RocmTimer::Create(executor_, stream_.get()));

// We don't really care what kernel we launch here as long as it takes a
// non-zero amount of time.
Expand Down
Loading

0 comments on commit 873f95f

Please sign in to comment.