Adjust pipeline optimizer for 3d parallelism #31939

Merged: 5 commits merged on Mar 31, 2021
27 changes: 3 additions & 24 deletions paddle/fluid/framework/pipeline_trainer.cc
@@ -71,37 +71,16 @@ void PipelineTrainer::CopyParameters(int microbatch_id,
const ProgramDesc& program,
const platform::Place& place) {
auto& global_block = program.Block(0);
std::map<std::string, int> param_map;
for (auto& var : global_block.AllVars()) {
if (var->Persistable()) {
param_map[var->Name()] = 1;
}
}

for (auto& var : global_block.AllVars()) {
bool is_param_grad = false;
size_t pos = 0;
// A magic suffix to indicate the merged gradient
std::string magicSuffix = std::string(kGradVarSuffix) + "@MERGED";
if ((pos = var->Name().find(magicSuffix)) != std::string::npos) {
auto prefix_name = var->Name().substr(0, pos);
if (param_map.find(prefix_name) != param_map.end()) {
is_param_grad = true;
}
}
if (var->Persistable() && microbatch_id == 0) {
auto* ptr = root_scope_->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create persistable var: " << var->Name()
<< ", which pointer is " << ptr;
} else if (is_param_grad && microbatch_id == 0) {
auto* ptr = minibatch_scope_->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create grad for persistable var: " << var->Name()
VLOG(5) << "Create persistable var: " << var->Name()
<< ", which pointer is " << ptr;
} else if (!var->Persistable() && !is_param_grad) {
} else if (!var->Persistable()) {
auto* ptr = microbatch_scopes_[microbatch_id]->Var(var->Name());
VLOG(3) << "Create variable " << var->Name() << " for microbatch "
VLOG(5) << "Create variable " << var->Name() << " for microbatch "
<< microbatch_id << ", which pointer is " << ptr;
InitializeVariable(ptr, var->GetType());
}
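Note on the deleted block above: merged gradients in the pipeline pass are identified purely by name, as the parameter name followed by kGradVarSuffix ("@GRAD") and the literal "@MERGED" marker, and this PR drops that detection from the C++ trainer. The following Python snippet is only an illustrative sketch of the same suffix-matching rule; the helper name and the example variable names are hypothetical and not part of this change.

# Illustrative sketch of the naming convention the removed C++ code matched.
# GRAD_SUFFIX mirrors kGradVarSuffix; "@MERGED" marks an accumulated gradient.
GRAD_SUFFIX = "@GRAD"
MERGED_SUFFIX = GRAD_SUFFIX + "@MERGED"

def is_merged_param_grad(var_name, param_names):
    """Return True if var_name looks like '<param>@GRAD@MERGED' for a known param."""
    pos = var_name.find(MERGED_SUFFIX)
    if pos == -1:
        return False
    return var_name[:pos] in set(param_names)

# Hypothetical usage:
# is_merged_param_grad("fc_0.w_0@GRAD@MERGED", ["fc_0.w_0"])  -> True
# is_merged_param_grad("fc_0.b_0", ["fc_0.w_0"])              -> False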
5 changes: 5 additions & 0 deletions python/paddle/distributed/fleet/meta_optimizers/common.py
@@ -106,6 +106,11 @@ def _add_sync_by_allreduce(block):
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward
})
block.append_op(
type='c_sync_calc_stream',
inputs={'X': sync_var},
outputs={'Out': sync_var},
attrs={OP_ROLE_KEY: OpRole.Forward})

block = program.global_block()
if core.is_compiled_with_cuda():
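For orientation, _add_sync_by_allreduce builds a short chain of ops on a block, and the c_sync_calc_stream appended here waits on the calculation stream, presumably so the allreduce issued with use_calc_stream=True has finished before execution proceeds. The snippet below is a simplified, hypothetical rendering of that pattern, not the actual helper: the function name, the sync_var argument, and the c_allreduce_sum attributes are assumptions based on the surrounding context.

# Simplified sketch (not the real helper) of the sync-by-allreduce pattern.
# OP_ROLE_KEY and OpRole are the constants used in this module.
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole

def append_sync_by_allreduce(block, sync_var, ring_id=0):
    # All-reduce a small flag variable on the calculation stream so that
    # every rank reaches this point before the program continues.
    block.append_op(
        type='c_allreduce_sum',
        inputs={'X': sync_var},
        outputs={'Out': sync_var},
        attrs={
            'ring_id': ring_id,
            'use_calc_stream': True,
            OP_ROLE_KEY: OpRole.Forward
        })
    # Barrier on the calculation stream, mirroring the op added in this diff.
    block.append_op(
        type='c_sync_calc_stream',
        inputs={'X': sync_var},
        outputs={'Out': sync_var},
        attrs={OP_ROLE_KEY: OpRole.Forward})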
@@ -171,6 +171,7 @@ def minimize_impl(self,
program._pipeline_opt['ring_id'] = self.start_pipeline_ring_id
program._pipeline_opt['micro_batch_size'] = self.micro_batch_size
program._pipeline_opt['schedule_mode'] = self.schedule_mode
program._pipeline_opt['use_sharding'] = False
optimize_ops, params_grads, prog_list, pp_pair, ring_map = self.wrapped_opt.minimize(
loss, startup_program, parameter_list, no_grad_set)
self.startup_program = orig_startup_program._pipeline_opt[
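As the minimize_impl hunk shows, the pipeline transform is configured through a _pipeline_opt dict attached to the program, and this PR adds a use_sharding key that is set to False on this plain pipeline path, presumably so the pass can tell whether it runs under the sharding-driven 3D-parallel setup. A minimal sketch of those keys is shown below; the helper name and the concrete values are hypothetical placeholders.

# Hypothetical illustration of the _pipeline_opt keys visible in this diff.
def build_pipeline_opt(start_pipeline_ring_id, micro_batch_size, schedule_mode):
    return {
        'ring_id': start_pipeline_ring_id,
        'micro_batch_size': micro_batch_size,
        'schedule_mode': schedule_mode,
        # Added by this PR: the plain pipeline optimizer runs without sharding.
        'use_sharding': False,
    }

# e.g. program._pipeline_opt.update(build_pipeline_opt(20, 1, '1F1B'))
# (20, 1 and '1F1B' are placeholder values, not defaults taken from the code above)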
@@ -218,7 +219,6 @@ def _insert_allreduce_ops(self, ring_id):
grad = None
processed_param_name = set()
first_optimize_op_idx = None
add_sync_calc_stream = False
for idx, op in reversed(list(enumerate(block.ops))):
if is_backward_op(op) and not first_optimize_op_idx:
first_optimize_op_idx = idx + 1
@@ -242,15 +242,6 @@
origin_param = origin_block.vars[op_role_var[i]]
if origin_param.is_distributed:
continue
if not add_sync_calc_stream:
add_sync_calc_stream = True
block._insert_op(
first_optimize_op_idx + offset,
type='c_sync_calc_stream',
inputs={'X': grad},
outputs={'Out': grad},
attrs={OP_ROLE_KEY: OpRole.Optimize})
offset += 1

block._insert_op(
first_optimize_op_idx + offset,
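The last hunk removes the add_sync_calc_stream bookkeeping from _insert_allreduce_ops: the loop no longer inserts a c_sync_calc_stream before the gradient allreduces (the sync op added in common.py above appears to take over that role), and the insertion of the allreduce op itself, truncated at the end of this view, is all that remains per non-distributed parameter gradient. The snippet below is a hedged guess at what that remaining insertion looks like; the c_allreduce_sum op type and its attributes are assumptions inferred from the surrounding context, not a verbatim copy of the hidden lines.

# Hedged sketch of the per-gradient insertion that remains after this change.
# OP_ROLE_KEY and OpRole are the constants used by these meta optimizers; the
# op type and attrs below are assumptions, not copied from the hidden lines.
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole

def insert_grad_allreduce(block, insert_idx, grad, ring_id):
    # Insert an in-place allreduce on the gradient at the given position.
    block._insert_op(
        insert_idx,
        type='c_allreduce_sum',
        inputs={'X': grad},
        outputs={'Out': grad},
        attrs={
            'ring_id': ring_id,
            OP_ROLE_KEY: OpRole.Optimize
        })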