Add support of DDP and CompiledAutograd. #319
@@ -187,7 +187,20 @@ def __init__(self):
             "--training.data_parallel_degree",
             type=int,
             default=-1,
-            help="Data Parallelism degree. -1 means leftover ranks will be used (After SP/PP). 1 means disabled.",
+            help="Data Parallelism degree (FSDP). -1 means leftover ranks will be used (After SP/PP/replicate). 1 means disabled.",
         )
+        self.parser.add_argument(
+            "--training.data_parallel_replicate_degree",
+            type=int,
+            default=1,
+            help="""
+                Degree of data parallelism with replicated parameters. 1 means disabled.
+                If data_parallel_degree > 1 and data_parallel_replicate_degree > 1,
+                the parallelism is HSDP. HSDP is not yet enabled but will be supported soon.
+                When data_parallel_degree is -1 and data_parallel_replicate_degree > 1,
+                the parallelism is DDP. DDP should only be used for small models, as
+                DDP + TP is not yet supported.
+            """,
+        )
         self.parser.add_argument(
             "--training.tensor_parallel_degree",
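To make the help text above concrete, here is a rough sketch of how the two degrees select a data-parallel style once the -1 placeholder is resolved. The helper name pick_dp_mode and the returned labels are hypothetical, not part of this PR:

# Hypothetical helper illustrating the help text above: which data-parallel
# style a given pair of degrees would select. Not part of the PR.
def pick_dp_mode(dp_degree: int, dp_replicate_degree: int, world_size: int,
                 tp: int = 1, pp: int = 1) -> str:
    if dp_degree == -1:
        # leftover ranks after replicate/TP/PP become the FSDP (shard) degree
        dp_degree = world_size // (dp_replicate_degree * tp * pp)
    if dp_degree > 1 and dp_replicate_degree > 1:
        return "HSDP (not yet enabled)"
    if dp_degree > 1:
        return "FSDP"
    if dp_replicate_degree > 1:
        return "DDP"
    return "no data parallelism"

# e.g. 8 GPUs, data_parallel_degree=-1, data_parallel_replicate_degree=8:
# the leftover shard degree resolves to 1, so the result is "DDP".
print(pick_dp_mode(-1, 8, world_size=8))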
@@ -210,7 +223,16 @@ def __init__(self):
         self.parser.add_argument(
             "--training.compile",
             action="store_true",
-            help="Whether to compile the model",
+            help="Whether to compile the model.",
         )
+        self.parser.add_argument(
+            "--training.compiled_autograd",
+            action="store_true",
+            help="""
+                Whether to use CompiledAutograd to trace the backward.
+                This is an experimental feature and should not be used
+                unless you are familiar with CompiledAutograd.
+            """,
+        )
         self.parser.add_argument(
             "--training.fp8_linear",

Review comment on --training.compiled_autograd: this should be added to the experimental space IMO.
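For background on what the new flag gates: CompiledAutograd traces the whole backward pass into a graph that can itself be compiled, instead of running the eager autograd engine. Below is a minimal sketch of that kind of usage, assuming the private torch._dynamo.compiled_autograd.enable context manager available in recent PyTorch builds; the toy model and compiler_fn are illustrative only and not code from this PR:

import torch

model = torch.nn.Linear(16, 16)
x = torch.randn(4, 16)

def compiler_fn(gm):
    # gm is the FX graph of the traced backward; compile it like any other graph
    return torch.compile(gm, backend="inductor")

loss = model(x).sum()
# Inside this context the backward is captured by CompiledAutograd and handed
# to compiler_fn; outside it, the eager autograd engine runs as usual.
with torch._dynamo.compiled_autograd.enable(compiler_fn):
    loss.backward()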
@@ -20,6 +20,7 @@
 @dataclass
 class ParallelDims:
     dp: int
+    dp_replicate: int
     tp: int
     pp: int
     world_size: int
@@ -29,21 +30,27 @@ def __post_init__(self):
         self._validate()

     def _validate(self):
-        dp, tp, pp = self.dp, self.tp, self.pp
+        dp, dp_replicate, tp, pp = self.dp, self.dp_replicate, self.tp, self.pp
         if dp == -1:
-            self.dp = dp = self.world_size // (tp * pp)
+            self.dp = dp = self.world_size // (dp_replicate * tp * pp)
         assert dp >= 1, dp
+        assert dp_replicate >= 1, dp_replicate
         assert tp >= 1, tp
         assert pp >= 1, pp
         assert (
-            dp * tp * pp == self.world_size
-        ), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
+            dp * dp_replicate * tp * pp == self.world_size
+        ), (
+            f"Invalid parallel dims: dp({dp}) * dp_replicate({dp_replicate}) * "
+            f"tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})."
+        )

     def build_mesh(self, device_type):
         dims = []
         names = []
         for d, name in zip(
-            [self.pp, self.dp, self.tp], ["pp", "dp", "tp"], strict=True
+            [self.pp, self.dp_replicate, self.dp, self.tp],
+            ["pp", "dp_replicate", "dp", "tp"],
+            strict=True
         ):
             if d > 1:
                 dims.append(d)
@@ -56,6 +63,10 @@ def build_mesh(self, device_type):
     def dp_enabled(self):
         return self.dp > 1

+    @property
+    def dp_replicate_enabled(self):
+        return self.dp_replicate > 1
+
     @property
     def tp_enabled(self):
         return self.tp > 1

Review comment on dp_replicate_enabled: ditto; the comments should be addressed together.
@@ -11,8 +11,10 @@
 from typing import Tuple

 import torch
+import torch.nn as nn

 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
+from torch.distributed._composable.replicate import replicate
 from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper as ptd_checkpoint_wrapper,
@@ -129,7 +131,56 @@ def get_tp_parallel_strategy(
     return RowwiseParallel, ColwiseParallel


-def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
+def maybe_enable_activation_checkpoint(
+    model: nn.Module, job_config: JobConfig
+) -> nn.Module:
+    config = job_config.activation_checkpoint
+    ac_mode = config.mode
+    if ac_mode in ("full", "selective"):
+        for layer_id, transformer_block in enumerate(model.layers):
+            model.layers[layer_id] = checkpoint_wrapper(transformer_block, config)
+        logger.info(f"Applied {ac_mode} activation checkpointing to the model")
+
+    return model
+
+
+def enable_fsdp(model: nn.Module, dp_mesh, job_config: JobConfig) -> nn.Module:
+    # TODO: Expose `reduce_dtype` as a config option.
+    mp_policy = MixedPrecisionPolicy(
+        param_dtype=torch.bfloat16, reduce_dtype=torch.float32
+    )
+    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
+    for layer_id, transformer_block in enumerate(model.layers):
+        # As an optimization, do not reshard after forward for the last
+        # transformer block since FSDP would prefetch it immediately
+        reshard_after_forward = layer_id < len(model.layers) - 1
+        fully_shard(
+            transformer_block,
+            **fsdp_config,
+            reshard_after_forward=reshard_after_forward,
+        )
+        model.layers[layer_id] = transformer_block
+    model = fully_shard(model, **fsdp_config)
+    logger.info("Applied FSDP to the model")
+
+    return model
+
+
+def enable_ddp(model: nn.Module, dp_mesh, job_config: JobConfig) -> nn.Module:
+    if job_config.training.compile:
+        if job_config.training.compiled_autograd:
+            torch._dynamo.config.optimize_ddp = "python_reducer"
+        else:
+            torch._dynamo.config.optimize_ddp = "ddp_optimizer"
+    model = replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
+    logger.info("Applied DDP to the model")
+
+    return model
+
+
+def parallelize_llama(
+    model: nn.Module, world_mesh, parallel_dims, job_config: JobConfig
+) -> nn.Module:
     """
     Apply parallelisms and activation checkpointing to the model.

@@ -144,6 +195,8 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
             raise NotImplementedError(
                 "fused_rmsnorm not yet compatible with TP. Please use layernorm or rmsnorm."
             )
+        if parallel_dims.dp_replicate_enabled:
+            raise NotImplementedError("DDP/HSDP + TP are not supported yet.")

         tp_mesh = world_mesh["tp"]
         row_parallel_strategy, col_parallel_strategy = get_tp_parallel_strategy(

Review comment on the new DDP/HSDP + TP check: We should make DDP + TP work and see if it could support llama3_8b or llama2_7b. If not, we could try to import other models instead of Llama, and have DDP apply to that model instead :)
@@ -206,32 +259,15 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):

         logger.info("Applied Tensor Parallelism to the model")

+    model = maybe_enable_activation_checkpoint(model, job_config)
     if parallel_dims.dp_enabled:
+        if parallel_dims.dp_replicate_enabled:
+            raise NotImplementedError("HSDP is not supported yet.")
         dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
         assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
-        # TODO: Expose `reduce_dtype` as a config option.
-        mp_policy = MixedPrecisionPolicy(
-            param_dtype=torch.bfloat16, reduce_dtype=torch.float32
-        )
-        ac_mode = job_config.activation_checkpoint.mode
-        fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
-        for layer_id, transformer_block in enumerate(model.layers):
-            if job_config.activation_checkpoint.mode in ("full", "selective"):
-                transformer_block = checkpoint_wrapper(
-                    transformer_block, job_config.activation_checkpoint
-                )
-            # As an optimization, do not reshard after forward for the last
-            # transformer block since FSDP would prefetch it immediately
-            reshard_after_forward = layer_id < len(model.layers) - 1
-            fully_shard(
-                transformer_block,
-                **fsdp_config,
-                reshard_after_forward=reshard_after_forward,
-            )
-            model.layers[layer_id] = transformer_block
-        model = fully_shard(model, **fsdp_config)
-        if ac_mode in ("full", "selective"):
-            logger.info(f"Applied {ac_mode} activation checkpointing to the model")
-        logger.info("Applied FSDP to the model")
+        model = enable_fsdp(model, dp_mesh, job_config)
+    elif parallel_dims.dp_replicate_enabled:
+        dp_mesh = world_mesh["dp_replicate"] if world_mesh.ndim > 1 else world_mesh
+        model = enable_ddp(model, dp_mesh, job_config)

     return model
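Putting the pieces together, here is a rough single-process sketch of what enable_ddp() does to a toy module, assuming a PyTorch build where the private composable replicate API and the string-valued torch._dynamo.config.optimize_ddp setting used in this PR are available; the gloo process group, one-rank mesh, and toy model are placeholders for a real torchrun launch, not code from this PR:

import os
import torch
import torch.distributed as dist
from torch.distributed._composable.replicate import replicate
from torch.distributed.device_mesh import init_device_mesh

# Stand-in for a real multi-process torchrun launch: a single rank on CPU/gloo.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

dp_mesh = init_device_mesh("cpu", (1,), mesh_dim_names=("dp_replicate",))
model = torch.nn.Linear(16, 16)

# Mirror enable_ddp(): choose how torch.compile should treat DDP, then replicate.
torch._dynamo.config.optimize_ddp = "python_reducer"  # the compiled_autograd branch
model = replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)

loss = model(torch.randn(4, 16)).sum()
loss.backward()  # gradients are all-reduced across the dp_replicate mesh
dist.destroy_process_group()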
@@ -0,0 +1,40 @@
+# TorchTrain Config.toml
+[job]
+dump_folder = "./outputs"
+description = "LLaMA 1B training"
+
+[profiling]
+enable_profiling = true
+save_traces_folder = "profile_trace"
+profile_freq = 100
+
+[metrics]
+log_freq = 10
+enable_tensorboard = true
+save_tb_folder = "tb"
+
+[model]
+name = "llama2"
+flavor = "1B"
+norm_type = "fused_rmsnorm"  # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm]
+tokenizer_path = "./torchtitan/datasets/tokenizer/tokenizer.model"
+
+[optimizer]
+name = "AdamW"
+lr = 1.5e-4
+
+[training]
+batch_size = 8
+seq_len = 1024
+warmup_steps = 200  # lr scheduler warm up
+max_norm = 1.0  # grad norm clipping
+steps = 1000
+data_parallel_degree = -1
+tensor_parallel_degree = 1
+pipeline_parallel_degree = 1
+fp8_linear = ""
+compile = false
+dataset = "c4"
+
+[activation_checkpoint]
+mode = "none"  # ['none', 'full', 'selective']

Review comment on the new config file: Since there's no official 1B model size for either the Llama 2 or Llama 3 release, and the toml files are user facing, it would be better if we only add released model sizes.
Review comment: I have a different suggestion here after some thoughts: data_parallel_replicate_degree

Reply: I don't have a strong opinion here. I also thought about using mode as well. If that makes sense to people, I can change it to that.