Skip to content

Commit

Permalink
[tune] Fix checkpoint sorting with nan values (#23909)
Browse files Browse the repository at this point in the history
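Following #23862, there was an uncaught bug when comparing nan-priority checkpoints. This is because float("nan") <= float("nan") is always False, so priority tuples that hold distinct nan objects cannot be ordered against each other (tuples that share the very same nan object, e.g. np.nan, still compare as expected, because Python's sequence comparison treats identical objects as equal before comparing their values).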

This PR fixes this bug and adds a new test to ensure correct behavior.
  • Loading branch information
krfricke authored Apr 14, 2022
1 parent 52eaf02 commit 79b154c
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
6 changes: 5 additions & 1 deletion python/ray/tune/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,11 @@ def _priority(self, checkpoint):
priority = result[self._checkpoint_score_attr]
if self._checkpoint_score_desc:
priority = -priority
return (not is_nan(priority), priority, checkpoint.order)
return (
not is_nan(priority),
priority if not is_nan(priority) else 0,
checkpoint.order,
)

def __getstate__(self):
state = self.__dict__.copy()
Expand Down
22 changes: 22 additions & 0 deletions python/ray/tune/tests/test_checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,28 @@ def testBestCheckpointsWithNan(self):
self.assertEqual(best_checkpoints[0].value, None)
self.assertEqual(best_checkpoints[1].value, 3)

def testBestCheckpointsOnlyNan(self):
    """
    Tests that ranking still works when every checkpoint has a nan priority.

    Four persistent checkpoints with values 0..3 are registered, each with
    a nan score, while the manager is limited to keeping two. The two kept
    checkpoints are expected to be those with values 2 and 3, in that order.
    """
    num_to_keep = 2
    manager = self.checkpoint_manager(num_to_keep)

    # Build every checkpoint up front, then register them one by one.
    nan_checkpoints = []
    for idx in range(4):
        nan_result = self.mock_result(float("nan"), idx)
        nan_checkpoints.append(
            _TuneCheckpoint(_TuneCheckpoint.PERSISTENT, idx, nan_result)
        )

    for ckpt in nan_checkpoints:
        manager.on_checkpoint(ckpt)

    # best_checkpoints is sorted from worst to best
    kept = manager.best_checkpoints()
    self.assertEqual(len(kept), num_to_keep)

    worst, best = kept[0], kept[1]
    self.assertEqual(worst.value, 2)
    self.assertEqual(best.value, 3)

def testOnCheckpointUnavailableAttribute(self):
"""
Tests that an error is logged when the associated result of the
Expand Down

0 comments on commit 79b154c

Please sign in to comment.