Skip to content

Commit

Permalink
dump memory snapshot to analyze OOMs (pytorch#395)
Browse files Browse the repository at this point in the history
when setting `enable_memory_snapshot = true` in `.toml`
* dump memory snapshots in case of OOMs. output folder is
`memory_snapshot/iteration_x_exit`
* dump regularly according to `profile_freq`. output folder is
`memory_snapshot/iteration_x`
* existing `.toml` works since `enable_memory_snapshot=False` by default

snapshot is an example of the dump when OOM happens

<img width="1640" alt="Screenshot 2024-06-12 at 9 26 53 PM"
src="https://github.com/pytorch/torchtitan/assets/134637289/6420799c-ae68-4b35-b8bb-f5b6ab3dd053">
  • Loading branch information
weifengpy authored Jun 19, 2024
1 parent efdac9b commit 4fef6d6
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 2 deletions.
12 changes: 12 additions & 0 deletions torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,18 @@ def __init__(self):
default=10,
help="How often to collect profiler traces, in iterations",
)
self.parser.add_argument(
"--profiling.enable_memory_snapshot",
action="store_true",
default=False,
help="Whether to dump memory snapshot",
)
self.parser.add_argument(
"--profiling.save_memory_snapshot_folder",
type=str,
default="memory_snapshot",
help="Memeory snapshot files location",
)

# metrics configs
self.parser.add_argument(
Expand Down
58 changes: 58 additions & 0 deletions torchtitan/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import contextlib
import os
import pickle
import time

import torch
Expand All @@ -15,6 +16,9 @@
# the number of warmup steps before the active step in each profiling cycle
WARMUP = 3

# how much memory allocation/free ops to record in memory snapshots
MEMORY_SNAPSHOT_MAX_ENTRIES = 100000


@contextlib.contextmanager
def maybe_enable_profiling(config: JobConfig, *, global_step: int = 0):
Expand Down Expand Up @@ -70,3 +74,57 @@ def trace_handler(prof):
else:
torch_profiler = contextlib.nullcontext()
yield None


@contextlib.contextmanager
def maybe_enable_memory_snapshot(config: JobConfig, *, global_step: int = 0):
enable_snapshot = config.profiling.enable_memory_snapshot
if enable_snapshot:
snapshot_folder = config.profiling.save_memory_snapshot_folder
snapshot_dir = os.path.join(config.job.dump_folder, snapshot_folder)
if not os.path.exists(snapshot_dir):
os.makedirs(snapshot_dir, exist_ok=True)
rank = torch.distributed.get_rank()

class MemoryProfiler:
def __init__(self, step_num: int, freq: int):
torch.cuda.memory._record_memory_history(
max_entries=MEMORY_SNAPSHOT_MAX_ENTRIES
)
# when resume training, we start from the last step
self.step_num = step_num
self.freq = freq

def step(self, exit_ctx: bool = False):
self.step_num += 1
if not exit_ctx and self.step_num % self.freq != 0:
return
if not exit_ctx:
curr_step = self.step_num
dir_name = f"iteration_{curr_step}"
else:
# dump as iteration_0_exit if OOM at iter 1
curr_step = self.step_num - 1
dir_name = f"iteration_{curr_step}_exit"
curr_snapshot_dir = os.path.join(snapshot_dir, dir_name)
if not os.path.exists(curr_snapshot_dir):
os.makedirs(curr_snapshot_dir, exist_ok=True)
logger.info(f"Dumping memory snapshot at step {curr_step}")
begin = time.monotonic()
with open(
f"{curr_snapshot_dir}/rank{rank}_memory_snapshot.pickle", "wb"
) as output:
pickle.dump(torch.cuda.memory._snapshot(), output)
logger.info(
f"Finished dumping memory snapshot in {time.monotonic() - begin:.2f} seconds"
)
torch.distributed.barrier()

logger.info(f"Memory profiler active. Snapshot will be saved at {snapshot_dir}")
profiler = MemoryProfiler(global_step, config.profiling.profile_freq)
try:
yield profiler
except torch.OutOfMemoryError as e:
profiler.step(exit_ctx=True)
else:
yield None
9 changes: 7 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
ParallelDims,
)
from torchtitan.parallelisms.pipelining_utils import build_pipeline_schedule
from torchtitan.profiling import maybe_enable_profiling
from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
from torchtitan.utils import (
Color,
dist_max,
Expand Down Expand Up @@ -339,7 +339,9 @@ def loss_fn(pred, labels):
logger.info(f"Training starts at step {train_state.step + 1}")
with maybe_enable_profiling(
job_config, global_step=train_state.step
) as torch_profiler:
) as torch_profiler, maybe_enable_memory_snapshot(
job_config, global_step=train_state.step
) as memory_profiler:
while train_state.step < job_config.training.steps:
train_state.step += 1
if train_state.step > 1 and train_state.step % _gc_freq == 0:
Expand Down Expand Up @@ -477,6 +479,9 @@ def loss_fn(pred, labels):
if torch_profiler:
torch_profiler.step()

if memory_profiler:
memory_profiler.step()

# Reduce timeout after first train step for faster signal (assumes lazy init, compile are finished)
if train_state.step == 1:
set_pg_timeouts(
Expand Down
2 changes: 2 additions & 0 deletions train_configs/debug_model.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use_for_integration_test = true
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 10
enable_memory_snapshot = true
save_memory_snapshot_folder = "memory_snapshot"

[metrics]
log_freq = 1
Expand Down

0 comments on commit 4fef6d6

Please sign in to comment.