Skip to content

Commit

Permalink
[tune] Treat checkpoints with nan value as worst (#23862)
Browse files Browse the repository at this point in the history
Changes the logic in CheckpointManager to treat checkpoints whose metric value is NaN as the worst checkpoints, meaning they will be deleted first if keep_checkpoints_num is set.
  • Loading branch information
Yard1 authored Apr 14, 2022
1 parent 6e37a48 commit 52eaf02
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 31 deletions.
12 changes: 8 additions & 4 deletions python/ray/train/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ray.train.session import TrainingResult
from ray.train.utils import construct_path
from ray.util import PublicAPI
from ray.util.ml_utils.util import is_nan

if TUNE_INSTALLED:
from ray import tune
Expand Down Expand Up @@ -212,10 +213,13 @@ def write_checkpoint(self, checkpoint: Dict):
)

def priority(checkpoint_score_order, checkpoint_score):
if checkpoint_score_order == MAX:
return checkpoint_score
else:
return -checkpoint_score
# Treat NaN as the worst score.
# The priority tuple is (not is_nan(score), score); the leading
# flag makes NaN values always rank as the worst entries in
# the heap.
if checkpoint_score_order != MAX:
checkpoint_score = -checkpoint_score
return (not is_nan(checkpoint_score), checkpoint_score)

checkpoint_priority = priority(checkpoint_score_order, checkpoint_score)

Expand Down
9 changes: 5 additions & 4 deletions python/ray/train/tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,7 @@ def test_persisted_checkpoint_strategy(ray_start_2_cpus):
)

def train_func():
train.save_checkpoint(loss=float("nan")) # nan, deleted
train.save_checkpoint(loss=3) # best
train.save_checkpoint(loss=7) # worst, deleted
train.save_checkpoint(loss=5)
Expand All @@ -530,16 +531,16 @@ def train_func():
assert trainer.logdir == Path(logdir).expanduser().resolve()
assert trainer.latest_checkpoint_dir.is_dir()
assert trainer.best_checkpoint_path.is_file()
assert trainer.best_checkpoint_path.name == f"checkpoint_{1:06d}"
assert trainer.best_checkpoint_path.name == f"checkpoint_{2:06d}"
assert trainer.latest_checkpoint["loss"] == 5
assert trainer.best_checkpoint["loss"] == 3

checkpoint_dir = trainer.latest_checkpoint_dir
file_names = [f.name for f in checkpoint_dir.iterdir()]
assert len(file_names) == 2
assert f"checkpoint_{1:06d}" in file_names
assert f"checkpoint_{2:06d}" not in file_names
assert f"checkpoint_{3:06d}" in file_names
assert f"checkpoint_{2:06d}" in file_names
assert f"checkpoint_{3:06d}" not in file_names
assert f"checkpoint_{4:06d}" in file_names

def validate():
checkpoint = train.load_checkpoint()
Expand Down
10 changes: 8 additions & 2 deletions python/ray/tune/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Any, Callable, Optional

from ray.tune.result import NODE_IP
from ray.tune.utils.util import flatten_dict
from ray.tune.utils.util import flatten_dict, is_nan

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -168,6 +168,10 @@ def on_checkpoint(self, checkpoint: _TuneCheckpoint):
self.delete(old_checkpoint)

try:
# NaN metrics are treated as the worst checkpoints.
# The priority tuple is (not is_nan(metric), metric, checkpoint.order);
# the leading flag makes NaN values always rank as the worst
# entries in the heap.
queue_item = QueueItem(self._priority(checkpoint), checkpoint)
except KeyError:
logger.error(
Expand Down Expand Up @@ -198,7 +202,9 @@ def best_checkpoints(self):
def _priority(self, checkpoint):
result = flatten_dict(checkpoint.result)
priority = result[self._checkpoint_score_attr]
return -priority if self._checkpoint_score_desc else priority
if self._checkpoint_score_desc:
priority = -priority
return (not is_nan(priority), priority, checkpoint.order)

def __getstate__(self):
state = self.__dict__.copy()
Expand Down
52 changes: 39 additions & 13 deletions python/ray/tune/tests/test_checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,20 @@

class CheckpointManagerTest(unittest.TestCase):
@staticmethod
def mock_result(i):
return {"i": i, TRAINING_ITERATION: i}
def mock_result(metric, i):
return {"i": metric, TRAINING_ITERATION: i}

def checkpoint_manager(self, keep_checkpoints_num):
return CheckpointManager(keep_checkpoints_num, "i", delete_fn=lambda c: None)

def testNewestCheckpoint(self):
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num=1)
memory_checkpoint = _TuneCheckpoint(
_TuneCheckpoint.MEMORY, {0}, self.mock_result(0)
_TuneCheckpoint.MEMORY, {0}, self.mock_result(0, 0)
)
checkpoint_manager.on_checkpoint(memory_checkpoint)
persistent_checkpoint = _TuneCheckpoint(
_TuneCheckpoint.PERSISTENT, {1}, self.mock_result(1)
_TuneCheckpoint.PERSISTENT, {1}, self.mock_result(1, 1)
)
checkpoint_manager.on_checkpoint(persistent_checkpoint)
self.assertEqual(
Expand All @@ -40,7 +40,7 @@ def testOnCheckpointOrdered(self):
keep_checkpoints_num = 2
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
checkpoints = [
_TuneCheckpoint(_TuneCheckpoint.PERSISTENT, {i}, self.mock_result(i))
_TuneCheckpoint(_TuneCheckpoint.PERSISTENT, {i}, self.mock_result(i, i))
for i in range(3)
]

Expand All @@ -66,7 +66,7 @@ def testOnCheckpointUnordered(self):
keep_checkpoints_num = 2
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
checkpoints = [
_TuneCheckpoint(_TuneCheckpoint.PERSISTENT, {i}, self.mock_result(i))
_TuneCheckpoint(_TuneCheckpoint.PERSISTENT, {i}, self.mock_result(i, i))
for i in range(3, -1, -1)
]

Expand All @@ -91,7 +91,7 @@ def testBestCheckpoints(self):
keep_checkpoints_num = 4
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
checkpoints = [
_TuneCheckpoint(_TuneCheckpoint.PERSISTENT, i, self.mock_result(i))
_TuneCheckpoint(_TuneCheckpoint.PERSISTENT, i, self.mock_result(i, i))
for i in range(16)
]
random.shuffle(checkpoints)
Expand All @@ -104,6 +104,32 @@ def testBestCheckpoints(self):
for i in range(len(best_checkpoints)):
self.assertEqual(best_checkpoints[i].value, i + 12)

def testBestCheckpointsWithNan(self):
    """
    Tests that checkpoints whose metric is NaN rank below a checkpoint
    with a real-valued metric.

    Two NaN-scored checkpoints and one checkpoint scored 0 are submitted
    in random order to a manager that keeps two checkpoints.
    """
    keep_checkpoints_num = 2
    manager = self.checkpoint_manager(keep_checkpoints_num)

    # Two checkpoints carrying a NaN metric (value=None) ...
    nan_checkpoints = [
        _TuneCheckpoint(
            _TuneCheckpoint.PERSISTENT,
            None,
            self.mock_result(float("nan"), step),
        )
        for step in range(2)
    ]
    # ... and a single checkpoint carrying a real metric (value=3).
    valid_checkpoint = _TuneCheckpoint(
        _TuneCheckpoint.PERSISTENT, 3, self.mock_result(0, 3)
    )
    checkpoints = nan_checkpoints + [valid_checkpoint]
    random.shuffle(checkpoints)

    for candidate in checkpoints:
        manager.on_checkpoint(candidate)

    ranked = manager.best_checkpoints()
    # best_checkpoints() is sorted from worst to best, so the surviving
    # NaN checkpoint comes first and the real-valued one last.
    self.assertEqual(len(ranked), keep_checkpoints_num)
    self.assertEqual(ranked[0].value, None)
    self.assertEqual(ranked[1].value, 3)

def testOnCheckpointUnavailableAttribute(self):
"""
Tests that an error is logged when the associated result of the
Expand All @@ -122,8 +148,8 @@ def testOnCheckpointUnavailableAttribute(self):

def testOnMemoryCheckpoint(self):
checkpoints = [
_TuneCheckpoint(_TuneCheckpoint.MEMORY, 0, self.mock_result(0)),
_TuneCheckpoint(_TuneCheckpoint.MEMORY, 0, self.mock_result(0)),
_TuneCheckpoint(_TuneCheckpoint.MEMORY, 0, self.mock_result(0, 0)),
_TuneCheckpoint(_TuneCheckpoint.MEMORY, 0, self.mock_result(0, 0)),
]
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num=1)
checkpoint_manager.on_checkpoint(checkpoints[0])
Expand All @@ -147,16 +173,16 @@ def testSameCheckpoint(self):

checkpoints = [
_TuneCheckpoint(
_TuneCheckpoint.PERSISTENT, tmpfiles[0], self.mock_result(5)
_TuneCheckpoint.PERSISTENT, tmpfiles[0], self.mock_result(5, 5)
),
_TuneCheckpoint(
_TuneCheckpoint.PERSISTENT, tmpfiles[1], self.mock_result(10)
_TuneCheckpoint.PERSISTENT, tmpfiles[1], self.mock_result(10, 10)
),
_TuneCheckpoint(
_TuneCheckpoint.PERSISTENT, tmpfiles[2], self.mock_result(0)
_TuneCheckpoint.PERSISTENT, tmpfiles[2], self.mock_result(0, 0)
),
_TuneCheckpoint(
_TuneCheckpoint.PERSISTENT, tmpfiles[1], self.mock_result(20)
_TuneCheckpoint.PERSISTENT, tmpfiles[1], self.mock_result(20, 20)
),
]
for checkpoint in checkpoints:
Expand Down
12 changes: 4 additions & 8 deletions python/ray/tune/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
unflattened_lookup,
)
from ray.util.ml_utils.json import SafeFallbackEncoder # noqa
from ray.util.ml_utils.util import ( # noqa: F401
is_nan,
is_nan_or_inf,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -276,14 +280,6 @@ def date_str():
return datetime.today().strftime("%Y-%m-%d_%H-%M-%S")


def is_nan(value):
return np.isnan(value)


def is_nan_or_inf(value):
return is_nan(value) or np.isinf(value)


def _to_pinnable(obj):
"""Converts obj to a form that can be pinned in object store memory.
Expand Down
9 changes: 9 additions & 0 deletions python/ray/util/ml_utils/util.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
from contextlib import closing
import socket
import numpy as np


def find_free_port():
    """Return the number of a TCP port that was free when probed.

    A socket is bound to port 0 so the operating system picks an unused
    port; the socket is closed again before the port number is returned.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with closing(probe):
        probe.bind(("", 0))
        probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # For AF_INET sockets the address is a (host, port) pair.
        address = probe.getsockname()
        return address[1]


def is_nan(value):
    """Return the result of ``np.isnan`` for ``value``."""
    nan_flag = np.isnan(value)
    return nan_flag


def is_nan_or_inf(value):
    """Return a truthy result when ``value`` is NaN or infinite.

    The NaN check runs first; ``np.isinf`` is only consulted when that
    check is falsy.
    """
    nan_result = is_nan(value)
    if nan_result:
        # Hand back the NaN check's own result object.
        return nan_result
    return np.isinf(value)

0 comments on commit 52eaf02

Please sign in to comment.