[TorchTitan][Checkpoint] Move checkpoint folder under dump_folder and a few config updates #230
Changes from 4 commits
```diff
@@ -17,6 +17,7 @@
     set_model_state_dict,
     set_optimizer_state_dict,
 )
+from torchtitan.config_manager import JobConfig
 from torchtitan.logging_utils import logger
```
```diff
@@ -61,46 +62,51 @@ def __init__(
         model: nn.Module,
         optimizer: torch.optim.Optimizer,
         states: Dict[str, Any],
-        folder: str,
-        interval_type: IntervalType,
-        interval: int,
-        model_weights_only: bool = False,
-        export_dtype: str = "float32",
+        job_config: JobConfig,
     ) -> None:
-        self.folder = folder
         self.states = states
-        self.model_weights_only = model_weights_only
         self.states.update(
             {
                 "model": ModelWrapper(model),
                 "optimizer": OptimizerWrapper(model, optimizer),
             }
         )
-        self.interval_type = interval_type
-        self.interval = interval
-        self.begin = 0
-        self.work = None
-        self.pg = dist.new_group(backend="gloo")
-        self.doit = None
-        self.export_dtype = DTYPE_MAP[export_dtype]
-
-        if self.folder:
+        self.enable_checkpoint = job_config.checkpoint.enable_checkpoint
+
+        if self.enable_checkpoint:
```
> **Review comment:** Can we have […]
>
> **Reply:** I am gonna indent everything under […]
```diff
+            ckpt_config = job_config.checkpoint
+            self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
+            self.interval_type = (
+                IntervalType.SECONDS
+                if ckpt_config.interval_type == "seconds"
+                else IntervalType.STEPS
+            )
+            self.interval = ckpt_config.interval
+            self.model_weights_only = ckpt_config.model_weights_only
+            self.export_dtype = DTYPE_MAP[ckpt_config.export_dtype]
+
             logger.info(
                 f"Checkpointing active. Checkpoints will be loaded from and saved to {self.folder}"
             )
 
+        self.begin = 0
+        self.work = None
+        self.pg = dist.new_group(backend="gloo")
+        self.doit = None
 
     def reset(self) -> None:
```
> **Review comment:** how about rename […]
>
> **Reply:** reset() did get used outside. Just updated _create_checkpoint_id.
```diff
         self.begin = time.monotonic()
 
-    def create_checkpoint_id(self, step: int) -> str:
+    def _create_checkpoint_id(self, step: int) -> str:
         return os.path.join(self.folder, f"step-{step}")
 
     def save(self, curr_step: int, force: bool = False) -> None:
         """
         force = True will force the checkpoint to be saved, even if the interval has not been reached.
         This only happens when train_state.step == job_config.training.steps.
         """
-        if not self.folder:
+        if not self.enable_checkpoint:
             return
 
         if not force:
```
```diff
@@ -155,18 +161,18 @@ def save(self, curr_step: int, force: bool = False) -> None:
         logger.info(f"Saving a full checkpoint at step {curr_step}")
 
         begin = time.monotonic()
-        dcp.save(self.states, checkpoint_id=self.create_checkpoint_id(curr_step))
+        dcp.save(self.states, checkpoint_id=self._create_checkpoint_id(curr_step))
         self.reset()
         logger.info(
             f"Finished saving the checkpoint in {time.monotonic() - begin:.2f} seconds"
         )
 
     def load(self, step: int = -1) -> bool:
-        if not self.folder:
+        if not self.enable_checkpoint:
             return False
         if not os.path.isdir(self.folder):
             return False
-        if step != -1 and not os.path.isdir(self.create_checkpoint_id(step)):
+        if step != -1 and not os.path.isdir(self._create_checkpoint_id(step)):
             return False
 
         if step == -1:
```
```diff
@@ -183,7 +189,7 @@ def load(self, step: int = -1) -> bool:
         begin = time.monotonic()
         dcp.load(
             self.states,
-            checkpoint_id=self.create_checkpoint_id(step),
+            checkpoint_id=self._create_checkpoint_id(step),
         )
         logger.info(
             f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds"
```
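
Taken together, `CheckpointManager` now pulls every checkpoint setting from `JobConfig` and gates all save/load work on `enable_checkpoint`. Below is a minimal sketch of how a training loop might drive the new interface; it assumes `torch.distributed` is already initialized (since `__init__` creates a gloo process group) and that `JobConfig` exposes a `parse_args` helper taking a list of CLI arguments, as elsewhere in torchtitan. The tiny model and the loop are illustrative, not torchtitan's actual `train.py`:

```python
import torch
import torch.nn as nn

from torchtitan.checkpoint import CheckpointManager
from torchtitan.config_manager import JobConfig

# Assumption: torch.distributed is already initialized, since
# CheckpointManager.__init__ creates a gloo process group.
job_config = JobConfig()
job_config.parse_args([])  # assumption: parse_args accepts a list of CLI args

model = nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

checkpoint = CheckpointManager(
    model=model,
    optimizer=optimizer,
    states={},  # train.py would pass extra states, e.g. {"train_state": train_state}
    job_config=job_config,
)

checkpoint.load()  # returns False immediately when enable_checkpoint is off
for step in range(1, job_config.training.steps + 1):
    ...  # forward/backward/optimizer step
    # force=True on the last step, matching the save() docstring above
    checkpoint.save(step, force=(step == job_config.training.steps))
```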
```diff
@@ -70,7 +70,7 @@ def __init__(self):
 
         # profiling configs
         self.parser.add_argument(
-            "--profiling.run_profiler",
+            "--profiling.enable_profiling",
             action="store_true",
             help="enable pytorch profiler",
         )
```
```diff
@@ -211,13 +211,20 @@ def __init__(self):
             action="store_true",
             help="Whether to compile the model.",
         )
 
         # checkpoint configs
         self.parser.add_argument(
-            "--checkpoint.interval",
-            type=int,
-            default=500,
+            "--checkpoint.enable_checkpoint",
+            action="store_true",
+            help="Whether to enable checkpoint",
+        )
+        self.parser.add_argument(
+            "--checkpoint.folder",
+            type=str,
+            default="checkpoint",
             help=(
-                "Checkpointing interval. The unit of measurement is in seconds or "
-                "steps depending on --checkpoint.interval_type."
+                "The folder to store the checkpoints."
```
> **Review comment:** nit: need a whitespace between sentences; same for the helper messages for other checkpointing options
```diff
+                "When enable_checkpoint is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}."
             ),
         )
         self.parser.add_argument(
```
```diff
@@ -230,22 +237,21 @@ def __init__(self):
             ),
         )
         self.parser.add_argument(
-            "--checkpoint.folder",
-            type=str,
-            default="",
+            "--checkpoint.interval",
+            type=int,
+            default=500,
             help=(
-                "The folder to store the checkpoints. If this is not specified or "
-                "is an empty string, checkpointing is disabled."
+                "Checkpointing interval. The unit of measurement is in seconds or "
+                "steps depending on --checkpoint.interval_type."
             ),
         )
         self.parser.add_argument(
             "--checkpoint.model_weights_only",
-            type=str,
-            default=False,
+            action="store_true",
             help=(
-                "When model_weights_only=True, we keep only model weights for your checkpoint at the end of training."
+                "When model_weights_only=True, only model weights will be saved at the end of training."
                 "With this, checkpoints can be loaded using `torch.load(..., weights_only=True)` after conversion."
-                "When model_weights_only=False, we do a full checkpoint."
+                "When model_weights_only=False, the full checkpoint will be saved."
                 "A full checkpoint includes model, optimizer and train_state, which can be used to resume training."
                 "The default value is false."
             ),
```
```diff
@@ -6,9 +6,9 @@ description = "LLaMA debug training"
 use_for_integration_test = true
 
 [profiling]
-run_profiler = true
-save_traces_folder = "profiling/traces"
-profile_freq = 10
+enable_profiling = true
+save_traces_folder = "profile_trace"
+profile_freq = 100
```
> **Review comment:** pls change back to 10 :)
```diff
 
 [metrics]
 log_freq = 1
```
```diff
@@ -40,9 +40,10 @@ compile = false
 dataset = "alpaca"   # supported datasets: alpaca (52K), openwebtext (8M), c4 (177M)
 
 [checkpoint]
-interval = 5
+enable_checkpoint = false
+folder = "checkpoint"
 interval_type = "steps"
-folder = ""
+interval = 5
 model_weights_only = false
 export_dtype = "float32"
 
```
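With `enable_checkpoint = false`, this debug config writes no checkpoints; flipping it on places them under the dump folder per the `__init__` change above. A small illustration of the resulting layout (the `./outputs` value for `--job.dump_folder` is an assumption based on torchtitan's usual default):

```python
import os

dump_folder = "./outputs"         # assumed --job.dump_folder default
checkpoint_folder = "checkpoint"  # [checkpoint] folder key from the TOML above

folder = os.path.join(dump_folder, checkpoint_folder)
# With interval = 5 and interval_type = "steps", the first save lands at:
print(os.path.join(folder, "step-5"))  # ./outputs/checkpoint/step-5
```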
> **Review comment:** do we need self.states when not enabling checkpointing?
>
> **Reply:** Moved under if.