Optimize communication in interleaved pipeline parallelism #331

Closed
wants to merge 5 commits into from
Changes from all commits
230 changes: 202 additions & 28 deletions megatron/core/pipeline_parallel/schedules.py
@@ -1,7 +1,7 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

from contextlib import contextmanager, nullcontext
from typing import Optional, List, Union, Callable, Any
import contextlib
from typing import Callable, Iterator, List, Optional, Union

import torch
from torch.autograd.variable import Variable
@@ -54,10 +54,11 @@ def forward_step(data_iterator, model):


data_iterator (required): an iterator over the data, will be
passed as is to forward_step_func
passed as is to forward_step_func. Expected to be a list of
iterators in the case of interleaved pipeline parallelism.

model (required): the actual model. A torch.nn.Module or, in the
case or iterleaving, a list of torch.nn.Module
model (required): the actual model. Expected to be a list of
modules in the case of interleaved pipeline parallelism.

num_microbatches (int, required):
The number of microbatches to go through
@@ -93,6 +94,21 @@ def forward_step(data_iterator, model):
enable_autocast (optional, default=False): If True, runs the
forward_step_func call inside torch.autocast context

no_sync_func (optional): Function that creates a context that
suppresses asynchronous data-parallel communication. If the
model is an instance of torch.nn.DistributedDataParallel, the
default is to use torch.nn.DistributedDataParallel.no_sync.

grad_sync_func (optional): Function that launches asynchronous
gradient reductions (e.g. distributed optimizer gradient
reduce-scatters). The function should take one argument: an
iterable of parameters whose gradients are to be synchronized.

param_sync_func (optional): Function that launches asynchronous
parameter synchronizations (e.g. distributed optimizer
parameter all-gathers). The function should take one argument:
an iterable of parameters to be synchronized.

"""
pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
if pipeline_model_parallel_size > 1:
@@ -188,7 +204,7 @@ def forward_step(forward_step_func,
set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
set_input_tensor(input_tensor)

context_manager = torch.autocast("cuda") if enable_autocast else nullcontext()
context_manager = torch.autocast("cuda") if enable_autocast else contextlib.nullcontext()
with context_manager:
output_tensor, loss_func = forward_step_func(data_iterator, model)

@@ -280,17 +296,9 @@ def backward_step(grad_scaler, input_tensor, output_tensor,
return input_tensor_grad


@contextmanager
def dummy_handler():
try:
yield
finally:
pass


def forward_backward_no_pipelining(*,
forward_step_func,
data_iterator,
data_iterator: Union[Iterator, List[Iterator]],
model: Union[torch.nn.Module, List[torch.nn.Module]],
num_microbatches: int,
dtype: Optional[torch.dtype] = None, # unused
@@ -301,7 +309,11 @@ def forward_backward_no_pipelining(*,
forward_only: bool = False,
timers: Callable = None,
collect_non_loss_data: bool = False,
enable_autocast: bool = False):
enable_autocast: bool = False,
no_sync_func: Optional[Callable] = None,
grad_sync_func: Optional[Callable] = None, # unused
param_sync_func: Optional[Callable] = None, # unused
):
"""Run forward and backward passes with no pipeline parallelism
(no inter-stage communication).

@@ -310,18 +322,26 @@ def forward_backward_no_pipelining(*,

See get_forward_backward_func() for argument details
"""
assert len(model) == 1
model = model[0]

context_handler = dummy_handler
if isinstance(model, torchDDP):
context_handler = model.no_sync
if isinstance(model, list):
assert len(model) == 1, \
"non-pipeline-parallel schedule does not support model chunking"
model = model[0]
if isinstance(data_iterator, list):
assert len(data_iterator) == 1, \
"non-pipeline-parallel schedule does not support model chunking"
data_iterator = data_iterator[0]

if no_sync_func is None and isinstance(model, torchDDP):
no_sync_func = model.no_sync
if no_sync_func is None:
no_sync_func = contextlib.nullcontext

model_type = get_model_type(model)

forward_data_store = []
input_tensor, output_tensor_grad = None, None
with context_handler():
with no_sync_func():
for i in range(num_microbatches - 1):
output_tensor = forward_step(forward_step_func, data_iterator,
model, num_microbatches, input_tensor, forward_data_store,
@@ -345,7 +365,7 @@ def forward_backward_no_pipelining(*,

def forward_backward_pipelining_with_interleaving(*,
forward_step_func,
data_iterator,
data_iterator: Union[Iterator, List[Iterator]],
model: Union[torch.nn.Module, List[torch.nn.Module]],
num_microbatches: int,
dtype: torch.dtype,
@@ -356,11 +376,49 @@ def forward_backward_pipelining_with_interleaving(*,
forward_only: bool = False,
timers: Callable = None,
collect_non_loss_data: bool = False,
enable_autocast: bool = False):
enable_autocast: bool = False,
no_sync_func: Optional[Callable] = None,
grad_sync_func: Optional[Callable] = None,
param_sync_func: Optional[Callable] = None,
):
"""Run interleaved 1F1B schedule (model split into model chunks), with
communication between pipeline stages as needed.

Returns dictionary with losses if the last stage, empty dict otherwise."""
assert isinstance(model, list), \
"interleaved pipeline parallelism expected model chunking"
assert all(isinstance(chunk, torch.nn.Module) for chunk in model), \
"invalid model chunking"
assert isinstance(data_iterator, list), \
"interleaved pipeline parallelism expected each model chunk to have a data iterator"

# Disable async grad reductions
if no_sync_func is None and all(isinstance(chunk, torchDDP) for chunk in model):
def multi_no_sync():
stack = contextlib.ExitStack()
for chunk in model:
stack.enter_context(chunk.no_sync())
return stack
no_sync_func = multi_no_sync
if no_sync_func is None:
no_sync_func = contextlib.nullcontext
no_sync_context = None
def disable_grad_sync():
"""Disable asynchronous grad reductions"""
nonlocal no_sync_context
if no_sync_context is None:
no_sync_context = no_sync_func()
no_sync_context.__enter__()
def enable_grad_sync():
"""Enable asynchronous grad reductions"""
nonlocal no_sync_context
if no_sync_context is not None:
no_sync_context.__exit__(None, None, None)
no_sync_context = None
disable_grad_sync()
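
For intuition, `multi_no_sync` returns a `contextlib.ExitStack` that has already entered every chunk's `no_sync()` context (that is what `enter_context` does), so the stack's later `__enter__` is effectively a no-op and its `__exit__` unwinds all chunks at once. A runnable toy version of the same pattern, using stand-in chunks instead of real DDP modules:

```python
# Toy sketch: ToyChunk is a hypothetical stand-in for a DDP-wrapped model chunk.
import contextlib

class ToyChunk:
    def __init__(self):
        self.sync_enabled = True

    @contextlib.contextmanager
    def no_sync(self):
        # Suppress grad sync while the context is active.
        self.sync_enabled = False
        try:
            yield
        finally:
            self.sync_enabled = True

model = [ToyChunk(), ToyChunk()]

def multi_no_sync():
    # Enter every chunk's no_sync() context and hand back one ExitStack whose
    # __exit__ re-enables grad sync for all chunks at once.
    stack = contextlib.ExitStack()
    for chunk in model:
        stack.enter_context(chunk.no_sync())
    return stack

no_sync_context = multi_no_sync()                 # grad sync suppressed everywhere
print([chunk.sync_enabled for chunk in model])    # [False, False]
no_sync_context.__exit__(None, None, None)        # grad sync re-enabled
print([chunk.sync_enabled for chunk in model])    # [True, True]
```
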

# Model chunk IDs with synchronized grads
synchronized_model_chunks = set()

input_tensors = [[] for _ in range(len(model))]
output_tensors = [[] for _ in range(len(model))]
@@ -418,6 +476,11 @@ def forward_backward_pipelining_with_interleaving(*,
num_microbatches_remaining = \
total_num_microbatches - num_warmup_microbatches

# Synchronize params for first two model chunks
if param_sync_func is not None:
param_sync_func(model[0].parameters())
param_sync_func(model[1].parameters())

def get_model_chunk_id(microbatch_id, forward):
"""Helper method to get the model chunk ID given the iteration number."""
microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
@@ -426,13 +489,48 @@ def get_model_chunk_id(microbatch_id, forward):
model_chunk_id = (num_model_chunks - model_chunk_id - 1)
return model_chunk_id
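
For concreteness, a standalone toy reproduction of this mapping with `pipeline_parallel_size = 2` and `num_model_chunks = 2`. The diff hides the middle of the function, so the sketch assumes the usual interleaved assignment `model_chunk_id = microbatch_id_in_group // pipeline_parallel_size`.

```python
# Assumed toy configuration; the line marked "assumed" reconstructs code that
# the diff does not show.
pipeline_parallel_size = 2
num_model_chunks = 2

def get_model_chunk_id(microbatch_id: int, forward: bool) -> int:
    microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
    model_chunk_id = microbatch_id_in_group // pipeline_parallel_size  # assumed (elided in diff)
    if not forward:
        model_chunk_id = num_model_chunks - model_chunk_id - 1
    return model_chunk_id

# Within one group of 4 microbatches, forward passes visit chunk 0 twice and
# then chunk 1 twice; backward passes visit the chunks in reverse order.
print([get_model_chunk_id(m, forward=True) for m in range(4)])   # [0, 0, 1, 1]
print([get_model_chunk_id(m, forward=False) for m in range(4)])  # [1, 1, 0, 0]
```
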

def is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool:
"""Check if an iteration is the first for a model chunk."""
microbatch_group_size = pipeline_parallel_size * num_model_chunks
num_microbatch_groups = num_microbatches // microbatch_group_size
microbatch_group_id = microbatch_id // microbatch_group_size
microbatch_id_in_group = microbatch_id % microbatch_group_size
if microbatch_group_id == 0:
return microbatch_id_in_group % pipeline_parallel_size == 0
else:
return False

def is_last_microbatch_for_model_chunk(microbatch_id: int) -> bool:
"""Check if an iteration is the last for a model chunk."""
microbatch_group_size = pipeline_parallel_size * num_model_chunks
num_microbatch_groups = num_microbatches // microbatch_group_size
microbatch_group_id = microbatch_id // microbatch_group_size
microbatch_id_in_group = microbatch_id % microbatch_group_size
if microbatch_group_id == num_microbatch_groups - 1:
return microbatch_id_in_group % pipeline_parallel_size == pipeline_parallel_size - 1
else:
return False
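
A quick sanity check of these two helpers under an assumed toy configuration (`pipeline_parallel_size = 2`, `num_model_chunks = 2`, `num_microbatches = 8`, i.e. two microbatch groups of four):

```python
# Assumed toy configuration for illustration only.
pipeline_parallel_size = 2
num_model_chunks = 2
num_microbatches = 8
microbatch_group_size = pipeline_parallel_size * num_model_chunks  # 4

def is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool:
    # Condensed copy of the helper above (the unused num_microbatch_groups is dropped).
    if microbatch_id // microbatch_group_size == 0:
        return (microbatch_id % microbatch_group_size) % pipeline_parallel_size == 0
    return False

def is_last_microbatch_for_model_chunk(microbatch_id: int) -> bool:
    num_microbatch_groups = num_microbatches // microbatch_group_size
    if microbatch_id // microbatch_group_size == num_microbatch_groups - 1:
        return ((microbatch_id % microbatch_group_size) % pipeline_parallel_size
                == pipeline_parallel_size - 1)
    return False

print([m for m in range(num_microbatches) if is_first_microbatch_for_model_chunk(m)])
# [0, 2] -- the first forward microbatch of chunk 0 and of chunk 1.
print([m for m in range(num_microbatches) if is_last_microbatch_for_model_chunk(m)])
# [5, 7] -- the last backward microbatch of chunk 1 and of chunk 0.
```
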


def forward_step_helper(microbatch_id):
"""Helper method to run forward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
forward_step())."""
model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

# launch param synchronization for next model chunk
# Note: Asynchronous communication tends to slow down compute.
# To reduce idling from mismatched microbatch times, we launch
# asynchronous communication at the same time across the
# pipeline-parallel group.
if param_sync_func is not None:
param_sync_microbatch_id = microbatch_id + pipeline_parallel_rank
if param_sync_microbatch_id < num_microbatches and is_first_microbatch_for_model_chunk(param_sync_microbatch_id):
param_sync_chunk_id = get_model_chunk_id(param_sync_microbatch_id, forward=True) + 1
if 1 < param_sync_chunk_id < num_model_chunks:
param_sync_func(model[param_sync_chunk_id].parameters())
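
To see what the `pipeline_parallel_rank` offset buys, consider an assumed toy configuration (`pipeline_parallel_size = 2`, `num_model_chunks = 3`, `num_microbatches = 12`). Each pipeline rank starts a given forward microbatch roughly one step after the rank before it, so shifting the trigger test by the rank makes every rank launch the parameter all-gather for chunk 2 at the same point in the schedule, as the code comment above intends:

```python
# Toy illustration; all values are assumptions for illustration, and
# get_model_chunk_id assumes the standard interleaved mapping (its middle is
# elided in this diff).
pipeline_parallel_size = 2
num_model_chunks = 3
num_microbatches = 12
microbatch_group_size = pipeline_parallel_size * num_model_chunks  # 6

def is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool:
    if microbatch_id // microbatch_group_size == 0:
        return (microbatch_id % microbatch_group_size) % pipeline_parallel_size == 0
    return False

def get_model_chunk_id(microbatch_id: int, forward: bool = True) -> int:
    chunk = (microbatch_id % microbatch_group_size) // pipeline_parallel_size
    return chunk if forward else num_model_chunks - chunk - 1

for pipeline_parallel_rank in range(pipeline_parallel_size):
    for microbatch_id in range(num_microbatches):
        param_sync_microbatch_id = microbatch_id + pipeline_parallel_rank
        if (param_sync_microbatch_id < num_microbatches
                and is_first_microbatch_for_model_chunk(param_sync_microbatch_id)):
            param_sync_chunk_id = get_model_chunk_id(param_sync_microbatch_id) + 1
            if 1 < param_sync_chunk_id < num_model_chunks:
                print(f"rank {pipeline_parallel_rank}: prefetch params of chunk "
                      f"{param_sync_chunk_id} during forward microbatch {microbatch_id}")
# rank 0 launches chunk 2's all-gather at forward microbatch 2; rank 1 at
# forward microbatch 1, i.e. one microbatch earlier to offset its later start.
```
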

# forward step
if parallel_state.is_pipeline_first_stage():
if len(input_tensors[model_chunk_id]) == \
@@ -464,6 +562,11 @@ def backward_step_helper(microbatch_id):
model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

# launch grad synchronization (default)
if grad_sync_func is None and is_last_microbatch_for_model_chunk(microbatch_id):
enable_grad_sync()
synchronized_model_chunks.add(model_chunk_id)

if parallel_state.is_pipeline_last_stage():
if len(output_tensor_grads[model_chunk_id]) == 0:
output_tensor_grads[model_chunk_id].append(None)
@@ -478,6 +581,20 @@ def backward_step_helper(microbatch_id):
model_type,
timers)

# launch grad synchronization (custom grad sync)
# Note: Asynchronous communication tends to slow down compute.
# To reduce idling from mismatched microbatch times, we launch
# asynchronous communication at the same time across the
# pipeline-parallel group.
if grad_sync_func is not None:
grad_sync_microbatch_id = microbatch_id - pipeline_parallel_rank
if grad_sync_microbatch_id >= 0 and is_last_microbatch_for_model_chunk(grad_sync_microbatch_id):
grad_sync_chunk_id = get_model_chunk_id(grad_sync_microbatch_id, forward=False)
enable_grad_sync()
grad_sync_func(model[grad_sync_chunk_id].parameters())
synchronized_model_chunks.add(grad_sync_chunk_id)
disable_grad_sync()

return input_tensor_grad
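
The backward-side trigger mirrors this with a negative offset, presumably because later pipeline stages begin their backward passes earlier in the 1F1B schedule. Using the same assumed toy configuration, the sketch below also shows why the "Launch any remaining grad reductions" block later in this function is needed: on non-zero ranks the shifted trigger for the last chunk can fall past the final microbatch.

```python
# Same assumed toy configuration as the forward-side sketch:
# pipeline_parallel_size = 2, num_model_chunks = 3, num_microbatches = 12.
pipeline_parallel_size = 2
num_microbatches = 12
microbatch_group_size = 6  # pipeline_parallel_size * num_model_chunks

def is_last_microbatch_for_model_chunk(microbatch_id: int) -> bool:
    num_microbatch_groups = num_microbatches // microbatch_group_size
    if microbatch_id // microbatch_group_size == num_microbatch_groups - 1:
        return ((microbatch_id % microbatch_group_size) % pipeline_parallel_size
                == pipeline_parallel_size - 1)
    return False

for pipeline_parallel_rank in range(pipeline_parallel_size):
    triggers = [m for m in range(num_microbatches)
                if m - pipeline_parallel_rank >= 0
                and is_last_microbatch_for_model_chunk(m - pipeline_parallel_rank)]
    print(f"rank {pipeline_parallel_rank}: grad sync launched after backward "
          f"microbatches {triggers}")
# rank 0: [7, 9, 11] -- one launch per model chunk.
# rank 1: [8, 10]    -- the shifted trigger for the remaining chunk would be
#                       microbatch 12, past the end, so that chunk is covered
#                       by the "Launch any remaining grad reductions" block.
```
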

# Run warmup forward passes.
@@ -616,6 +733,17 @@ def backward_step_helper(microbatch_id):
tensor_shape=tensor_shape, dtype=dtype,
timers=timers))

# Launch any remaining grad reductions
enable_grad_sync()
if grad_sync_func is not None:
params = []
for model_chunk_id in range(num_model_chunks):
if model_chunk_id not in synchronized_model_chunks:
params.extend(model[model_chunk_id].parameters())
synchronized_model_chunks.add(model_chunk_id)
if params:
grad_sync_func(params)

return forward_data_store

def get_tensor_shapes(*,
Expand Down Expand Up @@ -728,7 +856,7 @@ def send_backward_recv_forward(input_tensor_grads, tensor_shapes, dtype, timers)

def forward_backward_pipelining_without_interleaving(*,
forward_step_func,
data_iterator,
data_iterator: Union[Iterator, List[Iterator]],
model: Union[torch.nn.Module, List[torch.nn.Module]],
num_microbatches: int,
dtype: torch.dtype,
@@ -739,14 +867,44 @@ def forward_backward_pipelining_without_interleaving(*,
forward_only: bool = False,
timers: Callable = None,
collect_non_loss_data: bool = False,
enable_autocast: bool = False):
enable_autocast: bool = False,
no_sync_func: Optional[Callable] = None,
grad_sync_func: Optional[Callable] = None,
param_sync_func: Optional[Callable] = None, # unused
):
"""Run non-interleaved 1F1B schedule, with communication between pipeline
stages.

Returns dictionary with losses if the last stage, empty dict otherwise."""

assert len(model) == 1
model = model[0]
if isinstance(model, list):
assert len(model) == 1, \
"non-interleaved pipeline parallelism does not support model chunking"
model = model[0]
if isinstance(data_iterator, list):
assert len(data_iterator) == 1, \
"non-pipeline-parallel schedule does not support model chunking"
data_iterator = data_iterator[0]

# Disable async grad reductions
if no_sync_func is None and isinstance(model, torchDDP):
no_sync_func = model.no_sync
if no_sync_func is None:
no_sync_func = contextlib.nullcontext
no_sync_context = None
def disable_grad_sync():
"""Disable asynchronous grad reductions"""
nonlocal no_sync_context
if no_sync_context is None:
no_sync_context = no_sync_func()
no_sync_context.__enter__()
def enable_grad_sync():
"""Enable asynchronous grad reductions"""
nonlocal no_sync_context
if no_sync_context is not None:
no_sync_context.__exit__(None, None, None)
no_sync_context = None
disable_grad_sync()

# Compute number of warmup microbatches.
num_warmup_microbatches = \
@@ -844,6 +1002,16 @@ def forward_backward_pipelining_without_interleaving(*,
# Run cooldown backward passes.
if not forward_only:
for i in range(num_warmup_microbatches):

# Enable async grad reduction in the last backward pass
# Note: If grad sync function is provided, only enable
# async grad reduction in first pipeline stage. Other
# pipeline stages do grad reduction during pipeline
# bubble.
if i == num_warmup_microbatches-1:
if grad_sync_func is None or rank == 0:
enable_grad_sync()

input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)

@@ -855,4 +1023,10 @@ def forward_backward_pipelining_without_interleaving(*,

send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)

# Launch any remaining grad reductions
if no_sync_context is not None:
enable_grad_sync()
if grad_sync_func is not None:
grad_sync_func(model.parameters())

return forward_data_store