-
Notifications
You must be signed in to change notification settings - Fork 172
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
dump memory snapshot to analyze OOMs #395
Changes from 4 commits
6cd4853
ae76243
510e9f8
76ab55d
87aaf9e
48e8bc8
e7a3b08
0639da9
abec6fb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -6,6 +6,7 @@ | |||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
import contextlib | ||||||||||||||||||||||||||||
import os | ||||||||||||||||||||||||||||
import pickle | ||||||||||||||||||||||||||||
import time | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
import torch | ||||||||||||||||||||||||||||
|
@@ -15,6 +16,14 @@ | |||||||||||||||||||||||||||
# the number of warmup steps before the active step in each profiling cycle | ||||||||||||||||||||||||||||
WARMUP = 3 | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# how much memory allocation/free ops to record in memory snapshots | ||||||||||||||||||||||||||||
MEMORY_SNAPSHOT_MAX_ENTRIES = 100000 | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# default memory snapshot folder | ||||||||||||||||||||||||||||
ENABLE_MEMORY_SNAPSHOT_KEY = "enable_memory_snapshot" | ||||||||||||||||||||||||||||
MEMORY_SNAPSHOT_FOLDER_KEY = "memory_snapshot_folder" | ||||||||||||||||||||||||||||
MEMORY_SNAPSHOT_FOLDER_DEFAULT_VALUE = "memory_snapshot" | ||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should make these into configs. Please refer to how torch_profiler does this part, e.g. put into config_manager.py There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good to know config_manager.py. I will move deafults into |
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
@contextlib.contextmanager | ||||||||||||||||||||||||||||
def maybe_enable_profiling(config: JobConfig, *, global_step: int = 0): | ||||||||||||||||||||||||||||
|
@@ -70,3 +79,62 @@ def trace_handler(prof): | |||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||
torch_profiler = contextlib.nullcontext() | ||||||||||||||||||||||||||||
yield None | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
@contextlib.contextmanager | ||||||||||||||||||||||||||||
def maybe_enable_memory_snapshot(config: JobConfig, *, global_step: int = 0): | ||||||||||||||||||||||||||||
enable_snapshot = getattr(config.profiling, ENABLE_MEMORY_SNAPSHOT_KEY, False) | ||||||||||||||||||||||||||||
if enable_snapshot: | ||||||||||||||||||||||||||||
snapshot_folder = getattr( | ||||||||||||||||||||||||||||
config.profiling, | ||||||||||||||||||||||||||||
MEMORY_SNAPSHOT_FOLDER_KEY, | ||||||||||||||||||||||||||||
MEMORY_SNAPSHOT_FOLDER_DEFAULT_VALUE, | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
snapshot_dir = os.path.join(config.job.dump_folder, snapshot_folder) | ||||||||||||||||||||||||||||
if not os.path.exists(snapshot_dir): | ||||||||||||||||||||||||||||
os.makedirs(snapshot_dir, exist_ok=True) | ||||||||||||||||||||||||||||
rank = torch.distributed.get_rank() | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
class MemoryProfiler: | ||||||||||||||||||||||||||||
def __init__(self, step_num: int, freq: int): | ||||||||||||||||||||||||||||
torch.cuda.memory._record_memory_history( | ||||||||||||||||||||||||||||
max_entries=MEMORY_SNAPSHOT_MAX_ENTRIES | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
# when resume training, we start from the last step | ||||||||||||||||||||||||||||
self.step_num = step_num | ||||||||||||||||||||||||||||
self.freq = freq | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
def step(self, exit_ctx: bool = False): | ||||||||||||||||||||||||||||
if not exit_ctx and self.step_num % self.freq != 0: | ||||||||||||||||||||||||||||
self.step_num += 1 | ||||||||||||||||||||||||||||
return | ||||||||||||||||||||||||||||
if not exit_ctx: | ||||||||||||||||||||||||||||
curr_step = self.step_num | ||||||||||||||||||||||||||||
self.step_num += 1 | ||||||||||||||||||||||||||||
dir_name = f"iteration_{curr_step}" | ||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. torch.profiler starts from step 0, whereas train.py starts from step 1. In order to make things work as expected, I suggest we do the following, so that if we set
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks for pointing out the difference. updated accordingly |
||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||
curr_step = self.step_num - 1 | ||||||||||||||||||||||||||||
dir_name = f"iteration_{curr_step}_exit" | ||||||||||||||||||||||||||||
curr_snapshot_dir = os.path.join(snapshot_dir, dir_name) | ||||||||||||||||||||||||||||
if not os.path.exists(curr_snapshot_dir): | ||||||||||||||||||||||||||||
os.makedirs(curr_snapshot_dir, exist_ok=True) | ||||||||||||||||||||||||||||
logger.info(f"Dumping memory snapshot at step {curr_step}") | ||||||||||||||||||||||||||||
begin = time.monotonic() | ||||||||||||||||||||||||||||
with open( | ||||||||||||||||||||||||||||
f"{curr_snapshot_dir}/rank{rank}_memory_snapshot.pickle", "wb" | ||||||||||||||||||||||||||||
) as output: | ||||||||||||||||||||||||||||
pickle.dump(torch.cuda.memory._snapshot(), output) | ||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe add a threshold to control dumping the memory snapshot when the memory usage is larger than the threshold to avoid overwhelming data? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do you mean in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. googled for |
||||||||||||||||||||||||||||
logger.info( | ||||||||||||||||||||||||||||
f"Finished dumping memory snapshot in {time.monotonic() - begin:.2f} seconds" | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
torch.distributed.barrier() | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
logger.info(f"Memory profiler active. Snapshot will be saved at {snapshot_dir}") | ||||||||||||||||||||||||||||
profiler = MemoryProfiler(global_step, config.profiling.profile_freq) | ||||||||||||||||||||||||||||
try: | ||||||||||||||||||||||||||||
yield profiler | ||||||||||||||||||||||||||||
finally: | ||||||||||||||||||||||||||||
# dump snapshot when CUDA OOMs | ||||||||||||||||||||||||||||
profiler.step(exit_ctx=True) | ||||||||||||||||||||||||||||
weifengpy marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||
yield None |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,7 @@ use_for_integration_test = true | |
enable_profiling = true | ||
save_traces_folder = "profile_trace" | ||
profile_freq = 10 | ||
enable_memory_snapshot = false | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. existing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto: we should put default option |
||
|
||
[metrics] | ||
log_freq = 1 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MEMORY_SNAPSHOT_MAX_ENTRIES
controls how large.pickle
can be. Right now it's36MB