From 6cd485329ce9b2d952217d978b04d65b30ee8c20 Mon Sep 17 00:00:00 2001
From: willfengg
Date: Wed, 12 Jun 2024 11:09:15 -0700
Subject: [PATCH 1/6] del logits=(bs, seq_len, vocab_size) to save 3.9G memory

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 train.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/train.py b/train.py
index 4d58f5b8..adbd975f 100644
--- a/train.py
+++ b/train.py
@@ -351,6 +351,9 @@ def loss_fn(pred, labels):
             with loss_parallel_ctx():
                 pred = model(input_ids)
                 loss = loss_fn(pred, labels)
+                # pred.shape=(bs, seq_len, vocab_size)
+                # need to free it before bwd to avoid peaking memory
+                del pred
                 loss.backward()
 
         # clip gradients

From 510e9f80db0c52b020b0bb7ef123a1fe66494b9b Mon Sep 17 00:00:00 2001
From: willfengg
Date: Wed, 12 Jun 2024 21:47:40 -0700
Subject: [PATCH 2/6] dump memory snapshot to analyze OOM

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 torchtitan/profiling.py        | 68 ++++++++++++++++++++++++++++++++++
 train.py                       |  9 ++++-
 train_configs/debug_model.toml |  1 +
 train_configs/llama2_13b.toml  |  1 +
 train_configs/llama2_70b.toml  |  1 +
 train_configs/llama2_7b.toml   |  1 +
 train_configs/llama3_70b.toml  |  2 ++
 train_configs/llama3_8b.toml   |  1 +
 8 files changed, 82 insertions(+), 2 deletions(-)

diff --git a/torchtitan/profiling.py b/torchtitan/profiling.py
index b4c2b2e0..dcb5eb1f 100644
--- a/torchtitan/profiling.py
+++ b/torchtitan/profiling.py
@@ -6,6 +6,7 @@
 
 import contextlib
 import os
+import pickle
 import time
 
 import torch
@@ -15,6 +16,14 @@
 # the number of warmup steps before the active step in each profiling cycle
 WARMUP = 3
 
+# how many memory allocation/free ops to record in memory snapshots
+MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
+
+# default memory snapshot folder
+ENABLE_MEMORY_SNAPSHOT_KEY = "enable_memory_snapshot"
+MEMORY_SNAPSHOT_FOLDER_KEY = "memory_snapshot_folder"
+MEMORY_SNAPSHOT_FOLDER_DEFAULT_VALUE = "memory_snapshot"
+
 
 @contextlib.contextmanager
 def maybe_enable_profiling(config: JobConfig, *, global_step: int = 0):
@@ -70,3 +79,62 @@ def trace_handler(prof):
     else:
         torch_profiler = contextlib.nullcontext()
         yield None
+
+
+@contextlib.contextmanager
+def maybe_enable_memory_snapshot(config: JobConfig, *, global_step: int = 0):
+    enable_snapshot = getattr(config.profiling, ENABLE_MEMORY_SNAPSHOT_KEY, False)
+    if enable_snapshot:
+        snapshot_folder = getattr(
+            config.profiling,
+            MEMORY_SNAPSHOT_FOLDER_KEY,
+            MEMORY_SNAPSHOT_FOLDER_DEFAULT_VALUE,
+        )
+        snapshot_dir = os.path.join(config.job.dump_folder, snapshot_folder)
+        if not os.path.exists(snapshot_dir):
+            os.makedirs(snapshot_dir, exist_ok=True)
+        rank = torch.distributed.get_rank()
+
+        class MemoryProfiler:
+            def __init__(self, step_num: int, freq: int):
+                torch.cuda.memory._record_memory_history(
+                    max_entries=MEMORY_SNAPSHOT_MAX_ENTRIES
+                )
+                # when resuming training, we start from the last step
+                self.step_num = step_num
+                self.freq = freq
+
+            def step(self, exit_ctx: bool = False):
+                if not exit_ctx and self.step_num % self.freq != 0:
+                    self.step_num += 1
+                    return
+                if not exit_ctx:
+                    curr_step = self.step_num
+                    self.step_num += 1
+                    dir_name = f"iteration_{curr_step}"
+                else:
+                    curr_step = self.step_num - 1
+                    dir_name = f"iteration_{curr_step}_exit"
+                curr_snapshot_dir = os.path.join(snapshot_dir, dir_name)
+                if not os.path.exists(curr_snapshot_dir):
+                    os.makedirs(curr_snapshot_dir, exist_ok=True)
+                logger.info(f"Dumping memory snapshot at step {curr_step}")
+                begin = time.monotonic()
+                with open(
+                    f"{curr_snapshot_dir}/rank{rank}_memory_snapshot.pickle", "wb"
+                ) as output:
+                    pickle.dump(torch.cuda.memory._snapshot(), output)
+                logger.info(
+                    f"Finished dumping memory snapshot in {time.monotonic() - begin:.2f} seconds"
+                )
+                torch.distributed.barrier()
+
+        logger.info(f"Memory profiler active. Snapshot will be saved at {snapshot_dir}")
+        profiler = MemoryProfiler(global_step, config.profiling.profile_freq)
+        try:
+            yield profiler
+        finally:
+            # dump snapshot when CUDA OOMs
+            profiler.step(exit_ctx=True)
+    else:
+        yield None
diff --git a/train.py b/train.py
index adbd975f..2e75730c 100644
--- a/train.py
+++ b/train.py
@@ -38,7 +38,7 @@
     ParallelDims,
 )
 from torchtitan.parallelisms.pipelining_utils import build_pipeline_schedule
-from torchtitan.profiling import maybe_enable_profiling
+from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
 from torchtitan.utils import (
     Color,
     dist_max,
@@ -301,7 +301,9 @@ def loss_fn(pred, labels):
     logger.info(f"Training starts at step {train_state.step + 1}")
     with maybe_enable_profiling(
         job_config, global_step=train_state.step
-    ) as torch_profiler:
+    ) as torch_profiler, maybe_enable_memory_snapshot(
+        job_config, global_step=train_state.step
+    ) as memory_profiler:
         checkpoint.reset()
 
         # variables used to keep info for metrics logging
@@ -447,6 +449,9 @@ def loss_fn(pred, labels):
             if torch_profiler:
                 torch_profiler.step()
 
+            if memory_profiler:
+                memory_profiler.step()
+
             # Reduce timeout after first train step for faster signal (assumes lazy init, compile are finished)
             if train_state.step == 1:
                 set_pg_timeouts(
diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index 5d7e9987..ab9dff5b 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -9,6 +9,7 @@ use_for_integration_test = true
 enable_profiling = true
 save_traces_folder = "profile_trace"
 profile_freq = 10
+enable_memory_snapshot = false
 
 [metrics]
 log_freq = 1
diff --git a/train_configs/llama2_13b.toml b/train_configs/llama2_13b.toml
index 280ac2ae..e95d1723 100644
--- a/train_configs/llama2_13b.toml
+++ b/train_configs/llama2_13b.toml
@@ -9,6 +9,7 @@ description = "Llama2 13B training"
 enable_profiling = true
 save_traces_folder = "profile_trace"
 profile_freq = 100
+enable_memory_snapshot = false
 
 [metrics]
 log_freq = 10
diff --git a/train_configs/llama2_70b.toml b/train_configs/llama2_70b.toml
index 959c270a..71170e54 100644
--- a/train_configs/llama2_70b.toml
+++ b/train_configs/llama2_70b.toml
@@ -9,6 +9,7 @@ description = "Llama2 70B training"
 enable_profiling = true
 save_traces_folder = "profile_trace"
 profile_freq = 100
+enable_memory_snapshot = false
 
 [metrics]
 log_freq = 10
diff --git a/train_configs/llama2_7b.toml b/train_configs/llama2_7b.toml
index f2e66de7..9f0dfe31 100644
--- a/train_configs/llama2_7b.toml
+++ b/train_configs/llama2_7b.toml
@@ -8,6 +8,7 @@ description = "Llama2 7B training"
 enable_profiling = true
 save_traces_folder = "profile_trace"
 profile_freq = 100
+enable_memory_snapshot = false
 
 [metrics]
 log_freq = 10
diff --git a/train_configs/llama3_70b.toml b/train_configs/llama3_70b.toml
index f45632ad..2798cee3 100644
--- a/train_configs/llama3_70b.toml
+++ b/train_configs/llama3_70b.toml
@@ -9,6 +9,8 @@ description = "Llama 3 70B training"
 enable_profiling = true
 save_traces_folder = "profile_trace"
 profile_freq = 100
+enable_memory_snapshot = false
+
 [metrics]
 log_freq = 10
 
diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml
index aaba99a2..47366e39 100644
--- a/train_configs/llama3_8b.toml
+++ b/train_configs/llama3_8b.toml
@@ -9,6 +9,7 @@ description = "Llama 3 8B training"
 enable_profiling = true
 save_traces_folder = "profile_trace"
 profile_freq = 100
+enable_memory_snapshot = false
 
 [metrics]
 log_freq = 10

From 87aaf9e2525d271e3f2ae311062fbc1fcc6da827 Mon Sep 17 00:00:00 2001
From: willfengg
Date: Mon, 17 Jun 2024 17:03:43 -0700
Subject: [PATCH 3/6] avoid -1 iteration

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 torchtitan/profiling.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/torchtitan/profiling.py b/torchtitan/profiling.py
index dcb5eb1f..53e8c6e5 100644
--- a/torchtitan/profiling.py
+++ b/torchtitan/profiling.py
@@ -113,7 +113,9 @@ def step(self, exit_ctx: bool = False):
                     self.step_num += 1
                     dir_name = f"iteration_{curr_step}"
                 else:
-                    curr_step = self.step_num - 1
+                    # dump as iteration_0_exit if OOM at iter 0
+                    # instead of iteration_-1_exit
+                    curr_step = max(self.step_num - 1, 0)
                     dir_name = f"iteration_{curr_step}_exit"
                 curr_snapshot_dir = os.path.join(snapshot_dir, dir_name)
                 if not os.path.exists(curr_snapshot_dir):

From e7a3b084c238da30a67a87fdd0dad431f3dcc008 Mon Sep 17 00:00:00 2001
From: willfengg
Date: Mon, 17 Jun 2024 22:54:01 -0700
Subject: [PATCH 4/6] step from 1 and exit at OOM

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 torchtitan/config_manager.py   | 12 ++++++++++++
 torchtitan/profiling.py        | 24 ++++++------------------
 train_configs/debug_model.toml |  2 +-
 train_configs/llama2_13b.toml  |  1 -
 train_configs/llama2_70b.toml  |  1 -
 train_configs/llama2_7b.toml   |  1 -
 train_configs/llama3_70b.toml  |  2 --
 train_configs/llama3_8b.toml   |  1 -
 8 files changed, 19 insertions(+), 25 deletions(-)

diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 936020c5..665a8c3f 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -100,6 +100,18 @@ def __init__(self):
             default=10,
             help="How often to collect profiler traces, in iterations",
         )
+        self.parser.add_argument(
+            "--profiling.enable_memory_snapshot",
+            action="store_true",
+            default=False,
+            help="Whether to dump memory snapshots",
+        )
+        self.parser.add_argument(
+            "--profiling.memory_snapshot_folder",
+            type=str,
+            default="memory_snapshots",
+            help="Memory snapshot files location",
+        )
 
         # metrics configs
         self.parser.add_argument(
diff --git a/torchtitan/profiling.py b/torchtitan/profiling.py
index 53e8c6e5..fdb7d10c 100644
--- a/torchtitan/profiling.py
+++ b/torchtitan/profiling.py
@@ -19,11 +19,6 @@
 # how many memory allocation/free ops to record in memory snapshots
 MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
 
-# default memory snapshot folder
-ENABLE_MEMORY_SNAPSHOT_KEY = "enable_memory_snapshot"
-MEMORY_SNAPSHOT_FOLDER_KEY = "memory_snapshot_folder"
-MEMORY_SNAPSHOT_FOLDER_DEFAULT_VALUE = "memory_snapshot"
-
 
 @contextlib.contextmanager
 def maybe_enable_profiling(config: JobConfig, *, global_step: int = 0):
@@ -83,13 +78,9 @@ def trace_handler(prof):
 @contextlib.contextmanager
 def maybe_enable_memory_snapshot(config: JobConfig, *, global_step: int = 0):
-    enable_snapshot = getattr(config.profiling, ENABLE_MEMORY_SNAPSHOT_KEY, False)
+    enable_snapshot = config.profiling.enable_memory_snapshot
     if enable_snapshot:
-        snapshot_folder = getattr(
-            config.profiling,
-            MEMORY_SNAPSHOT_FOLDER_KEY,
-            MEMORY_SNAPSHOT_FOLDER_DEFAULT_VALUE,
-        )
+        snapshot_folder = config.profiling.memory_snapshot_folder
         snapshot_dir = os.path.join(config.job.dump_folder, snapshot_folder)
         if not os.path.exists(snapshot_dir):
             os.makedirs(snapshot_dir, exist_ok=True)
         rank = torch.distributed.get_rank()
@@ -105,17 +96,15 @@ def __init__(self, step_num: int, freq: int):
                 self.freq = freq
 
             def step(self, exit_ctx: bool = False):
+                self.step_num += 1
                 if not exit_ctx and self.step_num % self.freq != 0:
-                    self.step_num += 1
                     return
                 if not exit_ctx:
                     curr_step = self.step_num
-                    self.step_num += 1
                     dir_name = f"iteration_{curr_step}"
                 else:
-                    # dump as iteration_0_exit if OOM at iter 0
-                    # instead of iteration_-1_exit
-                    curr_step = max(self.step_num - 1, 0)
+                    # dump as iteration_0_exit if OOM at iter 1
+                    curr_step = self.step_num - 1
                     dir_name = f"iteration_{curr_step}_exit"
                 curr_snapshot_dir = os.path.join(snapshot_dir, dir_name)
                 if not os.path.exists(curr_snapshot_dir):
@@ -135,8 +124,7 @@ def step(self, exit_ctx: bool = False):
         profiler = MemoryProfiler(global_step, config.profiling.profile_freq)
         try:
             yield profiler
-        finally:
-            # dump snapshot when CUDA OOMs
+        except torch.OutOfMemoryError as e:
             profiler.step(exit_ctx=True)
     else:
         yield None
diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index ab9dff5b..b2dc9a72 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -9,7 +9,7 @@ use_for_integration_test = true
 enable_profiling = true
 save_traces_folder = "profile_trace"
 profile_freq = 10
-enable_memory_snapshot = false
+enable_memory_snapshot = true
 
 [metrics]
 log_freq = 1
diff --git a/train_configs/llama2_13b.toml b/train_configs/llama2_13b.toml
index 55811f12..f3048ac4 100644
--- a/train_configs/llama2_13b.toml
+++ b/train_configs/llama2_13b.toml
@@ -9,7 +9,6 @@ description = "Llama2 13B training"
 enable_profiling = true
 save_traces_folder = "profile_trace"
 profile_freq = 100
-enable_memory_snapshot = false
 
 [metrics]
 log_freq = 10
diff --git a/train_configs/llama2_70b.toml b/train_configs/llama2_70b.toml
index 3304432b..97b1bc71 100644
--- a/train_configs/llama2_70b.toml
+++ b/train_configs/llama2_70b.toml
@@ -9,7 +9,6 @@ description = "Llama2 70B training"
 enable_profiling = true
 save_traces_folder = "profile_trace"
 profile_freq = 100
-enable_memory_snapshot = false
 
 [metrics]
 log_freq = 10
diff --git a/train_configs/llama2_7b.toml b/train_configs/llama2_7b.toml
index 83ead66a..95b4c496 100644
--- a/train_configs/llama2_7b.toml
+++ b/train_configs/llama2_7b.toml
@@ -8,7 +8,6 @@ description = "Llama2 7B training"
 enable_profiling = true
 save_traces_folder = "profile_trace"
 profile_freq = 100
-enable_memory_snapshot = false
 
 [metrics]
 log_freq = 10
diff --git a/train_configs/llama3_70b.toml b/train_configs/llama3_70b.toml
index 7c88d75f..d498e677 100644
--- a/train_configs/llama3_70b.toml
+++ b/train_configs/llama3_70b.toml
@@ -9,8 +9,6 @@ description = "Llama 3 70B training"
 enable_profiling = true
 save_traces_folder = "profile_trace"
 profile_freq = 100
-enable_memory_snapshot = false
-
 [metrics]
 log_freq = 10
 
diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml
index 2813d632..f194addb 100644
--- a/train_configs/llama3_8b.toml
+++ b/train_configs/llama3_8b.toml
@@ -9,7 +9,6 @@ description = "Llama 3 8B training"
 enable_profiling = true
 save_traces_folder = "profile_trace"
 profile_freq = 100
-enable_memory_snapshot = false
 
 [metrics]
 log_freq = 10

From 0639da9087c00caad265cd49d221ce2ea283ba42 Mon Sep 17 00:00:00 2001
From: willfengg
Date: Mon, 17 Jun 2024 22:59:56 -0700
Subject: [PATCH 5/6] resolve merge conflict

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 train.py | 9 ---------
 1 file changed, 9 deletions(-)

diff --git a/train.py b/train.py
index c61b141f..e6217f4b 100644
--- a/train.py
+++ b/train.py
@@ -311,15 +311,6 @@ def loss_fn(pred, labels):
     ) as torch_profiler, maybe_enable_memory_snapshot(
         job_config, global_step=train_state.step
     ) as memory_profiler:
-        checkpoint.reset()
-
-        # variables used to keep info for metrics logging
-        losses_since_last_log: List[float] = []
-        ntokens_since_last_log = 0
-        data_loading_times: List[float] = []
-        time_last_log = timer()
-        gpu_memory_monitor.reset_peak_stats()
-
         while train_state.step < job_config.training.steps:
             train_state.step += 1
             if train_state.step > 1 and train_state.step % _gc_freq == 0:

From abec6fb9eee7d611f850d73fe4629ce6676a2d22 Mon Sep 17 00:00:00 2001
From: willfengg
Date: Tue, 18 Jun 2024 17:14:24 -0700
Subject: [PATCH 6/6] consistent naming with profiler trace

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 torchtitan/config_manager.py   | 4 ++--
 torchtitan/profiling.py        | 2 +-
 train_configs/debug_model.toml | 1 +
 3 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 665a8c3f..0eeac026 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -107,9 +107,9 @@ def __init__(self):
             help="Whether to dump memory snapshots",
         )
         self.parser.add_argument(
-            "--profiling.memory_snapshot_folder",
+            "--profiling.save_memory_snapshot_folder",
             type=str,
-            default="memory_snapshots",
+            default="memory_snapshot",
             help="Memory snapshot files location",
         )
 
diff --git a/torchtitan/profiling.py b/torchtitan/profiling.py
index fdb7d10c..c993a74f 100644
--- a/torchtitan/profiling.py
+++ b/torchtitan/profiling.py
@@ -80,7 +80,7 @@ def trace_handler(prof):
 def maybe_enable_memory_snapshot(config: JobConfig, *, global_step: int = 0):
     enable_snapshot = config.profiling.enable_memory_snapshot
     if enable_snapshot:
-        snapshot_folder = config.profiling.memory_snapshot_folder
+        snapshot_folder = config.profiling.save_memory_snapshot_folder
         snapshot_dir = os.path.join(config.job.dump_folder, snapshot_folder)
         if not os.path.exists(snapshot_dir):
             os.makedirs(snapshot_dir, exist_ok=True)
diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index b2dc9a72..0f0794a1 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -10,6 +10,7 @@ enable_profiling = true
 save_traces_folder = "profile_trace"
 profile_freq = 10
 enable_memory_snapshot = true
+save_memory_snapshot_folder = "memory_snapshot"
 
 [metrics]
 log_freq = 1
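
Usage note (illustrative, with assumptions flagged): with the final option
names from PATCH 6/6, a debug run that exercises the snapshot path could look
like the command below. The run script name and config path follow the usual
torchtitan layout; they are assumptions, not something these diffs confirm.

    CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh \
        --profiling.enable_memory_snapshot \
        --profiling.save_memory_snapshot_folder memory_snapshot

Per the code added in PATCH 2/6, each rank then writes
<job.dump_folder>/memory_snapshot/iteration_<N>/rank<R>_memory_snapshot.pickle
every profile_freq steps, plus an iteration_<N>_exit dump when the context
exits on a CUDA OOM. A dumped pickle can be inspected by dragging it into
https://pytorch.org/memory_viz, or rendered offline with PyTorch's bundled
viewer as documented for torch.cuda memory snapshots (paths here are assumed):

    python torch/cuda/_memory_viz.py trace_plot \
        outputs/memory_snapshot/iteration_10/rank0_memory_snapshot.pickle \
        -o rank0_iteration_10.html

On the 3.9G figure in PATCH 1/6: the freed logits hold
bs * seq_len * vocab_size * bytes_per_element; for instance, fp32 logits with
bs=1, seq_len=8192, vocab_size=128256 come to about 3.9 GiB. Those shapes are
an illustration only; the actual run configuration is not recorded in the
patch.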