diff --git a/xla/service/spmd/spmd_partitioner.cc b/xla/service/spmd/spmd_partitioner.cc index bd645199b6b47..3f988fb6beede 100644 --- a/xla/service/spmd/spmd_partitioner.cc +++ b/xla/service/spmd/spmd_partitioner.cc @@ -1537,9 +1537,14 @@ PartitionedHlo::ReshardFromPartialReplicateWithDynamicSlice( i, padded_base_shape.dimensions(i) * temp_target_sharding.tile_assignment().dim(i)); } - auto offsets = - MakePartitionOffsetsDiff(padded_base_shape, temp_target_sharding, - sharding(), state_.partition_id, state_.b); + auto offsets = MakePartitionOffsets(padded_base_shape, temp_target_sharding, + state_.partition_id, state_.b); + auto old_offsets = MakePartitionOffsets(padded_base_shape, sharding(), + state_.partition_id, state_.b); + for (int64_t i = 0; i < offsets.size(); ++i) { + offsets[i] = state_.b->AddInstruction(HloInstruction::CreateBinary( + offsets[i]->shape(), HloOpcode::kSubtract, offsets[i], old_offsets[i])); + } auto slice = state_.b->AddInstruction(HloInstruction::CreateDynamicSlice( shard_shape, padded_hlo.value(), offsets, shard_shape.dimensions())); slice->set_sharding(temp_target_sharding); diff --git a/xla/service/spmd/spmd_partitioner_test.cc b/xla/service/spmd/spmd_partitioner_test.cc index 3a1479d618e89..2d1d352351784 100644 --- a/xla/service/spmd/spmd_partitioner_test.cc +++ b/xla/service/spmd/spmd_partitioner_test.cc @@ -9577,12 +9577,10 @@ ENTRY entry { AllOf(op::Shape("f32[4,8]"), op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), op::Constant()))); - auto table_look_up = - AllOf(op::Shape("s32[]"), - op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()))); - auto tiled = AllOf(op::Shape("f32[4,4]"), - op::Copy(op::DynamicSlice(partially_replicated, - op::Constant(), table_look_up))); + auto tiled = + AllOf(op::Shape("f32[4,4]"), + op::Copy(op::DynamicSlice(partially_replicated, op::Subtract(), + op::Subtract()))); const auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, tiled); } @@ -9636,12 +9634,10 @@ ENTRY entry { AllOf(op::Shape("f32[4,8]"), op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), op::Constant()))); - auto table_look_up = - AllOf(op::Shape("s32[]"), - op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()))); - auto tiled = AllOf(op::Shape("f32[4,4]"), - op::Copy(op::DynamicSlice(partially_replicated, - op::Constant(), table_look_up))); + auto tiled = + AllOf(op::Shape("f32[4,4]"), + op::Copy(op::DynamicSlice(partially_replicated, op::Subtract(), + op::Subtract()))); const auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, tiled); } @@ -9695,13 +9691,10 @@ ENTRY entry { AllOf(op::Shape("f32[8,4]"), op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape()))); - auto table_look_up = - AllOf(op::Shape("s32[]"), - op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()))); auto tiled = AllOf(op::Shape("f32[4,4]"), op::Copy(op::CollectivePermute(op::DynamicSlice( - partially_replicated, table_look_up, op::Constant())))); + partially_replicated, op::Subtract(), op::Subtract())))); const auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, tiled); } @@ -10392,13 +10385,10 @@ ENTRY entry { op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(), op::Constant(), op::Constant())), op::Shape("f32[8,801,1,1024]")); - auto table_look_up = - AllOf(op::Shape("s32[]"), - op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()))); auto resharded_lhs = AllOf(op::Reshape(op::Transpose(op::AllToAll(op::Reshape( - op::Pad(op::DynamicSlice(lhs, op::Constant(), op::Constant(), - op::Constant(), table_look_up), + op::Pad(op::DynamicSlice(lhs, op::Subtract(), op::Subtract(), + op::Subtract(), op::Subtract()), op::Constant()))))), op::Shape("f32[16,401,1,512]")); auto left_halo = AllOf(op::Shape("f32[16,2, 1, 512]"), diff --git a/xla/service/spmd/spmd_partitioner_util.cc b/xla/service/spmd/spmd_partitioner_util.cc index eab525a5765fa..ac1d272bac56f 100644 --- a/xla/service/spmd/spmd_partitioner_util.cc +++ b/xla/service/spmd/spmd_partitioner_util.cc @@ -202,47 +202,6 @@ std::vector MakePartitionOffsets( return offsets; } -std::vector MakePartitionOffsetsDiff( - const Shape& shape, const HloSharding& sharding_1, - const HloSharding& sharding_2, HloInstruction* partition_id, SpmdBuilder* b, - absl::Span dims) { - CHECK(!shape.IsTuple()); - CHECK_EQ(sharding_1.tile_assignment().num_elements(), - sharding_2.tile_assignment().num_elements()); - - auto shard_shape_1 = MakePartitionedShape(shape, sharding_1); - auto shard_shape_2 = MakePartitionedShape(shape, sharding_2); - auto const_zero = - b->AddInstruction(HloInstruction::CreateConstant(LiteralUtil::Zero(S32))); - std::vector offsets; - - for (int64_t i = 0; i < shape.rank(); ++i) { - if (!dims.empty() && !absl::c_linear_search(dims, i)) { - offsets.push_back(const_zero); - } else { - std::vector offset_array( - sharding_1.tile_assignment().num_elements(), 0); - sharding_1.tile_assignment().Each( - [&](absl::Span indices, int64_t device) { - offset_array[device] = indices[i] * shard_shape_1.dimensions(i); - }); - sharding_2.tile_assignment().Each( - [&](absl::Span indices, int64_t device) { - offset_array[device] -= indices[i] * shard_shape_2.dimensions(i); - }); - if (absl::c_all_of(offset_array, - [](int32_t offset) { return offset == 0; })) { - offsets.push_back(const_zero); - } else { - offsets.push_back( - TableLookup(offset_array, S32, partition_id, b)); - } - } - } - - return offsets; -} - std::vector MakeTiledPartitionOrdinals( const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b) { CHECK(!sharding.IsTileMaximal()); diff --git a/xla/service/spmd/spmd_partitioner_util.h b/xla/service/spmd/spmd_partitioner_util.h index 1f72322a7d290..a982c3edf1e8d 100644 --- a/xla/service/spmd/spmd_partitioner_util.h +++ b/xla/service/spmd/spmd_partitioner_util.h @@ -197,19 +197,13 @@ Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape, // Generates the HLO instructions that represent the dimension offsets on any // device. The size of the returned vector is the rank of the given shape. -// If `dims` is non-empty, the dimensions not in `dims` are constant zero. +// If `dims` is non-empty, the generated offsets will only be non-zero for those +// dimensions. std::vector MakePartitionOffsets( const Shape& shape, const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b, absl::Span dims = {}); -// Generates the diff between offsets related to two shardings. It is equivalent -// to `MakePartitionOffsets(sharding_1) - MakePartitionOffsets(sharding_2)`. -std::vector MakePartitionOffsetsDiff( - const Shape& shape, const HloSharding& sharding_1, - const HloSharding& sharding_2, HloInstruction* partition_id, SpmdBuilder* b, - absl::Span dims = {}); - // Returns the offsets of the partition in the tile assignment. std::vector MakeTiledPartitionOrdinals( const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b);