
dump memory snapshot to analyze OOMs #395

Merged: 9 commits merged into pytorch:main on Jun 19, 2024

Conversation

weifengpy (Contributor) commented Jun 13, 2024

When setting `enable_memory_snapshot = true` in `.toml`:

  • dump memory snapshots in case of OOMs; the output folder is `memory_snapshot/iteration_x_exit`
  • dump regularly according to `profile_freq`; the output folder is `memory_snapshot/iteration_x`
  • existing `.toml` files keep working, since `enable_memory_snapshot=False` by default

The screenshot below is an example of the dump when an OOM happens.

Screenshot 2024-06-12 at 9 26 53 PM
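
For orientation, here is a heavily simplified, hedged sketch of that flow, assuming a CUDA device and the underscored torch.cuda.memory APIs; the helper name and training loop below are illustrative, not the exact torchtitan implementation.

```python
import os
import pickle

import torch

def dump_memory_snapshot(folder: str, step: int, exit_ctx: bool = False) -> None:
    # Write one snapshot per rank into iteration_x or iteration_x_exit.
    dir_name = f"iteration_{step}_exit" if exit_ctx else f"iteration_{step}"
    out_dir = os.path.join(folder, dir_name)
    os.makedirs(out_dir, exist_ok=True)
    rank = int(os.environ.get("RANK", "0"))
    with open(os.path.join(out_dir, f"rank{rank}_memory_snapshot.pickle"), "wb") as f:
        pickle.dump(torch.cuda.memory._snapshot(), f)

def train_step() -> None:
    pass  # stand-in for the real training step

# Start recording allocator history, then dump every profile_freq steps and
# once more if an OOM aborts training.
torch.cuda.memory._record_memory_history(max_entries=100000)
profile_freq, num_steps, step = 10, 20, 0
try:
    while step < num_steps:
        step += 1
        train_step()
        if step % profile_freq == 0:
            dump_memory_snapshot("memory_snapshot", step)
except torch.cuda.OutOfMemoryError:
    dump_memory_snapshot("memory_snapshot", step, exit_ctx=True)
    raise
```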

weifengpy and others added 3 commits June 12, 2024 11:09
@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) Jun 13, 2024
@@ -9,6 +9,7 @@ use_for_integration_test = true
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 10
enable_memory_snapshot = false
weifengpy (Contributor, Author) Jun 13, 2024

Existing `.toml` files without `enable_memory_snapshot` still work: the option is read with `getattr(config.profiling, 'enable_memory_snapshot', False)`, so it defaults to False when absent. I am just adding it here so people can start toggling it.
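
For illustration, a tiny self-contained sketch of that fallback, with SimpleNamespace standing in for the parsed profiling config:

```python
from types import SimpleNamespace

# An older profiling config without the new key still resolves to False.
profiling = SimpleNamespace(enable_profiling=True, profile_freq=10)
enable_memory_snapshot = getattr(profiling, "enable_memory_snapshot", False)
print(enable_memory_snapshot)  # False
```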

Contributor

ditto: we should put default option False into config_manager, and remove this option in all the toml config files. Maybe only enable it to True in debug_model.

@@ -15,6 +16,14 @@
# the number of warmup steps before the active step in each profiling cycle
WARMUP = 3

# how much memory allocation/free ops to record in memory snapshots
MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
weifengpy (Contributor, Author) Jun 13, 2024

MEMORY_SNAPSHOT_MAX_ENTRIES controls how large the .pickle file can get. Right now it's 36 MB.
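
For context, a hedged sketch of how such a cap is typically passed to the recorder; the exact call site in this PR may differ:

```python
import torch

# Keep at most this many alloc/free events in the recorded history, which
# bounds the size of the dumped .pickle file (requires a CUDA device).
MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
torch.cuda.memory._record_memory_history(max_entries=MEMORY_SNAPSHOT_MAX_ENTRIES)
```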

with open(
    f"{curr_snapshot_dir}/rank{rank}_memory_snapshot.pickle", "wb"
) as output:
    pickle.dump(torch.cuda.memory._snapshot(), output)
Contributor

Maybe add a threshold so the memory snapshot is only dumped when memory usage exceeds it, to avoid overwhelming amounts of data?

weifengpy (Contributor, Author) Jun 13, 2024

Do you mean a threshold in MB? Right now it's bounded by the number of free/allocate ops via MEMORY_SNAPSHOT_MAX_ENTRIES. For an MB-based threshold, I can google around.

weifengpy (Contributor, Author)

Googled for an MB threshold but did not find anything useful. Currently MEMORY_SNAPSHOT_MAX_ENTRIES=100000 keeps the file size at 36 MB. Let me know if this is still a blocker.
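
For completeness, one possible reading of the threshold idea, not adopted in this PR; the threshold value and variable names are hypothetical:

```python
import torch

# Only dump a snapshot when currently allocated CUDA memory exceeds a cap.
DUMP_THRESHOLD_MIB = 1024  # hypothetical threshold
allocated_mib = torch.cuda.memory_allocated() / (1024 * 1024)
if allocated_mib > DUMP_THRESHOLD_MIB:
    print(f"{allocated_mib:.0f} MiB allocated; a snapshot dump would happen here")
```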

tianyu-l (Contributor) left a comment

This is a great addition to torchtitan! Had some comments on how to structure the configs.

Also, I wonder if it makes sense to have a very short tutorial on how to read/parse the output of memory profiler. Maybe extract part of this tutorial.
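
Until such a tutorial exists, a minimal sketch of inspecting a dump, assuming the private torch.cuda._memory_viz helpers available in recent PyTorch releases; the snapshot path below is hypothetical:

```python
import pickle

from torch.cuda._memory_viz import trace_plot

with open("memory_snapshot/iteration_10/rank0_memory_snapshot.pickle", "rb") as f:
    snapshot = pickle.load(f)

# trace_plot returns a self-contained HTML page; the same .pickle can also be
# dragged into the viewer at https://pytorch.org/memory_viz.
with open("iteration_10_rank0.html", "w") as f:
    f.write(trace_plot(snapshot))
```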

Comment on lines 22 to 25
# default memory snapshot folder
ENABLE_MEMORY_SNAPSHOT_KEY = "enable_memory_snapshot"
MEMORY_SNAPSHOT_FOLDER_KEY = "memory_snapshot_folder"
MEMORY_SNAPSHOT_FOLDER_DEFAULT_VALUE = "memory_snapshot"
Contributor

We should make these into configs. Please refer to how torch_profiler does this part, e.g. put into config_manager.py

weifengpy (Contributor, Author)

Good to know about config_manager.py. I will move the defaults into config_manager.

@@ -9,6 +9,7 @@ use_for_integration_test = true
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 10
enable_memory_snapshot = false
Contributor

ditto: we should put default option False into config_manager, and remove this option in all the toml config files. Maybe only enable it to True in debug_model.

weifengpy (Contributor, Author)

Converting to draft now; will publish again after moving the defaults into config_manager.py. But the current version is good for benchmarking float8 + compile + fsdp2 on MAST.

@weifengpy marked this pull request as draft June 13, 2024 21:34
tianyu-l (Contributor) left a comment

The overhead from torch.profiler is only around the steps where dumping actually happens (warmup steps + actual profiling steps). If _record_memory_history is always enabled for the entire training run, there will constantly be overhead from this memory profiler.

In other words, torch.profiler only profiles one step per freq steps, while MemoryProfiler profiles every step and dumps all freq iterations per freq steps. As a result, adjusting freq only affects how often the snapshots are grouped into one pickle file. If we run a job for 3000 steps, there will be a snapshot for every step, regardless of freq.
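
For illustration, a minimal sketch of that trade-off, assuming the torch.cuda.memory._record_memory_history API where passing enabled=None turns recording off:

```python
import torch

torch.cuda.memory._record_memory_history(max_entries=100000)  # overhead starts here
# ... every training step from this point on is recorded ...
torch.cuda.memory._record_memory_history(enabled=None)        # recording stops here
```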

Comment on lines 108 to 114
if not exit_ctx and self.step_num % self.freq != 0:
    self.step_num += 1
    return
if not exit_ctx:
    curr_step = self.step_num
    self.step_num += 1
    dir_name = f"iteration_{curr_step}"
Contributor

torch.profiler starts from step 0, whereas train.py starts from step 1. In order to make things work as expected, I suggest we do the following, so that if we set profile_freq=10 and run training for 10 steps, there will be memory snapshots for iteration_10 (similar to torch.profiler) and iteration_10_exit. I've tested this offline.

Suggested change
-    if not exit_ctx and self.step_num % self.freq != 0:
-        self.step_num += 1
-        return
-    if not exit_ctx:
-        curr_step = self.step_num
-        self.step_num += 1
-        dir_name = f"iteration_{curr_step}"
+    self.step_num += 1
+    if not exit_ctx and self.step_num % self.freq != 0:
+        return
+    if not exit_ctx:
+        curr_step = self.step_num
+        dir_name = f"iteration_{curr_step}"
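
A tiny standalone illustration of the alignment this suggestion achieves, assuming 1-based steps driven by train.py and profile_freq=10:

```python
# Incrementing before the modulo check makes a 10-step run dump exactly one
# regular snapshot, named iteration_10.
freq, step_num, dumped = 10, 0, []
for _ in range(10):  # train.py drives 10 training steps
    step_num += 1
    if step_num % freq == 0:
        dumped.append(f"iteration_{step_num}")
print(dumped)  # ['iteration_10']
```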

weifengpy (Contributor, Author)

thanks for pointing out the difference. updated accordingly

weifengpy and others added 4 commits June 17, 2024 17:03
@weifengpy marked this pull request as ready for review June 18, 2024 16:14
@weifengpy requested a review from tianyu-l June 18, 2024 16:14
tianyu-l (Contributor) left a comment

great work!! thank you!
please address my nits before merge :)

@@ -9,6 +9,7 @@ use_for_integration_test = true
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 10
enable_memory_snapshot = true
Contributor

nit: let's put the folder here as well to be consistent and informative

weifengpy (Contributor, Author)

added save_memory_snapshot_folder in .toml

    help="Whether to dump memory snapshot",
)
self.parser.add_argument(
    "--profiling.memory_snapshot_folder",
Contributor

nit: please rename it to save_memory_snapshot_folder to be consistent with save_traces_folder and save_tb_folder.

self.parser.add_argument(
    "--profiling.memory_snapshot_folder",
    type=str,
    default="memory_snapshots",
Contributor

nit: let's call it memory_snapshot

weifengpy (Contributor, Author)

> great work!! thank you! please address my nits before merge :)

thanks. will address feedback before merging
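
For reference, a hedged, standalone sketch of the renamed flag once the nits above are addressed; the help text and the surrounding JobConfig wiring are illustrative:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--profiling.save_memory_snapshot_folder",
    type=str,
    default="memory_snapshot",
    help="Folder to dump memory snapshots into",
)
args = parser.parse_args([])
print(getattr(args, "profiling.save_memory_snapshot_folder"))  # memory_snapshot
```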

@weifengpy merged commit 8adbfa3 into pytorch:main Jun 19, 2024
5 checks passed
tianyu-l pushed a commit to tianyu-l/torchtitan_intern24 that referenced this pull request Aug 16, 2024
philippguevorguian pushed a commit to YerevaNN/YNNtitan that referenced this pull request Aug 17, 2024
Labels: CLA Signed (This label is managed by the Meta Open Source bot.)

5 participants