From 52eaf020bcd4e8ebeb94af11a8039313a37488d1 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 14 Apr 2022 11:11:37 +0200 Subject: [PATCH] [tune] Treat checkpoints with nan value as worst (#23862) Changes the logic in CheckpointManager to consider checkpoints with nan value of the metric as worst values, meaning they will be deleted first if keep_checkpoints_num is set. --- python/ray/train/checkpoint.py | 12 +++-- python/ray/train/tests/test_trainer.py | 9 ++-- python/ray/tune/checkpoint_manager.py | 10 +++- .../ray/tune/tests/test_checkpoint_manager.py | 52 ++++++++++++++----- python/ray/tune/utils/util.py | 12 ++--- python/ray/util/ml_utils/util.py | 9 ++++ 6 files changed, 73 insertions(+), 31 deletions(-) diff --git a/python/ray/train/checkpoint.py b/python/ray/train/checkpoint.py index 504ec04c4cdd..dd03ed3197eb 100644 --- a/python/ray/train/checkpoint.py +++ b/python/ray/train/checkpoint.py @@ -12,6 +12,7 @@ from ray.train.session import TrainingResult from ray.train.utils import construct_path from ray.util import PublicAPI +from ray.util.ml_utils.util import is_nan if TUNE_INSTALLED: from ray import tune @@ -212,10 +213,13 @@ def write_checkpoint(self, checkpoint: Dict): ) def priority(checkpoint_score_order, checkpoint_score): - if checkpoint_score_order == MAX: - return checkpoint_score - else: - return -checkpoint_score + # Treat NaN as worst + # The tuple structure is (not is_nan(), metric), which makes + # NaN values always rank as the worst + # metrics in the heap + if checkpoint_score_order != MAX: + checkpoint_score = -checkpoint_score + return (not is_nan(checkpoint_score), checkpoint_score) checkpoint_priority = priority(checkpoint_score_order, checkpoint_score) diff --git a/python/ray/train/tests/test_trainer.py b/python/ray/train/tests/test_trainer.py index 290fe30c580e..6233bd527dc1 100644 --- a/python/ray/train/tests/test_trainer.py +++ b/python/ray/train/tests/test_trainer.py @@ -517,6 +517,7 @@ def 
test_persisted_checkpoint_strategy(ray_start_2_cpus): ) def train_func(): + train.save_checkpoint(loss=float("nan")) # nan, deleted train.save_checkpoint(loss=3) # best train.save_checkpoint(loss=7) # worst, deleted train.save_checkpoint(loss=5) @@ -530,16 +531,16 @@ def train_func(): assert trainer.logdir == Path(logdir).expanduser().resolve() assert trainer.latest_checkpoint_dir.is_dir() assert trainer.best_checkpoint_path.is_file() - assert trainer.best_checkpoint_path.name == f"checkpoint_{1:06d}" + assert trainer.best_checkpoint_path.name == f"checkpoint_{2:06d}" assert trainer.latest_checkpoint["loss"] == 5 assert trainer.best_checkpoint["loss"] == 3 checkpoint_dir = trainer.latest_checkpoint_dir file_names = [f.name for f in checkpoint_dir.iterdir()] assert len(file_names) == 2 - assert f"checkpoint_{1:06d}" in file_names - assert f"checkpoint_{2:06d}" not in file_names - assert f"checkpoint_{3:06d}" in file_names + assert f"checkpoint_{2:06d}" in file_names + assert f"checkpoint_{3:06d}" not in file_names + assert f"checkpoint_{4:06d}" in file_names def validate(): checkpoint = train.load_checkpoint() diff --git a/python/ray/tune/checkpoint_manager.py b/python/ray/tune/checkpoint_manager.py index b0a7662ea7d7..2d9dd4749f66 100644 --- a/python/ray/tune/checkpoint_manager.py +++ b/python/ray/tune/checkpoint_manager.py @@ -5,7 +5,7 @@ from typing import Any, Callable, Optional from ray.tune.result import NODE_IP -from ray.tune.utils.util import flatten_dict +from ray.tune.utils.util import flatten_dict, is_nan logger = logging.getLogger(__name__) @@ -168,6 +168,10 @@ def on_checkpoint(self, checkpoint: _TuneCheckpoint): self.delete(old_checkpoint) try: + # NaN metrics are treated as the worst checkpoints + # The priority tuple is (not is_nan(), metric, order), which makes + # NaN values always rank as the worst + # metrics in the heap queue_item = QueueItem(self._priority(checkpoint), checkpoint) except KeyError: logger.error( @@ -198,7 +202,9 @@ def 
best_checkpoints(self): def _priority(self, checkpoint): result = flatten_dict(checkpoint.result) priority = result[self._checkpoint_score_attr] - return -priority if self._checkpoint_score_desc else priority + if self._checkpoint_score_desc: + priority = -priority + return (not is_nan(priority), priority, checkpoint.order) def __getstate__(self): state = self.__dict__.copy() diff --git a/python/ray/tune/tests/test_checkpoint_manager.py b/python/ray/tune/tests/test_checkpoint_manager.py index 95aba1590371..9b921fa4cccd 100644 --- a/python/ray/tune/tests/test_checkpoint_manager.py +++ b/python/ray/tune/tests/test_checkpoint_manager.py @@ -12,8 +12,8 @@ class CheckpointManagerTest(unittest.TestCase): @staticmethod - def mock_result(i): - return {"i": i, TRAINING_ITERATION: i} + def mock_result(metric, i): + return {"i": metric, TRAINING_ITERATION: i} def checkpoint_manager(self, keep_checkpoints_num): return CheckpointManager(keep_checkpoints_num, "i", delete_fn=lambda c: None) @@ -21,11 +21,11 @@ def checkpoint_manager(self, keep_checkpoints_num): def testNewestCheckpoint(self): checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num=1) memory_checkpoint = _TuneCheckpoint( - _TuneCheckpoint.MEMORY, {0}, self.mock_result(0) + _TuneCheckpoint.MEMORY, {0}, self.mock_result(0, 0) ) checkpoint_manager.on_checkpoint(memory_checkpoint) persistent_checkpoint = _TuneCheckpoint( - _TuneCheckpoint.PERSISTENT, {1}, self.mock_result(1) + _TuneCheckpoint.PERSISTENT, {1}, self.mock_result(1, 1) ) checkpoint_manager.on_checkpoint(persistent_checkpoint) self.assertEqual( @@ -40,7 +40,7 @@ def testOnCheckpointOrdered(self): keep_checkpoints_num = 2 checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num) checkpoints = [ - _TuneCheckpoint(_TuneCheckpoint.PERSISTENT, {i}, self.mock_result(i)) + _TuneCheckpoint(_TuneCheckpoint.PERSISTENT, {i}, self.mock_result(i, i)) for i in range(3) ] @@ -66,7 +66,7 @@ def testOnCheckpointUnordered(self): keep_checkpoints_num = 2 
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num) checkpoints = [ - _TuneCheckpoint(_TuneCheckpoint.PERSISTENT, {i}, self.mock_result(i)) + _TuneCheckpoint(_TuneCheckpoint.PERSISTENT, {i}, self.mock_result(i, i)) for i in range(3, -1, -1) ] @@ -91,7 +91,7 @@ def testBestCheckpoints(self): keep_checkpoints_num = 4 checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num) checkpoints = [ - _TuneCheckpoint(_TuneCheckpoint.PERSISTENT, i, self.mock_result(i)) + _TuneCheckpoint(_TuneCheckpoint.PERSISTENT, i, self.mock_result(i, i)) for i in range(16) ] random.shuffle(checkpoints) @@ -104,6 +104,32 @@ def testBestCheckpoints(self): for i in range(len(best_checkpoints)): self.assertEqual(best_checkpoints[i].value, i + 12) + def testBestCheckpointsWithNan(self): + """ + Tests that checkpoints with nan priority are handled correctly. + """ + keep_checkpoints_num = 2 + checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num) + checkpoints = [ + _TuneCheckpoint( + _TuneCheckpoint.PERSISTENT, None, self.mock_result(float("nan"), i) + ) + for i in range(2) + ] + checkpoints += [ + _TuneCheckpoint(_TuneCheckpoint.PERSISTENT, 3, self.mock_result(0, 3)) + ] + random.shuffle(checkpoints) + + for checkpoint in checkpoints: + checkpoint_manager.on_checkpoint(checkpoint) + + best_checkpoints = checkpoint_manager.best_checkpoints() + # best_checkpoints is sorted from worst to best + self.assertEqual(len(best_checkpoints), keep_checkpoints_num) + self.assertEqual(best_checkpoints[0].value, None) + self.assertEqual(best_checkpoints[1].value, 3) + def testOnCheckpointUnavailableAttribute(self): """ Tests that an error is logged when the associated result of the @@ -122,8 +148,8 @@ def testOnCheckpointUnavailableAttribute(self): def testOnMemoryCheckpoint(self): checkpoints = [ - _TuneCheckpoint(_TuneCheckpoint.MEMORY, 0, self.mock_result(0)), - _TuneCheckpoint(_TuneCheckpoint.MEMORY, 0, self.mock_result(0)), + _TuneCheckpoint(_TuneCheckpoint.MEMORY, 0, 
self.mock_result(0, 0)), + _TuneCheckpoint(_TuneCheckpoint.MEMORY, 0, self.mock_result(0, 0)), ] checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num=1) checkpoint_manager.on_checkpoint(checkpoints[0]) @@ -147,16 +173,16 @@ def testSameCheckpoint(self): checkpoints = [ _TuneCheckpoint( - _TuneCheckpoint.PERSISTENT, tmpfiles[0], self.mock_result(5) + _TuneCheckpoint.PERSISTENT, tmpfiles[0], self.mock_result(5, 5) ), _TuneCheckpoint( - _TuneCheckpoint.PERSISTENT, tmpfiles[1], self.mock_result(10) + _TuneCheckpoint.PERSISTENT, tmpfiles[1], self.mock_result(10, 10) ), _TuneCheckpoint( - _TuneCheckpoint.PERSISTENT, tmpfiles[2], self.mock_result(0) + _TuneCheckpoint.PERSISTENT, tmpfiles[2], self.mock_result(0, 0) ), _TuneCheckpoint( - _TuneCheckpoint.PERSISTENT, tmpfiles[1], self.mock_result(20) + _TuneCheckpoint.PERSISTENT, tmpfiles[1], self.mock_result(20, 20) ), ] for checkpoint in checkpoints: diff --git a/python/ray/tune/utils/util.py b/python/ray/tune/utils/util.py index 9bb320fb8577..6b096a73ba67 100644 --- a/python/ray/tune/utils/util.py +++ b/python/ray/tune/utils/util.py @@ -26,6 +26,10 @@ unflattened_lookup, ) from ray.util.ml_utils.json import SafeFallbackEncoder # noqa +from ray.util.ml_utils.util import ( # noqa: F401 + is_nan, + is_nan_or_inf, +) logger = logging.getLogger(__name__) @@ -276,14 +280,6 @@ def date_str(): return datetime.today().strftime("%Y-%m-%d_%H-%M-%S") -def is_nan(value): - return np.isnan(value) - - -def is_nan_or_inf(value): - return is_nan(value) or np.isinf(value) - - def _to_pinnable(obj): """Converts obj to a form that can be pinned in object store memory. 
diff --git a/python/ray/util/ml_utils/util.py b/python/ray/util/ml_utils/util.py index 0bd6e177f030..4f99f8bf92a3 100644 --- a/python/ray/util/ml_utils/util.py +++ b/python/ray/util/ml_utils/util.py @@ -1,5 +1,6 @@ from contextlib import closing import socket +import numpy as np def find_free_port(): @@ -7,3 +8,11 @@ def find_free_port(): s.bind(("", 0)) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) return s.getsockname()[1] + + +def is_nan(value): + return np.isnan(value) + + +def is_nan_or_inf(value): + return is_nan(value) or np.isinf(value)