Skip to content

Commit

Permalink
[air/tuner] Expose number of errored/terminated trials in ResultGrid (ray-project#26655)
Browse files Browse the repository at this point in the history

This introduces an easy interface to retrieve the number of errored and terminated (non-errored) trials from the result grid.

Previously `tune.run(raise_on_failed_trial)` could be used to raise a TuneError if at least one trial failed. We've removed this option to make sure we always get a return value. `ResultGrid.num_errored` will make it easy for users to identify if trials failed and react to it instead of the old try-catch loop.

Signed-off-by: Kai Fricke <[email protected]>
Signed-off-by: Xiaowei Jiang <[email protected]>
  • Loading branch information
krfricke authored and xwjiang2010 committed Jul 19, 2022
1 parent 14b2291 commit 3e14b45
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 24 deletions.
23 changes: 23 additions & 0 deletions python/ray/tune/result_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,29 @@ def __getitem__(self, i: int) -> Result:
self._experiment_analysis.trials[i],
)

@property
def errors(self):
    """Returns the exceptions of errored trials."""
    exceptions = []
    for result in self:
        # Only trials that actually carry an exception are included.
        if result.error:
            exceptions.append(result.error)
    return exceptions

@property
def num_errors(self):
    """Returns the number of errored trials.

    A trial counts as errored when its status is ``Trial.ERROR``.
    """
    # Count with a generator instead of materializing a throwaway list
    # just to call len() on it.
    return sum(
        1 for t in self._experiment_analysis.trials if t.status == Trial.ERROR
    )

@property
def num_terminated(self):
    """Returns the number of terminated (but not errored) trials.

    A trial counts as terminated when its status is ``Trial.TERMINATED``;
    errored trials carry ``Trial.ERROR`` instead and are not included here.
    """
    # Count with a generator instead of materializing a throwaway list
    # just to call len() on it.
    return sum(
        1
        for t in self._experiment_analysis.trials
        if t.status == Trial.TERMINATED
    )

@staticmethod
def _populate_exception(trial: Trial) -> Optional[Union[TuneError, RayTaskError]]:
if trial.pickled_error_file and os.path.exists(trial.pickled_error_file):
Expand Down
27 changes: 27 additions & 0 deletions python/ray/tune/tests/test_result_grid.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os
import pickle
import shutil

import pytest
import pandas as pd
Expand All @@ -11,6 +12,7 @@
from ray.tune.registry import get_trainable_cls
from ray.tune.result_grid import ResultGrid
from ray.tune.experiment import Trial
from ray.tune.tests.tune_test_util import create_tune_experiment_checkpoint
from ray.util.ml_utils.checkpoint_manager import CheckpointStorage, _TrackedCheckpoint


Expand Down Expand Up @@ -230,6 +232,31 @@ def f(config):
assert sorted(df["config/nested/param"]) == [1, 2]


def test_num_errors_terminated(tmpdir):
    """ResultGrid reports errored and terminated trial counts correctly."""
    error_file = tmpdir / "error.txt"
    with open(error_file, "w") as fp:
        fp.write("Test error\n")

    trials = [Trial("foo", stub=True) for _ in range(10)]

    # Three errored trials, each pointing at the error file above.
    for idx in (4, 6, 8):
        trials[idx].status = Trial.ERROR
        trials[idx].error_file = error_file

    # Two cleanly terminated (non-errored) trials.
    for idx in (3, 5):
        trials[idx].status = Trial.TERMINATED

    experiment_dir = create_tune_experiment_checkpoint(trials)
    try:
        result_grid = ResultGrid(tune.ExperimentAnalysis(experiment_dir))
        assert len(result_grid.errors) == 3
        assert result_grid.num_errors == 3
        assert result_grid.num_terminated == 2
    finally:
        # Clean up the checkpoint dir even if an assertion above fails.
        shutil.rmtree(experiment_dir)


if __name__ == "__main__":
import sys

Expand Down
2 changes: 1 addition & 1 deletion python/ray/tune/tests/test_trial_runner_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ray.tune.execution.trial_runner import TrialRunner
from ray.tune.resources import Resources
from ray.tune.search import BasicVariantGenerator
from ray.tune.tests.utils_for_test_trial_runner import TrialResultObserver
from ray.tune.tests.tune_test_util import TrialResultObserver
from ray.tune.trainable.util import TrainableUtil
from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint, CheckpointStorage

Expand Down
2 changes: 1 addition & 1 deletion python/ray/tune/tests/test_trial_runner_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from ray.tune.search import Searcher, ConcurrencyLimiter
from ray.tune.search.search_generator import SearchGenerator
from ray.tune.syncer import SyncConfig, Syncer
from ray.tune.tests.utils_for_test_trial_runner import TrialResultObserver
from ray.tune.tests.tune_test_util import TrialResultObserver


class TrialRunnerTest3(unittest.TestCase):
Expand Down
51 changes: 51 additions & 0 deletions python/ray/tune/tests/tune_test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import os
import tempfile

from ray.tune import Callback
from ray.tune.execution.trial_runner import TrialRunner


class TrialResultObserver(Callback):
    """Helper class to control runner.step() count."""

    def __init__(self):
        # Total results seen, and the count at the last reset/check.
        self._counter = 0
        self._last_counter = 0

    def reset(self):
        # Forget any results observed so far.
        self._last_counter = self._counter

    def just_received_a_result(self):
        # True iff on_trial_result fired since the last call/reset.
        received = self._counter != self._last_counter
        if received:
            self._last_counter = self._counter
        return received

    def on_trial_result(self, **kwargs):
        self._counter += 1


def create_tune_experiment_checkpoint(trials: list, **runner_kwargs) -> str:
    """Write an experiment checkpoint for ``trials`` into a fresh temp dir.

    Returns the path of the experiment directory; the caller is
    responsible for removing it.
    """
    experiment_dir = tempfile.mkdtemp()
    runner_kwargs.setdefault("local_checkpoint_dir", experiment_dir)

    # Snapshot the environment so it can be restored afterwards.
    saved_environ = os.environ.copy()

    # Set to 1 to disable ray cluster resource lookup. That way we can
    # create experiment checkpoints without initializing ray.
    os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1"

    try:
        runner = TrialRunner(**runner_kwargs)
        for trial in trials:
            runner.add_trial(trial)
        runner.checkpoint(force=True)
    finally:
        # Restore the pre-call environment exactly.
        os.environ.clear()
        os.environ.update(saved_environ)

    return experiment_dir
22 changes: 0 additions & 22 deletions python/ray/tune/tests/utils_for_test_trial_runner.py

This file was deleted.

0 comments on commit 3e14b45

Please sign in to comment.