Skip to content

Commit

Permalink
[Tune] Fix docstring failures (#32484)
Browse files Browse the repository at this point in the history
This PR fixes the `Stopper` doctests that were erroring. Previously, the doctests used `tune.Trainable` as the trainable, which errors on `fit()` because its abstract methods are not implemented. The doctests were also missing some imports.

Signed-off-by: Justin Yu <[email protected]>
  • Loading branch information
justinvyu authored Feb 14, 2023
1 parent 71dfd20 commit 421b527
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 19 deletions.
1 change: 1 addition & 0 deletions python/ray/air/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@ def from_checkpoint(cls, other: "Checkpoint") -> "Checkpoint":
generic :py:class:`Checkpoint` object.
Examples:
>>> result = TorchTrainer.fit(...) # doctest: +SKIP
>>> checkpoint = TorchCheckpoint.from_checkpoint(result.checkpoint) # doctest: +SKIP # noqa: E501
>>> model = checkpoint.get_model() # doctest: +SKIP
Expand Down
58 changes: 39 additions & 19 deletions python/ray/tune/stopper/stopper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
from typing import Any, Dict

from ray.util.annotations import PublicAPI

Expand All @@ -15,34 +16,42 @@ class Stopper(abc.ABC):
>>> import time
>>> from ray import air, tune
>>> from ray.air import session
>>> from ray.tune import Stopper
>>>
>>> class TimeStopper(Stopper):
... def __init__(self):
... self._start = time.time()
... self._deadline = 5
... self._deadline = 5 # Stop all trials after 5 seconds
...
... def __call__(self, trial_id, result):
... return False
...
... def stop_all(self):
... return time.time() - self._start > self._deadline
>>>
...
>>> def train_fn(config):
... for i in range(100):
... time.sleep(1)
... session.report({"iter": i})
...
>>> tuner = tune.Tuner(
... tune.Trainable,
... tune_config=tune.TuneConfig(num_samples=200),
... run_config=air.RunConfig(stop=TimeStopper())
... train_fn,
... tune_config=tune.TuneConfig(num_samples=2),
... run_config=air.RunConfig(stop=TimeStopper()),
... )
>>> tuner.fit()
== Status ==...
>>> print("[ignore]"); result_grid = tuner.fit() # doctest: +ELLIPSIS
[ignore]...
>>> all(result.metrics["time_total_s"] < 6 for result in result_grid)
True
"""

def __call__(self, trial_id, result):
def __call__(self, trial_id: str, result: Dict[str, Any]) -> bool:
"""Returns true if the trial should be terminated given the result."""
raise NotImplementedError

def stop_all(self):
def stop_all(self) -> bool:
"""Returns true if the experiment should be terminated."""
raise NotImplementedError

Expand All @@ -56,28 +65,39 @@ class CombinedStopper(Stopper):
Examples:
>>> from ray.tune.stopper import (CombinedStopper,
... MaximumIterationStopper, TrialPlateauStopper)
>>> import numpy as np
>>> from ray import air, tune
>>> from ray.air import session
>>> from ray.tune.stopper import (
... CombinedStopper,
... MaximumIterationStopper,
... TrialPlateauStopper,
... )
>>>
>>> stopper = CombinedStopper(
... MaximumIterationStopper(max_iter=20),
... TrialPlateauStopper(metric="my_metric")
... TrialPlateauStopper(metric="my_metric"),
... )
>>>
>>> def train_fn(config):
... for i in range(25):
... session.report({"my_metric": np.random.normal(0, 1 - i / 25)})
...
>>> tuner = tune.Tuner(
... tune.Trainable,
... run_config=air.RunConfig(stop=stopper)
... train_fn,
... run_config=air.RunConfig(stop=stopper),
... )
>>> tuner.fit()
== Status ==...
>>> print("[ignore]"); result_grid = tuner.fit() # doctest: +ELLIPSIS
[ignore]...
>>> all(result.metrics["training_iteration"] <= 20 for result in result_grid)
True
"""

def __init__(self, *stoppers: Stopper):
self._stoppers = stoppers

def __call__(self, trial_id, result):
def __call__(self, trial_id: str, result: Dict[str, Any]) -> bool:
return any(s(trial_id, result) for s in self._stoppers)

def stop_all(self):
def stop_all(self) -> bool:
return any(s.stop_all() for s in self._stoppers)

0 comments on commit 421b527

Please sign in to comment.