Commit
[BE][1/n] simplify train.py
ghstack-source-id: bdadd8da3821090f430b78aea8193b20bccf1528
Pull Request resolved: #494
tianyu-l committed Jul 31, 2024
1 parent b069f70 commit 3e456bf
Showing 19 changed files with 220 additions and 228 deletions.
45 changes: 42 additions & 3 deletions torchtitan/checkpoint.py
@@ -10,6 +10,8 @@
import re
import shutil
import time
from dataclasses import dataclass, field
from io import BytesIO
from multiprocessing import get_context
from typing import Any, Dict, List, Union

@@ -27,7 +29,7 @@
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import DataLoader
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging_utils import init_logger, logger
from torchtitan.logging import init_logger, logger


class IntervalType(enum.Enum):
@@ -41,6 +43,43 @@ class AsyncMode(str, enum.Enum):
ASYNC_WITH_PINNED_MEM = "async_with_pinned_mem"


@dataclass
class TrainState(Stateful):
    step: int = 0
    global_avg_losses: List[float] = field(default_factory=list)
    global_max_losses: List[float] = field(default_factory=list)
    log_steps: List[int] = field(default_factory=list)

    def state_dict(self) -> Dict[str, Any]:
        # Only checkpoint global_avg_losses and global_max_losses per log frequency
        # to avoid sync overhead in every iteration.
        global_avg_losses_bytes = BytesIO()
        torch.save(self.global_avg_losses, global_avg_losses_bytes)
        global_max_losses_bytes = BytesIO()
        torch.save(self.global_max_losses, global_max_losses_bytes)
        log_steps_bytes = BytesIO()
        torch.save(self.log_steps, log_steps_bytes)
        return {
            "step": torch.tensor(self.step, dtype=torch.int32),
            "global_avg_losses": global_avg_losses_bytes,
            "global_max_losses": global_max_losses_bytes,
            "log_steps": log_steps_bytes,
        }

    def load_state_dict(self, state_dict) -> None:
        self.step = state_dict["step"].item()
        state_dict["global_avg_losses"].seek(0)
        self.global_avg_losses = torch.load(
            state_dict["global_avg_losses"], weights_only=False
        )
        state_dict["global_max_losses"].seek(0)
        self.global_max_losses = torch.load(
            state_dict["global_max_losses"], weights_only=False
        )
        state_dict["log_steps"].seek(0)
        self.log_steps = torch.load(state_dict["log_steps"], weights_only=False)


class ModelWrapper(Stateful):
def __init__(self, model: Union[nn.Module, List[nn.Module]]) -> None:
self.model = [model] if isinstance(model, nn.Module) else model
@@ -124,10 +163,10 @@ def checkpoint_mp(recv, send):
class CheckpointManager:
def __init__(
self,
dataloader: DataLoader,
model_parts: List[nn.Module],
optimizers: List[torch.optim.Optimizer],
lr_schedulers: List[torch.optim.lr_scheduler.LRScheduler],
dataloader: DataLoader,
states: Dict[str, Any],
job_config: JobConfig,
) -> None:
@@ -390,7 +429,7 @@ def save(self, curr_step: int, force: bool = False) -> None:
f"in {time.monotonic() - begin:.2f} seconds."
)

def wait_for_staging(self) -> None:
def maybe_wait_for_staging(self) -> None:
if (
self.enable_checkpoint
and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
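To make the TrainState added above concrete, the sketch below round-trips it through state_dict()/load_state_dict(); the numeric values are illustrative placeholders, and nothing beyond the class shown in the diff is assumed.

# Minimal round-trip sketch for the new TrainState (illustrative values only).
from torchtitan.checkpoint import TrainState

state = TrainState(
    step=10,
    global_avg_losses=[2.5, 2.1],
    global_max_losses=[3.0, 2.4],
    log_steps=[5, 10],
)
saved = state.state_dict()  # loss/step histories are serialized into BytesIO buffers

restored = TrainState()
restored.load_state_dict(saved)
assert restored.step == 10 and restored.log_steps == [5, 10]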
2 changes: 1 addition & 1 deletion torchtitan/config_manager.py
@@ -16,7 +16,7 @@
except ModuleNotFoundError:
import tomli as tomllib

from torchtitan.logging_utils import logger
from torchtitan.logging import logger

TORCH_DTYPE_MAP = {
"float16": torch.float16,
4 changes: 2 additions & 2 deletions torchtitan/datasets/__init__.py
@@ -5,9 +5,9 @@
# LICENSE file in the root directory of this source tree.

from torchtitan.datasets.hf_datasets import build_hf_data_loader
from torchtitan.datasets.tokenizer import create_tokenizer
from torchtitan.datasets.tokenizer import build_tokenizer

__all__ = [
"build_hf_data_loader",
"create_tokenizer",
"build_tokenizer",
]
2 changes: 1 addition & 1 deletion torchtitan/datasets/hf_datasets.py
@@ -20,7 +20,7 @@
) from e

from torchtitan.datasets.tokenizer import Tokenizer
from torchtitan.logging_utils import logger
from torchtitan.logging import logger

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
4 changes: 2 additions & 2 deletions torchtitan/datasets/tokenizer/__init__.py
@@ -8,10 +8,10 @@
from torchtitan.datasets.tokenizer.tiktoken import TikTokenizer
from torchtitan.datasets.tokenizer.tokenizer import Tokenizer

from torchtitan.logging_utils import logger
from torchtitan.logging import logger


def create_tokenizer(tokenizer_type: str, tokenizer_path: str) -> Tokenizer:
def build_tokenizer(tokenizer_type: str, tokenizer_path: str) -> Tokenizer:
logger.info(f"Building {tokenizer_type} tokenizer locally from {tokenizer_path}")
if tokenizer_type == "sentencepiece":
return SentencePieceTokenizer(tokenizer_path)
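A hedged call-site sketch for the create_tokenizer -> build_tokenizer rename; the tokenizer path below is a placeholder, not a path from this repository, and "sentencepiece" is the tokenizer_type string visible in the hunk above.

# Hypothetical call site after the rename (placeholder path).
from torchtitan.datasets import build_tokenizer

tokenizer = build_tokenizer("sentencepiece", "path/to/tokenizer.model")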
2 changes: 1 addition & 1 deletion torchtitan/datasets/tokenizer/sentencepiece.py
@@ -11,7 +11,7 @@
from sentencepiece import SentencePieceProcessor

from torchtitan.datasets.tokenizer.tokenizer import Tokenizer
from torchtitan.logging_utils import logger
from torchtitan.logging import logger


class SentencePieceTokenizer(Tokenizer):
2 changes: 1 addition & 1 deletion torchtitan/datasets/tokenizer/tiktoken.py
@@ -26,7 +26,7 @@
from tiktoken.load import load_tiktoken_bpe

from torchtitan.datasets.tokenizer.tokenizer import Tokenizer
from torchtitan.logging_utils import logger
from torchtitan.logging import logger


class TikTokenizer(Tokenizer):
2 changes: 1 addition & 1 deletion torchtitan/float8_linear.py
@@ -20,7 +20,7 @@
from torch._logging import warning_once

from torchtitan.config_manager import JobConfig
from torchtitan.logging_utils import logger
from torchtitan.logging import logger


@functools.lru_cache(None)
File renamed without changes.
32 changes: 23 additions & 9 deletions torchtitan/metrics.py
@@ -12,7 +12,8 @@
import torch
from torch.utils.tensorboard import SummaryWriter
from torchtitan.config_manager import JobConfig
from torchtitan.logging_utils import logger
from torchtitan.logging import logger
from torchtitan.parallelisms import ParallelDims

# named tuple for passing GPU memory stats for logging
GPUMemStats = namedtuple(
@@ -110,16 +111,29 @@ def close(self):
self.writer.close()


def _get_metrics_rank(parallel_dims: ParallelDims) -> int:
    """
    Returns global rank 0 in non-pipeline-parallel configs, and returns the global
    rank of the 0th rank in the last pipeline stage when pipeline parallelism is enabled.
    """
    if parallel_dims.pp_enabled:
        world_size = parallel_dims.world_size
        pp_size = parallel_dims.pp
        metrics_log_rank = (world_size // pp_size) * (pp_size - 1)
    else:
        metrics_log_rank = 0

    return metrics_log_rank


def build_metric_logger(
config: JobConfig, metrics_log_rank: int = 0, tag: Optional[str] = None
config: JobConfig, parallel_dims: ParallelDims, tag: Optional[str] = None
):
"""
metrics_log_rank controls which rank acts as 'rank 0' for logging metrics.
If 'tb_config.rank_0_only' is set, then `metrics_log_rank` will be used as the rank to log metrics.
This is intended to allow logging from the 0th rank within the last pipeline stage group, in case pipeline
parallelism is enabled, without forcing logging from all ranks to capture loss information when using pipeline
parallelism.
parallel_dims is used to determine the rank to log metrics from if 'tb_config.rank_0_only=True'.
In that case, `_get_metrics_rank` will be used to calculate which rank acts as 'rank 0'. This is
intended to allow logging from the 0th rank within the last pipeline stage group, in case pipeline
parallelism is enabled, without forcing logging from all ranks to capture loss information.
"""
dump_dir = config.job.dump_folder
tb_config = config.metrics
@@ -134,7 +148,7 @@ def build_metric_logger(
f"Metrics logging active. Tensorboard logs will be saved at {log_dir}"
)
if tb_config.rank_0_only:
enable_tb = torch.distributed.get_rank() == metrics_log_rank
enable_tb = torch.distributed.get_rank() == _get_metrics_rank(parallel_dims)
else:
rank_str = f"rank_{torch.distributed.get_rank()}"
log_dir = os.path.join(log_dir, rank_str)
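As a worked example of the rank selection in _get_metrics_rank (values chosen for illustration, not taken from the diff): with 8 ranks and a pipeline-parallel degree of 2, each pipeline stage group spans world_size // pp_size = 4 ranks, so metrics are logged from global rank 4, the 0th rank of the last stage group.

# Illustrative arithmetic only; real call sites now pass a ParallelDims instance,
# e.g. build_metric_logger(job_config, parallel_dims).
world_size, pp_size = 8, 2
metrics_log_rank = (world_size // pp_size) * (pp_size - 1)
assert metrics_log_rank == 4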
8 changes: 4 additions & 4 deletions torchtitan/models/llama/model.py
@@ -14,7 +14,7 @@
import torch
import torch.nn.functional as F
from torch import nn
from torchtitan.models.norms import create_norm
from torchtitan.models.norms import build_norm


@dataclass
@@ -291,10 +291,10 @@ def __init__(self, layer_id: int, model_args: ModelArgs):
self.layer_id = layer_id
self.num_layers = model_args.n_layers

self.attention_norm = create_norm(
self.attention_norm = build_norm(
model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
)
self.ffn_norm = create_norm(
self.ffn_norm = build_norm(
model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
)

@@ -370,7 +370,7 @@ def __init__(self, model_args: ModelArgs):
for layer_id in range(model_args.n_layers):
self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)

self.norm = create_norm(
self.norm = build_norm(
model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
)

8 changes: 4 additions & 4 deletions torchtitan/models/norms.py
@@ -18,18 +18,18 @@
from torch.distributed._tensor.experimental import local_map


def create_norm(norm_type: str, dim: int, eps: float = 1e-6):
def build_norm(norm_type: str, dim: int, eps: float = 1e-6):
"""
Creates the specified normalization layer based on the norm_type.
Builds the specified normalization layer based on the norm_type.
Args:
norm_type (str): The type of normalization layer to create.
norm_type (str): The type of normalization layer to build.
Supported types: 1. rmsnorm 2. fused_rmsnorm 3. layernorm 4. np_layernorm
dim (int): The dimension of the normalization layer.
eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6.
Returns:
The created normalization layer.
The built normalization layer.
Raises:
NotImplementedError: If an unknown norm_type is provided.
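A small usage sketch for the renamed factory; the dimension below is a placeholder, and the norm_type strings come from the docstring shown in the hunk above.

# Hypothetical call of the renamed build_norm factory (placeholder dim).
from torchtitan.models.norms import build_norm

norm = build_norm("rmsnorm", dim=4096, eps=1e-6)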
53 changes: 50 additions & 3 deletions torchtitan/lr_scheduling.py → torchtitan/optimizer.py
@@ -6,10 +6,57 @@

import functools

import torch
from torch.optim.lr_scheduler import LambdaLR
from torchtitan.config_manager import JobConfig


# consider split between PP and non-PP
def build_optimizers(model_parts, job_config: JobConfig):
    """Wrap one optimizer per model part in an OptimizersContainer which provides a single
    step() and zero_grad() method for all the child optimizers.
    """

    def _build_optimizer(model):
        name = job_config.optimizer.name
        lr = job_config.optimizer.lr
        fused = job_config.optimizer.fused

        # Common parameters for both optimizers
        optimizer_kwargs = {
            "lr": lr,
            "betas": (0.9, 0.95),
            "weight_decay": 0.1,
            "fused": fused,
            "foreach": not fused,
        }
        if name == "Adam":
            # TODO: make the optimizer options configurable by toml/cmd args
            optimizer = torch.optim.Adam(model.parameters(), **optimizer_kwargs)
        elif name == "AdamW":
            optimizer = torch.optim.AdamW(model.parameters(), **optimizer_kwargs)
        else:
            raise NotImplementedError(f"Optimizer {name} not added.")

        return optimizer

    class OptimizersContainer:
        """Util for calling step/zero_grad on multiple optimizers needed for virtual pipeline stages"""

        def __init__(self, optimizers):
            self.optimizers = optimizers

        def step(self):
            for optimizer in self.optimizers:
                optimizer.step()

        def zero_grad(self):
            for optimizer in self.optimizers:
                optimizer.zero_grad()

    return OptimizersContainer([_build_optimizer(model) for model in model_parts])


def linear_warmup_linear_decay(
warmup_steps: int, decay_steps: int, current_step: int
) -> float:
@@ -32,8 +79,8 @@ def linear_warmup_linear_decay(
return curr_adjustment


def get_lr_schedulers(optimizers, job_config: JobConfig):
def _get_lr_scheduler(optimizer):
def build_lr_schedulers(optimizers, job_config: JobConfig):
def _build_lr_scheduler(optimizer):
"""Build a linear warmup and linear decay scheduler"""
warmup_steps = int(job_config.training.warmup_steps)
decay_steps = float(max(1, job_config.training.steps - warmup_steps))
@@ -54,5 +101,5 @@ def step(self):
schedulers.step()

return SchedulersContainer(
[_get_lr_scheduler(optimizer) for optimizer in optimizers]
[_build_lr_scheduler(optimizer) for optimizer in optimizers]
)
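A sketch of how the new build_optimizers and the renamed build_lr_schedulers might compose in a training script; model_parts and job_config are assumed to be constructed elsewhere, and each container exposes a single step()/zero_grad() over its children as shown in the hunks above.

# Assumed wiring, based on the factories shown above (not taken from train.py itself).
from torchtitan.optimizer import build_lr_schedulers, build_optimizers

optimizers = build_optimizers(model_parts, job_config)
lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)

# inside the training loop:
optimizers.zero_grad()
# ... forward / backward ...
optimizers.step()
lr_schedulers.step()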
11 changes: 10 additions & 1 deletion torchtitan/parallelisms/__init__.py
@@ -8,8 +8,17 @@
from functools import cached_property

from torch.distributed.device_mesh import init_device_mesh
from torchtitan.logging_utils import logger
from torchtitan.logging import logger
from torchtitan.parallelisms.parallelize_llama import parallelize_llama, pipeline_llama
from torchtitan.parallelisms.pipelining_utils import build_pipeline_schedule


__all__ = [
"build_pipeline_schedule",
"models_parallelize_fns",
"models_pipelining_fns",
"ParallelDims",
]

models_parallelize_fns = {
"llama2": parallelize_llama,
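With the new __all__ in place, downstream code can import the public helpers directly from the package; the sketch below uses only names listed in the diff and is illustrative rather than code from this commit.

# Hypothetical import sketch using names exported in the new __all__.
from torchtitan.parallelisms import (
    ParallelDims,
    build_pipeline_schedule,
    models_parallelize_fns,
)

parallelize_llama2 = models_parallelize_fns["llama2"]  # "llama2" key shown in the diff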
2 changes: 1 addition & 1 deletion torchtitan/parallelisms/parallelize_llama.py
@@ -32,7 +32,7 @@
)

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging_utils import logger
from torchtitan.logging import logger
from torchtitan.models.llama.model import ModelArgs
from torchtitan.parallelisms.pipelining_utils import stage_ids_this_rank

4 changes: 2 additions & 2 deletions torchtitan/parallelisms/pipelining_utils.py
@@ -11,12 +11,12 @@
ScheduleGPipe,
ScheduleInterleaved1F1B,
)
from torchtitan.logging_utils import logger
from torchtitan.logging import logger


def build_pipeline_schedule(job_config, parallel_dims, stages, loss_fn):

looped_schedule = False

if job_config.experimental.pipeline_parallel_schedule == "1f1b":
schedule_class = Schedule1F1B
elif job_config.experimental.pipeline_parallel_schedule == "gpipe":
2 changes: 1 addition & 1 deletion torchtitan/profiling.py
@@ -11,7 +11,7 @@

import torch
from torchtitan.config_manager import JobConfig
from torchtitan.logging_utils import logger
from torchtitan.logging import logger

# the number of warmup steps before the active step in each profiling cycle
WARMUP = 3