
[air/tuner] Add checkpoint_frequency/checkpoint_at_end arguments to CheckpointConfig #26661

Merged (24 commits) on Jul 19, 2022

Conversation

krfricke (Contributor)

Why are these changes needed?

Includes/depends on #26656

This PR adds the checkpoint_frequency and checkpoint_at_end arguments to CheckpointConfig (a short usage sketch follows the list below):

  • Adds CheckpointConfig.checkpoint_frequency and checkpoint_at_end
  • Implements the argument for LightGBM and XGBoost
  • Adds tests for LightGBM, XGBoost, and RLTrainer
  • Raises an error if used with an incompatible Trainer (e.g. TorchTrainer)
  • Sets default value for checkpoint_at_end
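
For reference, a minimal usage sketch of the new arguments (assuming the Ray AIR API around the time of this PR; the dataset, parameters, and values are purely illustrative):

import ray
from ray.air.config import CheckpointConfig, RunConfig
from ray.train.xgboost import XGBoostTrainer

# Illustrative toy dataset with a binary "target" label column.
train_ds = ray.data.from_items([{"x": i, "target": i % 2} for i in range(100)])

trainer = XGBoostTrainer(
    label_column="target",
    params={"objective": "binary:logistic"},
    datasets={"train": train_ds},
    run_config=RunConfig(
        checkpoint_config=CheckpointConfig(
            checkpoint_frequency=2,  # save a checkpoint every 2 boosting rounds
            checkpoint_at_end=True,  # always keep the final model
        )
    ),
)
result = trainer.fit()
print(result.checkpoint)  # the latest saved checkpoint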

Related issue number

Checks

  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Kai Fricke added 8 commits July 18, 2022 08:36
Kai Fricke added 8 commits July 18, 2022 17:26
_dmatrix_cls: type
_ray_params_cls: type
_tune_callback_cls: type
_tune_callback_report_cls: type
_tune_callback_checkpoint_cls: type
Member

Clearly being clueless here: why do we need to add these separately to all of the GBDT, LightGBM, and XGBoost trainers? Why don't we just save the CheckpointConfig and call save_checkpoint() from the base class?

Also, why doesn't RLTrainer need these changes?

Contributor Author (krfricke)

This is a specific pattern for some of our downstream libraries. XGBoost and LightGBM use callbacks in the framework library to save checkpoints and report to Tune (i.e., the calls to tune.checkpoint_dir() and tune.report(), which will be changed to session.report()).

Because LightGBM and XGBoost are so similar (LightGBM's API was modeled on XGBoost's), we have a GBDTTrainer that handles most of the common logic. We only have to deal with a few framework-specific details, namely the actual callbacks used for saving checkpoints and reporting results, and getting information from the library-native model and saving it to disk.
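
A rough, self-contained sketch of the delegation pattern described above (all class names are stand-ins for illustration, not the actual ray.train or xgboost_ray/lightgbm_ray classes):

class _XGBoostTuneCheckpointCallback:
    """Stand-in for a framework-native callback that saves checkpoints and reports to Tune."""

    def __init__(self, frequency: int = 1):
        self.frequency = frequency


class _LightGBMTuneCheckpointCallback(_XGBoostTuneCheckpointCallback):
    """Stand-in for the LightGBM equivalent of the callback above."""


class GBDTTrainerSketch:
    """Shared gradient-boosting logic; framework specifics come from class attributes."""

    _tune_callback_checkpoint_cls: type

    def __init__(self, checkpoint_frequency: int = 0):
        self.checkpoint_frequency = checkpoint_frequency

    def build_callbacks(self) -> list:
        # The base class decides *when* to checkpoint; the subclass decides *how*.
        if self.checkpoint_frequency > 0:
            return [self._tune_callback_checkpoint_cls(frequency=self.checkpoint_frequency)]
        return []


class XGBoostTrainerSketch(GBDTTrainerSketch):
    _tune_callback_checkpoint_cls = _XGBoostTuneCheckpointCallback


class LightGBMTrainerSketch(GBDTTrainerSketch):
    _tune_callback_checkpoint_cls = _LightGBMTuneCheckpointCallback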

# Conflicts:
#	python/ray/tune/tests/test_tuner.py
@krfricke marked this pull request as ready for review July 18, 2022 18:41
Comment on lines +197 to +207
if not any(
    isinstance(
        cb, (self._tune_callback_report_cls, self._tune_callback_checkpoint_cls)
    )
    for cb in config["callbacks"]
):
    # Only add our own callback if it hasn't been added before
    checkpoint_frequency = (
        self.run_config.checkpoint_config.checkpoint_frequency
    )
    if checkpoint_frequency > 0:
Contributor

can we ban this in the future?

Contributor Author (krfricke)

Users can always add their own callbacks (which could also be non-reporting, or derived from our reporting callbacks), so we should check whether those callbacks already exist. We do the same thing in Tune. But for better readability I can at least put this into a separate function.
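
For what it's worth, the extracted helper could look roughly like this (the function name is hypothetical, not the actual implementation):

def _contains_tune_callbacks(callbacks, report_cls, checkpoint_cls) -> bool:
    """Return True if a Tune report/checkpoint callback (or a subclass) is already present."""
    return any(isinstance(cb, (report_cls, checkpoint_cls)) for cb in callbacks)

# Usage corresponding to the snippet above:
# if not _contains_tune_callbacks(
#     config["callbacks"],
#     self._tune_callback_report_cls,
#     self._tune_callback_checkpoint_cls,
# ):
#     ...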

@richardliaw (Contributor) left a comment

looks good to me.

Kai Fricke added 4 commits July 19, 2022 00:16
# Conflicts:
#	python/ray/train/tests/test_torch_trainer.py
@gjoliver (Member) left a comment

a few quick questions.

@@ -193,6 +223,21 @@ def training_loop(self) -> None:
            **config,
        )

        checkpoint_at_end = self.run_config.checkpoint_config.checkpoint_at_end
        if checkpoint_at_end is None:
            checkpoint_at_end = True
Member

Providing a default in a consumer class seems strange.
Should we just give a True default in CheckpointConfig?

Contributor Author (krfricke)

In this case we can't. We have different default values for checkpoint_at_end depending on the trainable. If we set it to True by default, it will be incompatible with legacy trainables such as regular function trainables and raise an error downstream in tune.run. If we set it to False by default, we won't save checkpoints for most trainers and can't use their models in downstream processing.

I'm working on a FunctionTrainer today that will replace running legacy trainable functions. This will let us ignore the legacy function trainables here and default to True. Does that sound good?

Contributor Author (krfricke)

Note that another reason we can't do it at the moment is that if people specifically pass checkpoint_at_end=True for use with their function trainables, we don't want to silently set it to False. We could do this once we use a FunctionTrainer, though it's not ideal, as users may still pass True and wonder why they don't see any saved checkpoints.

Member

I see, appreciate the detailed explanation.
Navigating through a lot of legacy stuff is hard...

python/ray/tune/impl/tuner_internal.py (resolved)
python/ray/tune/impl/tuner_internal.py (outdated, resolved)
        checkpoint_at_end = False
    # If this is a user-defined trainable, just keep the value
    elif checkpoint_at_end is None:
        # Set default to False for function trainables and True for everything else
Member

Intuitively, why does the type of trainable have anything to do with whether we should checkpoint at the end?
Are these two orthogonal things?

@krfricke (Contributor Author), Jul 19, 2022

Function trainables and generic training-loop trainers like the TorchTrainer don't support driver-based checkpointing at the end of a run. This is because the user defines the training loop and thus decides themselves when to save checkpoints. This is in contrast to e.g. RLlib, where we can just call trainable.save.remote() anytime we want.
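
Roughly, the distinction being drawn here (a hedged sketch of the resolution logic, not the actual tuner_internal.py code):

def resolve_checkpoint_at_end(checkpoint_at_end, is_function_trainable: bool) -> bool:
    """Resolve the user's setting against what the trainable supports.

    Function trainables own their training loop, so the driver cannot save a
    checkpoint for them at the end of a run; class trainables (e.g. RLlib
    algorithms) expose a save method the driver can call at any time.
    """
    if is_function_trainable:
        if checkpoint_at_end:
            raise ValueError(
                "checkpoint_at_end=True is not supported for function trainables."
            )
        return False
    # For class trainables, default to saving a final checkpoint.
    return True if checkpoint_at_end is None else checkpoint_at_end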

Member

Right, I guess I am still not sure why we want to default checkpoint_at_end to True just because we are able to do it for class trainables.

@krfricke requested a review from gjoliver July 19, 2022 13:24
@gjoliver (Member) left a comment

Thanks for adding this.

@richardliaw merged commit 02dded1 into ray-project:master Jul 19, 2022
@krfricke deleted the air/checkpoint-freq branch July 19, 2022 17:23
Stefan-1313 pushed a commit to Stefan-1313/ray_mod that referenced this pull request Aug 18, 2022