[TorchTitan][Checkpoint] Move checkpoint folder under dump_folder and a few config updates #230
Conversation
train.py
Outdated
checkpoint = CheckpointManager(
    model=model,
    optimizer=optimizer,
    states={"train_state": train_state},
-   folder=job_config.checkpoint.folder,
+   folder=ckpt_folder,
Trying to see if we can further simplify train.py for the checkpoint logic. Can we pass job_config to CheckpointManager, and handle the following inside the CheckpointManager constructor:
- the checkpoint folder logic above
- setting all the options like interval_type/interval/model_weights_only
Sure. I think I can take the entire job_config.checkpoint and handle this inside checkpoint.py. Let me do that.
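As a rough sketch of what "take the entire job_config and handle it inside checkpoint.py" could look like — note that the dataclasses and field names below are illustrative stand-ins, not torchtitan's actual config definitions:

```python
import os
from dataclasses import dataclass, field


# Hypothetical stand-ins for torchtitan's config objects, for illustration only.
@dataclass
class CheckpointConfig:
    enable_checkpoint: bool = True
    checkpoint_folder: str = "checkpoint"
    interval_type: str = "steps"
    interval: int = 500
    model_weights_only: bool = False


@dataclass
class JobSection:
    dump_folder: str = "./outputs"


@dataclass
class JobConfig:
    job: JobSection = field(default_factory=JobSection)
    checkpoint: CheckpointConfig = field(default_factory=CheckpointConfig)


class CheckpointManager:
    def __init__(self, model=None, optimizer=None, states=None, job_config=None):
        ckpt_config = job_config.checkpoint  # alias, since it is read several times
        self.enable_checkpoint = ckpt_config.enable_checkpoint
        if not self.enable_checkpoint:
            return  # no-op manager when checkpointing is disabled
        self.states = states or {}
        self.interval_type = ckpt_config.interval_type
        self.interval = ckpt_config.interval
        self.model_weights_only = ckpt_config.model_weights_only
        # Checkpoints always live under the job's dump_folder.
        self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.checkpoint_folder)


manager = CheckpointManager(states={}, job_config=JobConfig())
print(manager.folder)  # "./outputs/checkpoint" on POSIX
```

This keeps train.py down to a single constructor call, with all option plumbing living in checkpoint.py.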
torchtitan/config_manager.py
Outdated
@@ -234,18 +234,19 @@ def __init__(self):
    type=str,
    default="",
As discussed offline, can we use None as the default and use it to disable checkpointing? An empty string is also a valid relative path.
If this is an empty string, I'll set the ckpt folder to None.
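The "empty string means disabled" mapping could be sketched like this (resolve_ckpt_folder is a hypothetical helper name, not torchtitan code):

```python
import os
from typing import Optional


def resolve_ckpt_folder(dump_folder: str, folder: Optional[str]) -> Optional[str]:
    """Treat an empty string the same as None: checkpointing is disabled.

    This matters because an empty string is also a valid relative path,
    so it cannot be used directly to signal "no checkpointing".
    """
    if not folder:  # catches both None and ""
        return None
    return os.path.join(dump_folder, folder)


print(resolve_ckpt_folder("./outputs", ""))            # None (disabled)
print(resolve_ckpt_folder("./outputs", "checkpoint"))  # joined path under dump_folder
```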
train.py
Outdated
@@ -229,11 +229,18 @@ def loss_fn(pred, labels):
    # train loop
    model.train()

    ckpt_folder = job_config.checkpoint.folder
    ckpt_folder = (
        os.path.join(job_config.job.dump_folder, ckpt_folder)
IIRC one of us proposed that we should support both relative and absolute paths. I'm OK with either way.
Based on the discussion today, we are putting the checkpoints under dump_folder, so the path would always be relative for now.
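Worth noting that os.path.join would support both conventions with no extra code: on POSIX, when the second argument is absolute, the earlier components are discarded, so relative folders land under dump_folder while absolute folders are used as-is.

```python
import os

# A relative checkpoint folder is nested under dump_folder...
print(os.path.join("./outputs", "checkpoint"))  # ./outputs/checkpoint

# ...while an absolute one overrides it entirely (POSIX behavior:
# an absolute second argument discards all preceding components).
print(os.path.join("./outputs", "/mnt/ckpts"))  # /mnt/ckpts
```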
if self.folder:
self.enable_checkpoint = job_config.checkpoint.enable_checkpoint

if self.enable_checkpoint:
Can we have "if not self.enable_checkpoint: return" and then everything else, just like in the save and load functions? Essentially we can make CheckpointManager a no-op class when it is not enabled.
I am going to indent everything under "if self.enable_checkpoint:". For the else case, the constructor simply exits; there is no return value anyway.
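The two styles are equivalent; the early-return variant saves one indentation level. A minimal illustrative sketch (the class body here is a placeholder, not torchtitan's actual implementation):

```python
class CheckpointManager:
    """Sketch: an early return makes the manager a no-op when disabled."""

    def __init__(self, enable_checkpoint: bool, states=None):
        self.enable_checkpoint = enable_checkpoint
        if not self.enable_checkpoint:
            return  # skip all remaining setup; __init__ returns None either way
        self.states = states or {}

    def save(self, curr_step: int) -> None:
        # The same guard appears in save/load, so every public method is
        # safe to call on a disabled manager.
        if not self.enable_checkpoint:
            return
        self.states["last_saved_step"] = curr_step  # placeholder for real save logic


disabled = CheckpointManager(enable_checkpoint=False)
disabled.save(100)  # safe no-op; no state was ever allocated
```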
torchtitan/config_manager.py
Outdated
help="Whether to enable checkpoint",
)
self.parser.add_argument(
    "--checkpoint.checkpoint_folder",
On second thought, I think it's better to name it "checkpoint.folder" rather than "checkpoint.checkpoint_folder", since there is no ambiguity. The other two appearances of "folder" need a prefix because there could be ambiguity there.
torchtitan/profiling.py
Outdated
@@ -12,7 +12,7 @@
@contextlib.contextmanager
def maybe_run_profiler(config: JobConfig, *pos_args, **kwargs):
    # get user defined profiler settings
-   run_profiler = config.profiling.run_profiler
+   run_profiler = config.profiling.enable_profiling
Let's rename "run_profiler" as well, to be consistent.
train_configs/debug_model.toml
Outdated
profile_freq = 10
enable_profiling = true
save_traces_folder = "profile_trace"
# profiling frequency - example: 10 means every 10th iter will be profiled
I think we can remove the comment, as there's no ambiguity here.
train_configs/llama_13b.toml
Outdated
save_traces_folder = "profiling/traces"
enable_profiling = true
save_traces_folder = "profile_trace"
# profiling frequency - example: 10 means every 10th iter will be profiled
ditto: remove
train_configs/llama_70b.toml
Outdated
save_traces_folder = "profiling/traces"
enable_profiling = true
save_traces_folder = "profile_trace"
# profiling frequency - example: 10 means every 10th iter will be profiled
ditto: remove
train_configs/llama_7b.toml
Outdated
save_traces_folder = "profiling/traces"
enable_profiling = true
save_traces_folder = "profile_trace"
# profiling frequency - example: 10 means every 10th iter will be profiled
ditto: remove
self.work = None
self.pg = dist.new_group(backend="gloo")
self.doit = None

def reset(self) -> None:
How about renaming "reset" and "create_checkpoint_id" to "_reset" and "_create_checkpoint_id", since they are helper functions only called internally?
reset() does get used outside the class. I just updated _create_checkpoint_id.
torchtitan/checkpoint.py
Outdated

if self.enable_checkpoint:
    self.folder = os.path.join(
        job_config.job.dump_folder, job_config.checkpoint.checkpoint_folder
Since we are calling "job_config.checkpoint" several times, shall we set "checkpoint_config = job_config.checkpoint" at the beginning?
LGTM! Thanks for improving the checkpointing UX! Had some final inline comments.
torchtitan/config_manager.py
Outdated
help=(
    "Checkpointing interval. The unit of measurement is in seconds or "
    "steps depending on --checkpoint.interval_type."
    "The folder to store the checkpoints."
Nit: need a whitespace between the concatenated sentences; same for the help messages of the other checkpointing options.
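The underlying gotcha: Python's implicit concatenation of adjacent string literals inserts nothing between them, so each literal except the last needs a trailing space.

```python
# Adjacent string literals fuse with no separator...
broken = (
    "Checkpointing interval, in seconds or steps"
    "depending on --checkpoint.interval_type."
)

# ...so every literal but the last needs a trailing space.
fixed = (
    "Checkpointing interval, in seconds or steps "
    "depending on --checkpoint.interval_type."
)

print("stepsdepending" in broken)  # True: the words ran together
print("steps depending" in fixed)  # True
```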
train_configs/debug_model.toml
Outdated
profile_freq = 10
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100
pls change back to 10 :)
torchtitan/checkpoint.py
Outdated
) -> None:
    self.folder = folder
    self.states = states
do we need self.states when not enabling checkpointing?
Moved under if.
… a few config updates (#230)

Let CheckpointManager take the entire job_config as an arg so we can keep train.py a little cleaner.

Discussed with @tianyu-l and made a few additional changes, including:
1. Rename "run_profiler" to "enable_profiling".
2. Add an "enable_checkpoint" flag so it is consistent with "enable_profiling" and "enable_tensorboard". We feel this is a little more explicit.
3. Change the default checkpoint folder to "./outputs/checkpoint" when checkpointing is enabled.
4. Rename "folder" in [checkpoint] to "checkpoint_folder".
5. Change save_traces_folder from "./outputs/profiling/traces" to "./outputs/profile_trace".