context parallelism (NVIDIA#7739)
* make nemo recognize sequence_parallel_size

Signed-off-by: xren <[email protected]>

* add helper functions to set up SP running in TE

Signed-off-by: xren <[email protected]>

* slice seq length for a specific rank

Signed-off-by: Xiaowei Ren <[email protected]>

* fix data_parallel_size calculation

Signed-off-by: Xiaowei Ren <[email protected]>

* minor change

Signed-off-by: Xiaowei Ren <[email protected]>

* add missing argument of self

Signed-off-by: Xiaowei Ren <[email protected]>

* pass sp_global_ranks to TE transformer layer

Signed-off-by: Xiaowei Ren <[email protected]>

* fix nsys setting

Signed-off-by: Xiaowei Ren <[email protected]>

* fix seq_len calculation

Signed-off-by: xren <[email protected]>

* fix attn_mask split across seq-length dim

Signed-off-by: xren <[email protected]>

* code update of input split

Signed-off-by: xren <[email protected]>

* fix loss calculation

Signed-off-by: xren <[email protected]>

* fix loss_mask_sum calculation

Signed-off-by: xren <[email protected]>

* fix loss calculation

Signed-off-by: xren <[email protected]>

* rename sequence_parallelism to context_parallelism

Signed-off-by: xren <[email protected]>

* minor change

Signed-off-by: xren <[email protected]>

* fix loss_mask_sum calculation

Signed-off-by: xren <[email protected]>

* make sure megatron-core parallel_state is not called while cp_size is 1

Signed-off-by: xren <[email protected]>

* slice position embedding for different CP rank

Signed-off-by: xren <[email protected]>

* fix missing property decorator

Signed-off-by: xren <[email protected]>

* typo fix

Signed-off-by: xren <[email protected]>

* fix rpe_bias CP slicing

Signed-off-by: xren <[email protected]>

* code style fix

Signed-off-by: xren <[email protected]>

* fix loss_mask_sum calculation

Signed-off-by: xren <[email protected]>

* do not load attention mask if it's not needed

Signed-off-by: Xiaowei Ren <[email protected]>

* bug fix

Signed-off-by: xren <[email protected]>

* fix ubuf size with CP > 1

Signed-off-by: Xiaowei Ren <[email protected]>

* address naming confusion of mixed dp and cp

Signed-off-by: xren <[email protected]>

* rewrite cp code by assuming with_context_parallel=False

Signed-off-by: xren <[email protected]>

* pop context_parallel from dist opt kwargs

Signed-off-by: xren <[email protected]>

* make sure amax reduction group is aware of context parallelism

Signed-off-by: xren <[email protected]>

* remove use_fp8 from initialize_model_parallel

Signed-off-by: xren <[email protected]>

* make implementations of setup_transformer_engine_tp_groups and setup_transformer_engine_cp_running consistent

Signed-off-by: xren <[email protected]>

* cp function renaming

Signed-off-by: xren <[email protected]>

* make loss logging broadcast aware of cp

Signed-off-by: xren <[email protected]>

* fix a typo

Signed-off-by: Xiaowei Ren <[email protected]>

* var name fix

Signed-off-by: Xiaowei Ren <[email protected]>

* import transformer layer specs from MCore

Signed-off-by: Xiaowei Ren <[email protected]>

* upgrade MCore version

Signed-off-by: Xiaowei Ren <[email protected]>

* add context_parallel into the kwargs of dist opt

Signed-off-by: Xiaowei Ren <[email protected]>

* remove redundant cp check

Signed-off-by: Xiaowei Ren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* code style fix

Signed-off-by: Xiaowei Ren <[email protected]>

* recover docker file

Signed-off-by: Xiaowei Ren <[email protected]>

* fix seq_length of CP

Signed-off-by: Xiaowei Ren <[email protected]>

* recover seq-length which has been fixed in mcore

Signed-off-by: Xiaowei Ren <[email protected]>

* function name fix

Signed-off-by: Xiaowei Ren <[email protected]>

---------

Signed-off-by: xren <[email protected]>
Signed-off-by: Xiaowei Ren <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
xrennvidia and pre-commit-ci[bot] authored Jan 10, 2024
1 parent 76a712a commit 58d6bce
Showing 12 changed files with 226 additions and 56 deletions.
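
Background for the diff below: context parallelism (CP) shards the sequence dimension of every microbatch across a group of GPUs, on top of tensor, pipeline, and data parallelism. For causal attention the split is not contiguous; each rank keeps two of 2*cp_size chunks so that early and late tokens, which are cheap and expensive under a causal mask, are mixed on every rank. The sketch below is illustrative only and is not part of the commit; it assumes cp_size=2 and a toy sequence of 8 tokens.

# Illustrative sketch (not from the commit): splitting a length-8 sequence
# across cp_size=2 context-parallel ranks. Each rank keeps chunk `cp_rank`
# and chunk `2*cp_size - cp_rank - 1`, which balances causal-attention work.
import torch

seq = torch.arange(8)                 # token positions 0..7
cp_size = 2
chunks = seq.view(2 * cp_size, -1)    # 4 chunks of 2 tokens each

for cp_rank in range(cp_size):
    index = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1])
    kept = chunks.index_select(0, index).reshape(-1)
    print(cp_rank, kept.tolist())
# rank 0 keeps positions [0, 1, 6, 7]
# rank 1 keeps positions [2, 3, 4, 5]

The same two-chunk selection appears below in get_batch_on_this_context_parallel_rank and in the position-embedding slicing in language_model.py.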
@@ -1330,7 +1330,7 @@ def get_samples_mapping(
)
torch.distributed.barrier()
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=parallel_state.get_data_parallel_group())
torch.distributed.all_reduce(counts, group=parallel_state.get_data_parallel_group(with_context_parallel=True))
torch.distributed.all_reduce(counts, group=parallel_state.get_pipeline_model_parallel_group())
assert counts[0].item() == (
torch.distributed.get_world_size()
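
The with_context_parallel=True flag in the hunk above matters because, once CP is enabled, the plain data-parallel group no longer spans all non-model-parallel ranks: the world decomposes as TP x PP x CP x DP, so a barrier or all-reduce meant to cover every data-consuming rank has to use the combined DP x CP group. A quick sanity check with made-up sizes (illustrative only):

# Illustrative only: with context parallelism enabled, the DP group alone no
# longer covers every rank, so the `counts` all-reduce must run over the
# DP x CP group (plus the PP group, as before).
world_size = 16                      # assumed toy value
tp, pp, cp = 2, 2, 2                 # assumed toy values
dp = world_size // (tp * pp * cp)    # -> 2 data-parallel replicas
assert tp * pp * cp * dp == world_size

The same reasoning drives the DistributedDataParallel process_group change in build_model.py further down.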
48 changes: 33 additions & 15 deletions nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py
@@ -313,6 +313,7 @@ def __init__(
self.indexed_dataset = indexed_dataset
self.drop_last = drop_last
self.seq_length = seq_length
self.get_attention_mask_from_fusion = cfg.get('get_attention_mask_from_fusion', True)

# Checks
assert np.min(documents) >= 0
@@ -433,13 +434,21 @@ def __getitem__(self, idx):
logging.debug('Got negative index. Masking loss from this sample')
loss_mask = torch.zeros_like(loss_mask)

return {
'tokens': tokens,
'labels': labels,
'attention_mask': attention_mask,
'loss_mask': loss_mask,
'position_ids': position_ids,
}
if self.get_attention_mask_from_fusion:
return {
'tokens': tokens,
'labels': labels,
'loss_mask': loss_mask,
'position_ids': position_ids,
}
else:
return {
'tokens': tokens,
'labels': labels,
'attention_mask': attention_mask,
'loss_mask': loss_mask,
'position_ids': position_ids,
}


class MockGPTDataset(Dataset):
@@ -457,6 +466,7 @@ def __init__(
self.vocab_size = tokenizer.vocab_size
self.length = num_samples
self.seed = seed
self.get_attention_mask_from_fusion = cfg.get('get_attention_mask_from_fusion', True)

self.attention_mask = torch.tril(torch.ones((self.seq_length, self.seq_length))).unsqueeze(0)
self.attention_mask = self.attention_mask < 0.5
@@ -476,13 +486,21 @@ def __getitem__(self, idx):
tokens = torch.from_numpy(np_gen.integers(self.vocab_size, size=[self.seq_length], dtype=np.int64))
labels = torch.from_numpy(np_gen.integers(self.vocab_size, size=[self.seq_length], dtype=np.int64))

return {
'tokens': tokens,
'labels': labels,
'attention_mask': self.attention_mask,
'loss_mask': self.loss_mask,
'position_ids': self.position_ids,
}
if self.get_attention_mask_from_fusion:
return {
'tokens': tokens,
'labels': labels,
'loss_mask': self.loss_mask,
'position_ids': self.position_ids,
}
else:
return {
'tokens': tokens,
'labels': labels,
'attention_mask': self.attention_mask,
'loss_mask': self.loss_mask,
'position_ids': self.position_ids,
}


@torch.no_grad()
@@ -674,7 +692,7 @@ def _build_index_mappings(

torch.distributed.barrier()
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=parallel_state.get_data_parallel_group())
torch.distributed.all_reduce(counts, group=parallel_state.get_data_parallel_group(with_context_parallel=True))
torch.distributed.all_reduce(counts, group=parallel_state.get_pipeline_model_parallel_group())
assert counts[0].item() == (
torch.distributed.get_world_size()
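
The gpt_dataset.py changes make the per-sample attention mask optional: when get_attention_mask_from_fusion is true (the default read via cfg.get), the fused attention path applies the causal mask itself, so the dataset can stop materializing and shipping a [1, seq_len, seq_len] boolean tensor with every sample. A rough estimate of what that skips; the 4096-token sequence length is an assumption, not from the diff:

# Illustrative only: size of the per-sample causal mask the dataset no longer
# has to build and copy to the GPU when the fused kernel handles masking.
seq_len = 4096                        # assumed sequence length
mask_bytes = seq_len * seq_len        # torch.bool is 1 byte per element
print(f"{mask_bytes / 2**20:.0f} MiB per sample")   # -> 16 MiB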
@@ -169,6 +169,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
pipeline_model_parallel_size=cfg.get('pipeline_model_parallel_size', 1),
virtual_pipeline_model_parallel_size=vp_size,
pipeline_model_parallel_split_rank=cfg.get('pipeline_model_parallel_split_rank', 0),
context_parallel_size=cfg.get('context_parallel_size', 1),
micro_batch_size=cfg.get('micro_batch_size'),
global_batch_size=cfg.get('global_batch_size'),
rampup_batch_size=cfg.get('rampup_batch_size', None),
@@ -231,6 +232,27 @@ def setup_transformer_engine_tp_groups(self):
tp_group = parallel_state.get_tensor_model_parallel_group()
child.set_tensor_parallel_group(tp_group)

def setup_transformer_engine_cp_groups(self):
""" This should be called after context parallel groups have been initialized
and only needs to be called when using Transformer Engine.
"""
cp_stream = torch.cuda.Stream()

for module in self.get_model_module_list():
"""Set context parallel running
Copied from: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/transformer.py
"""
# Deep iterate but skip self to avoid infinite recursion.
for index, child in enumerate(module.modules()):
if index == 0:
continue
if hasattr(child, "set_context_parallel_group"):
child.set_context_parallel_group(
parallel_state.get_context_parallel_group(),
parallel_state.get_context_parallel_global_ranks(),
cp_stream,
)

def _wrap_model_for_O2(self):
""" Wraps self.model in a float16 wrapper if the model is using megatron amp O2.
Args:
@@ -556,8 +578,10 @@ def allreduce_gradients(self):
bucket = buckets[tp]
grads = [param.grad.data for param in bucket]
coalesced = torch._utils._flatten_dense_tensors(grads)
coalesced /= parallel_state.get_data_parallel_world_size()
torch.distributed.all_reduce(coalesced, group=parallel_state.get_data_parallel_group())
coalesced /= parallel_state.get_data_parallel_world_size(with_context_parallel=True)
torch.distributed.all_reduce(
coalesced, group=parallel_state.get_data_parallel_group(with_context_parallel=True)
)
for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)

@@ -633,7 +657,6 @@ def setup_optimization(
):
optim_kwargs = {} if optim_kwargs is None else optim_kwargs.copy()
if self.with_distributed_adam:

# Allocate contiguous buffer to avoid extra copies
optim_kwargs['contiguous_grad_buffer'] = True

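
The new setup_transformer_engine_cp_groups helper mirrors the existing TP helper: it walks every submodule and, for Transformer Engine layers that expose set_context_parallel_group, hands them the CP process group, the CP ranks' global ids, and a dedicated CUDA stream (presumably so the CP point-to-point exchanges inside attention can overlap with compute). A hedged usage sketch; `model` is a placeholder for a NeMo Megatron model instance and the call order is inferred from the setup() hunk further down:

# Illustrative only: expected call order after the megatron-core process
# groups exist. Only meaningful when the model is built on Transformer
# Engine / mcore layers.
model.setup_transformer_engine_tp_groups()   # existing tensor-parallel wiring
model.setup_transformer_engine_cp_groups()   # new: context-parallel wiring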
62 changes: 50 additions & 12 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -278,7 +278,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
# Convert the global-batch-based profile index to micro-batch index
if hasattr(self, '_nsys_profile_enabled'):
mp_size = cfg.get('tensor_model_parallel_size', 1) * cfg.get('pipeline_model_parallel_size', 1)
data_parallel_world_size = trainer.world_size // mp_size
cp_size = cfg.get('context_parallel_size', 1)
data_parallel_world_size = trainer.world_size // (mp_size * cp_size)
grad_accum_steps = cfg.get('global_batch_size') // (cfg.get('micro_batch_size') * data_parallel_world_size)
self._nsys_profile_start_step *= grad_accum_steps
self._nsys_profile_end_step *= grad_accum_steps
@@ -553,7 +554,9 @@ def initialize_ub_func(self):
)

input_shape = [
self.cfg.get('encoder_seq_length') * self.cfg.get('micro_batch_size'),
self.cfg.get('encoder_seq_length')
* self.cfg.get('micro_batch_size')
// self.cfg.get('context_parallel_size', 1),
self.cfg.get('hidden_size'),
]

@@ -834,6 +837,32 @@ def __next__(self):
# TODO @tmoon: Use once available in Megatron-LM
# return DataIteratorList(iters)

def get_batch_on_this_context_parallel_rank(self, batch):
cp_size = self.cfg.get('context_parallel_size', 1)
num_valid_tokens_in_ub = None
if 'loss_mask' in batch and batch['loss_mask'] is not None:
num_valid_tokens_in_ub = batch['loss_mask'].sum()

if cp_size > 1:
cp_rank = parallel_state.get_context_parallel_rank()
for key, val in batch.items():
if val is not None:
seq_dim = 1 if key != 'attention_mask' else 2
val = val.view(
*val.shape[0:seq_dim],
2 * cp_size,
val.shape[seq_dim] // (2 * cp_size),
*val.shape[(seq_dim + 1) :],
)
index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device=val.device)
val = val.index_select(seq_dim, index)
val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
batch[key] = val

batch['num_valid_tokens_in_ub'] = num_valid_tokens_in_ub

return batch

def get_forward_output_and_loss_func(self, validation_step=False):
def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None):

@@ -852,15 +881,17 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
required_keys.update(('tokens', 'position_ids'))
if parallel_state.is_pipeline_last_stage():
required_keys.update(('labels', 'loss_mask'))
if self.get_attention_mask_from_fusion:
if self.get_attention_mask_from_fusion and 'attention_mask' in required_keys:
required_keys.remove('attention_mask')
batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in batch.items()}

batch = self.get_batch_on_this_context_parallel_rank(batch)

# Model forward pass
forward_args = {
'input_ids': batch['tokens'],
'position_ids': batch['position_ids'],
'attention_mask': batch['attention_mask'],
'attention_mask': None if self.get_attention_mask_from_fusion else batch['attention_mask'],
'labels': batch['labels'],
'loss_mask': batch['loss_mask'],
}
@@ -885,9 +916,10 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_

def loss_func(output_tensor):
# Loss for a micro-batch (ub)
loss_for_ub = self.loss_func(batch['loss_mask'], output_tensor)
loss_for_ub = self.loss_func(batch['loss_mask'], batch['num_valid_tokens_in_ub'], output_tensor)
cp_size = self.cfg.get('context_parallel_size', 1)
if validation_step and not self.cfg.data.get('validation_drop_last', True):
num_valid_tokens_in_ub = batch['loss_mask'].sum()
num_valid_tokens_in_ub = batch['num_valid_tokens_in_ub']
if loss_for_ub.isnan():
assert batch['loss_mask'].count_nonzero() == 0, 'Got NaN loss with non-empty input'
loss_sum_for_ub = torch.zeros_like(num_valid_tokens_in_ub)
@@ -904,10 +936,10 @@ def loss_func(output_tensor):
torch.distributed.all_reduce(
loss_sum_and_ub_size_all_gpu, group=parallel_state.get_data_parallel_group()
)
return loss_for_ub, {'loss_sum_and_ub_size': loss_sum_and_ub_size_all_gpu}
return loss_for_ub * cp_size, {'loss_sum_and_ub_size': loss_sum_and_ub_size_all_gpu}
else:
reduced_loss = average_losses_across_data_parallel_group([loss_for_ub])
return loss_for_ub, {'avg': reduced_loss}
return loss_for_ub * cp_size, {'avg': reduced_loss}

return output_tensor, loss_func

@@ -1007,10 +1039,11 @@ def on_validation_epoch_end(self):
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
if self.loss_broadcast_src_rank is None:
dp_size = parallel_state.get_data_parallel_world_size()
cp_size = parallel_state.get_context_parallel_world_size()
tp_size = parallel_state.get_tensor_model_parallel_world_size()
pp_size = parallel_state.get_pipeline_model_parallel_world_size()
rank_in_dp_tp_group = torch.distributed.get_rank() % (dp_size * tp_size)
last_pipeline_stage_offset = (tp_size * dp_size) * (pp_size - 1)
rank_in_dp_tp_group = torch.distributed.get_rank() % (dp_size * cp_size * tp_size)
last_pipeline_stage_offset = (tp_size * cp_size * dp_size) * (pp_size - 1)
self.loss_broadcast_src_rank = last_pipeline_stage_offset + rank_in_dp_tp_group
torch.distributed.broadcast(
averaged_loss, self.loss_broadcast_src_rank, group=parallel_state.get_pipeline_model_parallel_group(),
@@ -1029,11 +1062,14 @@ def on_test_epoch_end(self):
logging.info(f'test_loss: {averaged_loss[0]}')
self.test_step_outputs.clear() # free memory

def loss_func(self, loss_mask, output_tensor):
def loss_func(self, loss_mask, num_valid_tokens_in_ub, output_tensor):
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
# TODO: add nemo version here
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() # sequence level nll
loss = torch.sum(losses.view(-1) * loss_mask) / num_valid_tokens_in_ub # sequence level nll
cp_size = self.cfg.get('context_parallel_size', 1)
if cp_size > 1:
torch.distributed.all_reduce(loss, group=parallel_state.get_context_parallel_group())
return loss

def build_train_valid_test_datasets(self):
Expand Down Expand Up @@ -1185,6 +1221,7 @@ def setup(self, stage=None):

if self.cfg.get('transformer_engine', False) or self.cfg.get('mcore_gpt', False):
self.setup_transformer_engine_tp_groups()
self.setup_transformer_engine_cp_groups()

def setup_training_data(self, cfg):
if hasattr(self, '_train_ds'):
@@ -1243,6 +1280,7 @@ def dummy():

if self.cfg.get('transformer_engine', False):
self.setup_transformer_engine_tp_groups()
self.setup_transformer_engine_cp_groups()

# set the default sampling params if it is None.
# default do greedy sampling
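
The loss bookkeeping in megatron_gpt_model.py is the subtle part of the change: num_valid_tokens_in_ub is taken from the full loss mask before the sequence is sliced, each CP rank then sums the NLL over only its own chunks and divides by that global count, and the all-reduce over the CP group in loss_func restores the full-microbatch mean. The extra loss_for_ub * cp_size factor appears to compensate for the later averaging over the combined DP x CP group. A toy check of the reduction math (illustrative only, single process, no distributed calls):

# Illustrative only: partial per-rank sums divided by the GLOBAL valid-token
# count, then summed (what the CP all-reduce does), equal the full-sequence
# mean loss.
import torch

losses = torch.tensor([1., 2., 3., 4., 5., 6., 7., 8.])  # per-token NLL, seq_len=8
loss_mask = torch.ones(8)
num_valid_tokens_in_ub = loss_mask.sum()                  # computed BEFORE slicing

cp_size = 2
chunks = losses.view(2 * cp_size, -1)
partials = []
for cp_rank in range(cp_size):
    index = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1])
    partials.append(chunks.index_select(0, index).sum() / num_valid_tokens_in_ub)

full_mean = losses.sum() / num_valid_tokens_in_ub
assert torch.isclose(torch.stack(partials).sum(), full_mean)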
5 changes: 4 additions & 1 deletion nemo/collections/nlp/modules/common/megatron/build_model.py
@@ -151,7 +151,10 @@ def build_model(
i = torch.cuda.current_device()
model = [
torch.nn.parallel.distributed.DistributedDataParallel(
model_module, device_ids=[i], output_device=i, process_group=parallel_state.get_data_parallel_group(),
model_module,
device_ids=[i],
output_device=i,
process_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
)
for model_module in model
]
25 changes: 25 additions & 0 deletions nemo/collections/nlp/modules/common/megatron/language_model.py
@@ -535,6 +535,7 @@ def __init__(
self.position_embedding_type = position_embedding_type
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
self.sequence_parallel = config.sequence_parallel
self.context_parallel = parallel_state.get_context_parallel_world_size() > 1
if kv_channels is None:

assert (
@@ -722,6 +723,19 @@ def set_input_tensor(self, input_tensor):

self.encoder.set_input_tensor(input_tensor[0])

def get_position_embedding_on_this_context_parallel_rank(self, position_embedding, seq_dim):
cp_size = parallel_state.get_context_parallel_world_size()
cp_rank = parallel_state.get_context_parallel_rank()
cp_idx = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device=position_embedding.device)
position_embedding = position_embedding.view(
*position_embedding.shape[:seq_dim], 2 * cp_size, -1, *position_embedding.shape[(seq_dim + 1) :]
)
position_embedding = position_embedding.index_select(seq_dim, cp_idx)
position_embedding = position_embedding.view(
*position_embedding.shape[:seq_dim], -1, *position_embedding.shape[(seq_dim + 2) :]
)
return position_embedding

def forward(
self,
enc_input_ids,
@@ -775,10 +789,16 @@ def forward(
else:
enc_seq_length = encoder_input.size(0)

if self.context_parallel:
enc_seq_length = enc_seq_length * parallel_state.get_context_parallel_world_size()

rotary_pos_emb = None
encoder_self_attention_relative_position_bias = None
if self.position_embedding_type == 'rope':
rotary_pos_emb = self.rotary_pos_emb(enc_seq_length)

if self.context_parallel:
rotary_pos_emb = self.get_position_embedding_on_this_context_parallel_rank(rotary_pos_emb, 0)
elif (
self.position_embedding_type == 'alibi'
or self.position_embedding_type == 'sandwich'
Expand All @@ -790,6 +810,11 @@ def forward(
# causal attention bias: [1, head, 1, k]
# non-causal attention bias: [1, head, q, k]

if self.context_parallel and encoder_self_attention_relative_position_bias.shape[-2] > 1:
encoder_self_attention_relative_position_bias = self.get_position_embedding_on_this_context_parallel_rank(
encoder_self_attention_relative_position_bias, 2
)

# encoder.
if enc_hidden_states is None:
encoder_output = self.encoder(
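
In language_model.py, rotary and relative position embeddings are still generated for the full sequence (hence enc_seq_length is multiplied by the CP world size) and then subsampled per rank with the same two-chunk pattern, so each rank's positional information lines up with the tokens it actually holds after get_batch_on_this_context_parallel_rank. A toy demonstration of the slicing for the RoPE case, where the sequence dimension is 0 (illustrative only, not from the commit):

# Illustrative only: build the rotary table for the FULL sequence length, then
# keep the two chunks that match this rank's token slice.
import torch

full_seq_len, cp_size, cp_rank = 8, 2, 0
rotary = torch.arange(full_seq_len).float().unsqueeze(-1)   # stand-in for a [seq, dim] table

cp_idx = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1])
rotary = rotary.view(2 * cp_size, -1, rotary.shape[-1])     # [2*cp, seq/(2*cp), dim]
rotary = rotary.index_select(0, cp_idx)
rotary = rotary.view(-1, rotary.shape[-1])                  # back to [seq/cp, dim]
print(rotary.squeeze(-1).tolist())                          # rank 0 -> [0.0, 1.0, 6.0, 7.0]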