[Fix] Fix DDP wrapper for symbolic computation (PaddlePaddle#675)
* fix DDP wrapper for symbolic computation

* refine code

* fix
HydrogenSulfate authored Dec 1, 2023
1 parent 7772ef9 commit 6a3e6c6
Showing 1 changed file with 34 additions and 11 deletions.
45 changes: 34 additions & 11 deletions ppsci/solver/solver.py
@@ -286,11 +286,22 @@ def __init__(
                 "Please do not wrap your model with DataParallel "
                 "before 'Solver.__init__' and keep it's type as 'nn.Layer'."
             )
-            self.model = fleet.distributed_model(self.model)
-            if hasattr(self.model, "input_keys"):
-                self.model.input_keys = self.model._layers.input_keys
-            if hasattr(self.model, "output_keys"):
-                self.model.output_keys = self.model._layers.output_keys
+
+            def dist_wrapper(model: nn.Layer) -> paddle.DataParallel:
+                dist_model = fleet.distributed_model(model)
+                if hasattr(model, "input_keys"):
+                    dist_model.input_keys = dist_model._layers.input_keys
+                if hasattr(model, "output_keys"):
+                    dist_model.output_keys = dist_model._layers.output_keys
+                return dist_model
+
+            if isinstance(self.model, ppsci.arch.ModelList):
+                for i in range(len(self.model.model_list)):
+                    # NOTE: Convert each model in model_list to DataParallel
+                    self.model.model_list[i] = dist_wrapper(self.model.model_list[i])
+            else:
+                self.model = dist_wrapper(self.model)
+
             if self.optimizer is not None:
                 self.optimizer = fleet.distributed_optimizer(self.optimizer)

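Why `dist_wrapper` copies those attributes by hand: `fleet.distributed_model` returns a `paddle.DataParallel` that keeps the original module under `_layers`, and plain Python attributes such as `input_keys`/`output_keys` are not forwarded through the wrapper's attribute lookup. A minimal sketch of the same pattern in isolation, assuming a launch via `python -m paddle.distributed.launch`; `ToyModel` is a hypothetical stand-in for a ppsci.arch model:

import paddle.nn as nn
from paddle.distributed import fleet

class ToyModel(nn.Layer):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.input_keys = ("x",)   # plain Python attributes, not sub-layers
        self.output_keys = ("u",)
        self.fc = nn.Linear(2, 1)

    def forward(self, x):
        return self.fc(x)

fleet.init(is_collective=True)
dist_model = fleet.distributed_model(ToyModel())  # -> paddle.DataParallel
# The wrapper does not expose the inner module's custom attributes, so
# fetch them from the wrapped module (`_layers`) and pin them onto the
# wrapper, exactly as dist_wrapper does in the diff above:
dist_model.input_keys = dist_model._layers.input_keys
dist_model.output_keys = dist_model._layers.output_keys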
@@ -678,12 +689,24 @@ def no_sync_context_manager(
             contextlib.AbstractContextManager: Smart no_sync context manager.
         """
         if enable:
-            if not isinstance(ddp_model, paddle.DataParallel):
-                raise TypeError(
-                    "no_sync interface is only for model with type paddle.DataParallel, "
-                    f"but got type {misc.typename(ddp_model)}"
-                )
-            ctx_manager = ddp_model.no_sync()
+            if isinstance(self.model, ppsci.arch.ModelList):
+                for model in self.model.model_list:
+                    if not isinstance(model, paddle.DataParallel):
+                        raise TypeError(
+                            "no_sync interface is only for model with type "
+                            "paddle.DataParallel, but got type "
+                            f"{misc.typename(model)}"
+                        )
+                ctx_manager = contextlib.ExitStack()
+                for model in self.model.model_list:
+                    ctx_manager.enter_context(model.no_sync())
+            else:
+                if not isinstance(self.model, paddle.DataParallel):
+                    raise TypeError(
+                        "no_sync interface is only for model with type "
+                        f"paddle.DataParallel, but got type {misc.typename(ddp_model)}"
+                    )
+                ctx_manager = ddp_model.no_sync()
         else:
             ctx_manager = (
                 contextlib.nullcontext()
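The `ExitStack` branch is what generalizes `no_sync` to a ModelList: every sub-model's `no_sync()` context must be active at the same time so each `paddle.DataParallel` instance skips its gradient all-reduce during accumulation steps. A sketch of that pattern on its own, where `models` (a list of `paddle.DataParallel` instances) and `compute_loss` are hypothetical:

import contextlib

with contextlib.ExitStack() as stack:
    # Enter every model's no_sync() context; ExitStack exits them all
    # in reverse order when the with-block ends.
    for m in models:
        stack.enter_context(m.no_sync())
    loss = compute_loss(models)  # hypothetical forward pass
    loss.backward()              # gradients accumulate locally, no all-reduce
# Outside the block, the next backward() synchronizes across ranks as usual.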
