Skip to content

Commit

Permalink
[Tune] Fix Jupyter output with Ray Client and Tuner (ray-project#29956
Browse files Browse the repository at this point in the history
)

Ensures that we can have rich Jupyter output with the Tuner API.

Signed-off-by: Antoni Baum <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>
  • Loading branch information
Yard1 authored and WeichenXu123 committed Dec 19, 2022
1 parent 9643605 commit fd6b4e3
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 50 deletions.
2 changes: 1 addition & 1 deletion python/ray/tune/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,7 @@ py_test(

py_test(
name = "optuna_multiobjective_example",
size = "small",
size = "medium",
srcs = ["examples/optuna_multiobjective_example.py"],
deps = [":tune_lib"],
tags = ["team:ml", "exclusive", "example", "medium_instance"],
Expand Down
18 changes: 18 additions & 0 deletions python/ray/tune/impl/tuner_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

if TYPE_CHECKING:
from ray.train.trainer import BaseTrainer
from ray.util.queue import Queue


_TRAINABLE_PKL = "trainable.pkl"
Expand Down Expand Up @@ -126,6 +127,19 @@ def __init__(
pickle.dump(self._trainable, fp)
self._maybe_warn_resource_contention()

def get_run_config(self) -> RunConfig:
return self._run_config

# For Jupyter output with Ray Client
def set_run_config_and_remote_string_queue(
self, run_config: RunConfig, string_queue: "Queue"
):
self._run_config = run_config
self._tuner_kwargs["_remote_string_queue"] = string_queue

def clear_remote_string_queue(self):
self._tuner_kwargs.pop("_remote_string_queue", None)

def _expected_utilization(self, cpus_per_trial, cpus_total):
num_samples = self._tune_config.num_samples
if num_samples < 0: # TODO: simplify this in Tune
Expand Down Expand Up @@ -404,6 +418,7 @@ def _fit_internal(self, trainable, param_space) -> ExperimentAnalysis:
analysis = run(
**args,
)
self.clear_remote_string_queue()
return analysis

def _fit_resume(self, trainable) -> ExperimentAnalysis:
Expand Down Expand Up @@ -431,10 +446,13 @@ def _fit_resume(self, trainable) -> ExperimentAnalysis:
**self._tuner_kwargs,
}
analysis = run(**args)
self.clear_remote_string_queue()
return analysis

def __getstate__(self):
state = self.__dict__.copy()
state["_tuner_kwargs"] = state["_tuner_kwargs"].copy()
state["_tuner_kwargs"].pop("_remote_string_queue", None)
state.pop(_TRAINABLE_KEY, None)
state.pop(_PARAM_SPACE_KEY, None)
state.pop(_EXPERIMENT_ANALYSIS_KEY, None)
Expand Down
68 changes: 64 additions & 4 deletions python/ray/tune/progress_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
import numpy as np
from ray._private.dict import flatten_dict

import ray
from ray.tune.callback import Callback
from ray.tune.logger import logger, pretty_print
from ray.tune.logger import pretty_print
from ray.tune.result import (
AUTO_RESULT_KEYS,
DEFAULT_METRIC,
Expand All @@ -34,9 +35,10 @@
from ray.tune.experiment.trial import DEBUG_PRINT_INTERVAL, Trial, _Location
from ray.tune.trainable import Trainable
from ray.tune.utils import unflattened_lookup
from ray.tune.utils.log import Verbosity, has_verbosity
from ray.tune.utils.node import _force_on_current_node
from ray.tune.utils.log import Verbosity, has_verbosity, set_verbosity
from ray.util.annotations import DeveloperAPI, PublicAPI
from ray.util.queue import Queue
from ray.util.queue import Empty, Queue

from ray.widgets import Template

Expand Down Expand Up @@ -534,7 +536,7 @@ def __init__(
)

if not IS_NOTEBOOK:
logger.warning(
warnings.warn(
"You are using the `JupyterNotebookReporter`, but not "
"IPython/Jupyter-compatible environment was detected. "
"If this leads to unformatted output (e.g. like "
Expand Down Expand Up @@ -1532,3 +1534,61 @@ def _detect_progress_metrics(
return None

return getattr(trainable, "_progress_metrics", None)


def _prepare_progress_reporter_for_ray_client(
progress_reporter: ProgressReporter,
verbosity: Union[int, Verbosity],
string_queue: Optional[Queue] = None,
) -> Tuple[ProgressReporter, Queue]:
"""Prepares progress reported for Ray Client by setting the string queue.
The string queue will be created if it's None."""
set_verbosity(verbosity)
progress_reporter = progress_reporter or _detect_reporter()

# JupyterNotebooks don't work with remote tune runs out of the box
# (e.g. via Ray client) as they don't have access to the main
# process stdout. So we introduce a queue here that accepts
# strings, which will then be displayed on the driver side.
if isinstance(progress_reporter, RemoteReporterMixin):
if string_queue is None:
string_queue = Queue(
actor_options={"num_cpus": 0, **_force_on_current_node(None)}
)
progress_reporter.output_queue = string_queue

return progress_reporter, string_queue


def _stream_client_output(
remote_future: ray.ObjectRef,
progress_reporter: ProgressReporter,
string_queue: Queue,
) -> Any:
"""
Stream items from string queue to progress_reporter until remote_future resolves
"""
if string_queue is None:
return

def get_next_queue_item():
try:
return string_queue.get(block=False)
except Empty:
return None

def _handle_string_queue():
string_item = get_next_queue_item()
while string_item is not None:
# This happens on the driver side
progress_reporter.display(string_item)
string_item = get_next_queue_item()

# ray.wait(...)[1] returns futures that are not ready, yet
while ray.wait([remote_future], timeout=0.2)[1]:
# Check if we have items to execute
_handle_string_queue()

# Handle queue one last time
_handle_string_queue()
44 changes: 44 additions & 0 deletions python/ray/tune/tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,40 @@
import os
import tempfile
import time

import pytest
import sys
import io
from contextlib import redirect_stdout

import ray
from ray import tune
from ray.air import session, RunConfig
from ray.tune.progress_reporter import JupyterNotebookReporter
from ray.util.client.ray_client_helpers import ray_start_client_server


@pytest.fixture
def start_client_server():
with ray_start_client_server() as client:
yield client
ray.shutdown()


@pytest.fixture
def start_client_server_2_cpus():
ray.init(num_cpus=2)
with ray_start_client_server() as client:
yield client
ray.shutdown()


@pytest.fixture
def start_client_server_4_cpus():
ray.init(num_cpus=4)
with ray_start_client_server() as client:
yield client
ray.shutdown()


def test_pbt_function(start_client_server_2_cpus):
Expand Down Expand Up @@ -94,6 +103,41 @@ def test_pbt_transformers(start_client_server):
tune_transformer(num_samples=1, gpus_per_trial=0, smoke_test=True)


def test_jupyter_rich_output(start_client_server_4_cpus):
assert ray.util.client.ray.is_connected()

def dummy_objective(config):
time.sleep(1)
session.report(dict(metric=1))

ip = ray.util.get_node_ip_address()

class MockJupyterNotebookReporter(JupyterNotebookReporter):
def display(self, string: str) -> None:
# Make sure display is called on the driver
assert ip == ray.util.get_node_ip_address()
if string:
assert "<div" in string
print(string)

reporter = MockJupyterNotebookReporter()
buffer = io.StringIO()
with redirect_stdout(buffer):
tune.run(dummy_objective, num_samples=2, progress_reporter=reporter)
print("", flush=True)
assert "<div" in buffer.getvalue()

reporter = MockJupyterNotebookReporter()
buffer = io.StringIO()
with redirect_stdout(buffer):
tuner = tune.Tuner(
dummy_objective, run_config=RunConfig(progress_reporter=reporter)
)
tuner.fit()
print("", flush=True)
assert "<div" in buffer.getvalue()


if __name__ == "__main__":
import pytest

Expand Down
56 changes: 14 additions & 42 deletions python/ray/tune/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
from ray.tune.experiment import Experiment, _convert_to_experiment_list
from ray.tune.progress_reporter import (
ProgressReporter,
RemoteReporterMixin,
_detect_reporter,
_detect_progress_metrics,
_prepare_progress_reporter_for_ray_client,
_stream_client_output,
)
from ray.tune.execution.ray_trial_executor import RayTrialExecutor
from ray.tune.registry import get_trainable_cls, is_function_trainable
Expand Down Expand Up @@ -56,7 +57,7 @@
from ray.tune.utils.node import _force_on_current_node
from ray.tune.execution.placement_groups import PlacementGroupFactory
from ray.util.annotations import PublicAPI
from ray.util.queue import Empty, Queue
from ray.util.queue import Queue

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -167,6 +168,8 @@ def run(
# == internal only ==
_experiment_checkpoint_dir: Optional[str] = None,
_remote: Optional[bool] = None,
# Passed by the Tuner.
_remote_string_queue: Optional[Queue] = None,
) -> ExperimentAnalysis:
"""Executes training.
Expand Down Expand Up @@ -388,51 +391,20 @@ class and registered trainables.
# Make sure tune.run is called on the sever node.
remote_run = _force_on_current_node(remote_run)

set_verbosity(verbose)
progress_reporter = progress_reporter or _detect_reporter()

# JupyterNotebooks don't work with remote tune runs out of the box
# (e.g. via Ray client) as they don't have access to the main
# process stdout. So we introduce a queue here that accepts
# strings, which will then be displayed on the driver side.
if isinstance(progress_reporter, RemoteReporterMixin):
string_queue = Queue(
actor_options={"num_cpus": 0, **_force_on_current_node(None)}
)
progress_reporter.output_queue = string_queue

def get_next_queue_item():
try:
return string_queue.get(block=False)
except Empty:
return None

else:
# If we don't need a queue, use this dummy get fn instead of
# scheduling an unneeded actor
def get_next_queue_item():
return None

def _handle_string_queue():
string_item = get_next_queue_item()
while string_item is not None:
# This happens on the driver side
progress_reporter.display(string_item)

string_item = get_next_queue_item()
progress_reporter, string_queue = _prepare_progress_reporter_for_ray_client(
progress_reporter, verbose, _remote_string_queue
)

# Override with detected progress reporter
remote_run_kwargs["progress_reporter"] = progress_reporter
remote_future = remote_run.remote(_remote=False, **remote_run_kwargs)

# ray.wait(...)[1] returns futures that are not ready, yet
while ray.wait([remote_future], timeout=0.2)[1]:
# Check if we have items to execute
_handle_string_queue()

# Handle queue one last time
_handle_string_queue()
remote_future = remote_run.remote(_remote=False, **remote_run_kwargs)

_stream_client_output(
remote_future,
progress_reporter,
string_queue,
)
return ray.get(remote_future)

del remote_run_kwargs
Expand Down
Loading

0 comments on commit fd6b4e3

Please sign in to comment.