Skip to content

Commit

Permalink
Add TP comm overlap knobs to AutocastTransformerLayer (#8290)
Browse files Browse the repository at this point in the history
Signed-off-by: Jaemin Choi <[email protected]>
Co-authored-by: Jaemin Choi <[email protected]>
Signed-off-by: ataghibakhsh <[email protected]>
  • Loading branch information
2 people authored and JRD971000 committed Feb 16, 2024
1 parent ebfa075 commit f1646e0
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions nemo/collections/nlp/modules/common/megatron/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,12 @@ def __init__(
drop_path_rate: float = 0,
use_emha: bool = False,
ub_tp_comm_overlap: bool = False,
ub_bulk_wgrad: bool = True,
ub_bulk_dgrad: bool = True,
ub_split_ag: bool = True,
ub_split_rs: bool = True,
ub_atomic_gemm_ag: bool = False,
ub_atomic_gemm_rs: bool = False,
autocast_dtype: Any = 16,
zero_centered_gamma: bool = False,
device: str = 'cuda',
Expand Down Expand Up @@ -828,6 +834,12 @@ def __init__(
fuse_qkv_params=True,
zero_centered_gamma=zero_centered_gamma,
ub_tp_comm_overlap=ub_tp_comm_overlap,
ub_bulk_wgrad=ub_bulk_wgrad,
ub_bulk_dgrad=ub_bulk_dgrad,
ub_split_ag=ub_split_ag,
ub_split_rs=ub_split_rs,
ub_atomic_gemm_ag=ub_atomic_gemm_ag,
ub_atomic_gemm_rs=ub_atomic_gemm_rs,
device=device,
)
# use_emha=use_emha,
Expand Down Expand Up @@ -1076,6 +1088,12 @@ def build_layer(layer_number):
autocast_dtype=precision,
use_emha=use_emha,
ub_tp_comm_overlap=ub_tp_comm_overlap,
ub_bulk_wgrad=config.tp_comm_bulk_wgrad,
ub_bulk_dgrad=config.tp_comm_bulk_dgrad,
ub_split_ag=config.tp_comm_split_ag,
ub_split_rs=config.tp_comm_split_rs,
ub_atomic_gemm_ag=config.tp_comm_atomic_ag,
ub_atomic_gemm_rs=config.tp_comm_atomic_rs,
zero_centered_gamma=normalization == 'layernorm1p',
device='cpu' if config.use_cpu_initialization else 'cuda',
)
Expand Down

0 comments on commit f1646e0

Please sign in to comment.