diff --git a/distributed/worker.py b/distributed/worker.py index 3b365af6c6..31ead8f395 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2,6 +2,7 @@ import bisect import errno import heapq +import io import logging import os import random @@ -12,9 +13,9 @@ import weakref from collections import defaultdict, deque, namedtuple from collections.abc import MutableMapping -from contextlib import suppress +from contextlib import suppress ExitStack, redirect_stdout from datetime import timedelta -from functools import partial +from functools import partial, wraps from inspect import isawaitable from pickle import PicklingError @@ -3721,6 +3722,20 @@ def weight(k, v): return sizeof(v) +def log_stdout(func): + @wraps(func) + def wrapped(*args, **kwargs): + with ExitStack() as stack: + out, _ = io.StringIO(), io.StringIO() + stack.enter_context(redirect_stdout(out)) + try: + return func(*args, **kwargs) + finally: + logger.info(out.getvalue()) + + return wrapped + + async def run(server, comm, function, args=(), kwargs=None, is_coro=None, wait=True): kwargs = kwargs or {} function = pickle.loads(function) @@ -3743,12 +3758,12 @@ async def run(server, comm, function, args=(), kwargs=None, is_coro=None, wait=T logger.info("Run out-of-band function %r", funcname(function)) try: if not is_coro: - result = function(*args, **kwargs) + result = log_stdout(function)(*args, **kwargs) else: if wait: - result = await function(*args, **kwargs) + result = await log_stdout(function)(*args, **kwargs) else: - server.loop.add_callback(function, *args, **kwargs) + server.loop.add_callback(log_stdout(function), *args, **kwargs) result = None except Exception as e: