
[Tune] [PBT] [Doc] Fix and clean up PBT examples #29060

Merged

Conversation

justinvyu
Contributor

@justinvyu justinvyu commented Oct 4, 2022

Why are these changes needed?

Many of the PBT examples in the docs either did not checkpoint at all or were checkpointing/loading incorrectly. The updated examples are:

  • PBT User Guide (originally tune-advanced-tutorial) (converted this to a runnable notebook with the fixes)
  • pbt_function
  • pbt_example
  • pbt_memnn_example
  • pbt_tune_cifar10_with_keras
  • pb2_example
    • This example uses the same training function as pbt_function, so some of the parameters needed to be updated.
  • tune_cifar_torch_pbt_example
    • This one affects the long-running pytorch_pbt_failure release test (it uses the train function defined in this example).

This is especially tricky in examples that use the Function Trainable API with PBT, since we require the user to checkpoint themselves via session.report, which makes it harder to align checkpoints with PBT perturbations. The main issue is that the user has to keep track of the iteration number themselves (which is what they need to do if they want to replicate the checkpoint_frequency functionality available with the class Trainable API).

Consider the following cases (a consolidated sketch combining both fixes follows the second case):

  1. The starting step needs to be set to 1. Otherwise, checkpointing and perturbation will be out of sync:
step = 0

# Checkpoint every `checkpoint_interval` steps
checkpoint = None
if step % checkpoint_interval == 0:
    # NOTE: Since we initialized `step = 0` above, our checkpointing and perturbing
    # are out of sync by 1 step.
    # Ex: if `checkpoint_interval` = `perturbation_interval` = 3
    # step:                0 (checkpoint)  1     2            3 (checkpoint)
    # training_iteration:  1               2     3 (perturb)  4
    checkpoint = Checkpoint.from_dict({"acc": accuracy, "step": step})
session.report(..., checkpoint=checkpoint)
step += 1

vs.

step = 1

# Checkpoint every `checkpoint_interval` steps
checkpoint = None
if step % checkpoint_interval == 0:
    # Fixed if we initialize `step = 1`
    # Ex: if `checkpoint_interval` = `perturbation_interval` = 3
    # step:                1          2     3 (checkpoint)     4
    # training_iteration:  1          2     3 (perturb)        4
    checkpoint = Checkpoint.from_dict({"acc": accuracy, "step": step})
session.report(..., checkpoint=checkpoint)
step += 1
  2. The user can easily start from the wrong step upon restore if they don't increment the checkpointed step by 1:
loaded_checkpoint = session.get_checkpoint()
if loaded_checkpoint:
    state = loaded_checkpoint.to_dict()
    accuracy = state["acc"]
    last_step = state["step"]
    # Current step should be 1 more than the last checkpoint step.
    # If we did `step = last_step` instead, we might repeat the step and end up
    # checkpointing more than we want to.
    # Ex: last_step = 4, step = 4 --> if `checkpoint_interval` = 4,
    # then we would checkpoint again, even though we just restored.
    # Should be last_step = 4, step = 5 --> next checkpoint will be step = 8
    step = last_step + 1
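
Putting both fixes together, here is a minimal sketch of a PBT-compatible function trainable (assuming the Ray AIR session/Checkpoint APIs as of this PR; train_one_step is a hypothetical helper standing in for the real training logic):

from ray.air import session
from ray.air.checkpoint import Checkpoint


def train_func(config):
    # Start at 1 so checkpointing lines up with `training_iteration`,
    # which also starts at 1.
    step = 1
    accuracy = 0.0

    # Restore state when PBT exploits another trial's checkpoint into this one.
    loaded_checkpoint = session.get_checkpoint()
    if loaded_checkpoint:
        state = loaded_checkpoint.to_dict()
        accuracy = state["acc"]
        # Resume one step after the checkpointed step to avoid repeating it.
        step = state["step"] + 1

    while True:
        accuracy = train_one_step(config, accuracy)  # hypothetical helper

        checkpoint = None
        if step % config["checkpoint_interval"] == 0:
            checkpoint = Checkpoint.from_dict({"acc": accuracy, "step": step})
        session.report({"mean_accuracy": accuracy}, checkpoint=checkpoint)
        step += 1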

Open questions:

  • Would this issue of requiring the user to manually keep another step counter be solved if we introduced a session.get_training_iteration() API?
    • One problem with this is that the user needs to create the checkpoint before calling session.report, and session.report is what increments the training iteration (see the sketch below).
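
To make the ordering problem concrete, here is a sketch using the proposed API (session.get_training_iteration() does not exist; this is purely hypothetical):

# Hypothetical API -- session.get_training_iteration() is not part of Ray.
# Suppose it returned the number of completed iterations, i.e. the number of
# session.report() calls made so far.
iteration = session.get_training_iteration()

# The checkpoint has to be created *before* session.report(), but it is
# session.report() that increments the iteration. So the user must still
# reason about the iteration that is about to complete:
checkpoint = None
if (iteration + 1) % checkpoint_interval == 0:
    checkpoint = Checkpoint.from_dict({"acc": accuracy, "step": iteration + 1})
session.report({"mean_accuracy": accuracy}, checkpoint=checkpoint)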

Related issue number

Closes #22733

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

@justinvyu justinvyu self-assigned this Oct 4, 2022
Member

@Yard1 Yard1 left a comment

Let's make sure to use framework-specific checkpoints where applicable

python/ray/train/examples/tune_cifar_torch_pbt_example.py
@justinvyu justinvyu marked this pull request as ready for review October 25, 2022 16:05
Comment on lines 115 to 116
with FileLock(".ray.lock"):
    data_dir = config.get("data_dir", "~/data")
Member

Can we use something like os.path.expanduser("~/.ray.lock") instead? Ideally, tie it to the data_dir. If each worker runs in a separate directory, they will not use the same lock file otherwise.
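
One way the suggestion could be implemented (a hypothetical sketch, not the PR's actual change; config is assumed to be in scope, and download_dataset stands in for whatever work the example does under the lock):

import os

from filelock import FileLock

# Tie the lock file to the shared data directory so that workers running in
# different working directories still contend on the same lock.
data_dir = os.path.expanduser(config.get("data_dir", "~/data"))
os.makedirs(data_dir, exist_ok=True)
with FileLock(os.path.join(data_dir, ".ray.lock")):
    download_dataset(data_dir)  # hypothetical download helper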

@@ -110,22 +136,25 @@ def train_func(config):
     # Create loss.
     criterion = nn.CrossEntropyLoss()

-    results = []
-    for _ in range(epochs):
+    while True:
Member

@Yard1 Yard1 Oct 25, 2022

What's the reason for using while here? I realize we are defining stop conditions later, but the common pattern in examples is to use a for loop anyway.

Contributor Author

Yeah, I can change this back.
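
For reference, a minimal sketch of the for-loop pattern the reviewer suggests (config["epochs"], train_epoch, model, and optimizer are placeholders, not the example's actual names):

# Iterate for a bounded number of epochs rather than looping forever and
# relying solely on Tune's stopping criteria to terminate the trial.
for epoch in range(config["epochs"]):
    loss = train_epoch(model, optimizer)  # hypothetical per-epoch step
    session.report({"loss": loss})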


# Optimizer configs (`lr`, `momentum`) are being mutated by PBT and passed in
# through config, so we need to update the optimizer loaded from the checkpoint
update_optimizer_config(optimizer, optimizer_config)
Contributor

Hmm, following this: is it the case that one should probably not use an LR scheduler together with PBT if PBT also mutates the LR?

Contributor Author

I think we just need to also save and load the learning rate scheduler state (which holds the epoch information) along with the optimizer. Then PBT will perturb the LR, but training will still follow the same schedule. We could show how to do it in this example or maybe in another example?
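
A rough sketch of that idea (assuming a standard torch optimizer/scheduler pair inside a training function where optimizer, scheduler, metrics, and config are already defined; names are illustrative, not the PR's actual code):

# Saving: capture both the optimizer and scheduler state so that, after a
# PBT exploit, the schedule resumes at the same epoch.
checkpoint = Checkpoint.from_dict(
    {
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
    }
)
session.report(metrics, checkpoint=checkpoint)

# Restoring: load both states, then apply the PBT-mutated values on top.
loaded = session.get_checkpoint()
if loaded:
    state = loaded.to_dict()
    optimizer.load_state_dict(state["optimizer_state_dict"])
    scheduler.load_state_dict(state["scheduler_state_dict"])
    for param_group in optimizer.param_groups:
        param_group["lr"] = config["lr"]  # PBT-perturbed learning rate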

@@ -3,4 +3,9 @@
PBT Function Example
~~~~~~~~~~~~~~~~~~~~

The following script produces these results: for a population of 8 trials,
the PBT learning rate schedule roughly matches the optimal learning rate schedule.

Contributor

Hmm, how should I interpret this result?

Contributor Author

I reproduced the original pbt_function plots, and they were pretty much the same as these. I think the idea is that the cur_lr plot roughly matches the optimal_lr schedule.

Member

@Yard1 Yard1 left a comment

Thanks!

Contributor

@xwjiang2010 xwjiang2010 left a comment

Thanks, this is a great improvement!

@richardliaw richardliaw merged commit 8c4e6dc into ray-project:master Oct 27, 2022
WeichenXu123 pushed a commit to WeichenXu123/ray that referenced this pull request Dec 19, 2022
Successfully merging this pull request may close these issues.

[Bug] Train PBT example is not using checkpointing