Skip to content

Commit

Permalink
[tune] deflake pbt. (#21366)
Browse files Browse the repository at this point in the history
We use `trial.checkpoint` to restore a perturbed trial. Currently, `trial.checkpoint` looks at both in-memory and persistent checkpoints to find the most recent one, where "the most recent one" is defined by training iteration. This may no longer be a valid assumption in the PBT case: `trial_low_quantile` may have an iter=2 persistent checkpoint as well as an iter=1 in-memory checkpoint (perturbed from `trial_upper_quantile`). Comparing by iteration would then pick the older persistent checkpoint instead of the newer in-memory one, so checkpoints are now ranked by the order in which they were recorded.
  • Loading branch information
xwjiang2010 authored Jan 4, 2022
1 parent e453837 commit fc22200
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions python/ray/tune/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import gc
import logging

from ray.tune.result import TRAINING_ITERATION
from ray.tune.utils.util import flatten_dict

logger = logging.getLogger(__name__)
Expand All @@ -28,6 +27,10 @@ def __init__(self, storage, value, result=None):
self.storage = storage
self.value = value
self.result = result or {}
# The logical order of checkpoints (both in memory and persistent)
# The more recent checkpoints have larger order.
# The most recent checkpoint is used to restore the trial.
self.order = 0

@staticmethod
def from_object(value=None):
Expand Down Expand Up @@ -93,13 +96,14 @@ def __init__(self, keep_checkpoints_num, checkpoint_score_attr, delete_fn):
self._newest_memory_checkpoint = Checkpoint(Checkpoint.MEMORY, None)
self._best_checkpoints = []
self._membership = set()
self._cur_order = 0

@property
def newest_checkpoint(self):
"""Returns the newest checkpoint (based on training iteration)."""
newest_checkpoint = max(
[self.newest_persistent_checkpoint, self.newest_memory_checkpoint],
key=lambda c: c.result.get(TRAINING_ITERATION, -1))
key=lambda c: c.order)
return newest_checkpoint

@property
Expand All @@ -116,20 +120,28 @@ def replace_newest_memory_checkpoint(self, new_checkpoint):
def on_checkpoint(self, checkpoint):
"""Starts tracking checkpoint metadata on checkpoint.
Checkpoints get assigned with an `order` as they come in.
The order is monotonically increasing.
Sets the newest checkpoint. For PERSISTENT checkpoints: Deletes
previous checkpoint as long as it isn't one of the best ones. Also
deletes the worst checkpoint if at capacity.
Args:
checkpoint (Checkpoint): Trial state checkpoint.
"""
self._cur_order += 1
checkpoint.order = self._cur_order

if checkpoint.storage == Checkpoint.MEMORY:
self.replace_newest_memory_checkpoint(checkpoint)
return

old_checkpoint = self.newest_persistent_checkpoint

if old_checkpoint.value == checkpoint.value:
# Overwrite the order of the checkpoint.
old_checkpoint.order = checkpoint.order
return

self.newest_persistent_checkpoint = checkpoint
Expand Down

0 comments on commit fc22200

Please sign in to comment.