Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Log event print warn #5220

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion distributed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from .threadpoolexecutor import rejoin
from .utils import CancelledError, TimeoutError, sync
from .variable import Variable
from .worker import Reschedule, Worker, get_client, get_worker, secede
from .worker import Reschedule, Worker, get_client, get_worker, print, secede
from .worker_client import local_client, worker_client

versions = get_versions()
Expand Down
8 changes: 8 additions & 0 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,8 @@ def __init__(
"task-erred": self._handle_task_erred,
"restart": self._handle_restart,
"error": self._handle_error,
"print": self._handle_print,
"warn": self._handle_warn,
}

self._state_handlers = {
Expand Down Expand Up @@ -1281,6 +1283,12 @@ def _handle_error(self, exception=None):
logger.warning("Scheduler exception:")
logger.exception(exception)

def _handle_warn(self, warning=None, **kwargs):
warnings.warn(warning)

def _handle_print(self, message=None, **kwargs):
print(*message, **kwargs)

async def _close(self, fast=False):
"""Send close signal and wait until scheduler completes"""
if self.status == "closed":
Expand Down
19 changes: 12 additions & 7 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6749,8 +6749,8 @@ async def feed(
if teardown:
teardown(self, state)

def log_worker_event(self, worker=None, topic=None, msg=None):
self.log_event(topic, msg)
def log_worker_event(self, worker=None, topic=None, msg=None, **kwargs):
    """Record an event reported by a worker under ``topic``.

    The reporting worker's address is accepted (and currently unused);
    the message and any extra fields are forwarded to ``log_event``.
    """
    self.log_event(topic, msg, **kwargs)

def subscribe_worker_status(self, comm=None):
WorkerStatusPlugin(self, comm)
Expand Down Expand Up @@ -7410,15 +7410,20 @@ async def get_worker_logs(self, comm=None, n=None, workers=None, nanny=False):
)
return results

def log_event(self, name, msg):
event = (time(), msg)
if isinstance(name, list):
def log_event(self, name, msg, **kwargs):
if isinstance(name, (list, tuple)):
for n in name:
self.events[n].append(event)
self.event_counts[n] += 1
self.log_event(n, msg, **kwargs)
else:
event = (time(), msg)
self.events[name].append(event)
self.event_counts[name] += 1
if name == "print":
for comm in self.client_comms.values():
comm.send({"op": "print", "message": msg, **kwargs})
if name == "warn":
for comm in self.client_comms.values():
comm.send({"op": "warn", "warning": msg, **kwargs})
Comment on lines +7421 to +7426
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the case that we're logging the same message under multiple topics (e.g. name=["model_score", "print"]) we won't send print or warn ops to the client. I'd suggest modifying the logic above to be something along the lines of

if not isinstance(name, list):
    name = [name]
for n in name:
    ... # logic in the current `else:` branch here 

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I went the other way, and called self.log_event(...) a few times, but I agree that not duplicating the logic twice is a good idea.


def get_events(self, comm=None, topic=None):
if topic is not None:
Expand Down
31 changes: 31 additions & 0 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6498,6 +6498,37 @@ def log_scheduler(dask_scheduler):
assert events[1][1] == ("alice", "bob")


@gen_cluster(client=True)
async def test_log_event_warn(c, s, a, b):
    # An event logged under the special "warn" topic should surface as a
    # Warning in the client's local session.
    def foo():
        # Log under two topics at once; only "warn" should trigger the
        # client-side warning behaviour.
        get_worker().log_event(["foo", "warn"], "Hello!")

    with pytest.warns(Warning, match="Hello!"):
        await c.submit(foo)


@gen_cluster(client=True, Worker=Nanny)
async def test_print(c, s, a, b, capsys):
    # ``dask.distributed.print`` called on a worker should be mirrored to
    # the client's stdout via the scheduler's "print" event topic.
    from dask.distributed import print

    def foo():
        # Runs remotely (in a Nanny subprocess); keyword arguments such as
        # ``sep`` should survive the round trip to the client.
        print("Hello!", 123, sep=":")

    await c.submit(foo)

    out, err = capsys.readouterr()
    assert "Hello!:123" in out


def test_print_simple(capsys):
from dask.distributed import print
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not clear to me if this import will impact other tests which are run later. I'm wondering if we should use

from dask.distributed import print as dask_print

to not override the builtin print function for later tests

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that we're ok here. I think that imports are properly scoped

In [1]: def f():
   ...:     from dask.distributed import print
   ...: 

In [2]: print
Out[2]: <function print>

In [3]: print?
Docstring:
print(value, ..., sep=' ', end='\n', file=sys.stdout, flush=False)

Prints the values to a stream, or to sys.stdout by default.
Optional keyword arguments:
file:  a file-like object (stream); defaults to the current sys.stdout.
sep:   string inserted between values, default a space.
end:   string appended after the last value, default a newline.
flush: whether to forcibly flush the stream.
Type:      builtin_function_or_method

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, I forgot to call `f()` there, but I just ran it again and it's fine.


print("Hello!", 123, sep=":")

out, err = capsys.readouterr()
assert "Hello!:123" in out


@gen_cluster(client=True)
async def test_annotations_task_state(c, s, a, b):
da = pytest.importorskip("dask.array")
Expand Down
20 changes: 19 additions & 1 deletion distributed/worker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import bisect
import builtins
import concurrent.futures
import errno
import heapq
Expand Down Expand Up @@ -810,12 +811,13 @@ def __repr__(self):
def logs(self):
return self._deque_handler.deque

def log_event(self, topic, msg):
def log_event(self, topic, msg, **kwargs):
    """Forward a structured event to the scheduler over the batched stream.

    ``topic`` may be a single topic name or a list of them; extra keyword
    fields are merged into the outgoing message.
    """
    event = {"op": "log-event", "topic": topic, "msg": msg}
    event.update(kwargs)
    self.batched_stream.send(event)

Expand Down Expand Up @@ -4120,3 +4122,19 @@ def gpu_startup(worker):
return nvml.one_time()

DEFAULT_STARTUP_INFORMATION["gpu"] = gpu_startup


def print(*args, **kwargs):
    """Dask print function

    This prints both wherever this function is run, and also in the user's
    client session.  Keyword arguments (``sep``, ``end``, ...) are applied
    locally and forwarded with the event, except ``file``: a ``file=``
    argument refers to a local stream object that cannot be serialized and
    is meaningless on the client side, so it is applied locally only.
    """
    try:
        worker = get_worker()
    except ValueError:
        # Not running inside a worker (e.g. called on the client or in a
        # plain script); just behave like the builtin below.
        pass
    else:
        remote_kwargs = {k: v for k, v in kwargs.items() if k != "file"}
        worker.log_event("print", args, **remote_kwargs)

    builtins.print(*args, **kwargs)