[TorchTitan][Checkpoint] Move checkpoint folder under dump_folder and a few config updates #230

Merged 5 commits on Apr 16, 2024
20 changes: 14 additions & 6 deletions test_runner.py
@@ -51,30 +51,38 @@ class OverrideDefinitions:
),
OverrideDefinitions(
[
[f"--checkpoint.folder {test_checkpoint_dir}_full_checkpoint"],
[
"--checkpoint.enable_checkpoint",
f"--checkpoint.folder {test_checkpoint_dir}_full_checkpoint",
],
[
"--checkpoint.enable_checkpoint",
f"--checkpoint.folder {test_checkpoint_dir}_full_checkpoint",
"--training.steps 20",
],
],
"Checkpoint Integration Test - Model + Optimizer + TrainState",
"Checkpoint Integration Test - Save Load Full Checkpoint",
),
OverrideDefinitions(
[
[
f"--checkpoint.folder {test_checkpoint_dir}_model_weights_only_fp32 --checkpoint.model_weights_only true"
"--checkpoint.enable_checkpoint",
f"--checkpoint.folder {test_checkpoint_dir}_model_weights_only_fp32",
"--checkpoint.model_weights_only",
],
],
"Checkpoint Integration Test - Model Weights Only fp32",
"Checkpoint Integration Test - Save Model Weights Only fp32",
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
f"--checkpoint.folder {test_checkpoint_dir}_model_weights_only_bf16",
"--checkpoint.model_weights_only true --checkpoint.export_dtype bfloat16",
"--checkpoint.model_weights_only",
"--checkpoint.export_dtype bfloat16",
],
],
"Checkpoint Integration Test - Model Weights Only bf16",
"Checkpoint Integration Test - Save Model Weights Only bf16",
),
]

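For context, each inner list above is one set of command-line overrides that the test runner appends to a training run, one run per list. Below is a minimal sketch of how the first full-checkpoint override set could be composed into a command; the base invocation via run_llama_train.sh with the debug config is an assumption (not part of this diff), and test_checkpoint_dir is given a placeholder value here.

# Hypothetical sketch only; the base command and paths are assumptions.
test_checkpoint_dir = "./test_runner_checkpoint"
base_cmd = "CONFIG_FILE=./train_configs/debug_model.toml ./run_llama_train.sh"
overrides = [
    "--checkpoint.enable_checkpoint",
    f"--checkpoint.folder {test_checkpoint_dir}_full_checkpoint",
]
# The runner would execute one such command per override list; the second list
# in the test additionally passes --training.steps 20 so the run resumes from
# the saved checkpoint and trains further.
cmd = " ".join([base_cmd] + overrides)
print(cmd)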
64 changes: 35 additions & 29 deletions torchtitan/checkpoint.py
@@ -17,6 +17,7 @@
set_model_state_dict,
set_optimizer_state_dict,
)
from torchtitan.config_manager import JobConfig
from torchtitan.logging_utils import logger


@@ -61,46 +62,51 @@ def __init__(
model: nn.Module,
optimizer: torch.optim.Optimizer,
states: Dict[str, Any],
folder: str,
interval_type: IntervalType,
interval: int,
model_weights_only: bool = False,
export_dtype: str = "float32",
job_config: JobConfig,
) -> None:
self.folder = folder
self.states = states
self.model_weights_only = model_weights_only
self.states.update(
{
"model": ModelWrapper(model),
"optimizer": OptimizerWrapper(model, optimizer),
}
)
self.interval_type = interval_type
self.interval = interval
self.begin = 0
self.work = None
self.pg = dist.new_group(backend="gloo")
self.doit = None
self.export_dtype = DTYPE_MAP[export_dtype]

if self.folder:
ckpt_config = job_config.checkpoint
self.enable_checkpoint = ckpt_config.enable_checkpoint

if self.enable_checkpoint:
Contributor: Can we have if not self.enable_checkpoint: return and then everything else, just like in the save and load functions? Essentially we can make CheckpointManager a no-op class when checkpointing is not enabled.

Contributor Author: I am going to indent everything under if self.enable_checkpoint:. In the else case, the constructor simply exits; there is no return value.
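A rough sketch of the no-op shape being discussed, with a simplified constructor signature (illustrative only; the real setup body is shown in the surrounding diff):

class CheckpointManager:
    def __init__(self, job_config) -> None:
        self.enable_checkpoint = job_config.checkpoint.enable_checkpoint
        # Reviewer's suggestion: bail out immediately so the manager becomes a
        # no-op when checkpointing is disabled.
        if not self.enable_checkpoint:
            return
        # Full setup (states, folder, interval, process group) would follow here.
        # The author instead indents this setup under if self.enable_checkpoint:,
        # which is equivalent since __init__ returns nothing in either case.

    def save(self, curr_step: int, force: bool = False) -> None:
        if not self.enable_checkpoint:
            return
        # ... perform the actual save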

self.states = states
self.states.update(
{
"model": ModelWrapper(model),
"optimizer": OptimizerWrapper(model, optimizer),
}
)

self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
self.interval_type = (
IntervalType.SECONDS
if ckpt_config.interval_type == "seconds"
else IntervalType.STEPS
)
self.interval = ckpt_config.interval
self.model_weights_only = ckpt_config.model_weights_only
self.export_dtype = DTYPE_MAP[ckpt_config.export_dtype]

logger.info(
f"Checkpointing active. Checkpoints will be loaded from and saved to {self.folder}"
)

self.begin = 0
self.work = None
self.pg = dist.new_group(backend="gloo")
self.doit = None

def reset(self) -> None:
Contributor: How about renaming reset and create_checkpoint_id to _reset and _create_checkpoint_id, since they are helper functions only called within the class?

Contributor Author: reset() does get used outside the class. Only _create_checkpoint_id was updated.

self.begin = time.monotonic()

def create_checkpoint_id(self, step: int) -> str:
def _create_checkpoint_id(self, step: int) -> str:
return os.path.join(self.folder, f"step-{step}")

def save(self, curr_step: int, force: bool = False) -> None:
"""
force = True will force the checkpoint to be saved, even if the interval has not been reached.
This only happens when train_state.step == job_config.training.steps.
"""
if not self.folder:
if not self.enable_checkpoint:
return

if not force:
@@ -155,18 +161,18 @@ def save(self, curr_step: int, force: bool = False) -> None:
logger.info(f"Saving a full checkpoint at step {curr_step}")

begin = time.monotonic()
dcp.save(self.states, checkpoint_id=self.create_checkpoint_id(curr_step))
dcp.save(self.states, checkpoint_id=self._create_checkpoint_id(curr_step))
self.reset()
logger.info(
f"Finished saving the checkpoint in {time.monotonic() - begin:.2f} seconds"
)

def load(self, step: int = -1) -> bool:
if not self.folder:
if not self.enable_checkpoint:
return False
if not os.path.isdir(self.folder):
return False
if step != -1 and not os.path.isdir(self.create_checkpoint_id(step)):
if step != -1 and not os.path.isdir(self._create_checkpoint_id(step)):
return False

if step == -1:
@@ -183,7 +189,7 @@
begin = time.monotonic()
dcp.load(
self.states,
checkpoint_id=self.create_checkpoint_id(step),
checkpoint_id=self._create_checkpoint_id(step),
)
logger.info(
f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds"
46 changes: 26 additions & 20 deletions torchtitan/config_manager.py
@@ -70,7 +70,7 @@ def __init__(self):

# profiling configs
self.parser.add_argument(
"--profiling.run_profiler",
"--profiling.enable_profiling",
action="store_true",
help="enable pytorch profiler",
)
@@ -211,42 +211,48 @@ def __init__(self):
action="store_true",
help="Whether to compile the model.",
)

# checkpoint configs
self.parser.add_argument(
"--checkpoint.interval",
type=int,
default=500,
"--checkpoint.enable_checkpoint",
action="store_true",
help="Whether to enable checkpoint",
)
self.parser.add_argument(
"--checkpoint.folder",
type=str,
default="checkpoint",
help=(
"Checkpointing interval. The unit of measurement is in seconds or "
"steps depending on --checkpoint.interval_type."
"The folder to store the checkpoints. "
"When enable_checkpoint is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}."
),
)
self.parser.add_argument(
"--checkpoint.interval_type",
type=str,
default="steps",
help=(
"The checkpointing interval unit of measurement."
"The checkpointing interval unit of measurement. "
"The default value is steps."
),
)
self.parser.add_argument(
"--checkpoint.folder",
type=str,
default="",
"--checkpoint.interval",
type=int,
default=500,
help=(
"The folder to store the checkpoints. If this is not specified or "
"is an empty string, checkpointing is disabled."
"Checkpointing interval. The unit of measurement is in seconds or "
"steps depending on --checkpoint.interval_type."
),
)
self.parser.add_argument(
"--checkpoint.model_weights_only",
type=str,
default=False,
action="store_true",
help=(
"When model_weights_only=True, we keep only model weights for your checkpoint at the end of training."
"With this, checkpoints can be loaded using `torch.load(..., weights_only=True)` after conversion."
"When model_weights_only=False, we do a full checkpoint."
"A full checkpoint includes model, optimizer and train_state, which can be used to resume training."
"When model_weights_only=True, only model weights will be saved at the end of training. "
"With this, checkpoints can be loaded using `torch.load(..., weights_only=True)` after conversion. "
"When model_weights_only=False, the full checkpoint will be saved. "
"A full checkpoint includes model, optimizer and train_state, which can be used to resume training. "
"The default value is false."
),
)
@@ -255,8 +261,8 @@ def __init__(self):
type=str,
default="float32",
help=(
"Converts to the specified precision when training completes and model_weights_only=true."
"Currently supports float32, float16, and bfloat16."
"Converts to the specified precision when training completes and model_weights_only=true. "
"Currently supports float32, float16, and bfloat16. "
"The default value is float32."
),
)
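To make the new folder semantics concrete, here is a minimal sketch of the path composition described in the --checkpoint.folder help text above, using the defaults from the TOML configs in this PR (the real logic lives in CheckpointManager; the standalone variables here are illustrative):

import os

dump_folder = "./outputs"         # job.dump_folder from the llama configs
checkpoint_folder = "checkpoint"  # the new --checkpoint.folder default

# Mirrors self.folder and _create_checkpoint_id in torchtitan/checkpoint.py.
folder = os.path.join(dump_folder, checkpoint_folder)
checkpoint_id = os.path.join(folder, "step-10")
print(checkpoint_id)  # ./outputs/checkpoint/step-10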
6 changes: 3 additions & 3 deletions torchtitan/profiling.py
@@ -10,11 +10,11 @@


@contextlib.contextmanager
def maybe_run_profiler(config: JobConfig, *pos_args, **kwargs):
def maybe_enable_profiling(config: JobConfig, *pos_args, **kwargs):
# get user defined profiler settings
run_profiler = config.profiling.run_profiler
enable_profiling = config.profiling.enable_profiling

if run_profiler:
if enable_profiling:
dump_dir = config.job.dump_folder
save_trace_dir = config.profiling.save_traces_folder
trace_dir = os.path.join(dump_dir, save_trace_dir)
16 changes: 4 additions & 12 deletions train.py
@@ -20,7 +20,7 @@
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.tensor.parallel import loss_parallel

from torchtitan.checkpoint import CheckpointManager, IntervalType
from torchtitan.checkpoint import CheckpointManager
from torchtitan.config_manager import JobConfig
from torchtitan.datasets import create_tokenizer, dataloader_fn
from torchtitan.float8_linear import build_fp8_linear
@@ -29,7 +29,7 @@
from torchtitan.metrics import build_gpu_memory_monitor, build_metric_logger
from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtitan.parallelisms import models_parallelize_fns, ParallelDims
from torchtitan.profiling import maybe_run_profiler
from torchtitan.profiling import maybe_enable_profiling
from torchtitan.utils import (
Color,
dist_max,
@@ -233,15 +233,7 @@ def loss_fn(pred, labels):
model=model,
optimizer=optimizer,
states={"train_state": train_state},
folder=job_config.checkpoint.folder,
interval_type=(
IntervalType.SECONDS
if job_config.checkpoint.interval_type == "seconds"
else IntervalType.STEPS
),
interval=job_config.checkpoint.interval,
model_weights_only=job_config.checkpoint.model_weights_only,
export_dtype=job_config.checkpoint.export_dtype,
job_config=job_config,
)
checkpoint.load()

@@ -259,7 +251,7 @@ def loss_fn(pred, labels):
data_iterator = iter(data_loader)

logger.info(f"Training starts at step {train_state.step + 1}")
with maybe_run_profiler(job_config) as torch_profiler:
with maybe_enable_profiling(job_config) as torch_profiler:
checkpoint.reset()

# variables used to keep info for metrics logging
9 changes: 5 additions & 4 deletions train_configs/debug_model.toml
@@ -6,8 +6,8 @@ description = "LLaMA debug training"
use_for_integration_test = true

[profiling]
run_profiler = true
save_traces_folder = "profiling/traces"
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 10

[metrics]
@@ -40,9 +40,10 @@ compile = false
dataset = "alpaca" # supported datasets: alpaca (52K), openwebtext (8M), c4 (177M)

[checkpoint]
interval = 5
enable_checkpoint = false
folder = "checkpoint"
interval_type = "steps"
folder = ""
interval = 5
model_weights_only = false
export_dtype = "float32"

9 changes: 5 additions & 4 deletions train_configs/llama_13b.toml
@@ -6,8 +6,8 @@ dump_folder = "./outputs"
description = "LLaMA 13B training"

[profiling]
run_profiler = true
save_traces_folder = "profiling/traces"
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
@@ -39,9 +39,10 @@ compile = false
dataset = "openwebtext"

[checkpoint]
interval = 500
enable_checkpoint = false
folder = "checkpoint"
interval_type = "steps"
folder = ""
interval = 500
model_weights_only = false
export_dtype = "float32"

9 changes: 5 additions & 4 deletions train_configs/llama_70b.toml
@@ -6,8 +6,8 @@ dump_folder = "./outputs"
description = "LLaMA 70B training"

[profiling]
run_profiler = true
save_traces_folder = "profiling/traces"
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
@@ -39,9 +39,10 @@ compile = false
dataset = "openwebtext"

[checkpoint]
interval = 500
enable_checkpoint = false
folder = "checkpoint"
interval_type = "steps"
folder = ""
interval = 500
model_weights_only = false
export_dtype = "float32"
