Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tune; RLlib] Missing stopping criterion should not error (just warn). #45613

Merged
20 changes: 12 additions & 8 deletions python/ray/tune/experiment/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from ray.tune.trainable.metadata import _TrainingRunMetadata
from ray.tune.utils import date_str, flatten_dict
from ray.tune.utils.serialization import TuneFunctionDecoder, TuneFunctionEncoder
from ray.util import log_once
from ray.util.annotations import Deprecated, DeveloperAPI

DEBUG_PRINT_INTERVAL = 5
Expand Down Expand Up @@ -851,18 +852,21 @@ def should_stop(self, result):
if result.get(DONE):
return True

for criteria, stop_value in self.stopping_criterion.items():
if criteria not in result:
raise TuneError(
"Stopping criteria {} not provided in result dict. Keys "
"are {}.".format(criteria, list(result.keys()))
)
elif isinstance(criteria, dict):
# Check each user-provided stopping criterion against the latest result.
# Criterion keys are flat, slash-separated paths (e.g. "b/c") that are
# looked up directly in the (flattened) result dict.
for criterion, stop_value in self.stopping_criterion.items():
    if isinstance(criterion, dict):
        # Nested-dict criteria are no longer accepted; they must be
        # expressed as flattened "key1/key2/key3" string keys.
        raise ValueError(
            "Stopping criteria is now flattened by default. "
            "Use forward slashes to nest values `key1/key2/key3`."
        )
    elif criterion not in result:
        # A missing metric is not a hard error: warn once (process-wide,
        # via `log_once`) and let the trial keep running.
        if log_once("tune_trial_stop_criterion_not_found"):
            logger.warning(
                f"Stopping criterion '{criterion}' not found in result dict! "
                f"Available keys are {list(result.keys())}. If '{criterion}' is"
                " never reported, the run will continue until training is "
                "finished."
            )
    elif result[criterion] >= stop_value:
        # Metric reached or exceeded its threshold -> stop this trial.
        return True
# No criterion triggered -> keep going.
return False

Expand Down
48 changes: 48 additions & 0 deletions python/ray/tune/tests/test_trial.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import sys

import pytest
Expand Down Expand Up @@ -116,5 +117,52 @@ def test_trial_logdir_length():
assert len(trial.storage.trial_dir_name) < 200


def test_should_stop(caplog):
    """Test whether `Trial.should_stop()` works as expected given a result dict.

    Covers: a threshold not yet reached, exactly reached, and exceeded; a
    nested (slash-flattened) criterion; and a criterion key missing from the
    result dict, which should log a warning (only once) instead of raising.
    """
    trial = Trial(
        "MockTrainable",
        stub=True,
        trial_id="abcd1234",
        stopping_criterion={"a": 10.0, "b/c": 20.0},
    )

    # Criterion is not reached yet -> don't stop.
    # NOTE: Plain result dicts are passed directly (no `_TrainingResult`
    # wrapper needed), as `should_stop()` only inspects the metrics mapping.
    assert not trial.should_stop({"a": 9.999, "b/c": 0.0, "some_other_key": True})

    # Criterion is exactly reached (>= comparison) -> stop.
    assert trial.should_stop({"a": 10.0, "b/c": 0.0, "some_other_key": False})

    # Criterion is exceeded -> stop.
    assert trial.should_stop({"a": 10000.0, "b/c": 0.0, "some_other_key": False})

    # Test nested criteria (flattened with forward slashes).
    assert trial.should_stop({"a": 5.0, "b/c": 1000.0, "some_other_key": False})

    # Test criterion NOT found in result metrics -> warn, don't raise.
    with caplog.at_level(logging.WARNING):
        trial.should_stop({"b/c": 1000.0})
    assert (
        "Stopping criterion 'a' not found in result dict! Available keys are ['b/c']."
    ) in caplog.text

    # The warning should, however, only be triggered once (`log_once`).
    # `caplog` accumulates records across the whole test, so the previously
    # captured warning must be cleared first -- otherwise the assertion
    # below could never pass.
    caplog.clear()
    with caplog.at_level(logging.WARNING):
        trial.should_stop({"b/c": 1000.0})
    assert "Stopping criterion " not in caplog.text

# Allow running this test module directly (`python test_trial.py`);
# exit with pytest's return code so CI sees failures.
if __name__ == "__main__":
    sys.exit(pytest.main(["-v", __file__]))
Loading