diff --git a/distributed/__init__.py b/distributed/__init__.py
index 9f7a8d6f4b..7ccba10be0 100644
--- a/distributed/__init__.py
+++ b/distributed/__init__.py
@@ -43,7 +43,7 @@
 from .threadpoolexecutor import rejoin
 from .utils import CancelledError, TimeoutError, sync
 from .variable import Variable
-from .worker import Reschedule, Worker, get_client, get_worker, secede
+from .worker import Reschedule, Worker, get_client, get_worker, print, secede
 from .worker_client import local_client, worker_client
 
 versions = get_versions()
diff --git a/distributed/client.py b/distributed/client.py
index e7389220d4..86823beb6c 100644
--- a/distributed/client.py
+++ b/distributed/client.py
@@ -719,6 +719,8 @@ def __init__(
             "task-erred": self._handle_task_erred,
             "restart": self._handle_restart,
             "error": self._handle_error,
+            "print": self._handle_print,
+            "warn": self._handle_warn,
         }
 
         self._state_handlers = {
@@ -1281,6 +1283,12 @@ def _handle_error(self, exception=None):
         logger.warning("Scheduler exception:")
         logger.exception(exception)
 
+    def _handle_warn(self, warning=None, **kwargs):
+        warnings.warn(warning)
+
+    def _handle_print(self, message=None, **kwargs):
+        print(*message, **kwargs)
+
     async def _close(self, fast=False):
         """Send close signal and wait until scheduler completes"""
         if self.status == "closed":
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 1aedbd11d8..4df47ee46e 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -6749,8 +6749,8 @@ async def feed(
         if teardown:
             teardown(self, state)
 
-    def log_worker_event(self, worker=None, topic=None, msg=None):
-        self.log_event(topic, msg)
+    def log_worker_event(self, worker=None, topic=None, msg=None, **kwargs):
+        self.log_event(topic, msg, **kwargs)
 
     def subscribe_worker_status(self, comm=None):
         WorkerStatusPlugin(self, comm)
@@ -7410,15 +7410,20 @@ async def get_worker_logs(self, comm=None, n=None, workers=None, nanny=False):
         )
         return results
 
-    def log_event(self, name, msg):
-        event = (time(), msg)
-        if isinstance(name, list):
+    def log_event(self, name, msg, **kwargs):
+        if isinstance(name, (list, tuple)):
             for n in name:
-                self.events[n].append(event)
-                self.event_counts[n] += 1
+                self.log_event(n, msg, **kwargs)
         else:
+            event = (time(), msg)
             self.events[name].append(event)
             self.event_counts[name] += 1
+            if name == "print":
+                for comm in self.client_comms.values():
+                    comm.send({"op": "print", "message": msg, **kwargs})
+            if name == "warn":
+                for comm in self.client_comms.values():
+                    comm.send({"op": "warn", "warning": msg, **kwargs})
 
     def get_events(self, comm=None, topic=None):
         if topic is not None:
diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py
index a65ab4494d..4b28286f7b 100644
--- a/distributed/tests/test_client.py
+++ b/distributed/tests/test_client.py
@@ -6498,6 +6498,37 @@ def log_scheduler(dask_scheduler):
     assert events[1][1] == ("alice", "bob")
 
 
+@gen_cluster(client=True)
+async def test_log_event_warn(c, s, a, b):
+    def foo():
+        get_worker().log_event(["foo", "warn"], "Hello!")
+
+    with pytest.warns(Warning, match="Hello!"):
+        await c.submit(foo)
+
+
+@gen_cluster(client=True, Worker=Nanny)
+async def test_print(c, s, a, b, capsys):
+    from dask.distributed import print
+
+    def foo():
+        print("Hello!", 123, sep=":")
+
+    await c.submit(foo)
+
+    out, err = capsys.readouterr()
+    assert "Hello!:123" in out
+
+
+def test_print_simple(capsys):
+    from dask.distributed import print
+
+    print("Hello!", 123, sep=":")
+
+    out, err = capsys.readouterr()
+    assert "Hello!:123" in out
+
+
 @gen_cluster(client=True)
 async def test_annotations_task_state(c, s, a, b):
     da = pytest.importorskip("dask.array")
diff --git a/distributed/worker.py b/distributed/worker.py
index 39e77a8bb9..f7e3d80a65 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -1,5 +1,6 @@
 import asyncio
 import bisect
+import builtins
 import concurrent.futures
 import errno
 import heapq
@@ -810,12 +811,13 @@ def __repr__(self):
     def logs(self):
         return self._deque_handler.deque
 
-    def log_event(self, topic, msg):
+    def log_event(self, topic, msg, **kwargs):
         self.batched_stream.send(
             {
                 "op": "log-event",
                 "topic": topic,
                 "msg": msg,
+                **kwargs,
             }
         )
 
@@ -4120,3 +4122,19 @@ def gpu_startup(worker):
     return nvml.one_time()
 
 DEFAULT_STARTUP_INFORMATION["gpu"] = gpu_startup
+
+
+def print(*args, **kwargs):
+    """Dask print function
+
+    This prints both wherever this function is run, and also in the user's
+    client session
+    """
+    try:
+        worker = get_worker()
+    except ValueError:
+        pass
+    else:
+        worker.log_event("print", args, **kwargs)
+
+    builtins.print(*args, **kwargs)
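
For context, a minimal usage sketch (not part of the patch) of the behaviour this diff adds: a task calls the new dask.distributed.print, the worker attaches the arguments to a "print" log event, the scheduler broadcasts that event to every connected client, and Client._handle_print replays it on the client's stdout. LocalCluster, Client, and submit are existing dask.distributed APIs; the task function name here is illustrative.

from dask.distributed import Client, LocalCluster, print


def task(x):
    # Runs on a worker: prints there, and also emits a "print" event that the
    # scheduler forwards so the same line appears in the client session.
    print("processing", x, sep=": ")
    return x + 1


if __name__ == "__main__":
    cluster = LocalCluster(n_workers=1)
    client = Client(cluster)
    # "processing: 41" is printed both in the worker process and locally
    assert client.submit(task, 41).result() == 42
    client.close()
    cluster.close()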