-
Notifications
You must be signed in to change notification settings - Fork 172
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TorchTitan][Checkpoint] Move checkpoint folder under dump_folder and a few config updates #230
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
set_model_state_dict, | ||
set_optimizer_state_dict, | ||
) | ||
from torchtitan.config_manager import JobConfig | ||
from torchtitan.logging_utils import logger | ||
|
||
|
||
|
@@ -61,46 +62,51 @@ def __init__( | |
model: nn.Module, | ||
optimizer: torch.optim.Optimizer, | ||
states: Dict[str, Any], | ||
folder: str, | ||
interval_type: IntervalType, | ||
interval: int, | ||
model_weights_only: bool = False, | ||
export_dtype: str = "float32", | ||
job_config: JobConfig, | ||
) -> None: | ||
self.folder = folder | ||
self.states = states | ||
self.model_weights_only = model_weights_only | ||
self.states.update( | ||
{ | ||
"model": ModelWrapper(model), | ||
"optimizer": OptimizerWrapper(model, optimizer), | ||
} | ||
) | ||
self.interval_type = interval_type | ||
self.interval = interval | ||
self.begin = 0 | ||
self.work = None | ||
self.pg = dist.new_group(backend="gloo") | ||
self.doit = None | ||
self.export_dtype = DTYPE_MAP[export_dtype] | ||
|
||
if self.folder: | ||
ckpt_config = job_config.checkpoint | ||
self.enable_checkpoint = ckpt_config.enable_checkpoint | ||
|
||
if self.enable_checkpoint: | ||
self.states = states | ||
self.states.update( | ||
{ | ||
"model": ModelWrapper(model), | ||
"optimizer": OptimizerWrapper(model, optimizer), | ||
} | ||
) | ||
|
||
self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder) | ||
self.interval_type = ( | ||
IntervalType.SECONDS | ||
if ckpt_config.interval_type == "seconds" | ||
else IntervalType.STEPS | ||
) | ||
self.interval = ckpt_config.interval | ||
self.model_weights_only = ckpt_config.model_weights_only | ||
self.export_dtype = DTYPE_MAP[ckpt_config.export_dtype] | ||
|
||
logger.info( | ||
f"Checkpointing active. Checkpoints will be loaded from and saved to {self.folder}" | ||
) | ||
|
||
self.begin = 0 | ||
self.work = None | ||
self.pg = dist.new_group(backend="gloo") | ||
self.doit = None | ||
|
||
def reset(self) -> None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. how about rename There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. reset() did get used outside. Just updated _create_checkpoint_id. |
||
self.begin = time.monotonic() | ||
|
||
def create_checkpoint_id(self, step: int) -> str: | ||
def _create_checkpoint_id(self, step: int) -> str: | ||
return os.path.join(self.folder, f"step-{step}") | ||
|
||
def save(self, curr_step: int, force: bool = False) -> None: | ||
""" | ||
force = True will force the checkpoint to be saved, even if the interval has not been reached. | ||
This only happens when train_state.step == job_config.training.steps. | ||
""" | ||
if not self.folder: | ||
if not self.enable_checkpoint: | ||
return | ||
|
||
if not force: | ||
|
@@ -155,18 +161,18 @@ def save(self, curr_step: int, force: bool = False) -> None: | |
logger.info(f"Saving a full checkpoint at step {curr_step}") | ||
|
||
begin = time.monotonic() | ||
dcp.save(self.states, checkpoint_id=self.create_checkpoint_id(curr_step)) | ||
dcp.save(self.states, checkpoint_id=self._create_checkpoint_id(curr_step)) | ||
self.reset() | ||
logger.info( | ||
f"Finished saving the checkpoint in {time.monotonic() - begin:.2f} seconds" | ||
) | ||
|
||
def load(self, step: int = -1) -> bool: | ||
if not self.folder: | ||
if not self.enable_checkpoint: | ||
return False | ||
if not os.path.isdir(self.folder): | ||
return False | ||
if step != -1 and not os.path.isdir(self.create_checkpoint_id(step)): | ||
if step != -1 and not os.path.isdir(self._create_checkpoint_id(step)): | ||
return False | ||
|
||
if step == -1: | ||
|
@@ -183,7 +189,7 @@ def load(self, step: int = -1) -> bool: | |
begin = time.monotonic() | ||
dcp.load( | ||
self.states, | ||
checkpoint_id=self.create_checkpoint_id(step), | ||
checkpoint_id=self._create_checkpoint_id(step), | ||
) | ||
logger.info( | ||
f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we have
`if not self.enable_checkpoint: return`
and then everything else? Just like in the `save`
and `load`
functions. Essentially we can make CheckpointManager a no-op class if not enabled.

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am going to indent everything under
`if self.enable_checkpoint:`. For the else case, it would simply exit the constructor, and there is no return value.