Performance regression on some benchmarks
Reverts b4b3289

PiperOrigin-RevId: 687457925
ZixuanJiang authored and Google-ML-Automation committed Oct 18, 2024
1 parent a7245c5 commit 642f785
Showing 4 changed files with 21 additions and 73 deletions.
11 changes: 8 additions & 3 deletions xla/service/spmd/spmd_partitioner.cc
@@ -1537,9 +1537,14 @@ PartitionedHlo::ReshardFromPartialReplicateWithDynamicSlice(
         i, padded_base_shape.dimensions(i) *
                temp_target_sharding.tile_assignment().dim(i));
   }
-  auto offsets =
-      MakePartitionOffsetsDiff(padded_base_shape, temp_target_sharding,
-                               sharding(), state_.partition_id, state_.b);
+  auto offsets = MakePartitionOffsets(padded_base_shape, temp_target_sharding,
+                                      state_.partition_id, state_.b);
+  auto old_offsets = MakePartitionOffsets(padded_base_shape, sharding(),
+                                          state_.partition_id, state_.b);
+  for (int64_t i = 0; i < offsets.size(); ++i) {
+    offsets[i] = state_.b->AddInstruction(HloInstruction::CreateBinary(
+        offsets[i]->shape(), HloOpcode::kSubtract, offsets[i], old_offsets[i]));
+  }
   auto slice = state_.b->AddInstruction(HloInstruction::CreateDynamicSlice(
       shard_shape, padded_hlo.value(), offsets, shard_shape.dimensions()));
   slice->set_sharding(temp_target_sharding);
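For intuition, here is a minimal, dependency-free C++ sketch of the restored approach, assuming a simple 1-D tiled sharding; OffsetFor and the tile-assignment vectors are illustrative stand-ins, not XLA's API. Each partition's dynamic-slice offset is its offset under the target sharding minus its offset under the current sharding, mirroring the kSubtract loop in the hunk above.

#include <cstdint>
#include <iostream>
#include <vector>

// Offset of `partition` along one dimension: its tile index times the shard
// size. `tile_assignment` maps tile index -> device for a 1-D sharding.
int64_t OffsetFor(const std::vector<int64_t>& tile_assignment,
                  int64_t partition, int64_t shard_size) {
  for (size_t idx = 0; idx < tile_assignment.size(); ++idx) {
    if (tile_assignment[idx] == partition) {
      return static_cast<int64_t>(idx) * shard_size;
    }
  }
  return 0;  // Partition not present in this dimension's assignment.
}

int main() {
  // Four devices tiling a length-16 dimension (shard size 4) under two
  // different device orders.
  const std::vector<int64_t> target_sharding = {0, 1, 2, 3};
  const std::vector<int64_t> current_sharding = {0, 2, 1, 3};
  for (int64_t p = 0; p < 4; ++p) {
    // Mirrors offsets[i] = offsets[i] - old_offsets[i] in the hunk above.
    const int64_t offset = OffsetFor(target_sharding, p, 4) -
                           OffsetFor(current_sharding, p, 4);
    std::cout << "partition " << p << ": dynamic-slice offset " << offset
              << "\n";
  }
  return 0;
}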
32 changes: 11 additions & 21 deletions xla/service/spmd/spmd_partitioner_test.cc
@@ -9577,12 +9577,10 @@ ENTRY entry {
       AllOf(op::Shape("f32[4,8]"),
             op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
                                       op::Constant())));
-  auto table_look_up =
-      AllOf(op::Shape("s32[]"),
-            op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())));
-  auto tiled = AllOf(op::Shape("f32[4,4]"),
-                     op::Copy(op::DynamicSlice(partially_replicated,
-                                               op::Constant(), table_look_up)));
+  auto tiled =
+      AllOf(op::Shape("f32[4,4]"),
+            op::Copy(op::DynamicSlice(partially_replicated, op::Subtract(),
+                                      op::Subtract())));
   const auto root = module->entry_computation()->root_instruction();
   EXPECT_THAT(root, tiled);
 }
@@ -9636,12 +9634,10 @@ ENTRY entry {
       AllOf(op::Shape("f32[4,8]"),
             op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
                                       op::Constant())));
-  auto table_look_up =
-      AllOf(op::Shape("s32[]"),
-            op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())));
-  auto tiled = AllOf(op::Shape("f32[4,4]"),
-                     op::Copy(op::DynamicSlice(partially_replicated,
-                                               op::Constant(), table_look_up)));
+  auto tiled =
+      AllOf(op::Shape("f32[4,4]"),
+            op::Copy(op::DynamicSlice(partially_replicated, op::Subtract(),
+                                      op::Subtract())));
   const auto root = module->entry_computation()->root_instruction();
   EXPECT_THAT(root, tiled);
 }
@@ -9695,13 +9691,10 @@ ENTRY entry {
       AllOf(op::Shape("f32[8,4]"),
             op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
                                       op::Reshape())));
-  auto table_look_up =
-      AllOf(op::Shape("s32[]"),
-            op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())));
   auto tiled =
       AllOf(op::Shape("f32[4,4]"),
             op::Copy(op::CollectivePermute(op::DynamicSlice(
-                partially_replicated, table_look_up, op::Constant()))));
+                partially_replicated, op::Subtract(), op::Subtract()))));
   const auto root = module->entry_computation()->root_instruction();
   EXPECT_THAT(root, tiled);
 }
@@ -10392,13 +10385,10 @@ ENTRY entry {
       op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(),
                                 op::Constant(), op::Constant())),
       op::Shape("f32[8,801,1,1024]"));
-  auto table_look_up =
-      AllOf(op::Shape("s32[]"),
-            op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())));
   auto resharded_lhs =
       AllOf(op::Reshape(op::Transpose(op::AllToAll(op::Reshape(
-                op::Pad(op::DynamicSlice(lhs, op::Constant(), op::Constant(),
-                                         op::Constant(), table_look_up),
+                op::Pad(op::DynamicSlice(lhs, op::Subtract(), op::Subtract(),
+                                         op::Subtract(), op::Subtract()),
                         op::Constant()))))),
            op::Shape("f32[16,401,1,512]"));
   auto left_halo = AllOf(op::Shape("f32[16,2, 1, 512]"),
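The pattern across these test updates: the expected HLO for the resharded operand once again matches op::Subtract() offsets (the runtime difference between target and source partition offsets) where it previously matched a constant table sliced by op::PartitionId().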
41 changes: 0 additions & 41 deletions xla/service/spmd/spmd_partitioner_util.cc
@@ -202,47 +202,6 @@ std::vector<HloInstruction*> MakePartitionOffsets(
   return offsets;
 }
 
-std::vector<HloInstruction*> MakePartitionOffsetsDiff(
-    const Shape& shape, const HloSharding& sharding_1,
-    const HloSharding& sharding_2, HloInstruction* partition_id, SpmdBuilder* b,
-    absl::Span<const int64_t> dims) {
-  CHECK(!shape.IsTuple());
-  CHECK_EQ(sharding_1.tile_assignment().num_elements(),
-           sharding_2.tile_assignment().num_elements());
-
-  auto shard_shape_1 = MakePartitionedShape(shape, sharding_1);
-  auto shard_shape_2 = MakePartitionedShape(shape, sharding_2);
-  auto const_zero =
-      b->AddInstruction(HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
-  std::vector<HloInstruction*> offsets;
-
-  for (int64_t i = 0; i < shape.rank(); ++i) {
-    if (!dims.empty() && !absl::c_linear_search(dims, i)) {
-      offsets.push_back(const_zero);
-    } else {
-      std::vector<int32_t> offset_array(
-          sharding_1.tile_assignment().num_elements(), 0);
-      sharding_1.tile_assignment().Each(
-          [&](absl::Span<const int64_t> indices, int64_t device) {
-            offset_array[device] = indices[i] * shard_shape_1.dimensions(i);
-          });
-      sharding_2.tile_assignment().Each(
-          [&](absl::Span<const int64_t> indices, int64_t device) {
-            offset_array[device] -= indices[i] * shard_shape_2.dimensions(i);
-          });
-      if (absl::c_all_of(offset_array,
-                         [](int32_t offset) { return offset == 0; })) {
-        offsets.push_back(const_zero);
-      } else {
-        offsets.push_back(
-            TableLookup<int32_t>(offset_array, S32, partition_id, b));
-      }
-    }
-  }
-
-  return offsets;
-}
-
 std::vector<HloInstruction*> MakeTiledPartitionOrdinals(
     const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b) {
   CHECK(!sharding.IsTileMaximal());
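For reference, a hedged sketch of the strategy the deleted helper implemented, again with illustrative names and a 1-D sharding rather than XLA's real types: fold the two shardings' offsets into one per-device table at partitioning time, then emit a single lookup keyed by partition id, with a constant-zero fallback when every entry is zero.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t kNumDevices = 4;
  const int64_t kShardSize = 4;
  // Tile index of each device under the two shardings (1-D case).
  const std::vector<int64_t> tile_index_1 = {0, 1, 2, 3};
  const std::vector<int64_t> tile_index_2 = {0, 2, 1, 3};

  // Built once at partitioning time, like `offset_array` in the deleted code:
  // offset under sharding_1 minus offset under sharding_2, per device.
  std::vector<int32_t> offset_table(kNumDevices, 0);
  bool all_zero = true;
  for (int64_t d = 0; d < kNumDevices; ++d) {
    offset_table[d] = static_cast<int32_t>(
        (tile_index_1[d] - tile_index_2[d]) * kShardSize);
    all_zero &= (offset_table[d] == 0);
  }

  // At run time each device does one table lookup by its partition id
  // instead of computing two offsets and subtracting them.
  for (int64_t partition_id = 0; partition_id < kNumDevices; ++partition_id) {
    std::cout << "device " << partition_id << " -> offset "
              << (all_zero ? 0 : offset_table[partition_id]) << "\n";
  }
  return 0;
}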
10 changes: 2 additions & 8 deletions xla/service/spmd/spmd_partitioner_util.h
@@ -197,19 +197,13 @@ Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape,
 
 // Generates the HLO instructions that represent the dimension offsets on any
 // device. The size of the returned vector is the rank of the given shape.
-// If `dims` is non-empty, the dimensions not in `dims` are constant zero.
+// If `dims` is non-empty, the generated offsets will only be non-zero for
+// those dimensions.
 std::vector<HloInstruction*> MakePartitionOffsets(
     const Shape& shape, const HloSharding& sharding,
     HloInstruction* partition_id, SpmdBuilder* b,
     absl::Span<const int64_t> dims = {});
 
-// Generates the diff between offsets related to two shardings. It is equivalent
-// to `MakePartitionOffsets(sharding_1) - MakePartitionOffsets(sharding_2)`.
-std::vector<HloInstruction*> MakePartitionOffsetsDiff(
-    const Shape& shape, const HloSharding& sharding_1,
-    const HloSharding& sharding_2, HloInstruction* partition_id, SpmdBuilder* b,
-    absl::Span<const int64_t> dims = {});
-
 // Returns the offsets of the partition in the tile assignment.
 std::vector<HloInstruction*> MakeTiledPartitionOrdinals(
     const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b);
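To make the header comment concrete, a small stand-alone model of what MakePartitionOffsets computes in the simple tiled case; PartitionOffsets and its arguments are hypothetical simplifications, not the real signature. Each dimension's offset is the partition's tile coordinate times the shard size, and dimensions outside a non-empty `dims` stay zero.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Models MakePartitionOffsets for one partition: `tile_index` is its tile
// coordinate per dimension, `shard_sizes` the per-dimension shard extents.
// When `dims` is non-empty, dimensions not listed in it get offset 0.
std::vector<int64_t> PartitionOffsets(const std::vector<int64_t>& tile_index,
                                      const std::vector<int64_t>& shard_sizes,
                                      const std::vector<int64_t>& dims = {}) {
  std::vector<int64_t> offsets(tile_index.size(), 0);
  for (size_t i = 0; i < tile_index.size(); ++i) {
    const bool wanted =
        dims.empty() || std::find(dims.begin(), dims.end(),
                                  static_cast<int64_t>(i)) != dims.end();
    if (wanted) offsets[i] = tile_index[i] * shard_sizes[i];
  }
  return offsets;
}

int main() {
  // A partition at tile coordinate (1, 2) with 4x8 shards; only dim 1 wanted.
  for (int64_t off : PartitionOffsets({1, 2}, {4, 8}, {1})) {
    std::cout << off << " ";  // Prints: 0 16
  }
  std::cout << "\n";
  return 0;
}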
