diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 936020c5..0eeac026 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -100,6 +100,18 @@ def __init__(self): default=10, help="How often to collect profiler traces, in iterations", ) + self.parser.add_argument( + "--profiling.enable_memory_snapshot", + action="store_true", + default=False, + help="Whether to dump memory snapshot", + ) + self.parser.add_argument( + "--profiling.save_memory_snapshot_folder", + type=str, + default="memory_snapshot", + help="Memeory snapshot files location", + ) # metrics configs self.parser.add_argument( diff --git a/torchtitan/profiling.py b/torchtitan/profiling.py index b4c2b2e0..c993a74f 100644 --- a/torchtitan/profiling.py +++ b/torchtitan/profiling.py @@ -6,6 +6,7 @@ import contextlib import os +import pickle import time import torch @@ -15,6 +16,9 @@ # the number of warmup steps before the active step in each profiling cycle WARMUP = 3 +# how much memory allocation/free ops to record in memory snapshots +MEMORY_SNAPSHOT_MAX_ENTRIES = 100000 + @contextlib.contextmanager def maybe_enable_profiling(config: JobConfig, *, global_step: int = 0): @@ -70,3 +74,57 @@ def trace_handler(prof): else: torch_profiler = contextlib.nullcontext() yield None + + +@contextlib.contextmanager +def maybe_enable_memory_snapshot(config: JobConfig, *, global_step: int = 0): + enable_snapshot = config.profiling.enable_memory_snapshot + if enable_snapshot: + snapshot_folder = config.profiling.save_memory_snapshot_folder + snapshot_dir = os.path.join(config.job.dump_folder, snapshot_folder) + if not os.path.exists(snapshot_dir): + os.makedirs(snapshot_dir, exist_ok=True) + rank = torch.distributed.get_rank() + + class MemoryProfiler: + def __init__(self, step_num: int, freq: int): + torch.cuda.memory._record_memory_history( + max_entries=MEMORY_SNAPSHOT_MAX_ENTRIES + ) + # when resume training, we start from the last step + self.step_num = step_num + self.freq = freq + + def step(self, exit_ctx: bool = False): + self.step_num += 1 + if not exit_ctx and self.step_num % self.freq != 0: + return + if not exit_ctx: + curr_step = self.step_num + dir_name = f"iteration_{curr_step}" + else: + # dump as iteration_0_exit if OOM at iter 1 + curr_step = self.step_num - 1 + dir_name = f"iteration_{curr_step}_exit" + curr_snapshot_dir = os.path.join(snapshot_dir, dir_name) + if not os.path.exists(curr_snapshot_dir): + os.makedirs(curr_snapshot_dir, exist_ok=True) + logger.info(f"Dumping memory snapshot at step {curr_step}") + begin = time.monotonic() + with open( + f"{curr_snapshot_dir}/rank{rank}_memory_snapshot.pickle", "wb" + ) as output: + pickle.dump(torch.cuda.memory._snapshot(), output) + logger.info( + f"Finished dumping memory snapshot in {time.monotonic() - begin:.2f} seconds" + ) + torch.distributed.barrier() + + logger.info(f"Memory profiler active. Snapshot will be saved at {snapshot_dir}") + profiler = MemoryProfiler(global_step, config.profiling.profile_freq) + try: + yield profiler + except torch.OutOfMemoryError as e: + profiler.step(exit_ctx=True) + else: + yield None diff --git a/train.py b/train.py index 8bdd8934..64a50990 100644 --- a/train.py +++ b/train.py @@ -38,7 +38,7 @@ ParallelDims, ) from torchtitan.parallelisms.pipelining_utils import build_pipeline_schedule -from torchtitan.profiling import maybe_enable_profiling +from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling from torchtitan.utils import ( Color, dist_max, @@ -339,7 +339,9 @@ def loss_fn(pred, labels): logger.info(f"Training starts at step {train_state.step + 1}") with maybe_enable_profiling( job_config, global_step=train_state.step - ) as torch_profiler: + ) as torch_profiler, maybe_enable_memory_snapshot( + job_config, global_step=train_state.step + ) as memory_profiler: while train_state.step < job_config.training.steps: train_state.step += 1 if train_state.step > 1 and train_state.step % _gc_freq == 0: @@ -477,6 +479,9 @@ def loss_fn(pred, labels): if torch_profiler: torch_profiler.step() + if memory_profiler: + memory_profiler.step() + # Reduce timeout after first train step for faster signal (assumes lazy init, compile are finished) if train_state.step == 1: set_pg_timeouts( diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index 5d7e9987..0f0794a1 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -9,6 +9,8 @@ use_for_integration_test = true enable_profiling = true save_traces_folder = "profile_trace" profile_freq = 10 +enable_memory_snapshot = true +save_memory_snapshot_folder = "memory_snapshot" [metrics] log_freq = 1