Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Not for land] Added changes for GPT-2 perf #533

Draft
wants to merge 5 commits into
base: gh/awgu/15/base
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 59 additions & 3 deletions torchtitan/models/llama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,61 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
)


class ChunkedCE(torch.autograd.Function):
"""
Credit: https://github.com/Chillee
"""

@staticmethod
def forward(ctx, _input, weight, target, compiled=True):
CHUNK_SIZE = 8192
inp_shape = _input.shape
_input = _input.view(-1, _input.shape[-1])
target = target.view(-1)

def compute_loss(input_chunk, weight, target):
logits = torch.mm(input_chunk, weight.t())
logits = logits.float()
loss = torch.nn.functional.cross_entropy(logits, target)
return loss

grad_weight = torch.zeros_like(weight)
grad_inputs = []
loss_acc = torch.zeros((), device=_input.device)

chunks = max(_input.shape[0] // CHUNK_SIZE, 1)

def accumulate_chunk(input_chunk, target_chunk):
(
chunk_grad_input,
chunk_grad_weight,
), chunk_loss = torch.func.grad_and_value(compute_loss, argnums=(0, 1))(
input_chunk, weight, target_chunk
)
grad_weight.add_(chunk_grad_weight)
loss_acc.add_(chunk_loss)
return chunk_grad_input

if compiled:
accumulate_chunk = torch.compile(accumulate_chunk)

input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
target_chunks = torch.chunk(target, chunks=chunks, dim=0)
for input_chunk, target_chunk in zip(input_chunks, target_chunks):
grad_inputs.append(accumulate_chunk(input_chunk, target_chunk))

ctx.save_for_backward(
(torch.cat(grad_inputs, dim=0) / chunks).view(inp_shape),
grad_weight / chunks,
)
return loss_acc / chunks

@staticmethod
def backward(ctx, grad_output):
(grad_input, grad_weight) = ctx.saved_tensors
return (grad_input, grad_weight, None, None)


class Attention(nn.Module):
"""
Multi-head attention module.
Expand Down Expand Up @@ -421,7 +476,7 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
self.model_args.rope_theta,
)

def forward(self, tokens: torch.Tensor):
def forward(self, tokens: torch.Tensor, labels: torch.Tensor):
"""
Perform a forward pass through the Transformer model.

Expand All @@ -439,8 +494,9 @@ def forward(self, tokens: torch.Tensor):
h = layer(h, self.freqs_cis)

h = self.norm(h) if self.norm else h
output = self.output(h).float() if self.output else h
return output
if not self.output:
return h
return ChunkedCE.apply(h, self.output.weight, labels)

@classmethod
def from_model_args(cls, model_args: ModelArgs) -> "Transformer":
Expand Down
8 changes: 4 additions & 4 deletions torchtitan/parallelisms/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,15 +312,15 @@ def apply_fsdp(
# all-gathers, which can be expensive and non-overlapped
reshard_after_forward = False
else:
# As an optimization, do not reshard after forward for the last
# transformer block since FSDP would prefetch it immediately
reshard_after_forward = int(layer_id) < len(model.layers) - 1
# For small models (e.g. GPT-2), parameter memory is low, so there
# is no need to reshard after forward
reshard_after_forward = False
fully_shard(
transformer_block,
**fsdp_config,
reshard_after_forward=reshard_after_forward,
)
fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
fully_shard(model, **fsdp_config)

logger.info("Applied FSDP to the model")

Expand Down
12 changes: 7 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,12 @@ def main(job_config: JobConfig):

# loss function to be shared by Pipeline Parallel and SPMD training
def loss_fn(pred, labels):
if isinstance(pred, torch.Tensor):
pred_chunks = pred.chunk(token_chunked_cross_entropy_loss.num_chunks, dim=1)
else:
assert isinstance(pred, list)
pred_chunks = pred
return token_chunked_cross_entropy_loss(pred_chunks, labels)
return torch.nn.functional.cross_entropy(
pred.flatten(0, 1), labels.flatten(0, 1)
)
Expand Down Expand Up @@ -288,11 +294,7 @@ def loss_fn(pred, labels):
else:
# Non-PP forward / backward
with train_context():
pred = model(input_ids)
loss = loss_fn(pred, labels)
# pred.shape=(bs, seq_len, vocab_size)
# need to free to before bwd to avoid peaking memory
del pred
loss = model(input_ids, labels)
loss.backward()

# clip gradients
Expand Down
8 changes: 4 additions & 4 deletions train_configs/gpt2.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ dump_folder = "./outputs"
description = "GPT-2 training"

[profiling]
enable_profiling = false
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 10
profile_freq = 50

[metrics]
log_freq = 10
Expand All @@ -27,11 +27,11 @@ lr = 3e-4
fused = true

[training]
batch_size = 16
batch_size = 32
seq_len = 8192
warmup_steps = 200 # lr scheduler warm up
max_norm = 1.0 # grad norm clipping
steps = 100
steps = 50
data_parallel_degree = -1
tensor_parallel_degree = 1
compile = true
Expand Down
Loading