[air/tuner] Add checkpoint_frequency/checkpoint_at_end arguments to CheckpointConfig #26661
Conversation
_dmatrix_cls: type
_ray_params_cls: type
_tune_callback_cls: type
_tune_callback_report_cls: type
_tune_callback_checkpoint_cls: type
clearly being clueless here,
why do we need to add these separately to all of the gbdt, lightgbm, and xgboost trainers?
why don't we just save CheckpointConfig and call save_checkpoint() from the base class?
also, why doesn't RLTrainer need these changes???
This is a specific pattern for some of our downstream libraries. XGBoost and LightGBM use callbacks in the framework library to save checkpoints and report to Tune (i.e. the calls to tune.checkpoint_dir() and tune.report(), which will be changed to session.report()).
Because LightGBM and XGBoost are so similar (LightGBM's API was based on XGBoost), we have a GBDTTrainer that can be used for most common things. We only have to deal with a few framework-specific details, which are the actual callbacks used for saving checkpoints/reporting results, getting information from the library-native model, and saving it to disk.
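For illustration, here is a minimal sketch of that pattern; all class names below are hypothetical stand-ins, not Ray's actual internals. The shared base trainer only references abstract class attributes, and each framework subclass plugs in its own callback types:

```python
class _ReportCallback:
    """Stand-in for a framework-native reporting callback."""


class _CheckpointCallback:
    """Stand-in for a framework-native checkpointing callback."""

    def __init__(self, frequency: int):
        self.frequency = frequency


class GBDTTrainerSketch:
    # Framework-specific hooks, filled in by subclasses.
    _tune_callback_report_cls: type
    _tune_callback_checkpoint_cls: type

    def default_callbacks(self, checkpoint_frequency: int) -> list:
        # Shared logic lives once in the base class; only the concrete
        # callback classes differ per framework.
        callbacks = [self._tune_callback_report_cls()]
        if checkpoint_frequency > 0:
            callbacks.append(
                self._tune_callback_checkpoint_cls(frequency=checkpoint_frequency)
            )
        return callbacks


class XGBoostTrainerSketch(GBDTTrainerSketch):
    _tune_callback_report_cls = _ReportCallback
    _tune_callback_checkpoint_cls = _CheckpointCallback
```

The reporting/checkpointing logic is written once; the XGBoost and LightGBM subclasses only differ in which concrete callback classes they assign.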
if not any(
    isinstance(
        cb, (self._tune_callback_report_cls, self._tune_callback_checkpoint_cls)
    )
    for cb in config["callbacks"]
):
    # Only add our own callback if it hasn't been added before
    checkpoint_frequency = (
        self.run_config.checkpoint_config.checkpoint_frequency
    )
    if checkpoint_frequency > 0:
can we ban this in the future?
Users can always add their own callbacks (which could also be non-reporting, or derived from our reporting callbacks), so we should check if the callbacks exist. We do the same thing in Tune. But for better readability I can at least put this into a separate function.
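As a sketch of what that separate function could look like (the helper name and signature are hypothetical, purely for illustration):

```python
def _has_tune_callback(callbacks: list, callback_classes: tuple) -> bool:
    """Return True if any user-provided callback is (or derives from)
    one of the given reporting/checkpointing callback classes."""
    return any(isinstance(cb, callback_classes) for cb in callbacks)
```

The call site above would then read roughly: `if not _has_tune_callback(config["callbacks"], (self._tune_callback_report_cls, self._tune_callback_checkpoint_cls)): ...`, which also covers user callbacks derived from the built-in reporting callbacks.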
looks good to me.
a few quick questions.
@@ -193,6 +223,21 @@ def training_loop(self) -> None:
        **config,
    )

    checkpoint_at_end = self.run_config.checkpoint_config.checkpoint_at_end
    if checkpoint_at_end is None:
        checkpoint_at_end = True
providing a default in the consuming trainer class seems strange.
should we just give a True default in CheckpointConfig?
In this case we can't. We have different default values for checkpoint_at_end depending on the trainable. If we set it to True by default, it will be incompatible with legacy trainables such as regular function trainables and raise an error downstream in tune.run. If we set it to False by default, we won't save checkpoints for most trainers and can't use their models in downstream processing.
I'm working on a FunctionTrainer today that will replace running legacy trainable functions. This will enable us to ignore the legacy function trainables here and default to True. Does that sound good?
Note that another reason why we can't do it at the moment is that if people specifically pass checkpoint_at_end=True to use with their function trainables, we don't want to silently set it to False. We could do this once we use a FunctionTrainer (though it's not ideal, as users may still pass True and wonder why they don't see any saved checkpoints).
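To make the decision order concrete, here is a rough, self-contained sketch of the resolution logic described in this thread; is_function_trainable is a deliberately simplified hypothetical stand-in for Ray's actual trainable detection:

```python
import inspect


def is_function_trainable(trainable) -> bool:
    # Simplified stand-in; Ray's actual detection is more involved.
    return inspect.isfunction(trainable)


def resolve_checkpoint_at_end(checkpoint_at_end, trainable) -> bool:
    if checkpoint_at_end is None:
        # No explicit user choice: default to False for function
        # trainables (the user's loop decides when to checkpoint) and
        # True for class trainables, which the driver can checkpoint.
        return not is_function_trainable(trainable)
    # An explicit user value is kept as-is; an unsupported True for a
    # function trainable surfaces as an error downstream instead of
    # being silently dropped.
    return checkpoint_at_end
```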
I see. appreciate the detailed explanation.
navigating through a lot of legacy stuff is hard ...
    checkpoint_at_end = False
# If this is a user-defined trainable, just keep the value
elif checkpoint_at_end is None:
    # Set default to False for function trainables and True for everything else
intuitively, why does the type of trainable have anything to do with whether we should checkpoint at the end?
are these two orthogonal things?
Function trainables and generic training loop trainers like the TorchTrainer don't support driver-based checkpointing at the end of a run. This is because the user defines the training loop and thus decides themselves when to save checkpoints. This is in contrast to e.g. RLlib, where we can just call trainable.save.remote() anytime we want.
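A small illustration of that contrast using Tune's public trainable APIs (simplified to the bare minimum):

```python
from ray import tune


class MyTrainable(tune.Trainable):
    def step(self):
        return {"score": 1.0}

    def save_checkpoint(self, checkpoint_dir):
        # Driver-triggered: Tune can call this at any time, including
        # once more at the end of the run when checkpoint_at_end=True.
        return checkpoint_dir


def my_function_trainable(config):
    # Only the code inside this loop decides when to checkpoint; the
    # driver cannot force a final checkpoint from the outside.
    tune.report(score=1.0)
```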
right, I guess I am still not sure why we want to default checkpoint_at_end to True just because we are able to do it for class trainables.
Thanks for adding this.
Why are these changes needed?
Includes/depends on #26656
This PR adds the checkpoint_freq and checkpoint_at_end arguments to the CheckpointConfig: CheckpointConfig.checkpoint_frequency and CheckpointConfig.checkpoint_at_end.
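For example, after this PR the new arguments can be set like this (module path as in Ray AIR at the time of this change):

```python
from ray.air.config import RunConfig, CheckpointConfig

run_config = RunConfig(
    checkpoint_config=CheckpointConfig(
        checkpoint_frequency=10,  # save a checkpoint every 10 iterations
        checkpoint_at_end=True,   # save a final checkpoint when training ends
    )
)
```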
Related issue number
Checks
I've run scripts/format.sh to lint the changes in this PR.