Skip to content

Commit

Permalink
Cleanup in BatchedSend
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed May 5, 2022
1 parent 09e62e0 commit 14ac9f1
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 52 deletions.
29 changes: 11 additions & 18 deletions distributed/batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
import logging
from collections import deque

from tornado import gen, locks
from tornado.ioloop import IOLoop

import dask
from dask.utils import parse_timedelta

Expand Down Expand Up @@ -37,14 +34,9 @@ class BatchedSend:
['Hello,', 'world!']
"""

# XXX why doesn't BatchedSend follow either the IOStream or Comm API?

def __init__(self, interval, loop=None, serializers=None):
# XXX is the loop arg useful?
self.loop = loop or IOLoop.current()
def __init__(self, interval, serializers=None):
self.interval = parse_timedelta(interval, default="ms")
self.waker = locks.Event()
self.stopped = locks.Event()
self.waker = asyncio.Event()
self.please_stop = False
self.buffer = []
self.comm = None
Expand All @@ -62,7 +54,6 @@ def start(self, comm):
if self._background_task and not self._background_task.done():
raise RuntimeError("Background task still running")
self.please_stop = False
self.stopped.clear()
self.waker.set()
self.next_deadline = None
self.comm = comm
Expand All @@ -86,9 +77,12 @@ def __repr__(self):
async def _background_send(self):
while not self.please_stop:
try:
await self.waker.wait(self.next_deadline)
timeout = None
if self.next_deadline:
timeout = self.next_deadline - time()
await asyncio.wait_for(self.waker.wait(), timeout=timeout)
self.waker.clear()
except gen.TimeoutError:
except asyncio.TimeoutError:
pass
if not self.buffer:
# Nothing to send
Expand All @@ -100,6 +94,7 @@ async def _background_send(self):
payload, self.buffer = self.buffer, []
self.batch_count += 1
self.next_deadline = time() + self.interval

try:
nbytes = await self.comm.write(
payload, serializers=self.serializers, on_error="raise"
Expand Down Expand Up @@ -154,13 +149,11 @@ def send(self, *msgs: dict) -> None:
if self.comm and not self.comm.closed() and self.next_deadline is None:
self.waker.set()

async def close(self, timeout=None):
"""Flush existing messages and then close comm
If set, raises `tornado.util.TimeoutError` after a timeout.
"""
async def close(self):
"""Flush existing messages and then close comm"""
self.please_stop = True
self.waker.set()

if self._background_task:
await self._background_task

Expand Down
8 changes: 2 additions & 6 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1287,7 +1287,7 @@ async def _ensure_connected(self, timeout=None):
if msg[0].get("warning"):
warnings.warn(version_module.VersionMismatchWarning(msg[0]["warning"]))

bcomm = BatchedSend(interval="10ms", loop=self.loop)
bcomm = BatchedSend(interval="10ms")
bcomm.start(comm)
self.scheduler_comm = bcomm
if self._set_as_default:
Expand Down Expand Up @@ -1523,11 +1523,7 @@ async def _close(self, fast=False):
with suppress(asyncio.CancelledError, TimeoutError):
await asyncio.wait_for(asyncio.shield(handle_report_task), 0.1)

if (
self.scheduler_comm
and self.scheduler_comm.comm
and not self.scheduler_comm.comm.closed()
):
if self.scheduler_comm:
await self.scheduler_comm.close()

for key in list(self.futures):
Expand Down
19 changes: 7 additions & 12 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3500,7 +3500,7 @@ async def close(self, fast=False, close_workers=False):
await future

for comm in self.client_comms.values():
comm.abort()
await comm.close()

await self.rpc.close()

Expand Down Expand Up @@ -3732,7 +3732,7 @@ async def add_worker(
# for key in keys: # TODO
# self.mark_key_in_memory(key, [address])

self.stream_comms[address] = BatchedSend(interval="5ms", loop=self.loop)
self.stream_comms[address] = BatchedSend(interval="5ms")

if ws.nthreads > len(ws.processing):
self.idle[ws.address] = ws
Expand Down Expand Up @@ -4316,7 +4316,7 @@ async def remove_worker(self, address, stimulus_id, safe=False, close=True):

logger.info("Remove worker %s", ws)
if close:
with suppress(AttributeError, CommClosedError):
with suppress(AttributeError):
self.stream_comms[address].send({"op": "close", "report": False})

self.remove_resources(address)
Expand All @@ -4330,6 +4330,7 @@ async def remove_worker(self, address, stimulus_id, safe=False, close=True):
del self.host_info[host]

self.rpc.remove(address)
await self.stream_comms[address].close()
del self.stream_comms[address]
del self.aliases[ws.name]
self.idle.pop(ws.address, None)
Expand Down Expand Up @@ -4684,7 +4685,7 @@ async def add_client(
logger.exception(e)

try:
bcomm = BatchedSend(interval="2ms", loop=self.loop)
bcomm = BatchedSend(interval="2ms")
bcomm.start(comm)
self.client_comms[client] = bcomm
msg = {"op": "stream-start"}
Expand Down Expand Up @@ -5032,13 +5033,7 @@ def client_send(self, client, msg):
c = client_comms.get(client)
if c is None:
return
try:
c.send(msg)
except CommClosedError:
if self.status == Status.running:
logger.critical(
"Closed comm %r while trying to write %s", c, msg, exc_info=True
)
c.send(msg)

def send_all(self, client_msgs: dict, worker_msgs: dict):
"""Send messages to client and workers"""
Expand Down Expand Up @@ -5068,7 +5063,7 @@ def send_all(self, client_msgs: dict, worker_msgs: dict):
except KeyError:
# worker already gone
pass
except (CommClosedError, AttributeError):
except AttributeError:
self.loop.add_callback(
self.remove_worker,
address=worker,
Expand Down
8 changes: 4 additions & 4 deletions distributed/tests/test_batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,8 @@ async def test_restart():
b = BatchedSend(interval="2ms")
b.start(comm)
b.send(123)
assert (123,) == await comm.read()
b.abort()
assert await comm.read() == (123,)
await b.close()
assert b.closed()

# We can buffer stuff even while it is closed
Expand All @@ -269,7 +269,7 @@ async def test_restart():
new_comm = await connect(e.address)
b.start(new_comm)

assert (345,) == await new_comm.read()
assert await new_comm.read() == (345,)
await b.close()
assert new_comm.closed()

Expand All @@ -285,5 +285,5 @@ async def test_restart_fails_if_still_running():
b.start(comm)

b.send(123)
assert (123,) == await comm.read()
assert await comm.read() == (123,)
await b.close()
10 changes: 4 additions & 6 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2387,10 +2387,6 @@ async def test_hold_on_to_replicas(c, s, *workers):
await asyncio.sleep(0.01)


@pytest.mark.xfail(
WINDOWS and sys.version_info[:2] == (3, 8),
reason="https://github.com/dask/distributed/issues/5621",
)
@gen_cluster(client=True, nthreads=[("", 1), ("", 1)])
async def test_worker_reconnects_mid_compute(c, s, a, b):
"""Ensure that, if a worker disconnects while computing a result, the scheduler will
Expand Down Expand Up @@ -2436,7 +2432,8 @@ def fast_on_a(lock):

await s.stream_comms[a.address].close()

assert len(s.workers) == 1
while len(s.workers) == 1:
await asyncio.sleep(0.1)
a.heartbeat_active = False
await a.heartbeat()
assert len(s.workers) == 2
Expand Down Expand Up @@ -2508,7 +2505,8 @@ def fast_on_a(lock):
# The only way to get f3 to complete is for Worker A to reconnect.

f1.release()
assert len(s.workers) == 1
while len(s.workers) != 1:
await asyncio.sleep(0.01)
story = s.story(f1.key)
while len(story) == len(story_before):
story = s.story(f1.key)
Expand Down
10 changes: 4 additions & 6 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
)
from concurrent.futures import Executor
from contextlib import suppress
from datetime import timedelta
from inspect import isawaitable
from pickle import PicklingError
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
Expand Down Expand Up @@ -775,7 +774,7 @@ def __init__(
self.nthreads, thread_name_prefix="Dask-Default-Threads"
)

self.batched_stream = BatchedSend(interval="2ms", loop=self.loop)
self.batched_stream = BatchedSend(interval="2ms")
self.name = name
self.scheduler_delay = 0
self.stream_comms = {}
Expand Down Expand Up @@ -1269,7 +1268,7 @@ async def heartbeat(self):
async def handle_scheduler(self, comm):
await self.handle_stream(comm, every_cycle=[self.ensure_communicating])

self.batched_stream.abort()
await self.batched_stream.close()
if self.reconnect and self.status in Status.ANY_RUNNING:
logger.info("Connection to scheduler broken. Reconnecting...")
self.loop.add_callback(self.heartbeat)
Expand Down Expand Up @@ -1583,8 +1582,7 @@ async def close(
self.batched_stream.send({"op": "close-stream"})

if self.batched_stream:
with suppress(TimeoutError):
await self.batched_stream.close(timedelta(seconds=timeout))
await self.batched_stream.close()

for executor in self.executors.values():
if executor is utils._offload_executor:
Expand Down Expand Up @@ -1653,7 +1651,7 @@ async def wait_until_closed(self):

def send_to_worker(self, address, msg):
if address not in self.stream_comms:
bcomm = BatchedSend(interval="1ms", loop=self.loop)
bcomm = BatchedSend(interval="1ms")
self.stream_comms[address] = bcomm

async def batched_send_connect():
Expand Down

0 comments on commit 14ac9f1

Please sign in to comment.