
Crash on HLOs with nested tuples in conditionals.
PiperOrigin-RevId: 622009093
tensorflower-gardener authored and copybara-github committed Apr 5, 2024
1 parent 4d135db commit 0960128
Showing 2 changed files with 42 additions and 15 deletions.
55 changes: 41 additions & 14 deletions xla/hlo/experimental/auto_sharding/auto_sharding.cc
@@ -2309,10 +2309,10 @@ Status SetHloShardingPostProcessing(
       continue;
     } else {
       if (inst->shape().IsTuple()) {
-        // While we do not support nested tuples fully, this is a hack to get
-        // things to work in some cases (specifically observed for the llama and
-        // gemma models) where nested tuples as used as inputs/outputs of the
-        // kOptimizationBarrier instruction.
+        // While we do not support nested tuples fully (b/332951306), this is a
+        // hack to get things to work in some cases (specifically observed for
+        // the llama and gemma models) where nested tuples are used as
+        // inputs/outputs of the kOptimizationBarrier instruction.
         if (absl::c_any_of(
                 inst->shape().tuple_shapes(),
                 [](const Shape& shape) { return shape.IsTuple(); })) {
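
For readers skimming the hunk: the guard detects a nested tuple by asking whether any immediate element of the instruction's tuple shape is itself a tuple. A minimal standalone sketch of that check follows; the Shape struct here is a hypothetical stand-in for illustration, not XLA's real xla::Shape.

#include <vector>

#include "absl/algorithm/container.h"

// Hypothetical stand-in for xla::Shape, for illustration only: a shape is
// either a leaf array (empty element list) or a tuple of element shapes.
struct Shape {
  std::vector<Shape> tuple_elements;
  bool IsTuple() const { return !tuple_elements.empty(); }
  const std::vector<Shape>& tuple_shapes() const { return tuple_elements; }
};

// Mirrors the guard in the hunk above: true iff any immediate element of a
// tuple shape is itself a tuple, i.e. the shape is a nested tuple.
bool HasNestedTupleElement(const Shape& shape) {
  return shape.IsTuple() &&
         absl::c_any_of(shape.tuple_shapes(),
                        [](const Shape& s) { return s.IsTuple(); });
}
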
@@ -2355,7 +2355,7 @@ Status SetHloShardingPostProcessing(
         for (size_t i = 0; i < inst->shape().tuple_shapes_size(); ++i) {
           CHECK(!inst->shape().tuple_shapes(i).IsTuple())
               << "We currently do not support ops with nested tuples as "
-                 "output.";
+                 "output. See b/332951306.";
           const ShardingStrategy& stra =
               GetShardingStrategyForTuple(inst, {static_cast<int64_t>(i)},
                                           strategy_map, cost_graph, s_val);
@@ -2842,7 +2842,7 @@ void FindReplicateSet(
 }
 
 // Substitute all-reduce strategies with their reduce-scatter variants.
-void GenerateReduceScatter(
+absl::Status GenerateReduceScatter(
     const HloInstructionSequence& sequence, const AliasMap& alias_map,
     const InstructionDepthMap& depth_map, const StrategyMap& strategy_map,
     const CostGraph& cost_graph, absl::Span<const NodeStrategyIdx> s_val,
@@ -3107,8 +3107,9 @@ void GenerateReduceScatter(
       replace_with->set_sharding(
           GetShardingStrategy(inst, strategy_map, cost_graph, s_val)
              .output_sharding);
-      TF_CHECK_OK(inst->ReplaceAllUsesWith(replace_with));
+      TF_RETURN_IF_ERROR(inst->ReplaceAllUsesWith(replace_with));
     }
+  return OkStatus();
 }
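
The TF_CHECK_OK to TF_RETURN_IF_ERROR swap is the usual mechanical step when converting a void function that crashes on failure into one that returns absl::Status. A self-contained sketch of the pattern, with a simplified RETURN_IF_ERROR macro standing in for TensorFlow's TF_RETURN_IF_ERROR and a hypothetical ReplaceUses() in place of the real call:

#include <iostream>

#include "absl/status/status.h"

// Simplified stand-in for TF_RETURN_IF_ERROR: evaluate the expression and
// propagate any non-OK status to the caller instead of crashing.
#define RETURN_IF_ERROR(expr)            \
  do {                                   \
    const absl::Status status_ = (expr); \
    if (!status_.ok()) return status_;   \
  } while (false)

// Hypothetical fallible step, in place of HloInstruction::ReplaceAllUsesWith.
absl::Status ReplaceUses() { return absl::OkStatus(); }

// Before the change the body would be CHECK_OK(ReplaceUses()); in a void
// function, aborting the process on error; now the error flows upward.
absl::Status Rewrite() {
  RETURN_IF_ERROR(ReplaceUses());
  return absl::OkStatus();
}

int main() { std::cout << Rewrite() << "\n"; }  // prints "OK"
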

void AnnotateShardingWithSimpleHeuristic(
Expand Down Expand Up @@ -3837,8 +3838,9 @@ absl::StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(

// ----- Substitute all-reduce with reduce-scatter -----
if (option_.prefer_reduce_scatter) {
GenerateReduceScatter(sequence, alias_map, ins_depth_map, strategy_map,
cost_graph, s_val, cluster_env, option_);
TF_RETURN_IF_ERROR(GenerateReduceScatter(
sequence, alias_map, ins_depth_map, strategy_map, cost_graph, s_val,
cluster_env, option_));
}
// ----- Set Sharding -----
SetHloSharding(sequence, strategy_map, cost_graph, s_val,
@@ -3918,6 +3920,21 @@ bool ShardedOnTooManyMeshAxes(const HloModule& module) {
   return false;
 }
 
+bool HasUnsupportedNestedTuples(const HloModule& module) {
+  for (const auto* computation : module.computations()) {
+    for (const auto* instruction : computation->instructions()) {
+      if (instruction->opcode() == HloOpcode::kConditional) {
+        for (const HloInstruction* operand : instruction->operands()) {
+          if (ShapeUtil::IsNestedTuple(operand->shape())) {
+            return true;
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
 std::unique_ptr<HloModule> CloneModule(const HloModule* module) {
   auto module_clone = module->Clone("");
   module_clone->set_layout_canonicalization_callback(
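
Concretely, the new helper fires on a conditional whose operand shape is a tuple containing another tuple, while flat tuple operands pass. A quick check using the toy Shape sketch from earlier (XLA's real predicate is ShapeUtil::IsNestedTuple; the shapes below are stand-ins):

#include <iostream>

// Reuses the toy Shape and HasNestedTupleElement sketch from above.
int main() {
  Shape leaf;                   // stands in for an array shape like f32[10]
  Shape flat{{leaf, leaf}};     // (f32[10], f32[10])            -> flat tuple
  Shape nested{{leaf, flat}};   // (f32[10], (f32[10], f32[10])) -> nested

  std::cout << HasNestedTupleElement(flat) << "\n";    // prints 0
  std::cout << HasNestedTupleElement(nested) << "\n";  // prints 1
}
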
@@ -3938,15 +3955,25 @@ absl::StatusOr<bool> AutoSharding::Run(
 
   if (IsModuleManuallySharded(module)) {
     LOG(FATAL)
-        << "Auto-sharding on partially manually sharded modules is not yet "
-           "supported. Please fall back on the sharding propagation pass.";
+        << "Auto-sharding on partially manually sharded modules "  // Crash OK
+           "is not yet supported. Please fall back on the sharding "
+           "propagation pass.";
     return false;
   }
 
   if (ShardedOnTooManyMeshAxes(*module)) {
-    LOG(FATAL) << "The input module contains sharding annotations over a mesh "
-                  "with too many axes (>2). This case is currently not well "
-                  "supported.";
+    LOG(FATAL) << "The input module contains sharding annotations "  // Crash OK
+                  "over a mesh with too many axes (>2). This case is currently "
+                  "not well supported.";
     return false;
   }
 
+  // TODO(b/332951306): Remove this check once nested tuples are supported
+  // everywhere.
+  if (HasUnsupportedNestedTuples(*module)) {
+    LOG(FATAL) << "The input module contains nested tuples "  // Crash OK
+                  "which we do not currently support well. See b/332951306 to "
+                  "track progress on this.";
+    return false;
+  }

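A note on the repeated // Crash OK markers: these appear to follow the Google-style convention of annotating a deliberate LOG(FATAL) so lint checks that discourage crashing in library code are satisfied. The return false; after each one is then effectively unreachable, but it keeps the absl::StatusOr<bool> function body well-formed.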
2 changes: 1 addition & 1 deletion xla/hlo/experimental/auto_sharding/auto_sharding.h
@@ -210,7 +210,7 @@ Status CheckAliasSetCompatibility(const AliasSet& alias_set,
                                   const HloInstructionSequence& sequence,
                                   bool crash_on_error);
 
-void GenerateReduceScatter(
+absl::Status GenerateReduceScatter(
     const HloInstructionSequence& sequence, const AliasMap& alias_map,
     const InstructionDepthMap& depth_map, const StrategyMap& strategy_map,
     const CostGraph& cost_graph, absl::Span<const int64_t> s_val,
