Add support for DDP and CompiledAutograd #319

Closed · wants to merge 5 commits
26 changes: 24 additions & 2 deletions torchtitan/config_manager.py
@@ -187,7 +187,20 @@ def __init__(self):
"--training.data_parallel_degree",
type=int,
default=-1,
help="Data Parallelism degree. -1 means leftover ranks will be used (After SP/PP). 1 means disabled.",
help="Data Parallelism degree (FSDP). -1 means leftover ranks will be used (After SP/PP/replicate). 1 means disabled.",
)
self.parser.add_argument(
Contributor
I have a different suggestion here after some thought:

  • we should keep data_parallel_degree and have it used by all data parallel styles
  • we should add a dp_mode training arg that distinguishes whether to apply DDP/FSDP/HSDP, instead of adding data_parallel_replicate_degree
  • dp_degree -> int | tuple[int]; when it's a tuple of ints, it must be HSDP

Contributor Author
I don't have a strong opinion here. I also thought about using a mode flag. If that makes sense to people, I can change it to that.
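A minimal sketch of what that dp_mode style could look like, with hypothetical flag names that are not part of this PR: a single data_parallel_degree shared by all data-parallel styles, plus a mode argument selecting the flavor.

```python
import argparse

# Hypothetical sketch of the reviewer's suggestion (names are illustrative only):
# one shared data_parallel_degree plus a mode flag, instead of a separate
# data_parallel_replicate_degree.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--training.data_parallel_degree",
    type=int,
    default=-1,
    help="Data parallelism degree shared by FSDP/DDP/HSDP. -1 means use leftover ranks.",
)
parser.add_argument(
    "--training.data_parallel_mode",
    type=str,
    default="fsdp",
    choices=["fsdp", "ddp", "hsdp"],
    help="Which data-parallel flavor to apply with data_parallel_degree.",
)

args = parser.parse_args(
    ["--training.data_parallel_degree", "8", "--training.data_parallel_mode", "ddp"]
)
# argparse keeps the dotted option name as the attribute name.
print(getattr(args, "training.data_parallel_degree"))  # 8
print(getattr(args, "training.data_parallel_mode"))    # ddp
```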

"--training.data_parallel_replicate_degree",
type=int,
default=1,
help="""
Data parallelism degree with replicated parameters. 1 means disabled.
If data_parallel_degree > 1 and data_parallel_replicate_degree > 1,
the parallelism is HSDP. HSDP is not yet enabled but will be supported soon.
When data_parallel_degree is -1 and data_parallel_replicate_degree > 1,
the parallelism is DDP. DDP should only be used for small models, as
DDP + TP is not yet supported.
""",
)
self.parser.add_argument(
"--training.tensor_parallel_degree",
@@ -210,7 +223,16 @@ def __init__(self):
self.parser.add_argument(
"--training.compile",
action="store_true",
help="Whether to compile the model",
help="Whether to compile the model.",
)
self.parser.add_argument(
"--training.compiled_autograd",
Contributor
this should be added to the experimental space IMO

action="store_true",
help="""
Whether to use CompiledAutograd to trace the backward.
This is an experimental feature and should not be used
unless you are familiar with CompiledAutograd.
""",
)
self.parser.add_argument(
"--training.fp8_linear",
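Restating the rules from the data_parallel_replicate_degree help text as a tiny standalone helper (illustrative only, not code from this PR): the two degrees together pick the data-parallel flavor.

```python
def data_parallel_flavor(dp_degree: int, dp_replicate_degree: int) -> str:
    """Illustrative helper mirroring the help text above; not part of the PR."""
    if dp_replicate_degree <= 1:
        # Only the sharded degree is in play: 1 disables data parallelism,
        # anything else (including -1 for leftover ranks) means FSDP.
        return "none" if dp_degree == 1 else "fsdp"
    if dp_degree > 1:
        # Sharding and replication together is HSDP, which this PR does not enable yet.
        return "hsdp (not yet supported)"
    # Replication only (dp_degree is -1 or 1) is plain DDP.
    return "ddp"


assert data_parallel_flavor(-1, 1) == "fsdp"
assert data_parallel_flavor(-1, 8) == "ddp"
assert data_parallel_flavor(2, 4) == "hsdp (not yet supported)"
```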
21 changes: 16 additions & 5 deletions torchtitan/parallelisms/__init__.py
@@ -20,6 +20,7 @@
@dataclass
class ParallelDims:
dp: int
dp_replicate: int
tp: int
pp: int
world_size: int
@@ -29,21 +30,27 @@ def __post_init__(self):
self._validate()

def _validate(self):
dp, tp, pp = self.dp, self.tp, self.pp
dp, dp_replicate, tp, pp = self.dp, self.dp_replicate, self.tp, self.pp
if dp == -1:
self.dp = dp = self.world_size // (tp * pp)
self.dp = dp = self.world_size // (dp_replicate * tp * pp)
assert dp >= 1, dp
assert dp_replicate >= 1, dp_replicate
assert tp >= 1, tp
assert pp >= 1, pp
assert (
dp * tp * pp == self.world_size
), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
dp * dp_replicate * tp * pp == self.world_size
), (
f"Invalid parallel dims: dp({dp}) * dp_replicate({dp_replicate}) * "
f"tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})."
)

def build_mesh(self, device_type):
dims = []
names = []
for d, name in zip(
[self.pp, self.dp, self.tp], ["pp", "dp", "tp"], strict=True
[self.pp, self.dp_replicate, self.dp, self.tp],
["pp", "dp_replicate", "dp", "tp"],
strict=True
):
if d > 1:
dims.append(d)
@@ -56,6 +63,10 @@ def build_mesh(self, device_type):
def dp_enabled(self):
return self.dp > 1

@property
def dp_replicate_enabled(self):
Contributor
ditto: this and the earlier comment should be addressed together:
have dp_mode instead and reuse dp_degree.

return self.dp_replicate > 1

@property
def tp_enabled(self):
return self.tp > 1
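For concreteness, a worked example of the arithmetic in _validate and the dim names produced by build_mesh, re-derived as a standalone sketch (assumes the ParallelDims shown in this diff; the real class also builds a DeviceMesh): with 16 ranks, dp_replicate=2, tp=2, pp=1, and dp=-1, the leftover-rank rule gives dp = 16 // (2 * 2 * 1) = 4 and a mesh with dims ("dp_replicate", "dp", "tp").

```python
from dataclasses import dataclass


@dataclass
class ParallelDimsSketch:
    """Standalone re-derivation of the validation logic above, for checking the math."""

    dp: int
    dp_replicate: int
    tp: int
    pp: int
    world_size: int

    def __post_init__(self):
        if self.dp == -1:
            # Leftover ranks go to the sharded data-parallel dimension.
            self.dp = self.world_size // (self.dp_replicate * self.tp * self.pp)
        assert self.dp * self.dp_replicate * self.tp * self.pp == self.world_size

    def mesh_dim_names(self) -> tuple:
        # Mirrors build_mesh(): only dims greater than 1 make it into the mesh.
        return tuple(
            name
            for d, name in zip(
                [self.pp, self.dp_replicate, self.dp, self.tp],
                ["pp", "dp_replicate", "dp", "tp"],
            )
            if d > 1
        )


dims = ParallelDimsSketch(dp=-1, dp_replicate=2, tp=2, pp=1, world_size=16)
assert dims.dp == 4
assert dims.mesh_dim_names() == ("dp_replicate", "dp", "tp")
```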
86 changes: 61 additions & 25 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -11,8 +11,10 @@
from typing import Tuple

import torch
import torch.nn as nn

from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._composable.replicate import replicate
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper as ptd_checkpoint_wrapper,
@@ -129,7 +131,56 @@ def get_tp_parallel_strategy(
return RowwiseParallel, ColwiseParallel


def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
def maybe_enable_activation_checkpoint(
model: nn.Module, job_config: JobConfig
) -> nn.Module:
config = job_config.activation_checkpoint
ac_mode = config.mode
if ac_mode in ("full", "selective"):
for layer_id, transformer_block in enumerate(model.layers):
model.layers[layer_id] = checkpoint_wrapper(transformer_block, config)
logger.info(f"Applied {ac_mode} activation checkpointing to the model")

return model


def enable_fsdp(model: nn.Module, dp_mesh, job_config: JobConfig) -> nn.Module:
# TODO: Expose `reduce_dtype` as a config option.
mp_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16, reduce_dtype=torch.float32
)
fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
for layer_id, transformer_block in enumerate(model.layers):
# As an optimization, do not reshard after forward for the last
# transformer block since FSDP would prefetch it immediately
reshard_after_forward = layer_id < len(model.layers) - 1
fully_shard(
transformer_block,
**fsdp_config,
reshard_after_forward=reshard_after_forward,
)
model.layers[layer_id] = transformer_block
model = fully_shard(model, **fsdp_config)
logger.info("Applied FSDP to the model")

return model


def enable_ddp(model: nn.Module, dp_mesh, job_config: JobConfig) -> nn.Module:
if job_config.training.compile:
if job_config.training.compiled_autograd:
torch._dynamo.config.optimize_ddp = "python_reducer"
else:
torch._dynamo.config.optimize_ddp = "ddp_optimizer"
model = replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
logger.info("Applied DDP to the model")

return model


def parallelize_llama(
model: nn.Module, world_mesh, parallel_dims, job_config: JobConfig
) -> nn.Module:
"""
Apply parallelisms and activation checkpointing to the model.

@@ -144,6 +195,8 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
raise NotImplementedError(
"fused_rmsnorm not yet compatible with TP. Please use layernorm or rmsnorm."
)
if parallel_dims.dp_replicate_enabled:
raise NotImplementedError("DDP/HSDP + TP are not supported yet.")
Contributor
We should make DDP + TP work and see if it could support llama3_8b or llama2_7b. If not, we could try to import other models instead of Llama and apply DDP to that model instead :)


tp_mesh = world_mesh["tp"]
row_parallel_strategy, col_parallel_strategy = get_tp_parallel_strategy(
@@ -206,32 +259,15 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):

logger.info("Applied Tensor Parallelism to the model")

model = maybe_enable_activation_checkpoint(model, job_config)
if parallel_dims.dp_enabled:
if parallel_dims.dp_replicate_enabled:
raise NotImplementedError("HSDP is not supported yet.")
dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
# TODO: Expose `reduce_dtype` as a config option.
mp_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16, reduce_dtype=torch.float32
)
ac_mode = job_config.activation_checkpoint.mode
fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
for layer_id, transformer_block in enumerate(model.layers):
if job_config.activation_checkpoint.mode in ("full", "selective"):
transformer_block = checkpoint_wrapper(
transformer_block, job_config.activation_checkpoint
)
# As an optimization, do not reshard after forward for the last
# transformer block since FSDP would prefetch it immediately
reshard_after_forward = layer_id < len(model.layers) - 1
fully_shard(
transformer_block,
**fsdp_config,
reshard_after_forward=reshard_after_forward,
)
model.layers[layer_id] = transformer_block
model = fully_shard(model, **fsdp_config)
if ac_mode in ("full", "selective"):
logger.info(f"Applied {ac_mode} activation checkpointing to the model")
logger.info("Applied FSDP to the model")
model = enable_fsdp(model, dp_mesh, job_config)
elif parallel_dims.dp_replicate_enabled:
dp_mesh = world_mesh["dp_replicate"] if world_mesh.ndim > 1 else world_mesh
model = enable_ddp(model, dp_mesh, job_config)

return model
12 changes: 8 additions & 4 deletions train.py
@@ -121,6 +121,7 @@ def main(job_config: JobConfig):
world_size = int(os.environ["WORLD_SIZE"])
parallel_dims = ParallelDims(
dp=job_config.training.data_parallel_degree,
dp_replicate=job_config.training.data_parallel_replicate_degree,
tp=job_config.training.tensor_parallel_degree,
pp=job_config.training.pipeline_parallel_degree,
world_size=world_size,
@@ -303,10 +304,13 @@ def loss_fn(pred, labels):
optimizer.zero_grad()

# forward / backward
with loss_parallel_ctx():
pred = model(input_ids)
loss = loss_fn(pred, labels)
loss.backward()
with torch._dynamo.utils.maybe_enable_compiled_autograd(
job_config.training.compiled_autograd
):
with loss_parallel_ctx():
pred = model(input_ids)
loss = loss_fn(pred, labels)
loss.backward()

# clip gradients
torch.nn.utils.clip_grad_norm_(
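For reference, the same forward/backward wrapping on a toy module, as a self-contained sketch. It assumes torch._dynamo.utils.maybe_enable_compiled_autograd behaves as it is used in the diff above, i.e. a context manager gated by a bool; this is not the train.py code.

```python
import torch
import torch.nn as nn

# Toy version of the training-step change above: the whole forward/backward is
# wrapped so the backward can be traced by CompiledAutograd when the flag is on.
compiled_autograd = False  # stand-in for job_config.training.compiled_autograd

model = nn.Linear(16, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

inputs = torch.randn(8, 16)
labels = torch.randint(0, 4, (8,))

optimizer.zero_grad()
with torch._dynamo.utils.maybe_enable_compiled_autograd(compiled_autograd):
    pred = model(inputs)
    loss = loss_fn(pred, labels)
    loss.backward()

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```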
40 changes: 40 additions & 0 deletions train_configs/llama_1b.toml
@@ -0,0 +1,40 @@
# TorchTrain Config.toml
[job]
Contributor
Since there's no official 1B model size in either the Llama 2 or Llama 3 releases, and the toml files are user facing, it would be better to only add released model sizes.

dump_folder = "./outputs"
description = "LLaMA 1B training"

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 10
enable_tensorboard = true
save_tb_folder = "tb"

[model]
name = "llama2"
flavor = "1B"
norm_type = "fused_rmsnorm" # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm]
tokenizer_path = "./torchtitan/datasets/tokenizer/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 1.5e-4

[training]
batch_size = 8
seq_len = 1024
warmup_steps = 200 # lr scheduler warm up
max_norm = 1.0 # grad norm clipping
steps = 1000
data_parallel_degree = -1
tensor_parallel_degree = 1
pipeline_parallel_degree = 1
fp8_linear = ""
compile = false
dataset = "c4"

[activation_checkpoint]
mode = "none" # ['none', 'full', 'selective']