[Tune] [PBT] [Doc] Fix and clean up PBT examples #29060
Conversation
Let's make sure to use framework-specific checkpoints where applicable.
with FileLock(".ray.lock"):
    data_dir = config.get("data_dir", "~/data")
Can we use something like os.path.expanduser("~/.ray.lock") instead? Ideally, tie it to the data_dir. If each worker runs in a separate directory, then they will not use the same lock files.
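A minimal sketch of what this suggestion could look like, assuming the filelock package and a data_dir key in config; the prepare_data helper and the lock filename are illustrative, not the code that was actually merged:

```python
import os

from filelock import FileLock


def prepare_data(config):
    # Resolve the data directory first, then keep the lock file next to it,
    # so that workers sharing a data_dir also share the same lock.
    data_dir = os.path.expanduser(config.get("data_dir", "~/data"))
    os.makedirs(data_dir, exist_ok=True)

    with FileLock(os.path.join(data_dir, ".ray.lock")):
        # Download/unpack the dataset here; only one worker per data_dir
        # enters this block at a time.
        pass
```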
@@ -110,22 +136,25 @@ def train_func(config):
     # Create loss.
     criterion = nn.CrossEntropyLoss()

     results = []
-    for _ in range(epochs):
+    while True:
What's the reason for using while here? I realize we are defining stop conditions later, but the common pattern in examples is to use a for loop anyway.
Yeah, I can change this back.
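For reference, a rough sketch of the for-loop pattern the review refers to, assuming the Ray AIR session API; the epochs key and the reported metric are placeholders:

```python
from ray.air import session


def train_func(config):
    # Common pattern in Tune examples: loop over a fixed number of epochs
    # taken from the config instead of relying only on an external stopper.
    for epoch in range(config.get("epochs", 10)):
        loss = 1.0 / (epoch + 1)  # placeholder for a real training step
        session.report({"loss": loss, "epoch": epoch})
```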
# Optimizer configs (`lr`, `momentum`) are being mutated by PBT and passed in
# through config, so we need to update the optimizer loaded from the checkpoint
update_optimizer_config(optimizer, optimizer_config)
Hmm, following this, is it the case that one should probably not use an LR scheduler with PBT if PBT also mutates the LR?
I think we just need to also save and load the learning rate scheduler state (which holds the epoch information) along with the optimizer. Then PBT will perturb the LR, but it will still follow the same schedule. We could show how to do this in this example, or maybe in another one?
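A hedged sketch of saving and restoring the scheduler state alongside the optimizer, assuming PyTorch and the Ray AIR Checkpoint/session APIs. update_optimizer_config mirrors the helper from the diff above; the function names and checkpoint layout here are illustrative:

```python
from ray.air import session
from ray.air.checkpoint import Checkpoint


def update_optimizer_config(optimizer, optimizer_config):
    # A possible version of the example's helper: push the (possibly
    # PBT-mutated) values from the config into every parameter group.
    for param_group in optimizer.param_groups:
        param_group.update(optimizer_config)


def save_checkpoint(epoch, model, optimizer, scheduler):
    # Persist the scheduler state (it tracks where we are in the schedule)
    # together with the model and optimizer state.
    return Checkpoint.from_dict(
        {
            "epoch": epoch,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
        }
    )


def restore_checkpoint(model, optimizer, scheduler, config):
    start_epoch = 0
    checkpoint = session.get_checkpoint()
    if checkpoint:
        state = checkpoint.to_dict()
        model.load_state_dict(state["model"])
        optimizer.load_state_dict(state["optimizer"])
        scheduler.load_state_dict(state["scheduler"])
        start_epoch = state["epoch"] + 1
        # PBT may have perturbed lr/momentum: apply them on top of the
        # restored optimizer state so the trial resumes mid-schedule with
        # the new hyperparameters.
        optimizer_config = {k: config[k] for k in ("lr", "momentum") if k in config}
        update_optimizer_config(optimizer, optimizer_config)
    return start_epoch
```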
@@ -3,4 +3,9 @@
PBT Function Example
~~~~~~~~~~~~~~~~~~~~

The following script produces the following results. For a population of 8 trials,
the PBT learning rate schedule roughly matches the optimal learning rate schedule.
Hmm, how should I interpret this result?
I reproduced the original pbt_function plots, and they were pretty much the same as these. I think the idea is that the cur_lr plot roughly matches the optimal_lr schedule.
Thanks!
Thanks, this is a great improvement!
Why are these changes needed?
Many examples in the docs related to PBT either did not have checkpointing or were checkpointing/loading incorrectly. The examples updated are:
- tune-advanced-tutorial (converted this to a runnable notebook with the fixes)
- pbt_function
- pbt_example
- pbt_memnn_example
- pbt_tune_cifar10_with_keras
- pb2_example (uses pbt_function, so some of the parameters needed to be updated)
- tune_cifar_torch_pbt_example (also tested by the pytorch_pbt_failure release test, which uses the train function defined in this example)

This is especially the case with examples that use the Function Trainable API with PBT, since we require the user to checkpoint themselves via session.report. This makes it a bit harder to align checkpoints with PBT perturbations. The main issue is that it's difficult for the user to keep track of the iteration number themselves (which is what they need to do if they want to copy the checkpoint_frequency functionality that is available when using the class Trainable API).

Consider the following cases (the two snippets being compared did not survive in this excerpt).
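As a rough stand-in, here is a minimal sketch of the bookkeeping a function trainable needs in order to mimic checkpoint_frequency, assuming the Ray AIR session and Checkpoint APIs; the metric and checkpoint contents are placeholders, and this is not the exact comparison from the original description:

```python
from ray.air import session
from ray.air.checkpoint import Checkpoint


def train_func(config):
    # Restore the step counter from the checkpoint so that a trial PBT just
    # exploited keeps counting from where the copied trial left off.
    step = 0
    checkpoint = session.get_checkpoint()
    if checkpoint:
        step = checkpoint.to_dict()["step"] + 1

    checkpoint_frequency = config.get("checkpoint_frequency", 1)
    while True:
        loss = 1.0 / (step + 1)  # placeholder for a real training step
        # The user has to track `step` manually to decide when to attach
        # a checkpoint to the report.
        if step % checkpoint_frequency == 0:
            session.report(
                {"loss": loss},
                checkpoint=Checkpoint.from_dict({"step": step}),
            )
        else:
            session.report({"loss": loss})
        step += 1
```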
Open questions:
- Could the problem of keeping track of the step counter be solved if we introduced a session.get_training_iteration() API? The function trainable already calls session.report, and session.report is what increments the training iteration.

Related issue number
Closes #22733
Checks
- I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.