dump memory snapshot to analyze OOMs #395

Merged (9 commits, Jun 19, 2024)
Changes from 4 commits
68 changes: 68 additions & 0 deletions torchtitan/profiling.py
@@ -6,6 +6,7 @@

import contextlib
import os
import pickle
import time

import torch
@@ -15,6 +16,14 @@
# the number of warmup steps before the active step in each profiling cycle
WARMUP = 3

# how many memory allocation/free ops to record in memory snapshots
MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
Contributor Author (@weifengpy), Jun 13, 2024:

MEMORY_SNAPSHOT_MAX_ENTRIES controls how large the .pickle file can get. Right now it's 36 MB.
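
For anyone analyzing the dumps: the resulting .pickle can be dragged into pytorch.org/memory_viz, or inspected directly with plain pickle. A minimal sketch, assuming the snapshot dict produced by torch.cuda.memory._snapshot() keeps its current "segments"/"device_traces" layout (the path in the commented call is illustrative, not from this PR):

import os
import pickle

def inspect_snapshot(path: str) -> None:
    # report file size and how many allocator events were captured
    size_mb = os.path.getsize(path) / 1e6
    with open(path, "rb") as f:
        snapshot = pickle.load(f)
    # _snapshot() returns a dict with "segments" (allocator state) and
    # "device_traces" (per-device lists of alloc/free events), whose length
    # is capped by MEMORY_SNAPSHOT_MAX_ENTRIES
    num_events = sum(len(trace) for trace in snapshot.get("device_traces", []))
    print(f"{path}: {size_mb:.1f} MB, {num_events} recorded events")

# inspect_snapshot("outputs/memory_snapshot/iteration_10/rank0_memory_snapshot.pickle")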


# default memory snapshot folder
ENABLE_MEMORY_SNAPSHOT_KEY = "enable_memory_snapshot"
MEMORY_SNAPSHOT_FOLDER_KEY = "memory_snapshot_folder"
MEMORY_SNAPSHOT_FOLDER_DEFAULT_VALUE = "memory_snapshot"
Contributor (reviewer):

We should make these into configs. Please refer to how torch_profiler does this, e.g. put them into config_manager.py.

Contributor Author (@weifengpy):

Good to know about config_manager.py. I will move the defaults into config_manager.
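
For reference, a minimal sketch of what registering these options in config_manager.py could look like, assuming it follows the same argparse pattern used for the existing torch_profiler options; the argument names mirror the keys in this diff, but the defaults and help strings are illustrative, not the final PR code:

import argparse

parser = argparse.ArgumentParser()
# sketch only: in config_manager.py these would be added to the JobConfig parser
parser.add_argument(
    "--profiling.enable_memory_snapshot",
    action="store_true",
    default=False,
    help="Whether to dump memory snapshots to analyze OOMs",
)
parser.add_argument(
    "--profiling.memory_snapshot_folder",
    type=str,
    default="memory_snapshot",
    help="Folder under job.dump_folder to write memory snapshots to",
)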



@contextlib.contextmanager
def maybe_enable_profiling(config: JobConfig, *, global_step: int = 0):
@@ -70,3 +79,62 @@ def trace_handler(prof):
    else:
        torch_profiler = contextlib.nullcontext()
        yield None


@contextlib.contextmanager
def maybe_enable_memory_snapshot(config: JobConfig, *, global_step: int = 0):
    enable_snapshot = getattr(config.profiling, ENABLE_MEMORY_SNAPSHOT_KEY, False)
    if enable_snapshot:
        snapshot_folder = getattr(
            config.profiling,
            MEMORY_SNAPSHOT_FOLDER_KEY,
            MEMORY_SNAPSHOT_FOLDER_DEFAULT_VALUE,
        )
        snapshot_dir = os.path.join(config.job.dump_folder, snapshot_folder)
        if not os.path.exists(snapshot_dir):
            os.makedirs(snapshot_dir, exist_ok=True)
        rank = torch.distributed.get_rank()

        class MemoryProfiler:
            def __init__(self, step_num: int, freq: int):
                torch.cuda.memory._record_memory_history(
                    max_entries=MEMORY_SNAPSHOT_MAX_ENTRIES
                )
                # when resuming training, we start from the last step
                self.step_num = step_num
                self.freq = freq

            def step(self, exit_ctx: bool = False):
                if not exit_ctx and self.step_num % self.freq != 0:
                    self.step_num += 1
                    return
                if not exit_ctx:
                    curr_step = self.step_num
                    self.step_num += 1
                    dir_name = f"iteration_{curr_step}"
Contributor (reviewer):

torch.profiler starts from step 0, whereas train.py starts from step 1. In order to make things work as expected, I suggest we do the following, so that if we set profile_freq=10 and run training for 10 steps, there will be memory snapshots for iteration_10 (similar to torch.profiler) and iteration_10_exit. I've tested this offline.

Suggested change

Remove:
                if not exit_ctx and self.step_num % self.freq != 0:
                    self.step_num += 1
                    return
                if not exit_ctx:
                    curr_step = self.step_num
                    self.step_num += 1
                    dir_name = f"iteration_{curr_step}"

Replace with:
                self.step_num += 1
                if not exit_ctx and self.step_num % self.freq != 0:
                    return
                if not exit_ctx:
                    curr_step = self.step_num
                    dir_name = f"iteration_{curr_step}"

Contributor Author (@weifengpy):

Thanks for pointing out the difference. Updated accordingly.
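
To make the off-by-one concrete, here is a small standalone simulation (illustrative only, not code from the PR) of the two step() variants with profile_freq = 10 over 10 training steps plus the final exit call:

# Standalone illustration: which snapshot folders each step() variant produces
# for profile_freq = 10 over 10 training steps, plus the exit-time call.
def simulate(variant: str, freq: int = 10, train_steps: int = 10):
    step_num = 0
    dumps = []

    def step(exit_ctx: bool = False):
        nonlocal step_num
        if variant == "original":
            if not exit_ctx and step_num % freq != 0:
                step_num += 1
                return
            if not exit_ctx:
                dumps.append(f"iteration_{step_num}")
                step_num += 1
            else:
                dumps.append(f"iteration_{step_num - 1}_exit")
        else:  # "suggested"
            step_num += 1
            if not exit_ctx and step_num % freq != 0:
                return
            if not exit_ctx:
                dumps.append(f"iteration_{step_num}")
            else:
                dumps.append(f"iteration_{step_num - 1}_exit")

    for _ in range(train_steps):
        step()           # called once per training step, as in train.py
    step(exit_ctx=True)  # the context manager's finally-block call
    return dumps

print(simulate("original"))   # ['iteration_0', 'iteration_9_exit']
print(simulate("suggested"))  # ['iteration_10', 'iteration_10_exit']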

                else:
                    curr_step = self.step_num - 1
                    dir_name = f"iteration_{curr_step}_exit"
                curr_snapshot_dir = os.path.join(snapshot_dir, dir_name)
                if not os.path.exists(curr_snapshot_dir):
                    os.makedirs(curr_snapshot_dir, exist_ok=True)
                logger.info(f"Dumping memory snapshot at step {curr_step}")
                begin = time.monotonic()
                with open(
                    f"{curr_snapshot_dir}/rank{rank}_memory_snapshot.pickle", "wb"
                ) as output:
                    pickle.dump(torch.cuda.memory._snapshot(), output)
Contributor (reviewer):

Maybe add a threshold so the memory snapshot is only dumped when memory usage exceeds it, to avoid overwhelming amounts of data?

Contributor Author (@weifengpy), Jun 13, 2024:

Do you mean a threshold in MB? Right now it's bounded by the number of free/allocate ops via MEMORY_SNAPSHOT_MAX_ENTRIES. For an MB-based threshold, I can google around.

Contributor Author (@weifengpy):

Googled for an MB threshold but did not find anything useful. Currently MEMORY_SNAPSHOT_MAX_ENTRIES=100000 keeps the file size at 36 MB. Let me know if this is still a blocker.
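
For illustration, an MB-based gate along the lines of the reviewer's suggestion could check current allocator usage before dumping. This is a hypothetical sketch, not code from this PR; the helper name and threshold value are made up:

import torch

def should_dump_snapshot(threshold_mb: float = 40 * 1024) -> bool:
    # hypothetical gate: dump only when allocated CUDA memory on the current
    # device exceeds threshold_mb megabytes
    if not torch.cuda.is_available():
        return False
    allocated_mb = torch.cuda.memory_allocated() / (1024 * 1024)
    return allocated_mb >= threshold_mb

The step() method could then guard the pickle.dump call with this check if the overhead of unconditional dumps ever becomes a problem.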

                logger.info(
                    f"Finished dumping memory snapshot in {time.monotonic() - begin:.2f} seconds"
                )
                torch.distributed.barrier()

        logger.info(f"Memory profiler active. Snapshot will be saved at {snapshot_dir}")
        profiler = MemoryProfiler(global_step, config.profiling.profile_freq)
        try:
            yield profiler
        finally:
            # dump snapshot when CUDA OOMs
            profiler.step(exit_ctx=True)
    else:
        yield None
9 changes: 7 additions & 2 deletions train.py
@@ -38,7 +38,7 @@
    ParallelDims,
)
from torchtitan.parallelisms.pipelining_utils import build_pipeline_schedule
from torchtitan.profiling import maybe_enable_profiling
from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
from torchtitan.utils import (
    Color,
    dist_max,
@@ -301,7 +301,9 @@ def loss_fn(pred, labels):
logger.info(f"Training starts at step {train_state.step + 1}")
with maybe_enable_profiling(
job_config, global_step=train_state.step
) as torch_profiler:
) as torch_profiler, maybe_enable_memory_snapshot(
job_config, global_step=train_state.step
) as memory_profiler:
checkpoint.reset()

# variables used to keep info for metrics logging
@@ -447,6 +449,9 @@ def loss_fn(pred, labels):
            if torch_profiler:
                torch_profiler.step()

            if memory_profiler:
                memory_profiler.step()

            # Reduce timeout after first train step for faster signal (assumes lazy init, compile are finished)
            if train_state.step == 1:
                set_pg_timeouts(
1 change: 1 addition & 0 deletions train_configs/debug_model.toml
@@ -9,6 +9,7 @@ use_for_integration_test = true
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 10
enable_memory_snapshot = false
Contributor Author (@weifengpy), Jun 13, 2024:

Existing .toml files without enable_memory_snapshot still work, since the option is read optionally via getattr(config.profiling, 'enable_memory_snapshot', False). I am just adding it here so people can start toggling it.

Contributor (reviewer):

Ditto: we should put the default value False into config_manager, and remove this option from all the toml config files. Maybe only set it to True in debug_model.


[metrics]
log_freq = 1
1 change: 1 addition & 0 deletions train_configs/llama2_13b.toml
@@ -9,6 +9,7 @@ description = "Llama2 13B training"
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100
enable_memory_snapshot = false

[metrics]
log_freq = 10
1 change: 1 addition & 0 deletions train_configs/llama2_70b.toml
@@ -9,6 +9,7 @@ description = "Llama2 70B training"
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100
enable_memory_snapshot = false

[metrics]
log_freq = 10
1 change: 1 addition & 0 deletions train_configs/llama2_7b.toml
@@ -8,6 +8,7 @@ description = "Llama2 7B training"
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100
enable_memory_snapshot = false

[metrics]
log_freq = 10
2 changes: 2 additions & 0 deletions train_configs/llama3_70b.toml
@@ -9,6 +9,8 @@ description = "Llama 3 70B training"
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100
enable_memory_snapshot = false


[metrics]
log_freq = 10
1 change: 1 addition & 0 deletions train_configs/llama3_8b.toml
@@ -9,6 +9,7 @@ description = "Llama 3 8B training"
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100
enable_memory_snapshot = false

[metrics]
log_freq = 10