From 5a1966bd9a0be6ede467e25d35f311f418df90cb Mon Sep 17 00:00:00 2001 From: wz337 Date: Mon, 15 Apr 2024 16:44:11 -0700 Subject: [PATCH 1/5] move ckpt under outputs --- torchtitan/config_manager.py | 9 +++++---- train.py | 9 ++++++++- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 4a3e61fc..ecd1321d 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -234,8 +234,9 @@ def __init__(self): type=str, default="", help=( - "The folder to store the checkpoints. If this is not specified or " - "is an empty string, checkpointing is disabled." + "The folder to store the checkpoints. If this is an empty string, checkpointing is disabled." + "When specified, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}." + "The default value is an empty string." ), ) self.parser.add_argument( @@ -243,9 +244,9 @@ def __init__(self): type=str, default=False, help=( - "When model_weights_only=True, we keep only model weights for your checkpoint at the end of training." + "When model_weights_only=True, only model weights will be saved at the end of training." "With this, checkpoints can be loaded using `torch.load(..., weights_only=True)` after conversion." - "When model_weights_only=False, we do a full checkpoint." + "When model_weights_only=False, the full checkpoint will be saved." "A full checkpoint includes model, optimizer and train_state, which can be used to resume training." "The default value is false." ), diff --git a/train.py b/train.py index 2a669fcf..c1c08051 100644 --- a/train.py +++ b/train.py @@ -229,11 +229,18 @@ def loss_fn(pred, labels): # train loop model.train() + ckpt_folder = job_config.checkpoint.folder + ckpt_folder = ( + os.path.join(job_config.job.dump_folder, ckpt_folder) + if ckpt_folder != "" + else None + ) + checkpoint = CheckpointManager( model=model, optimizer=optimizer, states={"train_state": train_state}, - folder=job_config.checkpoint.folder, + folder=ckpt_folder, interval_type=( IntervalType.SECONDS if job_config.checkpoint.interval_type == "seconds" From 8e3b2775f26cb764223964b819dc0a7fa0567fc7 Mon Sep 17 00:00:00 2001 From: wz337 Date: Mon, 15 Apr 2024 17:41:19 -0700 Subject: [PATCH 2/5] address comments --- torchtitan/checkpoint.py | 31 +++++++++++++++++++++---------- train.py | 19 ++----------------- 2 files changed, 23 insertions(+), 27 deletions(-) diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py index 54fa4d71..8abfba26 100644 --- a/torchtitan/checkpoint.py +++ b/torchtitan/checkpoint.py @@ -17,6 +17,7 @@ set_model_state_dict, set_optimizer_state_dict, ) +from torchtitan.config_manager import JobConfig from torchtitan.logging_utils import logger @@ -61,28 +62,38 @@ def __init__( model: nn.Module, optimizer: torch.optim.Optimizer, states: Dict[str, Any], - folder: str, - interval_type: IntervalType, - interval: int, - model_weights_only: bool = False, - export_dtype: str = "float32", + job_config: JobConfig, ) -> None: - self.folder = folder self.states = states - self.model_weights_only = model_weights_only self.states.update( { "model": ModelWrapper(model), "optimizer": OptimizerWrapper(model, optimizer), } ) - self.interval_type = interval_type - self.interval = interval + + ckpt_folder = job_config.checkpoint.folder + ckpt_folder = ( + os.path.join(job_config.job.dump_folder, ckpt_folder) + if ckpt_folder != "" + else None + ) + self.folder = ckpt_folder + + self.interval_type = ( + IntervalType.SECONDS + if job_config.checkpoint.interval_type == "seconds" + else IntervalType.STEPS + ) + + self.interval = job_config.checkpoint.interval + self.model_weights_only = job_config.checkpoint.model_weights_only + self.export_dtype = DTYPE_MAP[job_config.checkpoint.export_dtype] + self.begin = 0 self.work = None self.pg = dist.new_group(backend="gloo") self.doit = None - self.export_dtype = DTYPE_MAP[export_dtype] if self.folder: logger.info( diff --git a/train.py b/train.py index c1c08051..6ccb01d0 100644 --- a/train.py +++ b/train.py @@ -20,7 +20,7 @@ from torch.distributed.elastic.multiprocessing.errors import record from torch.distributed.tensor.parallel import loss_parallel -from torchtitan.checkpoint import CheckpointManager, IntervalType +from torchtitan.checkpoint import CheckpointManager from torchtitan.config_manager import JobConfig from torchtitan.datasets import create_tokenizer, dataloader_fn from torchtitan.float8_linear import build_fp8_linear @@ -229,26 +229,11 @@ def loss_fn(pred, labels): # train loop model.train() - ckpt_folder = job_config.checkpoint.folder - ckpt_folder = ( - os.path.join(job_config.job.dump_folder, ckpt_folder) - if ckpt_folder != "" - else None - ) - checkpoint = CheckpointManager( model=model, optimizer=optimizer, states={"train_state": train_state}, - folder=ckpt_folder, - interval_type=( - IntervalType.SECONDS - if job_config.checkpoint.interval_type == "seconds" - else IntervalType.STEPS - ), - interval=job_config.checkpoint.interval, - model_weights_only=job_config.checkpoint.model_weights_only, - export_dtype=job_config.checkpoint.export_dtype, + job_config=job_config, ) checkpoint.load() From 929e5ff58e6f1c876ae08be737f0ec7ac6fece5d Mon Sep 17 00:00:00 2001 From: wz337 Date: Mon, 15 Apr 2024 19:28:08 -0700 Subject: [PATCH 3/5] address TY's comments --- test_runner.py | 24 +++++++++++++------- torchtitan/checkpoint.py | 41 +++++++++++++++------------------- torchtitan/config_manager.py | 33 +++++++++++++++------------ torchtitan/profiling.py | 2 +- train_configs/debug_model.toml | 12 +++++----- train_configs/llama_13b.toml | 10 +++++---- train_configs/llama_70b.toml | 10 +++++---- train_configs/llama_7b.toml | 10 +++++---- 8 files changed, 79 insertions(+), 63 deletions(-) diff --git a/test_runner.py b/test_runner.py index ebaa68ca..1077ce85 100755 --- a/test_runner.py +++ b/test_runner.py @@ -51,30 +51,38 @@ class OverrideDefinitions: ), OverrideDefinitions( [ - [f"--checkpoint.folder {test_checkpoint_dir}_full_checkpoint"], [ - f"--checkpoint.folder {test_checkpoint_dir}_full_checkpoint", + "--checkpoint.enable_checkpoint", + f"--checkpoint.checkpoint_folder {test_checkpoint_dir}_full_checkpoint", + ], + [ + "--checkpoint.enable_checkpoint", + f"--checkpoint.checkpoint_folder {test_checkpoint_dir}_full_checkpoint", "--training.steps 20", ], ], - "Checkpoint Integration Test - Model + Optimizer + TrainState", + "Checkpoint Integration Test - Save Load Full Checkpoint", ), OverrideDefinitions( [ [ - f"--checkpoint.folder {test_checkpoint_dir}_model_weights_only_fp32 --checkpoint.model_weights_only true" + "--checkpoint.enable_checkpoint", + f"--checkpoint.checkpoint_folder {test_checkpoint_dir}_model_weights_only_fp32", + "--checkpoint.model_weights_only", ], ], - "Checkpoint Integration Test - Model Weights Only fp32", + "Checkpoint Integration Test - Save Model Weights Only fp32", ), OverrideDefinitions( [ [ - f"--checkpoint.folder {test_checkpoint_dir}_model_weights_only_bf16", - "--checkpoint.model_weights_only true --checkpoint.export_dtype bfloat16", + "--checkpoint.enable_checkpoint", + f"--checkpoint.checkpoint_folder {test_checkpoint_dir}_model_weights_only_bf16", + "--checkpoint.model_weights_only", + "--checkpoint.export_dtype bfloat16", ], ], - "Checkpoint Integration Test - Model Weights Only bf16", + "Checkpoint Integration Test - Save Model Weights Only bf16", ), ] diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py index 8abfba26..9c973bfa 100644 --- a/torchtitan/checkpoint.py +++ b/torchtitan/checkpoint.py @@ -72,34 +72,29 @@ def __init__( } ) - ckpt_folder = job_config.checkpoint.folder - ckpt_folder = ( - os.path.join(job_config.job.dump_folder, ckpt_folder) - if ckpt_folder != "" - else None - ) - self.folder = ckpt_folder - - self.interval_type = ( - IntervalType.SECONDS - if job_config.checkpoint.interval_type == "seconds" - else IntervalType.STEPS - ) + self.enable_checkpoint = job_config.checkpoint.enable_checkpoint - self.interval = job_config.checkpoint.interval - self.model_weights_only = job_config.checkpoint.model_weights_only - self.export_dtype = DTYPE_MAP[job_config.checkpoint.export_dtype] + if self.enable_checkpoint: + self.folder = os.path.join( + job_config.job.dump_folder, job_config.checkpoint.checkpoint_folder + ) + self.interval_type = ( + IntervalType.SECONDS + if job_config.checkpoint.interval_type == "seconds" + else IntervalType.STEPS + ) + self.interval = job_config.checkpoint.interval + self.model_weights_only = job_config.checkpoint.model_weights_only + self.export_dtype = DTYPE_MAP[job_config.checkpoint.export_dtype] + logger.info( + f"Checkpointing active. Checkpoints will be loaded from and saved to {self.folder}" + ) self.begin = 0 self.work = None self.pg = dist.new_group(backend="gloo") self.doit = None - if self.folder: - logger.info( - f"Checkpointing active. Checkpoints will be loaded from and saved to {self.folder}" - ) - def reset(self) -> None: self.begin = time.monotonic() @@ -111,7 +106,7 @@ def save(self, curr_step: int, force: bool = False) -> None: force = True will force the checkpoint to be saved, even if the interval has not been reached. This only happens when train_state.step == job_config.training.steps. """ - if not self.folder: + if not self.enable_checkpoint: return if not force: @@ -173,7 +168,7 @@ def save(self, curr_step: int, force: bool = False) -> None: ) def load(self, step: int = -1) -> bool: - if not self.folder: + if not self.enable_checkpoint: return False if not os.path.isdir(self.folder): return False diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index ecd1321d..fbc15ab6 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -70,7 +70,7 @@ def __init__(self): # profiling configs self.parser.add_argument( - "--profiling.run_profiler", + "--profiling.enable_profiling", action="store_true", help="enable pytorch profiler", ) @@ -211,13 +211,20 @@ def __init__(self): action="store_true", help="Whether to compile the model.", ) + + # checkpoint configs self.parser.add_argument( - "--checkpoint.interval", - type=int, - default=500, + "--checkpoint.enable_checkpoint", + action="store_true", + help="Whether to enable checkpoint", + ) + self.parser.add_argument( + "--checkpoint.checkpoint_folder", + type=str, + default="checkpoint", help=( - "Checkpointing interval. The unit of measurement is in seconds or " - "steps depending on --checkpoint.interval_type." + "The folder to store the checkpoints." + "When enable_checkpoint is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}." ), ) self.parser.add_argument( @@ -230,19 +237,17 @@ def __init__(self): ), ) self.parser.add_argument( - "--checkpoint.folder", - type=str, - default="", + "--checkpoint.interval", + type=int, + default=500, help=( - "The folder to store the checkpoints. If this is an empty string, checkpointing is disabled." - "When specified, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}." - "The default value is an empty string." + "Checkpointing interval. The unit of measurement is in seconds or " + "steps depending on --checkpoint.interval_type." ), ) self.parser.add_argument( "--checkpoint.model_weights_only", - type=str, - default=False, + action="store_true", help=( "When model_weights_only=True, only model weights will be saved at the end of training." "With this, checkpoints can be loaded using `torch.load(..., weights_only=True)` after conversion." diff --git a/torchtitan/profiling.py b/torchtitan/profiling.py index c32d98ec..5384a941 100644 --- a/torchtitan/profiling.py +++ b/torchtitan/profiling.py @@ -12,7 +12,7 @@ @contextlib.contextmanager def maybe_run_profiler(config: JobConfig, *pos_args, **kwargs): # get user defined profiler settings - run_profiler = config.profiling.run_profiler + run_profiler = config.profiling.enable_profiling if run_profiler: dump_dir = config.job.dump_folder diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index 16caa750..c6317b91 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -6,9 +6,10 @@ description = "LLaMA debug training" use_for_integration_test = true [profiling] -run_profiler = true -save_traces_folder = "profiling/traces" -profile_freq = 10 +enable_profiling = true +save_traces_folder = "profile_trace" +# profiling frequency - example: 10 means every 10th iter will be profiled +profile_freq = 100 [metrics] log_freq = 1 @@ -40,9 +41,10 @@ compile = false dataset = "alpaca" # supported datasets: alpaca (52K), openwebtext (8M), c4 (177M) [checkpoint] -interval = 5 +enable_checkpoint = false +checkpoint_folder = "checkpoint" interval_type = "steps" -folder = "" +interval = 5 model_weights_only = false export_dtype = "float32" diff --git a/train_configs/llama_13b.toml b/train_configs/llama_13b.toml index 9553c9c1..5ee60bbf 100644 --- a/train_configs/llama_13b.toml +++ b/train_configs/llama_13b.toml @@ -6,8 +6,9 @@ dump_folder = "./outputs" description = "LLaMA 13B training" [profiling] -run_profiler = true -save_traces_folder = "profiling/traces" +enable_profiling = true +save_traces_folder = "profile_trace" +# profiling frequency - example: 10 means every 10th iter will be profiled profile_freq = 100 [metrics] @@ -39,9 +40,10 @@ compile = false dataset = "openwebtext" [checkpoint] -interval = 500 +enable_checkpoint = false +checkpoint_folder = "checkpoint" interval_type = "steps" -folder = "" +interval = 500 model_weights_only = false export_dtype = "float32" diff --git a/train_configs/llama_70b.toml b/train_configs/llama_70b.toml index 032092e1..45fd1bfa 100644 --- a/train_configs/llama_70b.toml +++ b/train_configs/llama_70b.toml @@ -6,8 +6,9 @@ dump_folder = "./outputs" description = "LLaMA 70B training" [profiling] -run_profiler = true -save_traces_folder = "profiling/traces" +enable_profiling = true +save_traces_folder = "profile_trace" +# profiling frequency - example: 10 means every 10th iter will be profiled profile_freq = 100 [metrics] @@ -39,9 +40,10 @@ compile = false dataset = "openwebtext" [checkpoint] -interval = 500 +enable_checkpoint = false +checkpoint_folder = "checkpoint" interval_type = "steps" -folder = "" +interval = 500 model_weights_only = false export_dtype = "float32" diff --git a/train_configs/llama_7b.toml b/train_configs/llama_7b.toml index 0c51dc16..640ad628 100644 --- a/train_configs/llama_7b.toml +++ b/train_configs/llama_7b.toml @@ -5,8 +5,9 @@ dump_folder = "./outputs" description = "LLaMA 7B training" [profiling] -run_profiler = true -save_traces_folder = "profiling/traces" +enable_profiling = true +save_traces_folder = "profile_trace" +# profiling frequency - example: 10 means every 10th iter will be profiled profile_freq = 100 [metrics] @@ -38,9 +39,10 @@ compile = false dataset = "openwebtext" [checkpoint] -interval = 500 +enable_checkpoint = false +checkpoint_folder = "checkpoint" interval_type = "steps" -folder = "" +interval = 500 model_weights_only = false export_dtype = "float32" From 273c5662bf2c85a1681477a4700a3e413d49353f Mon Sep 17 00:00:00 2001 From: wz337 Date: Mon, 15 Apr 2024 20:53:47 -0700 Subject: [PATCH 4/5] address TY's comments --- test_runner.py | 8 ++++---- torchtitan/checkpoint.py | 30 +++++++++++++++--------------- torchtitan/config_manager.py | 2 +- torchtitan/profiling.py | 6 +++--- train.py | 4 ++-- train_configs/debug_model.toml | 3 +-- train_configs/llama_13b.toml | 3 +-- train_configs/llama_70b.toml | 3 +-- train_configs/llama_7b.toml | 3 +-- 9 files changed, 29 insertions(+), 33 deletions(-) diff --git a/test_runner.py b/test_runner.py index 1077ce85..33d404c0 100755 --- a/test_runner.py +++ b/test_runner.py @@ -53,11 +53,11 @@ class OverrideDefinitions: [ [ "--checkpoint.enable_checkpoint", - f"--checkpoint.checkpoint_folder {test_checkpoint_dir}_full_checkpoint", + f"--checkpoint.folder {test_checkpoint_dir}_full_checkpoint", ], [ "--checkpoint.enable_checkpoint", - f"--checkpoint.checkpoint_folder {test_checkpoint_dir}_full_checkpoint", + f"--checkpoint.folder {test_checkpoint_dir}_full_checkpoint", "--training.steps 20", ], ], @@ -67,7 +67,7 @@ class OverrideDefinitions: [ [ "--checkpoint.enable_checkpoint", - f"--checkpoint.checkpoint_folder {test_checkpoint_dir}_model_weights_only_fp32", + f"--checkpoint.folder {test_checkpoint_dir}_model_weights_only_fp32", "--checkpoint.model_weights_only", ], ], @@ -77,7 +77,7 @@ class OverrideDefinitions: [ [ "--checkpoint.enable_checkpoint", - f"--checkpoint.checkpoint_folder {test_checkpoint_dir}_model_weights_only_bf16", + f"--checkpoint.folder {test_checkpoint_dir}_model_weights_only_bf16", "--checkpoint.model_weights_only", "--checkpoint.export_dtype bfloat16", ], diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py index 9c973bfa..16e7517d 100644 --- a/torchtitan/checkpoint.py +++ b/torchtitan/checkpoint.py @@ -75,30 +75,30 @@ def __init__( self.enable_checkpoint = job_config.checkpoint.enable_checkpoint if self.enable_checkpoint: - self.folder = os.path.join( - job_config.job.dump_folder, job_config.checkpoint.checkpoint_folder - ) + ckpt_config = job_config.checkpoint + self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder) self.interval_type = ( IntervalType.SECONDS - if job_config.checkpoint.interval_type == "seconds" + if ckpt_config.interval_type == "seconds" else IntervalType.STEPS ) - self.interval = job_config.checkpoint.interval - self.model_weights_only = job_config.checkpoint.model_weights_only - self.export_dtype = DTYPE_MAP[job_config.checkpoint.export_dtype] + self.interval = ckpt_config.interval + self.model_weights_only = ckpt_config.model_weights_only + self.export_dtype = DTYPE_MAP[ckpt_config.export_dtype] + logger.info( f"Checkpointing active. Checkpoints will be loaded from and saved to {self.folder}" ) - self.begin = 0 - self.work = None - self.pg = dist.new_group(backend="gloo") - self.doit = None + self.begin = 0 + self.work = None + self.pg = dist.new_group(backend="gloo") + self.doit = None def reset(self) -> None: self.begin = time.monotonic() - def create_checkpoint_id(self, step: int) -> str: + def _create_checkpoint_id(self, step: int) -> str: return os.path.join(self.folder, f"step-{step}") def save(self, curr_step: int, force: bool = False) -> None: @@ -161,7 +161,7 @@ def save(self, curr_step: int, force: bool = False) -> None: logger.info(f"Saving a full checkpoint at step {curr_step}") begin = time.monotonic() - dcp.save(self.states, checkpoint_id=self.create_checkpoint_id(curr_step)) + dcp.save(self.states, checkpoint_id=self._create_checkpoint_id(curr_step)) self.reset() logger.info( f"Finished saving the checkpoint in {time.monotonic() - begin:.2f} seconds" @@ -172,7 +172,7 @@ def load(self, step: int = -1) -> bool: return False if not os.path.isdir(self.folder): return False - if step != -1 and not os.path.isdir(self.create_checkpoint_id(step)): + if step != -1 and not os.path.isdir(self._create_checkpoint_id(step)): return False if step == -1: @@ -189,7 +189,7 @@ def load(self, step: int = -1) -> bool: begin = time.monotonic() dcp.load( self.states, - checkpoint_id=self.create_checkpoint_id(step), + checkpoint_id=self._create_checkpoint_id(step), ) logger.info( f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds" diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index fbc15ab6..4eb8a978 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -219,7 +219,7 @@ def __init__(self): help="Whether to enable checkpoint", ) self.parser.add_argument( - "--checkpoint.checkpoint_folder", + "--checkpoint.folder", type=str, default="checkpoint", help=( diff --git a/torchtitan/profiling.py b/torchtitan/profiling.py index 5384a941..ca194066 100644 --- a/torchtitan/profiling.py +++ b/torchtitan/profiling.py @@ -10,11 +10,11 @@ @contextlib.contextmanager -def maybe_run_profiler(config: JobConfig, *pos_args, **kwargs): +def maybe_enable_profiling(config: JobConfig, *pos_args, **kwargs): # get user defined profiler settings - run_profiler = config.profiling.enable_profiling + enable_profiling = config.profiling.enable_profiling - if run_profiler: + if enable_profiling: dump_dir = config.job.dump_folder save_trace_dir = config.profiling.save_traces_folder trace_dir = os.path.join(dump_dir, save_trace_dir) diff --git a/train.py b/train.py index 6ccb01d0..41d3ea33 100644 --- a/train.py +++ b/train.py @@ -29,7 +29,7 @@ from torchtitan.metrics import build_gpu_memory_monitor, build_metric_logger from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config from torchtitan.parallelisms import models_parallelize_fns, ParallelDims -from torchtitan.profiling import maybe_run_profiler +from torchtitan.profiling import maybe_enable_profiling from torchtitan.utils import ( Color, dist_max, @@ -251,7 +251,7 @@ def loss_fn(pred, labels): data_iterator = iter(data_loader) logger.info(f"Training starts at step {train_state.step + 1}") - with maybe_run_profiler(job_config) as torch_profiler: + with maybe_enable_profiling(job_config) as torch_profiler: checkpoint.reset() # variables used to keep info for metrics logging diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index c6317b91..2d861bc7 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -8,7 +8,6 @@ use_for_integration_test = true [profiling] enable_profiling = true save_traces_folder = "profile_trace" -# profiling frequency - example: 10 means every 10th iter will be profiled profile_freq = 100 [metrics] @@ -42,7 +41,7 @@ dataset = "alpaca" # supported datasets: alpaca (52K), openwebtext (8M), c4 (1 [checkpoint] enable_checkpoint = false -checkpoint_folder = "checkpoint" +folder = "checkpoint" interval_type = "steps" interval = 5 model_weights_only = false diff --git a/train_configs/llama_13b.toml b/train_configs/llama_13b.toml index 5ee60bbf..4fc72c11 100644 --- a/train_configs/llama_13b.toml +++ b/train_configs/llama_13b.toml @@ -8,7 +8,6 @@ description = "LLaMA 13B training" [profiling] enable_profiling = true save_traces_folder = "profile_trace" -# profiling frequency - example: 10 means every 10th iter will be profiled profile_freq = 100 [metrics] @@ -41,7 +40,7 @@ dataset = "openwebtext" [checkpoint] enable_checkpoint = false -checkpoint_folder = "checkpoint" +folder = "checkpoint" interval_type = "steps" interval = 500 model_weights_only = false diff --git a/train_configs/llama_70b.toml b/train_configs/llama_70b.toml index 45fd1bfa..1878647d 100644 --- a/train_configs/llama_70b.toml +++ b/train_configs/llama_70b.toml @@ -8,7 +8,6 @@ description = "LLaMA 70B training" [profiling] enable_profiling = true save_traces_folder = "profile_trace" -# profiling frequency - example: 10 means every 10th iter will be profiled profile_freq = 100 [metrics] @@ -41,7 +40,7 @@ dataset = "openwebtext" [checkpoint] enable_checkpoint = false -checkpoint_folder = "checkpoint" +folder = "checkpoint" interval_type = "steps" interval = 500 model_weights_only = false diff --git a/train_configs/llama_7b.toml b/train_configs/llama_7b.toml index 640ad628..7e8f7f78 100644 --- a/train_configs/llama_7b.toml +++ b/train_configs/llama_7b.toml @@ -7,7 +7,6 @@ description = "LLaMA 7B training" [profiling] enable_profiling = true save_traces_folder = "profile_trace" -# profiling frequency - example: 10 means every 10th iter will be profiled profile_freq = 100 [metrics] @@ -40,7 +39,7 @@ dataset = "openwebtext" [checkpoint] enable_checkpoint = false -checkpoint_folder = "checkpoint" +folder = "checkpoint" interval_type = "steps" interval = 500 model_weights_only = false From 687322f06c320108453795b628ddf2a1ce1105b6 Mon Sep 17 00:00:00 2001 From: wz337 Date: Mon, 15 Apr 2024 21:43:52 -0700 Subject: [PATCH 5/5] address final comments --- torchtitan/checkpoint.py | 20 ++++++++++---------- torchtitan/config_manager.py | 16 ++++++++-------- train_configs/debug_model.toml | 2 +- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py index 16e7517d..ca272e7d 100644 --- a/torchtitan/checkpoint.py +++ b/torchtitan/checkpoint.py @@ -64,18 +64,18 @@ def __init__( states: Dict[str, Any], job_config: JobConfig, ) -> None: - self.states = states - self.states.update( - { - "model": ModelWrapper(model), - "optimizer": OptimizerWrapper(model, optimizer), - } - ) - - self.enable_checkpoint = job_config.checkpoint.enable_checkpoint + ckpt_config = job_config.checkpoint + self.enable_checkpoint = ckpt_config.enable_checkpoint if self.enable_checkpoint: - ckpt_config = job_config.checkpoint + self.states = states + self.states.update( + { + "model": ModelWrapper(model), + "optimizer": OptimizerWrapper(model, optimizer), + } + ) + self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder) self.interval_type = ( IntervalType.SECONDS diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 4eb8a978..0fcf84c8 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -223,7 +223,7 @@ def __init__(self): type=str, default="checkpoint", help=( - "The folder to store the checkpoints." + "The folder to store the checkpoints. " "When enable_checkpoint is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}." ), ) @@ -232,7 +232,7 @@ def __init__(self): type=str, default="steps", help=( - "The checkpointing interval unit of measurement." + "The checkpointing interval unit of measurement. " "The default value is steps." ), ) @@ -249,10 +249,10 @@ def __init__(self): "--checkpoint.model_weights_only", action="store_true", help=( - "When model_weights_only=True, only model weights will be saved at the end of training." - "With this, checkpoints can be loaded using `torch.load(..., weights_only=True)` after conversion." - "When model_weights_only=False, the full checkpoint will be saved." - "A full checkpoint includes model, optimizer and train_state, which can be used to resume training." + "When model_weights_only=True, only model weights will be saved at the end of training. " + "With this, checkpoints can be loaded using `torch.load(..., weights_only=True)` after conversion. " + "When model_weights_only=False, the full checkpoint will be saved. " + "A full checkpoint includes model, optimizer and train_state, which can be used to resume training. " "The default value is false." ), ) @@ -261,8 +261,8 @@ def __init__(self): type=str, default="float32", help=( - "Converts to the specified precision when training completes and model_weights_only=true." - "Currently supports float32, float16, and bfloat16." + "Converts to the specified precision when training completes and model_weights_only=true. " + "Currently supports float32, float16, and bfloat16. " "The default value is float32." ), ) diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index 2d861bc7..6eb623a5 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -8,7 +8,7 @@ use_for_integration_test = true [profiling] enable_profiling = true save_traces_folder = "profile_trace" -profile_freq = 100 +profile_freq = 10 [metrics] log_freq = 1