From 2b5cbbf5e1c8e3ae895b387c41ad032c90e19505 Mon Sep 17 00:00:00 2001 From: Jaemin Choi Date: Thu, 15 Feb 2024 14:36:23 -0800 Subject: [PATCH] Add TP comm overlap knobs to AutocastTransformerLayer (#8290) Signed-off-by: Jaemin Choi Co-authored-by: Jaemin Choi --- .../nlp/modules/common/megatron/transformer.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/nemo/collections/nlp/modules/common/megatron/transformer.py b/nemo/collections/nlp/modules/common/megatron/transformer.py index ca8c0ecafefd..036f44b4bcca 100644 --- a/nemo/collections/nlp/modules/common/megatron/transformer.py +++ b/nemo/collections/nlp/modules/common/megatron/transformer.py @@ -797,6 +797,12 @@ def __init__( drop_path_rate: float = 0, use_emha: bool = False, ub_tp_comm_overlap: bool = False, + ub_bulk_wgrad: bool = True, + ub_bulk_dgrad: bool = True, + ub_split_ag: bool = True, + ub_split_rs: bool = True, + ub_atomic_gemm_ag: bool = False, + ub_atomic_gemm_rs: bool = False, autocast_dtype: Any = 16, zero_centered_gamma: bool = False, device: str = 'cuda', @@ -829,6 +835,12 @@ def __init__( fuse_qkv_params=True, zero_centered_gamma=zero_centered_gamma, ub_tp_comm_overlap=ub_tp_comm_overlap, + ub_bulk_wgrad=ub_bulk_wgrad, + ub_bulk_dgrad=ub_bulk_dgrad, + ub_split_ag=ub_split_ag, + ub_split_rs=ub_split_rs, + ub_atomic_gemm_ag=ub_atomic_gemm_ag, + ub_atomic_gemm_rs=ub_atomic_gemm_rs, device=device, ) # use_emha=use_emha, @@ -1077,6 +1089,12 @@ def build_layer(layer_number): autocast_dtype=precision, use_emha=use_emha, ub_tp_comm_overlap=ub_tp_comm_overlap, + ub_bulk_wgrad=config.tp_comm_bulk_wgrad, + ub_bulk_dgrad=config.tp_comm_bulk_dgrad, + ub_split_ag=config.tp_comm_split_ag, + ub_split_rs=config.tp_comm_split_rs, + ub_atomic_gemm_ag=config.tp_comm_atomic_ag, + ub_atomic_gemm_rs=config.tp_comm_atomic_rs, zero_centered_gamma=normalization == 'layernorm1p', device='cpu' if config.use_cpu_initialization else 'cuda', )