diff --git a/nemo/collections/nlp/modules/common/megatron/transformer.py b/nemo/collections/nlp/modules/common/megatron/transformer.py
index 9e9c7b526782..9bbe863d34ff 100644
--- a/nemo/collections/nlp/modules/common/megatron/transformer.py
+++ b/nemo/collections/nlp/modules/common/megatron/transformer.py
@@ -796,6 +796,12 @@ def __init__(
         drop_path_rate: float = 0,
         use_emha: bool = False,
         ub_tp_comm_overlap: bool = False,
+        ub_bulk_wgrad: bool = True,
+        ub_bulk_dgrad: bool = True,
+        ub_split_ag: bool = True,
+        ub_split_rs: bool = True,
+        ub_atomic_gemm_ag: bool = False,
+        ub_atomic_gemm_rs: bool = False,
         autocast_dtype: Any = 16,
         zero_centered_gamma: bool = False,
         device: str = 'cuda',
@@ -828,6 +834,12 @@ def __init__(
             fuse_qkv_params=True,
             zero_centered_gamma=zero_centered_gamma,
             ub_tp_comm_overlap=ub_tp_comm_overlap,
+            ub_bulk_wgrad=ub_bulk_wgrad,
+            ub_bulk_dgrad=ub_bulk_dgrad,
+            ub_split_ag=ub_split_ag,
+            ub_split_rs=ub_split_rs,
+            ub_atomic_gemm_ag=ub_atomic_gemm_ag,
+            ub_atomic_gemm_rs=ub_atomic_gemm_rs,
             device=device,
         )
         # use_emha=use_emha,
@@ -1076,6 +1088,12 @@ def build_layer(layer_number):
                     autocast_dtype=precision,
                     use_emha=use_emha,
                     ub_tp_comm_overlap=ub_tp_comm_overlap,
+                    ub_bulk_wgrad=config.tp_comm_bulk_wgrad,
+                    ub_bulk_dgrad=config.tp_comm_bulk_dgrad,
+                    ub_split_ag=config.tp_comm_split_ag,
+                    ub_split_rs=config.tp_comm_split_rs,
+                    ub_atomic_gemm_ag=config.tp_comm_atomic_ag,
+                    ub_atomic_gemm_rs=config.tp_comm_atomic_rs,
                     zero_centered_gamma=normalization == 'layernorm1p',
                     device='cpu' if config.use_cpu_initialization else 'cuda',
                 )
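
For context only (not part of the patch): a minimal sketch of how the new userbuffer (ub_*) keyword arguments are threaded from the config fields referenced above into the layer constructor. The TPCommConfig dataclass and ub_kwargs helper below are hypothetical names used purely for illustration; only the field and keyword names come from the diff.

# Illustrative sketch, assuming a Megatron-core style config with the
# tp_comm_* fields referenced in the diff; not the actual NeMo code path.
from dataclasses import dataclass


@dataclass
class TPCommConfig:
    # Defaults mirror the new constructor defaults introduced by the patch.
    tp_comm_bulk_wgrad: bool = True
    tp_comm_bulk_dgrad: bool = True
    tp_comm_split_ag: bool = True
    tp_comm_split_rs: bool = True
    tp_comm_atomic_ag: bool = False
    tp_comm_atomic_rs: bool = False


def ub_kwargs(cfg: TPCommConfig) -> dict:
    """Map config fields to the ub_* keyword arguments the layer now accepts."""
    return dict(
        ub_bulk_wgrad=cfg.tp_comm_bulk_wgrad,
        ub_bulk_dgrad=cfg.tp_comm_bulk_dgrad,
        ub_split_ag=cfg.tp_comm_split_ag,
        ub_split_rs=cfg.tp_comm_split_rs,
        ub_atomic_gemm_ag=cfg.tp_comm_atomic_ag,
        ub_atomic_gemm_rs=cfg.tp_comm_atomic_rs,
    )


# Usage (hypothetical call site):
#   layer = AutocastTransformerLayer(..., ub_tp_comm_overlap=True, **ub_kwargs(cfg))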