dump memory snapshot to analyze OOMs #395

Merged: 9 commits, Jun 19, 2024
12 changes: 12 additions & 0 deletions torchtitan/config_manager.py
@@ -100,6 +100,18 @@ def __init__(self):
default=10,
help="How often to collect profiler traces, in iterations",
)
self.parser.add_argument(
"--profiling.enable_memory_snapshot",
action="store_true",
default=False,
help="Whether to dump memory snapshot",
)
self.parser.add_argument(
"--profiling.save_memory_snapshot_folder",
type=str,
default="memory_snapshot",
help="Memeory snapshot files location",
)

# metrics configs
self.parser.add_argument(
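The new flags can be set in the TOML config (see the train_configs/debug_model.toml change below) or passed on the command line. A minimal usage sketch, not part of this PR, assuming JobConfig.parse_args accepts an argv-style list as it does for the existing --profiling.* options:

# Hypothetical sketch: exercising the new flags through JobConfig.
# The parse_args call and attribute layout are assumptions based on how
# config_manager.py exposes the other --profiling.* options.
from torchtitan.config_manager import JobConfig

config = JobConfig()
config.parse_args(
    [
        "--profiling.enable_memory_snapshot",
        "--profiling.save_memory_snapshot_folder",
        "memory_snapshot",
    ]
)
assert config.profiling.enable_memory_snapshot is True
assert config.profiling.save_memory_snapshot_folder == "memory_snapshot"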
58 changes: 58 additions & 0 deletions torchtitan/profiling.py
@@ -6,6 +6,7 @@

import contextlib
import os
import pickle
import time

import torch
@@ -15,6 +16,9 @@
# the number of warmup steps before the active step in each profiling cycle
WARMUP = 3

# how many memory allocation/free ops to record in memory snapshots
MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
@weifengpy (Contributor, Author) commented on Jun 13, 2024:
MEMORY_SNAPSHOT_MAX_ENTRIES controls how large the .pickle file can be. Right now it's 36 MB.



@contextlib.contextmanager
def maybe_enable_profiling(config: JobConfig, *, global_step: int = 0):
@@ -70,3 +74,57 @@ def trace_handler(prof):
else:
torch_profiler = contextlib.nullcontext()
yield None


@contextlib.contextmanager
def maybe_enable_memory_snapshot(config: JobConfig, *, global_step: int = 0):
enable_snapshot = config.profiling.enable_memory_snapshot
if enable_snapshot:
snapshot_folder = config.profiling.save_memory_snapshot_folder
snapshot_dir = os.path.join(config.job.dump_folder, snapshot_folder)
if not os.path.exists(snapshot_dir):
os.makedirs(snapshot_dir, exist_ok=True)
rank = torch.distributed.get_rank()

class MemoryProfiler:
def __init__(self, step_num: int, freq: int):
torch.cuda.memory._record_memory_history(
max_entries=MEMORY_SNAPSHOT_MAX_ENTRIES
)
# when resuming training, we start from the last step
self.step_num = step_num
self.freq = freq

def step(self, exit_ctx: bool = False):
self.step_num += 1
if not exit_ctx and self.step_num % self.freq != 0:
return
if not exit_ctx:
curr_step = self.step_num
dir_name = f"iteration_{curr_step}"
else:
# dump as iteration_0_exit if OOM at iter 1
curr_step = self.step_num - 1
dir_name = f"iteration_{curr_step}_exit"
curr_snapshot_dir = os.path.join(snapshot_dir, dir_name)
if not os.path.exists(curr_snapshot_dir):
os.makedirs(curr_snapshot_dir, exist_ok=True)
logger.info(f"Dumping memory snapshot at step {curr_step}")
begin = time.monotonic()
with open(
f"{curr_snapshot_dir}/rank{rank}_memory_snapshot.pickle", "wb"
) as output:
pickle.dump(torch.cuda.memory._snapshot(), output)
A reviewer (Contributor) commented:
Maybe add a threshold so the memory snapshot is only dumped when memory usage exceeds it, to avoid generating overwhelming amounts of data?

@weifengpy (Contributor, Author) replied on Jun 13, 2024:
Do you mean a threshold in MB? Right now it's bounded by the number of free/allocate ops (MEMORY_SNAPSHOT_MAX_ENTRIES). For MB, I can google around.

@weifengpy (Contributor, Author) replied:
Googled for an MB threshold but did not find anything useful. Currently MEMORY_SNAPSHOT_MAX_ENTRIES=100000 keeps the file size to 36 MB. Let me know if this is still a blocker.
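
As a side note, the entry bound can be checked on a dumped file. A minimal sketch, not part of this PR; the file path is hypothetical and the "device_traces" key is an assumption about the snapshot pickle's layout:

# Hypothetical sketch: count the alloc/free events recorded in a dumped snapshot.
# The path is made up; "device_traces" is assumed to be the per-device trace list.
import pickle

with open("rank0_memory_snapshot.pickle", "rb") as f:
    snapshot = pickle.load(f)

num_events = sum(len(trace) for trace in snapshot.get("device_traces", []))
# Each per-device trace is capped by MEMORY_SNAPSHOT_MAX_ENTRIES, which is what
# keeps the pickle around 36 MB here.
print(f"recorded alloc/free events: {num_events}")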

logger.info(
f"Finished dumping memory snapshot in {time.monotonic() - begin:.2f} seconds"
)
torch.distributed.barrier()

logger.info(f"Memory profiler active. Snapshot will be saved at {snapshot_dir}")
profiler = MemoryProfiler(global_step, config.profiling.profile_freq)
try:
yield profiler
except torch.OutOfMemoryError as e:
profiler.step(exit_ctx=True)
else:
yield None
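
Once dumped, a snapshot can be visualized by dragging the .pickle into https://pytorch.org/memory_viz, or programmatically. A minimal sketch, not part of this PR; the path is hypothetical and torch.cuda._memory_viz.trace_plot is a private PyTorch helper that may change between releases:

# Hypothetical sketch: render a dumped snapshot as a standalone HTML timeline.
import pickle

from torch.cuda._memory_viz import trace_plot  # private API, subject to change

snapshot_path = "outputs/memory_snapshot/iteration_10/rank0_memory_snapshot.pickle"
with open(snapshot_path, "rb") as f:
    snapshot = pickle.load(f)

with open("rank0_iteration_10.html", "w") as f:
    f.write(trace_plot(snapshot))  # open the HTML in a browser to inspect allocations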
9 changes: 7 additions & 2 deletions train.py
@@ -38,7 +38,7 @@
ParallelDims,
)
from torchtitan.parallelisms.pipelining_utils import build_pipeline_schedule
from torchtitan.profiling import maybe_enable_profiling
from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
from torchtitan.utils import (
Color,
dist_max,
@@ -308,7 +308,9 @@ def loss_fn(pred, labels):
logger.info(f"Training starts at step {train_state.step + 1}")
with maybe_enable_profiling(
job_config, global_step=train_state.step
) as torch_profiler:
) as torch_profiler, maybe_enable_memory_snapshot(
job_config, global_step=train_state.step
) as memory_profiler:
while train_state.step < job_config.training.steps:
train_state.step += 1
if train_state.step > 1 and train_state.step % _gc_freq == 0:
@@ -445,6 +447,9 @@ def loss_fn(pred, labels):
if torch_profiler:
torch_profiler.step()

if memory_profiler:
memory_profiler.step()

# Reduce timeout after first train step for faster signal (assumes lazy init, compile are finished)
if train_state.step == 1:
set_pg_timeouts(
2 changes: 2 additions & 0 deletions train_configs/debug_model.toml
@@ -9,6 +9,8 @@ use_for_integration_test = true
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 10
enable_memory_snapshot = true
A reviewer (Contributor) commented:
nit: let's put the folder here as well to be consistent and informative

@weifengpy (Contributor, Author) replied:
Added save_memory_snapshot_folder in the .toml.

save_memory_snapshot_folder = "memory_snapshot"

[metrics]
log_freq = 1