[Not for land] Added changes for GPT-2 perf
ghstack-source-id: 0cdcc964f2012f1b0c00e3eeba7eaca14e768629
Pull Request resolved: #533
awgu committed Sep 7, 2024
1 parent ce48156 commit 628f4bd
Showing 4 changed files with 74 additions and 16 deletions.
61 changes: 58 additions & 3 deletions torchtitan/models/llama/model.py
@@ -127,6 +127,60 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
)


class ChunkedCE(torch.autograd.Function):
"""
Credit: https://github.com/Chillee
"""

@staticmethod
def forward(ctx, _input, weight, target, compiled=True):
CHUNK_SIZE = 8192
inp_shape = _input.shape
_input = _input.view(-1, _input.shape[-1])
target = target.view(-1)

def compute_loss(input_chunk, weight, target):
logits = torch.mm(input_chunk, weight.t())
logits = logits.float()
loss = torch.nn.functional.cross_entropy(logits, target)
return loss

grad_weight = torch.zeros_like(weight)
grad_inputs = []
loss_acc = torch.zeros((), device=_input.device)

chunks = max(_input.shape[0] // CHUNK_SIZE, 1)

def accumulate_chunk(input_chunk, target_chunk):
(chunk_grad_input, chunk_grad_weight), chunk_loss = (
torch.func.grad_and_value(compute_loss, argnums=(0, 1))(
input_chunk, weight, target_chunk
)
)
grad_weight.add_(chunk_grad_weight)
loss_acc.add_(chunk_loss)
return chunk_grad_input

if compiled:
accumulate_chunk = torch.compile(accumulate_chunk)

input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
target_chunks = torch.chunk(target, chunks=chunks, dim=0)
for input_chunk, target_chunk in zip(input_chunks, target_chunks, strict=True):
grad_inputs.append(accumulate_chunk(input_chunk, target_chunk))

ctx.save_for_backward(
(torch.cat(grad_inputs, dim=0) / chunks).view(inp_shape),
grad_weight / chunks,
)
return loss_acc / chunks

@staticmethod
def backward(ctx, grad_output):
(grad_input, grad_weight) = ctx.saved_tensors
return (grad_input, grad_weight, None, None)


class Attention(nn.Module):
"""
Multi-head attention module.
@@ -421,7 +475,7 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
self.model_args.rope_theta,
)

def forward(self, tokens: torch.Tensor):
def forward(self, tokens: torch.Tensor, labels: torch.Tensor):
"""
Perform a forward pass through the Transformer model.
@@ -439,8 +493,9 @@ def forward(self, tokens: torch.Tensor):
h = layer(h, self.freqs_cis)

h = self.norm(h) if self.norm else h
output = self.output(h).float() if self.output else h
return output
if not self.output:
return h
return ChunkedCE.apply(h, self.output.weight, labels)

@classmethod
def from_model_args(cls, model_args: ModelArgs) -> "Transformer":
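The ChunkedCE function above computes the output projection, the float32 upcast, and the cross-entropy loss over chunks of at most 8192 flattened tokens, using torch.func.grad_and_value to produce the input and weight gradients already during the forward pass; the full (tokens, vocab_size) logits tensor is therefore never materialized, and backward() just returns the saved, averaged gradients. A small self-contained sanity check, not part of this commit — all shapes are made-up toy sizes, ChunkedCE is assumed to be imported from the patched model.py, and compiled=False simply skips torch.compile:

```python
import torch

torch.manual_seed(0)
bs, seq_len, dim, vocab = 2, 8192, 128, 1024           # toy sizes (assumptions)
h = torch.randn(bs, seq_len, dim, requires_grad=True)   # final hidden states
weight = torch.randn(vocab, dim, requires_grad=True)    # output projection weight
labels = torch.randint(0, vocab, (bs, seq_len))

# Chunked path: 16384 tokens -> 2 chunks of 8192; only one chunk's logits are live.
loss = ChunkedCE.apply(h, weight, labels, False)
# ChunkedCE.backward ignores grad_output, so this is only valid when the loss is
# the final scalar being differentiated (as in the training loop).
loss.backward()

# Unchunked reference: materializes the full (bs * seq_len, vocab) logits at once.
h_ref = h.detach().clone().requires_grad_(True)
w_ref = weight.detach().clone().requires_grad_(True)
logits = torch.mm(h_ref.view(-1, dim), w_ref.t()).float()
loss_ref = torch.nn.functional.cross_entropy(logits, labels.view(-1))
loss_ref.backward()

assert torch.allclose(loss, loss_ref, atol=1e-4)
assert torch.allclose(h.grad, h_ref.grad, atol=1e-4)
assert torch.allclose(weight.grad, w_ref.grad, atol=1e-4)
```

Because each chunk's loss is a mean over an equal number of tokens, dividing the accumulated per-chunk losses and gradients by `chunks` reproduces the global mean cross-entropy exactly.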
8 changes: 4 additions & 4 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -312,15 +312,15 @@ def apply_fsdp(
# all-gathers, which can be expensive and non-overlapped
reshard_after_forward = False
else:
# As an optimization, do not reshard after forward for the last
# transformer block since FSDP would prefetch it immediately
reshard_after_forward = int(layer_id) < len(model.layers) - 1
# For small models (e.g. GPT-2), parameter memory is low, so there
# is no need to reshard after forward
reshard_after_forward = False
fully_shard(
transformer_block,
**fsdp_config,
reshard_after_forward=reshard_after_forward,
)
fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
fully_shard(model, **fsdp_config)

logger.info("Applied FSDP to the model")

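With FSDP2's fully_shard, reshard_after_forward=True frees each block's unsharded parameters after its forward and re-all-gathers them in backward; setting it to False keeps the gathered parameters resident, trading a little parameter memory for fewer collectives. For a model as small as GPT-2 that extra parameter memory is negligible, which is what the new comment relies on. A hedged sketch of the resulting wiring — the import path, the model.layers container, and the fsdp_config contents are assumptions, not lifted verbatim from the repo:

```python
from torch.distributed._composable.fsdp import fully_shard  # FSDP2 path (PyTorch 2.4-era)

def apply_fsdp_keep_unsharded(model, fsdp_config: dict):
    # Shard every transformer block but skip the post-forward reshard: the
    # unsharded parameters stay resident, so backward issues no extra
    # all-gather per block. Cheap when parameter memory is small (GPT-2).
    for transformer_block in model.layers.values():
        fully_shard(
            transformer_block,
            **fsdp_config,
            reshard_after_forward=False,
        )
    # Root wrap; the previous pp-dependent reshard_after_forward override is dropped.
    fully_shard(model, **fsdp_config)
    return model
```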
13 changes: 8 additions & 5 deletions train.py
@@ -8,6 +8,7 @@
import os
import time
from datetime import timedelta
from typing import List

import torch
from torch.distributed.elastic.multiprocessing.errors import record
@@ -136,6 +137,12 @@ def main(job_config: JobConfig):

# loss function to be shared by Pipeline Parallel and SPMD training
def loss_fn(pred, labels):
if isinstance(pred, torch.Tensor):
pred_chunks = pred.chunk(token_chunked_cross_entropy_loss.num_chunks, dim=1)
else:
assert isinstance(pred, list)
pred_chunks = pred
return token_chunked_cross_entropy_loss(pred_chunks, labels)
return torch.nn.functional.cross_entropy(
pred.flatten(0, 1), labels.flatten(0, 1)
)
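The rewritten loss_fn dispatches on the prediction type: a plain tensor is split along the sequence dimension into token_chunked_cross_entropy_loss.num_chunks pieces, while a list is assumed to be pre-chunked; either way the chunks go to a token_chunked_cross_entropy_loss helper that this diff references but does not define. A purely illustrative stand-in for that helper (its real signature and chunking may differ):

```python
from typing import List

import torch
import torch.nn.functional as F


class TokenChunkedCrossEntropyLoss:
    """Hypothetical stand-in for the helper referenced by the new loss_fn."""

    def __init__(self, num_chunks: int = 8):
        self.num_chunks = num_chunks

    def __call__(self, pred_chunks: List[torch.Tensor], labels: torch.Tensor) -> torch.Tensor:
        # pred_chunks[i]: (bs, seq_len / num_chunks, vocab_size); labels: (bs, seq_len)
        label_chunks = labels.chunk(len(pred_chunks), dim=1)
        losses = [
            F.cross_entropy(p.flatten(0, 1).float(), t.flatten(0, 1))
            for p, t in zip(pred_chunks, label_chunks)
        ]
        # With equal-sized chunks, the mean of per-chunk means is the global mean.
        return torch.stack(losses).mean()


token_chunked_cross_entropy_loss = TokenChunkedCrossEntropyLoss()
```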
Expand Down Expand Up @@ -288,11 +295,7 @@ def loss_fn(pred, labels):
else:
# Non-PP forward / backward
with train_context():
pred = model(input_ids)
loss = loss_fn(pred, labels)
# pred.shape=(bs, seq_len, vocab_size)
# need to free before bwd to avoid peaking memory
del pred
loss = model(input_ids, labels)
loss.backward()

# clip gradients
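In the non-PP path the model now returns the loss directly, so the training step never holds a (bs, seq_len, vocab_size) prediction tensor and the old del pred workaround disappears. A simplified view of the step (grad clipping, optimizer, and scheduler handling are shown only schematically; max_norm comes from the [training] config below):

```python
with train_context():
    # Forward returns a scalar loss; the full logits are never materialized here.
    loss = model(input_ids, labels)
    loss.backward()

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # [training] max_norm
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
```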
8 changes: 4 additions & 4 deletions train_configs/gpt2.toml
@@ -6,9 +6,9 @@ dump_folder = "./outputs"
description = "GPT-2 training"

[profiling]
enable_profiling = false
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 10
profile_freq = 50

[metrics]
log_freq = 10
@@ -27,11 +27,11 @@ lr = 3e-4
fused = true

[training]
batch_size = 16
batch_size = 32
seq_len = 8192
warmup_steps = 200 # lr scheduler warm up
max_norm = 1.0 # grad norm clipping
steps = 100
steps = 50
data_parallel_degree = -1
tensor_parallel_degree = 1
compile = true
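With the bumped batch size, each rank (assuming batch_size is per data-parallel rank) processes 32 × 8192 = 262,144 tokens per step, which ChunkedCE splits into 32 chunks of 8192. Back-of-envelope logits memory, assuming a GPT-2 vocabulary of 50257 (an assumption, not stated in this diff):

```python
batch_size, seq_len, vocab = 32, 8192, 50257   # vocab size is an assumption
chunk_size = 8192                              # CHUNK_SIZE in ChunkedCE

tokens_per_step = batch_size * seq_len                 # 262,144 tokens
num_chunks = max(tokens_per_step // chunk_size, 1)     # 32 chunks

full_logits_gib = tokens_per_step * vocab * 4 / 2**30  # fp32 logits, all at once
chunk_logits_gib = chunk_size * vocab * 4 / 2**30      # fp32 logits, one chunk
print(num_chunks, round(full_logits_gib, 1), round(chunk_logits_gib, 2))
# -> 32 chunks; ~49.1 GiB if materialized at once vs ~1.53 GiB per chunk
```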
