Skip to content

Commit

Permalink
Configurable timeouts for worker_client and get_client. (#4146)
Browse files Browse the repository at this point in the history
Issue: #4114

Previously, connection timeout for worker_client and get_client was
hard-coded to 3s by default.
With this change, the default timeout value is fetched from the
dask config 'distributed.comm.timeouts.connect'.
  • Loading branch information
geethanjalieswaran authored Nov 18, 2020
1 parent dfbe171 commit 04a6b78
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 9 deletions.
24 changes: 19 additions & 5 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2953,7 +2953,7 @@ def client(self):
else:
return self._get_client()

def _get_client(self, timeout=3):
def _get_client(self, timeout=None):
"""Get local client attached to this worker
If no such client exists, create one
Expand All @@ -2962,6 +2962,12 @@ def _get_client(self, timeout=3):
--------
get_client
"""

if timeout is None:
timeout = dask.config.get("distributed.comm.timeouts.connect")

timeout = parse_timedelta(timeout, "s")

try:
from .client import default_client

Expand Down Expand Up @@ -2992,6 +2998,7 @@ def _get_client(self, timeout=3):
)
if not asynchronous:
assert self._client.status == "running"

return self._client

def get_current_task(self):
Expand Down Expand Up @@ -3043,7 +3050,7 @@ def get_worker():
raise ValueError("No workers found")


def get_client(address=None, timeout=3, resolve_address=True):
def get_client(address=None, timeout=None, resolve_address=True):
"""Get a client while within a task.
This client connects to the same scheduler to which the worker is connected
Expand All @@ -3053,8 +3060,9 @@ def get_client(address=None, timeout=3, resolve_address=True):
address : str, optional
The address of the scheduler to connect to. Defaults to the scheduler
the worker is connected to.
timeout : int, default 3
Timeout (in seconds) for getting the Client
timeout : int or str
Timeout (in seconds) for getting the Client. Defaults to the
``distributed.comm.timeouts.connect`` configuration value.
resolve_address : bool, default True
Whether to resolve `address` to its canonical form.
Expand All @@ -3065,7 +3073,7 @@ def get_client(address=None, timeout=3, resolve_address=True):
Examples
--------
>>> def f():
... client = get_client()
... client = get_client(timeout="10s")
... futures = client.map(lambda x: x + 1, range(10)) # spawn many tasks
... results = client.gather(futures)
... return sum(results)
Expand All @@ -3080,6 +3088,12 @@ def get_client(address=None, timeout=3, resolve_address=True):
worker_client
secede
"""

if timeout is None:
timeout = dask.config.get("distributed.comm.timeouts.connect")

timeout = parse_timedelta(timeout, "s")

if address and resolve_address:
address = comm.resolve_address(address)
try:
Expand Down
17 changes: 13 additions & 4 deletions distributed/worker_client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from contextlib import contextmanager
import warnings

import dask
from .threadpoolexecutor import secede, rejoin
from .worker import thread_state, get_client, get_worker
from .utils import parse_timedelta


@contextmanager
def worker_client(timeout=3, separate_thread=True):
def worker_client(timeout=None, separate_thread=True):
"""Get client for this thread
This context manager is intended to be called within functions that we run
Expand All @@ -15,16 +17,17 @@ def worker_client(timeout=3, separate_thread=True):
Parameters
----------
timeout: Number
Timeout after which to err
timeout: Number or String
Timeout after which to error out. Defaults to the
``distributed.comm.timeouts.connect`` configuration value.
separate_thread: bool, optional
Whether to run this function outside of the normal thread pool
defaults to True
Examples
--------
>>> def func(x):
... with worker_client() as c: # connect from worker back to scheduler
... with worker_client(timeout="10s") as c: # connect from worker back to scheduler
... a = c.submit(inc, x) # this task can submit more tasks
... b = c.submit(dec, x)
... result = c.gather([a, b]) # and gather results
Expand All @@ -38,6 +41,12 @@ def worker_client(timeout=3, separate_thread=True):
get_client
secede
"""

if timeout is None:
timeout = dask.config.get("distributed.comm.timeouts.connect")

timeout = parse_timedelta(timeout, "s")

worker = get_worker()
client = get_client(timeout=timeout)
if separate_thread:
Expand Down

0 comments on commit 04a6b78

Please sign in to comment.