[xla] Avoid repeatedly traversing computations in a module by processing the
computations in post-order.

PiperOrigin-RevId: 678332958
bixia1 authored and Google-ML-Automation committed Sep 24, 2024
1 parent ef85a7b commit bcc98dc
Showing 1 changed file with 81 additions and 82 deletions.
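
Before reading the diff, a minimal, self-contained sketch of the idea the new code relies on may help: visiting computations callees-first (post-order over the call graph) means anything hoisted out of an inner while body is already present in its caller by the time the caller is processed, so a single sweep can replace the old run-until-fixpoint driver. This is not XLA code; the Comp struct, PostOrder helper, and module shape below are invented purely for illustration.

#include <iostream>
#include <set>
#include <string>
#include <vector>

// Toy stand-in for an HLO computation: a name plus the computations it calls.
struct Comp {
  std::string name;
  std::vector<Comp*> callees;
};

// Emit `comp` only after all of its callees (post-order over the call DAG).
void PostOrder(Comp* comp, std::set<Comp*>& visited,
               std::vector<Comp*>& order) {
  if (!visited.insert(comp).second) return;  // shared callee: visit once
  for (Comp* callee : comp->callees) PostOrder(callee, visited, order);
  order.push_back(comp);
}

int main() {
  // entry calls an outer while body, which calls a nested while body.
  Comp inner{"inner_while_body", {}};
  Comp outer{"outer_while_body", {&inner}};
  Comp entry{"entry", {&outer}};

  std::set<Comp*> visited;
  std::vector<Comp*> order;
  PostOrder(&entry, visited, order);

  // Prints inner_while_body, outer_while_body, entry: by the time
  // outer_while_body is visited, an all-reduce sunk out of inner_while_body
  // already sits in outer_while_body, so one traversal suffices.
  for (Comp* c : order) std::cout << c->name << "\n";
}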
--- a/xla/service/while_loop_all_reduce_code_motion.cc
+++ b/xla/service/while_loop_all_reduce_code_motion.cc
@@ -936,7 +936,7 @@ absl::StatusOr<bool> WhileLoopAllReduceCodeMotion::Run(
     HloModule* module,
     const absl::flat_hash_set<absl::string_view>& execution_threads) {
   bool is_changed = false;
-  bool run_next_pass = true;
+
   // In case of MPMD, all-reduces might be cross-module and should preserve
   // their channel ID. Do not move all-reduces in this case since the channel
   // ID might be changed.
@@ -965,96 +965,95 @@ absl::StatusOr<bool> WhileLoopAllReduceCodeMotion::Run(
   // loop. We recursively sink the all-reduce through nested while loops if
   // applicable by repeating this process.
   uint32_t count_all_reduce = 0, count_reduce_scatter = 0;
-  while (run_next_pass) {
-    run_next_pass = false;
-    std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
+  // We process all callees of a computation before processing the computation,
+  // so that when we process a computation, the all-reduce instructions that
+  // need to be hoisted to the computation from its callees have been hoisted.
+  for (HloComputation* computation :
+       module->MakeComputationPostOrder(execution_threads)) {
     // A computation could be the while body of multiple while instructions,
     // so we start from the computation and find all of its callers that is a
     // kWhile if there is any.
-    for (HloComputation* computation :
-         module->computations(execution_threads)) {
-      std::vector<HloInstruction*> computation_callers =
-          call_graph->GetComputationCallers(computation);
-      std::vector<HloInstruction*> while_caller_instructions;
-      for (HloInstruction* caller_instruction : computation_callers) {
-        // For simplicity, we only support while instructions whose shape is
-        // tuple.
-        if (caller_instruction->opcode() == HloOpcode::kWhile &&
-            caller_instruction->shape().IsTuple() &&
-            caller_instruction->while_body() == computation) {
-          while_caller_instructions.push_back(caller_instruction);
-        }
-      }
-      // Skip to next computation if this computation is not the while body of
-      // any while instruction.
-      if (while_caller_instructions.empty()) {
-        continue;
-      }
-      std::vector<HloAllReduceInstructionBase*> while_body_all_reduces;
-      for (HloInstruction* while_body_instruction :
-           computation->MakeInstructionPostOrder()) {
-        HloOpcode op = while_body_instruction->opcode();
-        const bool is_candidate =
-            (op == HloOpcode::kAllReduce) ||
-            (enable_reduce_scatter_ && op == HloOpcode::kReduceScatter);
-        if (!is_candidate) {
-          continue;
-        }
-        auto* all_reduce_instruction =
-            Cast<HloAllReduceInstructionBase>(while_body_instruction);
-        if (all_reduce_instruction->constrain_layout()) {
-          return false;
-        } else {
-          while_body_all_reduces.push_back(all_reduce_instruction);
-        }
-      }
-      HloInstructionMap<std::vector<AccumulationContext>>
-          all_reduce_to_accumulations;
-      for (HloAllReduceInstructionBase* all_reduce : while_body_all_reduces) {
-        auto movable_all_reduce_context = IsAllReduceMovable(
-            all_reduce, computation, cross_replica_replication_analysis,
-            cross_partition_replication_analysis);
-        if (movable_all_reduce_context.is_movable) {
-          all_reduce_to_accumulations[all_reduce] =
-              std::move(movable_all_reduce_context.accumulation_contexts);
-        }
-        VLOG(3) << "WhileLoopAllReduceCodeMotion, all-reduce: "
-                << all_reduce->ToString()
-                << " is_movable: " << movable_all_reduce_context.is_movable
-                << " while loop: " << while_caller_instructions.front()->name()
-                << " num_accumulations: "
-                << (movable_all_reduce_context.is_movable
-                        ? all_reduce_to_accumulations[all_reduce].size()
-                        : 0);
-      }
-      if (all_reduce_to_accumulations.empty()) {
-        continue;
-      }
-      // For each while instruction calling this computation, create the
-      // corresponding all-reduces after the while loop.
-      for (HloInstruction* while_instruction : while_caller_instructions) {
-        TF_RETURN_IF_ERROR(AddSinkedAllReducesAndReplaceWhile(
-            while_instruction, all_reduce_to_accumulations));
-        is_changed = true;
-        run_next_pass = true;
-      }
-      // At last, remove the old all-reduce instructions in the while body.
-      for (const auto& all_reduce_accumulations_pair :
-           all_reduce_to_accumulations) {
-        HloInstruction* all_reduce = all_reduce_accumulations_pair.first;
-        if (all_reduce->opcode() == HloOpcode::kAllReduce) {
-          count_all_reduce++;
-        } else {
-          count_reduce_scatter++;
-        }
-        TF_RETURN_IF_ERROR(computation->ReplaceInstructionWithDifferentShape(
-            all_reduce, all_reduce->mutable_operand(0)));
-      }
-      // Needs to rebuild the call graph or we could access removed
-      // instructions.
-      if (run_next_pass) {
-        break;
-      }
-    }
-  }
+    std::vector<HloInstruction*> computation_callers =
+        call_graph->GetComputationCallers(computation);
+    std::vector<HloInstruction*> while_caller_instructions;
+    for (HloInstruction* caller_instruction : computation_callers) {
+      // For simplicity, we only support while instructions whose shape is
+      // tuple.
+      if (caller_instruction->opcode() == HloOpcode::kWhile &&
+          caller_instruction->shape().IsTuple() &&
+          caller_instruction->while_body() == computation) {
+        while_caller_instructions.push_back(caller_instruction);
+      }
+    }
+    // Skip to next computation if this computation is not the while body of
+    // any while instruction.
+    if (while_caller_instructions.empty()) {
+      continue;
+    }
+    std::vector<HloAllReduceInstructionBase*> while_body_all_reduces;
+    for (HloInstruction* while_body_instruction :
+         computation->MakeInstructionPostOrder()) {
+      HloOpcode op = while_body_instruction->opcode();
+      const bool is_candidate =
+          (op == HloOpcode::kAllReduce) ||
+          (enable_reduce_scatter_ && op == HloOpcode::kReduceScatter);
+      if (!is_candidate) {
+        continue;
+      }
+      auto* all_reduce_instruction =
+          Cast<HloAllReduceInstructionBase>(while_body_instruction);
+      if (all_reduce_instruction->constrain_layout()) {
+        return false;
+      } else {
+        while_body_all_reduces.push_back(all_reduce_instruction);
+      }
+    }
+    HloInstructionMap<std::vector<AccumulationContext>>
+        all_reduce_to_accumulations;
+    for (HloAllReduceInstructionBase* all_reduce : while_body_all_reduces) {
+      auto movable_all_reduce_context = IsAllReduceMovable(
+          all_reduce, computation, cross_replica_replication_analysis,
+          cross_partition_replication_analysis);
+      if (movable_all_reduce_context.is_movable) {
+        all_reduce_to_accumulations[all_reduce] =
+            std::move(movable_all_reduce_context.accumulation_contexts);
+      }
+      VLOG(3) << "WhileLoopAllReduceCodeMotion, all-reduce: "
+              << all_reduce->ToString()
+              << " is_movable: " << movable_all_reduce_context.is_movable
+              << " while loop: " << while_caller_instructions.front()->name()
+              << " num_accumulations: "
+              << (movable_all_reduce_context.is_movable
+                      ? all_reduce_to_accumulations[all_reduce].size()
+                      : 0);
+    }
+    if (all_reduce_to_accumulations.empty()) {
+      continue;
+    }
+    // For each while instruction calling this computation, create the
+    // corresponding all-reduces after the while loop.
+    for (HloInstruction* while_instruction : while_caller_instructions) {
+      TF_RETURN_IF_ERROR(AddSinkedAllReducesAndReplaceWhile(
+          while_instruction, all_reduce_to_accumulations));
+      is_changed = true;
+    }
+    // At last, remove the old all-reduce instructions in the while body.
+    for (const auto& all_reduce_accumulations_pair :
+         all_reduce_to_accumulations) {
+      HloInstruction* all_reduce = all_reduce_accumulations_pair.first;
+      if (all_reduce->opcode() == HloOpcode::kAllReduce) {
+        count_all_reduce++;
+      } else {
+        count_reduce_scatter++;
+      }
+      TF_RETURN_IF_ERROR(computation->ReplaceInstructionWithDifferentShape(
+          all_reduce, all_reduce->mutable_operand(0)));
+    }
+    // Needs to rebuild the call graph after we remove instructions to avoid
+    // accessing removed instructions.
+    if (!all_reduce_to_accumulations.empty()) {
+      call_graph = CallGraph::Build(module);
+    }
+  }
   VLOG(2) << "Hoisted " << count_all_reduce << " all-reduce and "
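
To see why this change matters for compile time, here is a distilled toy contrast between the removed run-until-fixpoint driver and the single post-order sweep. This is not XLA code; the computation list and the "a hoist succeeds once per computation" rule are invented assumptions for illustration.

#include <iostream>
#include <vector>

int main() {
  std::vector<int> computations = {0, 1, 2};  // assume already callees-first

  // Old driver shape: break out and rescan the whole module after every
  // successful hoist, because the call graph had to be rebuilt.
  int old_visits = 0;
  bool run_next_pass = true;
  std::vector<bool> hoisted(computations.size(), false);
  while (run_next_pass) {
    run_next_pass = false;
    for (int c : computations) {
      ++old_visits;
      if (!hoisted[c]) {  // stand-in for a hoist that succeeds once per comp
        hoisted[c] = true;
        run_next_pass = true;
        break;
      }
    }
  }

  // New driver shape: one callees-first sweep visits each computation once;
  // the call graph is rebuilt in place only after a computation changed.
  int new_visits = static_cast<int>(computations.size());

  // old visits: 9 (1 + 2 + 3 + 3), new visits: 3 -- quadratic vs linear.
  std::cout << "old visits: " << old_visits
            << ", new visits: " << new_visits << "\n";
}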