Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tune] deflake pbt. #21366

Merged
merged 5 commits into from
Jan 4, 2022
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions python/ray/tune/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import gc
import logging

from ray.tune.result import TRAINING_ITERATION
from ray.tune.utils.util import flatten_dict

logger = logging.getLogger(__name__)
Expand All @@ -28,6 +27,10 @@ def __init__(self, storage, value, result=None):
self.storage = storage
self.value = value
self.result = result or {}
# The logical order of checkpoints (both in memory and persistent)
# The more recent checkpoints have larger order.
# The most recent checkpoint is used to restore the trial.
self.order = 0

@staticmethod
def from_object(value=None):
Expand Down Expand Up @@ -93,13 +96,14 @@ def __init__(self, keep_checkpoints_num, checkpoint_score_attr, delete_fn):
self._newest_memory_checkpoint = Checkpoint(Checkpoint.MEMORY, None)
self._best_checkpoints = []
self._membership = set()
self._cur_order = 0

@property
def newest_checkpoint(self):
"""Returns the newest checkpoint (based on training iteration)."""
newest_checkpoint = max(
[self.newest_persistent_checkpoint, self.newest_memory_checkpoint],
key=lambda c: c.result.get(TRAINING_ITERATION, -1))
key=lambda c: c.order)
return newest_checkpoint

@property
Expand All @@ -114,6 +118,11 @@ def replace_newest_memory_checkpoint(self, new_checkpoint):
self._newest_memory_checkpoint = new_checkpoint

def on_checkpoint(self, checkpoint):
self._cur_order += 1
checkpoint.order = self._cur_order
self._on_checkpoint_internal(checkpoint)

def _on_checkpoint_internal(self, checkpoint):
"""Starts tracking checkpoint metadata on checkpoint.

Sets the newest checkpoint. For PERSISTENT checkpoints: Deletes
Expand All @@ -130,6 +139,8 @@ def on_checkpoint(self, checkpoint):
old_checkpoint = self.newest_persistent_checkpoint

if old_checkpoint.value == checkpoint.value:
# Override the order of the checkpoint.
old_checkpoint.order = checkpoint.order
return

self.newest_persistent_checkpoint = checkpoint
Expand Down