-
Notifications
You must be signed in to change notification settings - Fork 172
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
dump memory snapshot to analyze OOMs #395
Changes from all commits
6cd4853
ae76243
510e9f8
76ab55d
87aaf9e
48e8bc8
e7a3b08
0639da9
abec6fb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
|
||
import contextlib | ||
import os | ||
import pickle | ||
import time | ||
|
||
import torch | ||
|
@@ -15,6 +16,9 @@ | |
# the number of warmup steps before the active step in each profiling cycle | ||
WARMUP = 3 | ||
|
||
# how many memory allocation/free ops to record in memory snapshots | ||
MEMORY_SNAPSHOT_MAX_ENTRIES = 100000 | ||
|
||
|
||
@contextlib.contextmanager | ||
def maybe_enable_profiling(config: JobConfig, *, global_step: int = 0): | ||
|
@@ -70,3 +74,57 @@ def trace_handler(prof): | |
else: | ||
torch_profiler = contextlib.nullcontext() | ||
yield None | ||
|
||
|
||
@contextlib.contextmanager
def maybe_enable_memory_snapshot(config: JobConfig, *, global_step: int = 0):
    """Optionally record CUDA memory history and dump snapshots for OOM analysis.

    When ``config.profiling.enable_memory_snapshot`` is set, yields a profiler
    object whose ``step()`` must be called once per training iteration; every
    ``config.profiling.profile_freq`` steps it writes a per-rank pickle of
    ``torch.cuda.memory._snapshot()`` under
    ``<dump_folder>/<save_memory_snapshot_folder>/iteration_<N>/``.
    If the wrapped block raises ``torch.OutOfMemoryError``, a final snapshot is
    dumped to ``iteration_<N>_exit`` before the exception propagates.

    Args:
        config: job configuration; reads ``profiling.*`` and ``job.dump_folder``.
        global_step: step to resume counting from (for checkpoint restarts).

    Yields:
        A ``MemoryProfiler`` when snapshotting is enabled, else ``None``.
    """
    enable_snapshot = config.profiling.enable_memory_snapshot
    if not enable_snapshot:
        # Profiling disabled: behave as a no-op context manager.
        yield None
        return

    snapshot_folder = config.profiling.save_memory_snapshot_folder
    snapshot_dir = os.path.join(config.job.dump_folder, snapshot_folder)
    # exist_ok=True already tolerates a pre-existing directory; no need for a
    # separate (race-prone) os.path.exists() check.
    os.makedirs(snapshot_dir, exist_ok=True)
    # One snapshot file per rank; requires torch.distributed to be initialized.
    rank = torch.distributed.get_rank()

    class MemoryProfiler:
        def __init__(self, step_num: int, freq: int):
            # Start recording allocator events; bounded so the pickle stays
            # a manageable size (MEMORY_SNAPSHOT_MAX_ENTRIES entries).
            torch.cuda.memory._record_memory_history(
                max_entries=MEMORY_SNAPSHOT_MAX_ENTRIES
            )
            # when resume training, we start from the last step
            self.step_num = step_num
            self.freq = freq

        def step(self, exit_ctx: bool = False):
            """Advance the step counter; dump a snapshot every ``freq`` steps.

            ``exit_ctx=True`` forces a dump (used on OOM), tagged ``_exit``.
            """
            self.step_num += 1
            if not exit_ctx and self.step_num % self.freq != 0:
                return
            if not exit_ctx:
                curr_step = self.step_num
                dir_name = f"iteration_{curr_step}"
            else:
                # dump as iteration_0_exit if OOM at iter 1
                curr_step = self.step_num - 1
                dir_name = f"iteration_{curr_step}_exit"
            curr_snapshot_dir = os.path.join(snapshot_dir, dir_name)
            os.makedirs(curr_snapshot_dir, exist_ok=True)
            logger.info(f"Dumping memory snapshot at step {curr_step}")
            begin = time.monotonic()
            with open(
                f"{curr_snapshot_dir}/rank{rank}_memory_snapshot.pickle", "wb"
            ) as output:
                pickle.dump(torch.cuda.memory._snapshot(), output)
            logger.info(
                f"Finished dumping memory snapshot in {time.monotonic() - begin:.2f} seconds"
            )
            # Keep ranks in lockstep so all snapshot files for this step land
            # together before training proceeds.
            torch.distributed.barrier()

    logger.info(f"Memory profiler active. Snapshot will be saved at {snapshot_dir}")
    profiler = MemoryProfiler(global_step, config.profiling.profile_freq)
    try:
        yield profiler
    except torch.OutOfMemoryError:
        # Dump a final snapshot so the OOM can be analyzed, then RE-RAISE.
        # The original code swallowed the exception here, which would let the
        # caller continue as if the step had succeeded.
        profiler.step(exit_ctx=True)
        raise
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,8 @@ use_for_integration_test = true | |
enable_profiling = true | ||
save_traces_folder = "profile_trace" | ||
profile_freq = 10 | ||
enable_memory_snapshot = true | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: let's put the folder here as well to be consistent and informative There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added |
||
save_memory_snapshot_folder = "memory_snapshot" | ||
|
||
[metrics] | ||
log_freq = 1 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MEMORY_SNAPSHOT_MAX_ENTRIES
controls how large the
.pickle
file can be. Right now it's about
36MB