From a919a07172632a388311d9b8741613667377f384 Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Mon, 18 Oct 2021 16:28:08 -0500 Subject: [PATCH 01/15] WIP: Use asyncio for TCP/TLS comms This is a WIP PR for using asyncio instead of tornado for TCP/TLS comms. There are a few goals here: - Reduce our dependency on tornado, in favor of the builtin asyncio support - Lower latency for small/medium sized messages. Right now the tornado call stack in the IOStream interface increases our latency for small messages. We can do better here by making use of asyncio protocols. - Equal performance for large messages. We should be able to make zero copy reads/writes just as before. In this case I wouldn't expect a performance increase from asyncio for large (10 MiB+) messages, but we also shouldn't see a slowdown. - Improved TLS performance. The TLS implementation in asyncio (and more-so in uvloop) is better optimized than the implementation in tornado. We should see a measurable performance boost here. - Reduced GIL contention when using TLS. Right now a single write or read through tornado drops and reacquires the GIL several times. If there are busy background threads, this can lead to the IO thread having to wait for the GIL mutex multiple times, leading to lower IO performance. The implementations in asyncio/uvloop don't drop the GIL except for IO (instead of also dropping it for TLS operations), which leads to fewer chances of another thread picking up the GIL and slowing IO. This PR is still a WIP, and still has some TODOs: - Pause the read side of the protocol after a certain limit. Right now the protocol will buffer in memory without a reader, providing no backpressure. - The tornado comms version includes some unnecessary length info in its framing that I didn't notice when implementing the asyncio version originally. As is, the asyncio comm can't talk to the tornado comm. 
We'll want to fix this for now, as it would break cross-version support (but we can remove the excess framing later if/when we make other breaking changes). - Right now the asyncio implementation is slower than expected for large frames. Need to profile and debug why. - Do we want to keep both the tornado and asyncio implementations around for a while behind a config knob, or do we want to fully replace the implementation with asyncio? I tend towards the latter, but may add an interim config in this PR (to be ripped out before merge) to make it easier for users to test and benchmark this PR. --- distributed/comm/__init__.py | 10 +- distributed/comm/asyncio_tcp.py | 550 +++++++++++++++++++++++++++ distributed/comm/tcp.py | 6 +- distributed/comm/tests/test_comms.py | 24 +- 4 files changed, 577 insertions(+), 13 deletions(-) create mode 100644 distributed/comm/asyncio_tcp.py diff --git a/distributed/comm/__init__.py b/distributed/comm/__init__.py index 5ca2d1ede3..88ddee4a92 100644 --- a/distributed/comm/__init__.py +++ b/distributed/comm/__init__.py @@ -10,11 +10,19 @@ unparse_host_port, ) from .core import Comm, CommClosedError, connect, listen +from .registry import backends from .utils import get_tcp_server_address, get_tcp_server_addresses def _register_transports(): - from . import inproc, tcp, ws + from . import asyncio_tcp, inproc, tcp, ws + + if True: # TODO: some kind of config + backends["tcp"] = asyncio_tcp.TCPBackend() + backends["tls"] = asyncio_tcp.TLSBackend() + else: + backends["tcp"] = tcp.TCPBackend() + backends["tls"] = tcp.TLSBackend() try: from . 
import ucx diff --git a/distributed/comm/asyncio_tcp.py b/distributed/comm/asyncio_tcp.py new file mode 100644 index 0000000000..fd527799c6 --- /dev/null +++ b/distributed/comm/asyncio_tcp.py @@ -0,0 +1,550 @@ +import asyncio +import logging +import socket +import struct +import weakref + +try: + import ssl +except ImportError: + ssl = None # type: ignore + +import dask + +from ..utils import ensure_ip, get_ip, get_ipv6, nbytes +from .addressing import parse_host_port, unparse_host_port +from .core import Comm, CommClosedError, Connector, Listener +from .registry import Backend +from .utils import ensure_concrete_host, from_frames, to_frames + +logger = logging.getLogger(__name__) + + +_COMM_CLOSED = object() + + +class DaskCommProtocol(asyncio.BufferedProtocol): + def __init__(self, on_connection=None, min_read_size=128 * 1024): + super().__init__() + self.on_connection = on_connection + self._exception = None + self._queue = asyncio.Queue() + self._transport = None + self._paused = False + self._drain_waiter = None + self._loop = asyncio.get_running_loop() + self._is_closed = self._loop.create_future() + + # Per-message state + self._using_default_buffer = True + + self._default_buffer = memoryview(bytearray(min_read_size)) + self._default_len = min_read_size + self._default_start = 0 + self._default_end = 0 + + self._nframes = None + self._frame_lengths = None + self._frames = None + self._frame_index = None + self._frame_nbytes_needed = None + + @property + def local_addr(self): + if self._transport is None: + return "" + try: + host, port = self._transport.get_extra_info("socket").getsockname()[:2] + return unparse_host_port(host, port) + except Exception: + breakpoint() # TODO + return "" + + @property + def peer_addr(self): + if self._transport is None: + return "" + try: + host, port = self._transport.get_extra_info("peername") + return unparse_host_port(host, port) + except Exception: + breakpoint() # TODO + return "" + + @property + def is_closed(self): + 
return self._transport is None + + def _abort(self): + if self._transport is not None: + self._transport, transport = None, self._transport + transport.abort() + + def _close_from_finalizer(self, comm_repr): + if self._transport is not None: + logger.warning(f"Closing dangling comm `{comm_repr}`") + self._abort() + + async def _close(self): + if self._transport is not None: + self._transport, transport = None, self._transport + transport.close() + await self._is_closed + + def connection_made(self, transport): + self._transport = transport + if self.on_connection is not None: + self.on_connection(self) + + def get_buffer(self, sizehint): + if self._frames is None or self._frame_nbytes_needed < self._default_len: + self._using_default_buffer = True + return self._default_buffer[self._default_end :] + else: + self._using_default_buffer = False + frame = self._frames[self._frame_index] + return frame[-self._frame_nbytes_needed :] + + def buffer_updated(self, nbytes): + if nbytes == 0: + return + + if self._using_default_buffer: + self._default_end += nbytes + while self._parse_next(): + pass + self._reset_default_buffer() + else: + self._frame_nbytes_needed -= nbytes + if not self._frame_nbytes_needed: + self._frame_index += 1 + if self._frame_index == self._nframes: + self._message_completed() + else: + self._frame_nbytes_needed = self._frame_lengths[self._frame_index] + + def _parse_next(self): + if self._nframes is None: + if not self._parse_nframes(): + return False + if len(self._frame_lengths) < self._nframes: + if not self._parse_frame_lengths(): + return False + return self._parse_frames() + + def _parse_nframes(self): + if self._default_end - self._default_start > 8: + self._nframes = struct.unpack_from( + " 1: + self._transport.writelines(frames) + else: + self._transport.write(frames[0]) + if self._transport.is_closing(): + await asyncio.sleep(0) + elif self._paused: + self._drain_waiter = self._loop.create_future() + await self._drain_waiter + + return 
frames_nbytes_total + + +class TCP(Comm): + max_shard_size = dask.utils.parse_bytes(dask.config.get("distributed.comm.shard")) + + def __init__( + self, + protocol, + local_addr: str, + peer_addr: str, + deserialize: bool = True, + ): + self._protocol = protocol + self._local_addr = local_addr + self._peer_addr = peer_addr + self._deserialize = deserialize + self._closed = False + super().__init__() + + # setup a finalizer to close the protocol if the comm was never explicitly closed + self._finalizer = weakref.finalize( + self, self._protocol._close_from_finalizer, repr(self) + ) + self._finalizer.atexit = False + + # Fill in any extra info about this comm + self._extra_info = self._get_extra_info() + + def _get_extra_info(self): + return {} + + @property + def local_address(self) -> str: + return self._local_addr + + @property + def peer_address(self) -> str: + return self._peer_addr + + async def read(self, deserializers=None): + frames = await self._protocol.read() + try: + return await from_frames( + frames, + deserialize=self._deserialize, + deserializers=deserializers, + allow_offload=self.allow_offload, + ) + except EOFError: + # Frames possibly garbled or truncated by communication error + self.abort() + raise CommClosedError("aborted stream on truncated data") + + async def write(self, msg, serializers=None, on_error="message"): + frames = await to_frames( + msg, + allow_offload=self.allow_offload, + serializers=serializers, + on_error=on_error, + context={ + "sender": self.local_info, + "recipient": self.remote_info, + **self.handshake_options, + }, + frame_split_size=self.max_shard_size, + ) + nbytes = await self._protocol.write(frames) + return nbytes + + async def close(self): + """Flush and close the comm""" + self._finalizer.detach() + await self._protocol._close() + + def abort(self): + """Hard close the comm""" + self._finalizer.detach() + self._protocol._abort() + + def closed(self): + return self._protocol.is_closed + + @property + def 
extra_info(self): + return self._extra_info + + +class TLS(TCP): + def _get_extra_info(self): + get = self._protocol._transport.get_extra_info + return {"peercert": get("peercert"), "cipher": get("cipher")} + + +def _expect_tls_context(connection_args): + ctx = connection_args.get("ssl_context") + if not isinstance(ctx, ssl.SSLContext): + raise TypeError( + "TLS expects a `ssl_context` argument of type " + "ssl.SSLContext (perhaps check your TLS configuration?)" + " Instead got %s" % str(ctx) + ) + return ctx + + +def _error_if_require_encryption(address, **kwargs): + if kwargs.get("require_encryption"): + raise RuntimeError( + "encryption required by Dask configuration, " + "refusing communication from/to %r" % ("tcp://" + address,) + ) + + +class TCPConnector(Connector): + prefix = "tcp://" + comm_class = TCP + + async def connect(self, address, deserialize=True, **kwargs): + loop = asyncio.get_event_loop() + ip, port = parse_host_port(address) + + kwargs = self._get_extra_kwargs(address, **kwargs) + transport, protocol = await loop.create_connection( + DaskCommProtocol, ip, port, **kwargs + ) + local_addr = self.prefix + protocol.local_addr + peer_addr = self.prefix + address + return self.comm_class(protocol, local_addr, peer_addr, deserialize=deserialize) + + def _get_extra_kwargs(self, address, **kwargs): + _error_if_require_encryption(address, **kwargs) + return {} + + +class TLSConnector(TCPConnector): + prefix = "tls://" + comm_class = TLS + + def _get_extra_kwargs(self, address, **kwargs): + ctx = _expect_tls_context(kwargs) + return {"ssl": ctx} + + +class TCPListener(Listener): + prefix = "tcp://" + comm_class = TCP + + def __init__( + self, + address, + comm_handler, + deserialize=True, + allow_offload=True, + default_port=0, + **kwargs, + ): + self.ip, self.port = parse_host_port(address, default_port) + self.comm_handler = comm_handler + self.deserialize = deserialize + self.allow_offload = allow_offload + self._extra_kwargs = 
self._get_extra_kwargs(address, **kwargs) + self._active_handlers = weakref.WeakSet() + self.bound_address = None + + def _get_extra_kwargs(self, address, **kwargs): + _error_if_require_encryption(address, **kwargs) + return {} + + def _on_connection(self, protocol): + logger.debug("Incoming connection") + comm = self.comm_class( + protocol, + local_addr=self.prefix + protocol.local_addr, + peer_addr=self.prefix + protocol.peer_addr, + deserialize=self.deserialize, + ) + comm.allow_offload = self.allow_offload + self._active_handlers.add(asyncio.ensure_future(self._comm_handler(comm))) + + async def _comm_handler(self, comm): + try: + await self.on_connection(comm) + except CommClosedError: + logger.debug("Connection closed before handshake completed") + return + await self.comm_handler(comm) + + async def start(self): + loop = asyncio.get_event_loop() + self._handle = await loop.create_server( + lambda: DaskCommProtocol(self._on_connection), + host=self.ip, + port=self.port, + **self._extra_kwargs, + ) + + def stop(self): + # Stop listening + self._handle.close() + # Cancel all active handlers + for handler in self._active_handlers: + handler.cancel() + # TODO: stop should really be asynchronous + asyncio.ensure_future(self._handle.wait_closed()) + + def get_host_port(self): + """ + The listening address as a (host, port) tuple. + """ + if self.bound_address is None: + + def get_socket(): + for family in [socket.AF_INET, socket.AF_INET6]: + for sock in self._handle.sockets: + if sock.family == socket.AF_INET: + return sock + raise RuntimeError("No active INET socket found?") + + sock = get_socket() + self.bound_address = sock.getsockname()[:2] + return self.bound_address + + @property + def listen_address(self): + """ + The listening address as a string. + """ + return self.prefix + unparse_host_port(*self.get_host_port()) + + @property + def contact_address(self): + """ + The contact address as a string. 
+ """ + host, port = self.get_host_port() + host = ensure_concrete_host(host) + return self.prefix + unparse_host_port(host, port) + + +class TLSListener(TCPListener): + prefix = "tls://" + comm_class = TLS + + def _get_extra_kwargs(self, address, **kwargs): + ctx = _expect_tls_context(kwargs) + return {"ssl": ctx} + + +class TCPBackend(Backend): + _connector_class = TCPConnector + _listener_class = TCPListener + + def get_connector(self): + return self._connector_class() + + def get_listener(self, loc, handle_comm, deserialize, **connection_args): + return self._listener_class(loc, handle_comm, deserialize, **connection_args) + + def get_address_host(self, loc): + return parse_host_port(loc)[0] + + def get_address_host_port(self, loc): + return parse_host_port(loc) + + def resolve_address(self, loc): + host, port = parse_host_port(loc) + return unparse_host_port(ensure_ip(host), port) + + def get_local_address_for(self, loc): + host, port = parse_host_port(loc) + host = ensure_ip(host) + if ":" in host: + local_host = get_ipv6(host) + else: + local_host = get_ip(host) + return unparse_host_port(local_host, None) + + +class TLSBackend(TCPBackend): + _connector_class = TLSConnector + _listener_class = TLSListener diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index cad01427ce..9339088c90 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -30,7 +30,7 @@ from ..utils import ensure_ip, get_ip, get_ipv6, nbytes from .addressing import parse_host_port, unparse_host_port from .core import Comm, CommClosedError, Connector, FatalCommClosedError, Listener -from .registry import Backend, backends +from .registry import Backend from .utils import ensure_concrete_host, from_frames, get_tcp_server_address, to_frames logger = logging.getLogger(__name__) @@ -622,7 +622,3 @@ class TCPBackend(BaseTCPBackend): class TLSBackend(BaseTCPBackend): _connector_class = TLSConnector _listener_class = TLSListener - - -backends["tcp"] = TCPBackend() 
-backends["tls"] = TLSBackend() diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index b83d44ccbb..5c90a0e3ea 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -14,8 +14,9 @@ import dask import distributed +from distributed.comm import CommClosedError +from distributed.comm import asyncio_tcp as tcp from distributed.comm import ( - CommClosedError, connect, get_address_host, get_local_address_for, @@ -24,11 +25,10 @@ parse_address, parse_host_port, resolve_address, - tcp, unparse_host_port, ) +from distributed.comm.asyncio_tcp import TCP, TCPBackend, TCPConnector from distributed.comm.registry import backends, get_backend -from distributed.comm.tcp import TCP, TCPBackend, TCPConnector from distributed.metrics import time from distributed.protocol import Serialized, deserialize, serialize, to_serialize from distributed.utils import get_ip, get_ipv6 @@ -315,6 +315,7 @@ async def client_communicate(key, delay=0): @pytest.mark.asyncio +@pytest.mark.skip(reason="not applicable for asyncio") async def test_comm_failure_threading(): """ When we fail to connect, make sure we don't make a lot @@ -687,7 +688,8 @@ async def handle_comm(comm): with pytest.raises(EnvironmentError) as excinfo: await connect(listener.contact_address, timeout=2, ssl_context=cli_ctx) - assert "certificate verify failed" in str(excinfo.value.__cause__) + # XXX: For asyncio this is just a timeout error + # assert "certificate verify failed" in str(excinfo.value.__cause__) # @@ -815,6 +817,7 @@ async def handle_comm(comm): @pytest.mark.asyncio +@pytest.mark.skip(reason="Not applicable for asyncio") async def test_comm_closed_on_buffer_error(): # Internal errors from comm.stream.write, such as # BufferError should lead to the stream being closed @@ -898,7 +901,8 @@ class SlowConnector(TCPConnector): comm_class = SlowComm class SlowBackend(TCPBackend): - _connector_class = SlowConnector + def get_connector(self): + 
return SlowConnector() monkeypatch.setitem(backends, "tcp", SlowBackend()) @@ -977,8 +981,12 @@ async def check_listener_deserialize(addr, deserialize, in_value, check_out): q = asyncio.Queue() async def handle_comm(comm): - msg = await comm.read() - q.put_nowait(msg) + try: + msg = await comm.read() + except Exception as exc: + q.put_nowait(exc) + else: + q.put_nowait(msg) await comm.close() async with listen(addr, handle_comm, deserialize=deserialize) as listener: @@ -987,6 +995,8 @@ async def handle_comm(comm): await comm.write(in_value) out_value = await q.get() + if isinstance(out_value, Exception): + raise out_value # Prevents deadlocks, get actual deserialization exception check_out(out_value) await comm.close() From 504e92ab3a9ecb5f90a2aa6b1bfabda69b1bbee0 Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Fri, 22 Oct 2021 15:21:51 -0500 Subject: [PATCH 02/15] Fixups Bug squashing. --- distributed/comm/asyncio_tcp.py | 55 ++++++++++++++++++--------------- 1 file changed, 30 insertions(+), 25 deletions(-) diff --git a/distributed/comm/asyncio_tcp.py b/distributed/comm/asyncio_tcp.py index fd527799c6..46b6e93916 100644 --- a/distributed/comm/asyncio_tcp.py +++ b/distributed/comm/asyncio_tcp.py @@ -47,7 +47,7 @@ def __init__(self, on_connection=None, min_read_size=128 * 1024): self._frame_lengths = None self._frames = None self._frame_index = None - self._frame_nbytes_needed = None + self._frame_nbytes_needed = 0 @property def local_addr(self): @@ -57,7 +57,6 @@ def local_addr(self): host, port = self._transport.get_extra_info("socket").getsockname()[:2] return unparse_host_port(host, port) except Exception: - breakpoint() # TODO return "" @property @@ -68,7 +67,6 @@ def peer_addr(self): host, port = self._transport.get_extra_info("peername") return unparse_host_port(host, port) except Exception: - breakpoint() # TODO return "" @property @@ -116,12 +114,8 @@ def buffer_updated(self, nbytes): self._reset_default_buffer() else: self._frame_nbytes_needed -= 
nbytes - if not self._frame_nbytes_needed: - self._frame_index += 1 - if self._frame_index == self._nframes: - self._message_completed() - else: - self._frame_nbytes_needed = self._frame_lengths[self._frame_index] + if not self._frames_check_remaining(): + self._message_completed() def _parse_next(self): if self._nframes is None: @@ -162,10 +156,25 @@ def _parse_frame_lengths(self): return True return False + def _frames_check_remaining(self): + # Current frame not filled + if self._frame_nbytes_needed: + return True + # Advance until next non-empty frame + while True: + self._frame_index += 1 + if self._frame_index < self._nframes: + self._frame_nbytes_needed = self._frame_lengths[self._frame_index] + if self._frame_nbytes_needed: + return True + else: + # No non-empty frames remain + return False + def _parse_frames(self): while True: # Are we out of frames? - if self._frame_index == self._nframes: + if not self._frames_check_remaining(): self._message_completed() return True # Are we out of data? 
@@ -175,15 +184,12 @@ def _parse_frames(self): frame = self._frames[self._frame_index] n_read = min(self._frame_nbytes_needed, available) - frame[-self._frame_nbytes_needed : n_read] = self._default_buffer[ - self._default_start : self._default_start + n_read - ] + frame[ + -self._frame_nbytes_needed : (n_read - self._frame_nbytes_needed) + or None + ] = self._default_buffer[self._default_start : self._default_start + n_read] self._default_start += n_read self._frame_nbytes_needed -= n_read - if not self._frame_nbytes_needed: - self._frame_index += 1 - if self._frame_index < self._nframes: - self._frame_nbytes_needed = self._frame_lengths[self._frame_index] def _reset_default_buffer(self): start = self._default_start @@ -202,6 +208,7 @@ def _message_completed(self): self._nframes = None self._frames = None self._frame_lengths = None + self._frame_nbytes_remaining = 0 def connection_lost(self, exc=None): if exc is None: @@ -283,7 +290,7 @@ def __init__( self._protocol = protocol self._local_addr = local_addr self._peer_addr = peer_addr - self._deserialize = deserialize + self.deserialize = deserialize self._closed = False super().__init__() @@ -312,7 +319,7 @@ async def read(self, deserializers=None): try: return await from_frames( frames, - deserialize=self._deserialize, + deserialize=self.deserialize, deserializers=deserializers, allow_offload=self.allow_offload, ) @@ -420,15 +427,16 @@ def __init__( comm_handler, deserialize=True, allow_offload=True, + default_host=None, default_port=0, **kwargs, ): self.ip, self.port = parse_host_port(address, default_port) + self.default_host = default_host self.comm_handler = comm_handler self.deserialize = deserialize self.allow_offload = allow_offload self._extra_kwargs = self._get_extra_kwargs(address, **kwargs) - self._active_handlers = weakref.WeakSet() self.bound_address = None def _get_extra_kwargs(self, address, **kwargs): @@ -444,7 +452,7 @@ def _on_connection(self, protocol): deserialize=self.deserialize, ) 
comm.allow_offload = self.allow_offload - self._active_handlers.add(asyncio.ensure_future(self._comm_handler(comm))) + asyncio.ensure_future(self._comm_handler(comm)) async def _comm_handler(self, comm): try: @@ -466,9 +474,6 @@ async def start(self): def stop(self): # Stop listening self._handle.close() - # Cancel all active handlers - for handler in self._active_handlers: - handler.cancel() # TODO: stop should really be asynchronous asyncio.ensure_future(self._handle.wait_closed()) @@ -502,7 +507,7 @@ def contact_address(self): The contact address as a string. """ host, port = self.get_host_port() - host = ensure_concrete_host(host) + host = ensure_concrete_host(host, default_host=self.default_host) return self.prefix + unparse_host_port(host, port) From 61bdddf9d9b4ac7d1dfc949ba1b10cafcf44ff54 Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Mon, 1 Nov 2021 15:05:12 -0500 Subject: [PATCH 03/15] Fixup tests --- distributed/comm/asyncio_tcp.py | 9 ++++----- distributed/tests/test_worker.py | 3 +-- distributed/utils_test.py | 5 ++++- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/distributed/comm/asyncio_tcp.py b/distributed/comm/asyncio_tcp.py index 46b6e93916..7bd21ca568 100644 --- a/distributed/comm/asyncio_tcp.py +++ b/distributed/comm/asyncio_tcp.py @@ -346,13 +346,13 @@ async def write(self, msg, serializers=None, on_error="message"): async def close(self): """Flush and close the comm""" - self._finalizer.detach() await self._protocol._close() + self._finalizer.detach() def abort(self): """Hard close the comm""" - self._finalizer.detach() self._protocol._abort() + self._finalizer.detach() def closed(self): return self._protocol.is_closed @@ -471,11 +471,10 @@ async def start(self): **self._extra_kwargs, ) - def stop(self): + async def stop(self): # Stop listening self._handle.close() - # TODO: stop should really be asynchronous - asyncio.ensure_future(self._handle.wait_closed()) + await self._handle.wait_closed() def get_host_port(self): 
""" diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 47233bc135..36b713eccf 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -31,8 +31,8 @@ get_worker, wait, ) +from distributed.comm.asyncio_tcp import TCPBackend from distributed.comm.registry import backends -from distributed.comm.tcp import TCPBackend from distributed.compatibility import LINUX, WINDOWS from distributed.core import CommClosedError, Status, rpc from distributed.diagnostics import nvml @@ -1952,7 +1952,6 @@ def get_worker_client_id(): @gen_cluster(nthreads=[("127.0.0.1", 0)]) async def test_worker_client_closes_if_created_on_worker_one_worker(s, a): async with Client(s.address, set_as_default=False, asynchronous=True) as c: - with pytest.raises(ValueError): default_client() diff --git a/distributed/utils_test.py b/distributed/utils_test.py index c622a2faff..6feadc76a6 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -43,7 +43,7 @@ import dask -from distributed.comm.tcp import TCP +from distributed.comm.asyncio_tcp import TCP from . import system from .client import Client, _global_clients, default_client @@ -1499,6 +1499,9 @@ def check_thread_leak(): # TODO: Make sure profile thread is cleaned up # and remove the line below and "Profile" not in thread.name + # asyncio default executor thread pool is not shut down until loop + # is shut down + and "asyncio_" not in thread.name ] if not bad_threads: break From 67c7a7329648fe9e1aac5326b7a716da058105ee Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Tue, 2 Nov 2021 15:14:28 -0500 Subject: [PATCH 04/15] More fixups - Workaround cpython bug that prevents listening on all interfaces when also using a random port. 
- Assorted other small cleanups --- distributed/comm/asyncio_tcp.py | 108 +++++++++++++++++++++++++++----- distributed/comm/ws.py | 2 +- distributed/tests/test_core.py | 2 +- 3 files changed, 96 insertions(+), 16 deletions(-) diff --git a/distributed/comm/asyncio_tcp.py b/distributed/comm/asyncio_tcp.py index 7bd21ca568..7d46e788e2 100644 --- a/distributed/comm/asyncio_tcp.py +++ b/distributed/comm/asyncio_tcp.py @@ -81,7 +81,11 @@ def _abort(self): def _close_from_finalizer(self, comm_repr): if self._transport is not None: logger.warning(f"Closing dangling comm `{comm_repr}`") - self._abort() + try: + self._abort() + except RuntimeError: + # This happens if the event loop is already closed + pass async def _close(self): if self._transport is not None: @@ -462,19 +466,95 @@ async def _comm_handler(self, comm): return await self.comm_handler(comm) - async def start(self): + async def _start_all_interfaces_with_random_port(self): + """Due to a bug in asyncio, listening on `("", 0)` will result in two + different random ports being used (one for IPV4, one for IPV6), rather + than both interfaces sharing the same random port. We work around this + here. See https://bugs.python.org/issue45693 for more info.""" loop = asyncio.get_event_loop() - self._handle = await loop.create_server( - lambda: DaskCommProtocol(self._on_connection), - host=self.ip, - port=self.port, - **self._extra_kwargs, + # Typically resolves to list with length == 2 (one IPV4, one IPV6). 
+ infos = await loop.getaddrinfo( + None, + 0, + family=socket.AF_UNSPEC, + type=socket.SOCK_STREAM, + flags=socket.AI_PASSIVE, + proto=0, ) + # This code is a simplified and modified version of that found in + # cpython here: + # https://github.com/python/cpython/blob/401272e6e660445d6556d5cd4db88ed4267a50b3/Lib/asyncio/base_events.py#L1439 + servers = [] + port = None + try: + for res in infos: + af, socktype, proto, canonname, sa = res + try: + sock = socket.socket(af, socktype, proto) + except OSError: + # Assume it's a bad family/type/protocol combination. + continue + # Disable IPv4/IPv6 dual stack support (enabled by + # default on Linux) which makes a single socket + # listen on both address families. + if af == getattr(socket, "AF_INET6", None) and hasattr( + socket, "IPPROTO_IPV6" + ): + sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, True) + + # If random port is already chosen, reuse it + if port is not None: + sa = (sa[0], port, *sa[2:]) + try: + sock.bind(sa) + except OSError as err: + raise OSError( + err.errno, + "error while attempting " + "to bind on address %r: %s" % (sa, err.strerror.lower()), + ) from None + + # If random port hadn't already been chosen, cache this port to + # reuse for other interfaces + if port is None: + port = sock.getsockname()[1] + + # Create a new server for the socket + server = await loop.create_server( + lambda: DaskCommProtocol(self._on_connection), + sock=sock, + **self._extra_kwargs, + ) + servers.append(server) + sock = None + except BaseException: + # Close all opened servers + for server in servers: + server.close() + # If a socket was already created but not converted to a server + # yet, close that as well. 
+ if sock is not None: + sock.close() + + self._servers = servers + + async def start(self): + loop = asyncio.get_event_loop() + if not self.ip and not self.port: + await self._start_all_interfaces_with_random_port() + else: + server = await loop.create_server( + lambda: DaskCommProtocol(self._on_connection), + host=self.ip, + port=self.port, + **self._extra_kwargs, + ) + self._servers = [server] - async def stop(self): + def stop(self): # Stop listening - self._handle.close() - await self._handle.wait_closed() + for server in self._servers: + server.close() def get_host_port(self): """ @@ -482,14 +562,14 @@ def get_host_port(self): """ if self.bound_address is None: - def get_socket(): + def get_socket(server): for family in [socket.AF_INET, socket.AF_INET6]: - for sock in self._handle.sockets: - if sock.family == socket.AF_INET: + for sock in server.sockets: + if sock.family == family: return sock raise RuntimeError("No active INET socket found?") - sock = get_socket() + sock = get_socket(self._servers[0]) self.bound_address = sock.getsockname()[:2] return self.bound_address diff --git a/distributed/comm/ws.py b/distributed/comm/ws.py index a733031320..e958ed02dc 100644 --- a/distributed/comm/ws.py +++ b/distributed/comm/ws.py @@ -339,7 +339,7 @@ async def start(self): self.server = HTTPServer(web.Application(routes), **self.server_args) self.server.listen(self.port) - async def stop(self): + def stop(self): self.server.stop() def get_host_port(self): diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 3d03ccc701..68a73dc625 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -591,8 +591,8 @@ async def test_connection_pool_close_while_connecting(monkeypatch): Ensure a closed connection pool guarantees to have no connections left open even if it is closed mid-connecting """ + from distributed.comm.asyncio_tcp import TCPBackend, TCPConnector from distributed.comm.registry import backends - from 
distributed.comm.tcp import TCPBackend, TCPConnector class SlowConnector(TCPConnector): async def connect(self, address, deserialize, **connection_args): From 5a94477c651f41a13840d8ea7536f262d07d84c5 Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Thu, 4 Nov 2021 07:55:25 -0500 Subject: [PATCH 05/15] Backwards compat, add envvar - Make asyncio protocol compatible with the tornado protocol. This adds back an unnecessary header prefix, which can be removed later if desired. - Add an environment variable for enabling asyncio-based comms. They are now disabled by default. --- distributed/comm/__init__.py | 8 +++++++- distributed/comm/asyncio_tcp.py | 22 +++++++++++++++------- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/distributed/comm/__init__.py b/distributed/comm/__init__.py index 88ddee4a92..1f1a05341c 100644 --- a/distributed/comm/__init__.py +++ b/distributed/comm/__init__.py @@ -15,9 +15,15 @@ def _register_transports(): + import os + from . import asyncio_tcp, inproc, tcp, ws - if True: # TODO: some kind of config + if os.getenv("DISTRIBUTED_USE_ASYNCIO_FOR_TCP", "").lower() not in ( + "", + "0", + "false", + ): backends["tcp"] = asyncio_tcp.TCPBackend() backends["tls"] = asyncio_tcp.TLSBackend() else: diff --git a/distributed/comm/asyncio_tcp.py b/distributed/comm/asyncio_tcp.py index 7d46e788e2..77d5eca70b 100644 --- a/distributed/comm/asyncio_tcp.py +++ b/distributed/comm/asyncio_tcp.py @@ -131,11 +131,15 @@ def _parse_next(self): return self._parse_frames() def _parse_nframes(self): - if self._default_end - self._default_start > 8: + # TODO: we drop the message total size prefix (sent as part of the + # tornado-based tcp implementation), as it's not needed. If we ever + # drop that prefix entirely, we can adjust this code (change 16 -> 8 + # and 8 -> 0). 
+ if self._default_end - self._default_start >= 16: self._nframes = struct.unpack_from( - " Date: Thu, 4 Nov 2021 14:49:50 -0500 Subject: [PATCH 06/15] Optimize send for larger messages The builtin asyncio socket transport makes lots of copies, which can slow down large writes. To get around this, we implement a hacky wrapper for the transport that removes the use of copies (at the cost of some more bookkeeping). --- distributed/comm/asyncio_tcp.py | 183 ++++++++++++++++++++++++++++++++ 1 file changed, 183 insertions(+) diff --git a/distributed/comm/asyncio_tcp.py b/distributed/comm/asyncio_tcp.py index 77d5eca70b..5d8204fc47 100644 --- a/distributed/comm/asyncio_tcp.py +++ b/distributed/comm/asyncio_tcp.py @@ -1,4 +1,5 @@ import asyncio +import collections import logging import socket import struct @@ -94,7 +95,10 @@ async def _close(self): await self._is_closed def connection_made(self, transport): + if type(transport) is asyncio.selector_events._SelectorSocketTransport: + transport = _ZeroCopyWriter(transport) self._transport = transport + # self._transport = transport if self.on_connection is not None: self.on_connection(self) @@ -640,3 +644,182 @@ def get_local_address_for(self, loc): class TLSBackend(TCPBackend): _connector_class = TLSConnector _listener_class = TLSListener + + +_LARGE_BUF_LIMIT = 2048 + + +class _ZeroCopyWriter: + """The builtin socket transport in asyncio makes a bunch of copies, which + can make sending large amounts of data much slower. This hacks around that. 
+ Note that this workaround isn't used on windows or uvloop""" + + def __init__(self, transport): + self.transport = transport + self._buffers = collections.deque() + self._first_pos = 0 + self._size = 0 + + def _buffer_append(self, data): + size = len(data) + if size > _LARGE_BUF_LIMIT: + self._buffers.append(memoryview(data)) + elif size > 0: + if self._buffers: + last_buf = self._buffers[-1] + last_buf_typ = type(last_buf) + new_buf = last_buf_typ is memoryview or len(last_buf) > _LARGE_BUF_LIMIT + else: + new_buf = True + + if new_buf: + self._buffers.append(data) + else: + if last_buf_typ is bytes: + last_buf = self._buffers[-1] = bytearray(last_buf) + last_buf.extend(data) + + self._size += size + + def _buffer_peek_one(self): + b = self._buffers[0] + if not isinstance(b, memoryview): + b = memoryview(b) + return b[self._first_pos :] + + def _buffer_peek_many(self): + pos = self._first_pos + buffers = [] + size = 0 + for b in self._buffers: + if pos: + if not isinstance(b, memoryview): + b = memoryview(b) + b = b[pos:] + pos = 0 + buffers.append(b) + size += len(b) + if size > 2 ** 17: + break + return buffers + + def _buffer_advance(self, size): + pos = self._first_pos + self._size -= size + buffers = self._buffers + while size: + b = buffers[0] + b_len = len(b) - pos + if b_len <= size: + buffers.popleft() + size -= b_len + pos = 0 + else: + pos += size + break + self._first_pos = pos + + def write(self, data): + transport = self.transport + + if transport._eof: + raise RuntimeError("Cannot call write() after write_eof()") + if transport._empty_waiter is not None: + raise RuntimeError("unable to write; sendfile is in progress") + if not data: + return + if transport._conn_lost: + return + + if not self._buffers: + try: + n = transport._sock.send(data) + except (BlockingIOError, InterruptedError): + pass + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + transport._fatal_error(exc, "Fatal write error on socket transport") + 
return + else: + data = data[n:] + if not data: + return + # Not all was written; register write handler. + transport._loop._add_writer(transport._sock_fd, self._on_write_ready) + + # Add it to the buffer. + self._buffer_append(data) + transport._maybe_pause_protocol() + + def writelines(self, buffers): + waiting = bool(self._buffers) + for b in buffers: + self._buffer_append(b) + if not waiting: + try: + self._do_bulk_write() + except (BlockingIOError, InterruptedError): + pass + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + self.transport._fatal_error( + exc, "Fatal write error on socket transport" + ) + return + if not self._buffers: + return + # Not all was written; register write handler. + self.transport._loop._add_writer( + self.transport._sock_fd, self._on_write_ready + ) + + self.transport._maybe_pause_protocol() + + def is_closing(self): + return self.transport.is_closing() + + def close(self): + return self.transport.close() + + def abort(self): + return self.transport.abort() + + def get_extra_info(self, key): + return self.transport.get_extra_info(key) + + def _do_bulk_write(self): + # TODO: figure out why/when sendmsg is faster/slower + # buffers = self._buffer_peek_many() + # n = self.transport._sock.sendmsg(buffers) + n = self.transport._sock.send(self._buffer_peek_one()) + if n: + self._buffer_advance(n) + + def _on_write_ready(self): + transport = self.transport + if transport._conn_lost: + return + try: + self._do_bulk_write() + except (BlockingIOError, InterruptedError): + pass + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + transport._loop._remove_writer(transport._sock_fd) + self._buffers.clear() + transport._fatal_error(exc, "Fatal write error on socket transport") + if transport._empty_waiter is not None: + transport._empty_waiter.set_exception(exc) + else: + transport._maybe_resume_protocol() + if not self._buffers: + transport._loop._remove_writer(transport._sock_fd) + 
if transport._empty_waiter is not None: + transport._empty_waiter.set_result(None) + if transport._closing: + transport._call_connection_lost(None) + elif transport._eof: + transport._sock.shutdown(socket.SHUT_WR) From fc23a63b32c385d677183566f18618432a0d0829 Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Fri, 5 Nov 2021 11:12:04 -0500 Subject: [PATCH 07/15] Simplify buffer management --- distributed/comm/asyncio_tcp.py | 72 +++++++-------------------------- 1 file changed, 14 insertions(+), 58 deletions(-) diff --git a/distributed/comm/asyncio_tcp.py b/distributed/comm/asyncio_tcp.py index 5d8204fc47..ad5a2d8fbe 100644 --- a/distributed/comm/asyncio_tcp.py +++ b/distributed/comm/asyncio_tcp.py @@ -98,7 +98,6 @@ def connection_made(self, transport): if type(transport) is asyncio.selector_events._SelectorSocketTransport: transport = _ZeroCopyWriter(transport) self._transport = transport - # self._transport = transport if self.on_connection is not None: self.on_connection(self) @@ -661,63 +660,22 @@ def __init__(self, transport): self._size = 0 def _buffer_append(self, data): - size = len(data) - if size > _LARGE_BUF_LIMIT: - self._buffers.append(memoryview(data)) - elif size > 0: - if self._buffers: - last_buf = self._buffers[-1] - last_buf_typ = type(last_buf) - new_buf = last_buf_typ is memoryview or len(last_buf) > _LARGE_BUF_LIMIT - else: - new_buf = True - - if new_buf: - self._buffers.append(data) - else: - if last_buf_typ is bytes: - last_buf = self._buffers[-1] = bytearray(last_buf) - last_buf.extend(data) - - self._size += size + data = memoryview(data).cast("B") + self._buffers.append(data) + self._size += len(data) - def _buffer_peek_one(self): - b = self._buffers[0] - if not isinstance(b, memoryview): - b = memoryview(b) - return b[self._first_pos :] - - def _buffer_peek_many(self): - pos = self._first_pos - buffers = [] - size = 0 - for b in self._buffers: - if pos: - if not isinstance(b, memoryview): - b = memoryview(b) - b = b[pos:] - pos = 0 - 
buffers.append(b) - size += len(b) - if size > 2 ** 17: - break - return buffers + def _buffer_peek(self): + return self._buffers[0][self._first_pos :] def _buffer_advance(self, size): - pos = self._first_pos + b = self._buffers[0] self._size -= size - buffers = self._buffers - while size: - b = buffers[0] - b_len = len(b) - pos - if b_len <= size: - buffers.popleft() - size -= b_len - pos = 0 - else: - pos += size - break - self._first_pos = pos + b_len = len(b) - self._first_pos + if b_len == size: + self._buffers.popleft() + self._first_pos = 0 + else: + self._first_pos += size def write(self, data): transport = self.transport @@ -790,10 +748,8 @@ def get_extra_info(self, key): return self.transport.get_extra_info(key) def _do_bulk_write(self): - # TODO: figure out why/when sendmsg is faster/slower - # buffers = self._buffer_peek_many() - # n = self.transport._sock.sendmsg(buffers) - n = self.transport._sock.send(self._buffer_peek_one()) + buf = self._buffer_peek() + n = self.transport._sock.send(buf) if n: self._buffer_advance(n) From 682eed884b604b41522e02c7b0920fb841a5a2ef Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Tue, 23 Nov 2021 12:43:52 -0600 Subject: [PATCH 08/15] Cleanups - Add config option for selecting TCP backend - Parametrize TCP comm tests to test both backends - A few small code cleanups --- distributed/comm/__init__.py | 23 ++++--- distributed/comm/asyncio_tcp.py | 55 ++++++++-------- distributed/comm/tests/test_comms.py | 98 ++++++++++++++++------------ distributed/distributed-schema.yaml | 8 +++ distributed/distributed.yaml | 3 + distributed/tests/test_core.py | 2 +- distributed/tests/test_worker.py | 2 +- distributed/utils_test.py | 2 +- 8 files changed, 113 insertions(+), 80 deletions(-) diff --git a/distributed/comm/__init__.py b/distributed/comm/__init__.py index 1f1a05341c..a93e7705d3 100644 --- a/distributed/comm/__init__.py +++ b/distributed/comm/__init__.py @@ -15,20 +15,27 @@ def _register_transports(): - import os + import 
dask.config - from . import asyncio_tcp, inproc, tcp, ws + from . import inproc, ws + + tcp_backend = dask.config.get("distributed.comm.tcp.backend") + + if tcp_backend == "asyncio": + from . import asyncio_tcp - if os.getenv("DISTRIBUTED_USE_ASYNCIO_FOR_TCP", "").lower() not in ( - "", - "0", - "false", - ): backends["tcp"] = asyncio_tcp.TCPBackend() backends["tls"] = asyncio_tcp.TLSBackend() - else: + elif tcp_backend == "tornado": + from . import tcp + backends["tcp"] = tcp.TCPBackend() backends["tls"] = tcp.TLSBackend() + else: + raise ValueError( + f"Expected `distributed.comm.tcp.backend` to be in `('asyncio', " + f"'tornado')`, got {tcp_backend}" + ) try: from . import ucx diff --git a/distributed/comm/asyncio_tcp.py b/distributed/comm/asyncio_tcp.py index ad5a2d8fbe..40c7eb5011 100644 --- a/distributed/comm/asyncio_tcp.py +++ b/distributed/comm/asyncio_tcp.py @@ -478,10 +478,11 @@ async def _comm_handler(self, comm): await self.comm_handler(comm) async def _start_all_interfaces_with_random_port(self): - """Due to a bug in asyncio, listening on `("", 0)` will result in two - different random ports being used (one for IPV4, one for IPV6), rather - than both interfaces sharing the same random port. We work around this - here. See https://bugs.python.org/issue45693 for more info.""" + """Due to a design decision in asyncio, listening on `("", 0)` will + result in two different random ports being used (one for IPV4, one for + IPV6), rather than both interfaces sharing the same random port. We + work around this here. See https://bugs.python.org/issue45693 for more + info.""" loop = asyncio.get_event_loop() # Typically resolves to list with length == 2 (one IPV4, one IPV6). infos = await loop.getaddrinfo( @@ -546,21 +547,24 @@ async def _start_all_interfaces_with_random_port(self): # yet, close that as well. 
if sock is not None: sock.close() + raise - self._servers = servers + return servers async def start(self): loop = asyncio.get_event_loop() if not self.ip and not self.port: - await self._start_all_interfaces_with_random_port() + servers = await self._start_all_interfaces_with_random_port() else: - server = await loop.create_server( - lambda: DaskCommProtocol(self._on_connection), - host=self.ip, - port=self.port, - **self._extra_kwargs, - ) - self._servers = [server] + servers = [ + await loop.create_server( + lambda: DaskCommProtocol(self._on_connection), + host=self.ip, + port=self.port, + **self._extra_kwargs, + ) + ] + self._servers = servers def stop(self): # Stop listening @@ -645,37 +649,34 @@ class TLSBackend(TCPBackend): _listener_class = TLSListener -_LARGE_BUF_LIMIT = 2048 - - class _ZeroCopyWriter: """The builtin socket transport in asyncio makes a bunch of copies, which can make sending large amounts of data much slower. This hacks around that. - Note that this workaround isn't used on windows or uvloop""" + + Note that this workaround isn't used with the windows ProactorEventLoop or + uvloop.""" def __init__(self, transport): self.transport = transport self._buffers = collections.deque() - self._first_pos = 0 - self._size = 0 + self._offset = 0 def _buffer_append(self, data): - data = memoryview(data).cast("B") - self._buffers.append(data) - self._size += len(data) + self._buffers.append(memoryview(data).cast("B")) def _buffer_peek(self): - return self._buffers[0][self._first_pos :] + offset = self._offset + buf = self._buffers[0] + return buf[offset:] if offset else buf def _buffer_advance(self, size): b = self._buffers[0] - self._size -= size - b_len = len(b) - self._first_pos + b_len = len(b) - self._offset if b_len == size: self._buffers.popleft() - self._first_pos = 0 + self._offset = 0 else: - self._first_pos += size + self._offset += size def write(self, data): transport = self.transport diff --git a/distributed/comm/tests/test_comms.py 
b/distributed/comm/tests/test_comms.py index 5c90a0e3ea..8a15896b06 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -14,9 +14,9 @@ import dask import distributed -from distributed.comm import CommClosedError -from distributed.comm import asyncio_tcp as tcp from distributed.comm import ( + CommClosedError, + asyncio_tcp, connect, get_address_host, get_local_address_for, @@ -27,7 +27,6 @@ resolve_address, unparse_host_port, ) -from distributed.comm.asyncio_tcp import TCP, TCPBackend, TCPConnector from distributed.comm.registry import backends, get_backend from distributed.metrics import time from distributed.protocol import Serialized, deserialize, serialize, to_serialize @@ -47,6 +46,18 @@ EXTERNAL_IP6 = get_ipv6() +@pytest.fixture(params=["tornado", "asyncio"]) +def tcp(monkeypatch, request): + """Set the TCP backend to either tornado or asyncio""" + if request.param == "tornado": + import distributed.comm.tcp as tcp + else: + import distributed.comm.asyncio_tcp as tcp + monkeypatch.setitem(backends, "tcp", tcp.TCPBackend()) + monkeypatch.setitem(backends, "tls", tcp.TLSBackend()) + return tcp + + ca_file = get_cert("tls-ca-cert.pem") # The Subject field of our test certs @@ -117,7 +128,7 @@ async def debug_loop(): # -def test_parse_host_port(): +def test_parse_host_port(tcp): f = parse_host_port assert f("localhost:123") == ("localhost", 123) @@ -140,7 +151,7 @@ def test_parse_host_port(): f("::1") -def test_unparse_host_port(): +def test_unparse_host_port(tcp): f = unparse_host_port assert f("localhost", 123) == "localhost:123" @@ -157,14 +168,14 @@ def test_unparse_host_port(): assert f("::1", "*") == "[::1]:*" -def test_get_address_host(): +def test_get_address_host(tcp): f = get_address_host assert f("tcp://127.0.0.1:123") == "127.0.0.1" assert f("inproc://%s/%d/123" % (get_ip(), os.getpid())) == get_ip() -def test_resolve_address(): +def test_resolve_address(tcp): f = resolve_address assert 
f("tcp://127.0.0.1:123") == "tcp://127.0.0.1:123" @@ -184,7 +195,7 @@ def test_resolve_address(): assert f("tls://localhost:456") == "tls://127.0.0.1:456" -def test_get_local_address_for(): +def test_get_local_address_for(tcp): f = get_local_address_for assert f("tcp://127.0.0.1:80") == "tcp://127.0.0.1" @@ -204,7 +215,7 @@ def test_get_local_address_for(): @pytest.mark.asyncio -async def test_tcp_listener_does_not_call_handler_on_handshake_error(): +async def test_tcp_listener_does_not_call_handler_on_handshake_error(tcp): handle_comm_called = False async def handle_comm(comm): @@ -226,7 +237,7 @@ async def handle_comm(comm): @pytest.mark.asyncio -async def test_tcp_specific(): +async def test_tcp_specific(tcp): """ Test concrete TCP API. """ @@ -269,7 +280,7 @@ async def client_communicate(key, delay=0): @pytest.mark.asyncio -async def test_tls_specific(): +async def test_tls_specific(tcp): """ Test concrete TLS API. """ @@ -315,8 +326,7 @@ async def client_communicate(key, delay=0): @pytest.mark.asyncio -@pytest.mark.skip(reason="not applicable for asyncio") -async def test_comm_failure_threading(): +async def test_comm_failure_threading(tcp): """ When we fail to connect, make sure we don't make a lot of threads. @@ -324,6 +334,8 @@ async def test_comm_failure_threading(): We only assert for PY3, because the thread limit only is set for python 3. See github PR #2403 discussion for info. 
""" + if tcp is asyncio_tcp: + pytest.skip("not applicable for asyncio") async def sleep_for_60ms(): max_thread_count = 0 @@ -562,7 +574,7 @@ def checker(loc): @pytest.mark.asyncio -async def test_default_client_server_ipv4(): +async def test_default_client_server_ipv4(tcp): # Default scheme is (currently) TCP await check_client_server("127.0.0.1", tcp_eq("127.0.0.1")) await check_client_server("127.0.0.1:3201", tcp_eq("127.0.0.1", 3201)) @@ -579,7 +591,7 @@ async def test_default_client_server_ipv4(): @requires_ipv6 @pytest.mark.asyncio -async def test_default_client_server_ipv6(): +async def test_default_client_server_ipv6(tcp): await check_client_server("[::1]", tcp_eq("::1")) await check_client_server("[::1]:3211", tcp_eq("::1", 3211)) await check_client_server("[::]", tcp_eq("::"), tcp_eq(EXTERNAL_IP6)) @@ -589,7 +601,7 @@ async def test_default_client_server_ipv6(): @pytest.mark.asyncio -async def test_tcp_client_server_ipv4(): +async def test_tcp_client_server_ipv4(tcp): await check_client_server("tcp://127.0.0.1", tcp_eq("127.0.0.1")) await check_client_server("tcp://127.0.0.1:3221", tcp_eq("127.0.0.1", 3221)) await check_client_server("tcp://0.0.0.0", tcp_eq("0.0.0.0"), tcp_eq(EXTERNAL_IP4)) @@ -604,7 +616,7 @@ async def test_tcp_client_server_ipv4(): @requires_ipv6 @pytest.mark.asyncio -async def test_tcp_client_server_ipv6(): +async def test_tcp_client_server_ipv6(tcp): await check_client_server("tcp://[::1]", tcp_eq("::1")) await check_client_server("tcp://[::1]:3231", tcp_eq("::1", 3231)) await check_client_server("tcp://[::]", tcp_eq("::"), tcp_eq(EXTERNAL_IP6)) @@ -614,7 +626,7 @@ async def test_tcp_client_server_ipv6(): @pytest.mark.asyncio -async def test_tls_client_server_ipv4(): +async def test_tls_client_server_ipv4(tcp): await check_client_server("tls://127.0.0.1", tls_eq("127.0.0.1"), **tls_kwargs) await check_client_server( "tls://127.0.0.1:3221", tls_eq("127.0.0.1", 3221), **tls_kwargs @@ -626,7 +638,7 @@ async def 
test_tls_client_server_ipv4(): @requires_ipv6 @pytest.mark.asyncio -async def test_tls_client_server_ipv6(): +async def test_tls_client_server_ipv6(tcp): await check_client_server("tls://[::1]", tls_eq("::1"), **tls_kwargs) @@ -642,7 +654,7 @@ async def test_inproc_client_server(): @pytest.mark.asyncio -async def test_tls_reject_certificate(): +async def test_tls_reject_certificate(tcp): cli_ctx = get_client_ssl_context() serv_ctx = get_server_ssl_context() @@ -714,12 +726,12 @@ async def handle_comm(comm): @pytest.mark.asyncio -async def test_tcp_comm_closed_implicit(): +async def test_tcp_comm_closed_implicit(tcp): await check_comm_closed_implicit("tcp://127.0.0.1") @pytest.mark.asyncio -async def test_tls_comm_closed_implicit(): +async def test_tls_comm_closed_implicit(tcp): await check_comm_closed_implicit("tls://127.0.0.1", **tls_kwargs) @@ -752,12 +764,12 @@ async def check_comm_closed_explicit(addr, listen_args={}, connect_args={}): @pytest.mark.asyncio -async def test_tcp_comm_closed_explicit(): +async def test_tcp_comm_closed_explicit(tcp): await check_comm_closed_explicit("tcp://127.0.0.1") @pytest.mark.asyncio -async def test_tls_comm_closed_explicit(): +async def test_tls_comm_closed_explicit(tcp): await check_comm_closed_explicit("tls://127.0.0.1", **tls_kwargs) @@ -817,11 +829,13 @@ async def handle_comm(comm): @pytest.mark.asyncio -@pytest.mark.skip(reason="Not applicable for asyncio") -async def test_comm_closed_on_buffer_error(): +async def test_comm_closed_on_buffer_error(tcp): # Internal errors from comm.stream.write, such as # BufferError should lead to the stream being closed # and not re-used. 
See GitHub #4133 + if tcp is asyncio_tcp: + pytest.skip("Not applicable for asyncio") + reader, writer = await get_tcp_comm_pair() def _write(data): @@ -847,12 +861,12 @@ async def echo(comm): @pytest.mark.asyncio -async def test_retry_connect(monkeypatch): +async def test_retry_connect(tcp, monkeypatch): async def echo(comm): message = await comm.read() await comm.write(message) - class UnreliableConnector(TCPConnector): + class UnreliableConnector(tcp.TCPConnector): def __init__(self): self.num_failures = 2 @@ -866,7 +880,7 @@ async def connect(self, address, deserialize=True, **connection_args): self.failures += 1 raise OSError() - class UnreliableBackend(TCPBackend): + class UnreliableBackend(tcp.TCPBackend): _connector_class = UnreliableConnector monkeypatch.setitem(backends, "tcp", UnreliableBackend()) @@ -882,8 +896,8 @@ class UnreliableBackend(TCPBackend): @pytest.mark.asyncio -async def test_handshake_slow_comm(monkeypatch): - class SlowComm(TCP): +async def test_handshake_slow_comm(tcp, monkeypatch): + class SlowComm(tcp.TCP): def __init__(self, *args, delay_in_comm=0.5, **kwargs): super().__init__(*args, **kwargs) self.delay_in_comm = delay_in_comm @@ -897,10 +911,10 @@ async def write(self, *args, **kwargs): res = await super(type(self), self).write(*args, **kwargs) return res - class SlowConnector(TCPConnector): + class SlowConnector(tcp.TCPConnector): comm_class = SlowComm - class SlowBackend(TCPBackend): + class SlowBackend(tcp.TCPBackend): def get_connector(self): return SlowConnector() @@ -933,7 +947,7 @@ async def check_connect_timeout(addr): @pytest.mark.asyncio -async def test_tcp_connect_timeout(): +async def test_tcp_connect_timeout(tcp): await check_connect_timeout("tcp://127.0.0.1:44444") @@ -961,7 +975,7 @@ async def handle_comm(comm): @pytest.mark.asyncio -async def test_tcp_many_listeners(): +async def test_tcp_many_listeners(tcp): await check_many_listeners("tcp://127.0.0.1") await check_many_listeners("tcp://0.0.0.0") await 
check_many_listeners("tcp://") @@ -1117,7 +1131,7 @@ def check_out(deserialize_flag, out_value): @pytest.mark.asyncio -async def test_tcp_deserialize(): +async def test_tcp_deserialize(tcp): await check_deserialize("tcp://") @@ -1165,7 +1179,7 @@ async def test_inproc_deserialize_roundtrip(): @pytest.mark.asyncio -async def test_tcp_deserialize_roundtrip(): +async def test_tcp_deserialize_roundtrip(tcp): await check_deserialize_roundtrip("tcp://") @@ -1195,7 +1209,7 @@ async def handle_comm(comm): @pytest.mark.asyncio -async def test_tcp_deserialize_eoferror(): +async def test_tcp_deserialize_eoferror(tcp): await check_deserialize_eoferror("tcp://") @@ -1218,7 +1232,7 @@ async def check_repr(a, b): @pytest.mark.asyncio -async def test_tcp_repr(): +async def test_tcp_repr(tcp): a, b = await get_tcp_comm_pair() assert a.local_address in repr(b) assert b.local_address in repr(a) @@ -1226,7 +1240,7 @@ async def test_tcp_repr(): @pytest.mark.asyncio -async def test_tls_repr(): +async def test_tls_repr(tcp): a, b = await get_tls_comm_pair() assert a.local_address in repr(b) assert b.local_address in repr(a) @@ -1249,13 +1263,13 @@ async def check_addresses(a, b): @pytest.mark.asyncio -async def test_tcp_adresses(): +async def test_tcp_adresses(tcp): a, b = await get_tcp_comm_pair() await check_addresses(a, b) @pytest.mark.asyncio -async def test_tls_adresses(): +async def test_tls_adresses(tcp): a, b = await get_tls_comm_pair() await check_addresses(a, b) diff --git a/distributed/distributed-schema.yaml b/distributed/distributed-schema.yaml index eba7a6bf83..89e9116219 100644 --- a/distributed/distributed-schema.yaml +++ b/distributed/distributed-schema.yaml @@ -844,6 +844,14 @@ properties: ``True``, a CUDA context will be created on the first device listed in ``CUDA_VISIBLE_DEVICES``. + tcp: + type: object + properties: + backend: + type: string + description: | + The TCP backend implementation to use. Must be either `tornado` or `asyncio`. 
+ websockets: type: object properties: diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 8edd3da9ed..05428a0ae4 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -221,6 +221,9 @@ distributed: key: null cert: null + tcp: + backend: tornado # The backend to use for TCP, one of {tornado, asyncio} + websockets: shard: 8MiB diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 68a73dc625..3d03ccc701 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -591,8 +591,8 @@ async def test_connection_pool_close_while_connecting(monkeypatch): Ensure a closed connection pool guarantees to have no connections left open even if it is closed mid-connecting """ - from distributed.comm.asyncio_tcp import TCPBackend, TCPConnector from distributed.comm.registry import backends + from distributed.comm.tcp import TCPBackend, TCPConnector class SlowConnector(TCPConnector): async def connect(self, address, deserialize, **connection_args): diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 22d3ce3f9b..57176b7ce1 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -31,7 +31,6 @@ get_worker, wait, ) -from distributed.comm.asyncio_tcp import TCPBackend from distributed.comm.registry import backends from distributed.compatibility import LINUX, WINDOWS from distributed.core import CommClosedError, Status, rpc @@ -1538,6 +1537,7 @@ async def test_protocol_from_scheduler_address(cleanup, Worker): async def test_host_uses_scheduler_protocol(cleanup, monkeypatch): # Ensure worker uses scheduler's protocol to determine host address, not the default scheme # See https://github.com/dask/distributed/pull/4883 + from distributed.comm.tcp import TCPBackend class BadBackend(TCPBackend): def get_address_host(self, loc): diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 129af9ccce..727be41cc6 
100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -46,7 +46,7 @@ import dask -from distributed.comm.asyncio_tcp import TCP +from distributed.comm.tcp import TCP from . import system from .client import Client, _global_clients, default_client From 95f44e7a5dd224627690994cc09bc804d6eb8f17 Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Wed, 24 Nov 2021 08:12:14 -0600 Subject: [PATCH 09/15] Prioritize IPV4 over IPV6 when reporting addresses --- distributed/comm/asyncio_tcp.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/distributed/comm/asyncio_tcp.py b/distributed/comm/asyncio_tcp.py index 40c7eb5011..48c0775054 100644 --- a/distributed/comm/asyncio_tcp.py +++ b/distributed/comm/asyncio_tcp.py @@ -64,11 +64,10 @@ def local_addr(self): def peer_addr(self): if self._transport is None: return "" - try: - host, port = self._transport.get_extra_info("peername") - return unparse_host_port(host, port) - except Exception: - return "" + peername = self._transport.get_extra_info("peername") + if peername is not None: + return unparse_host_port(*peername[:2]) + return "" @property def is_closed(self): @@ -459,7 +458,6 @@ def _get_extra_kwargs(self, address, **kwargs): return {} def _on_connection(self, protocol): - logger.debug("Incoming connection") comm = self.comm_class( protocol, local_addr=self.prefix + protocol.local_addr, @@ -493,6 +491,8 @@ async def _start_all_interfaces_with_random_port(self): flags=socket.AI_PASSIVE, proto=0, ) + # Sort infos to always bind ipv4 before ipv6 + infos = sorted(infos, key=lambda x: x[0].name) # This code is a simplified and modified version of that found in # cpython here: # https://github.com/python/cpython/blob/401272e6e660445d6556d5cd4db88ed4267a50b3/Lib/asyncio/base_events.py#L1439 From 36470629055bc71963b7050ab1f12b7bb6748493 Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Fri, 3 Dec 2021 14:15:47 -0600 Subject: [PATCH 10/15] coalesce small buffers Due to flaws in 
distributed's current serialization process, it's easy for a message to be serialized as hundreds/thousands of tiny buffers. We now coalesce these buffers into a series of larger buffers following a few heuristics. The goal is to minimize copies while also minimizing socket.send calls. This also moves to using `sendmsg` for writing from multiple buffers. Both optimizations seem critical for some real world workflows. --- distributed/comm/asyncio_tcp.py | 112 ++++++++++++++++++++++++++------ 1 file changed, 93 insertions(+), 19 deletions(-) diff --git a/distributed/comm/asyncio_tcp.py b/distributed/comm/asyncio_tcp.py index 48c0775054..b08ec9123f 100644 --- a/distributed/comm/asyncio_tcp.py +++ b/distributed/comm/asyncio_tcp.py @@ -24,6 +24,60 @@ _COMM_CLOSED = object() +def coalesce_buffers(buffers, target_buffer_size=64 * 1024, small_buffer_size=2048): + """Given a list of buffers, coalesce them into a new list of buffers that + minimizes both copying and tiny writes. + + Parameters + ---------- + buffers : list of bytes_like + target_buffer_size : int, optional + The target intermediate buffer size from concatenating small buffers + together. Coalesced buffers will be no larger than approximately this size. + small_buffer_size : int, optional + Buffers <= this size are considered "small" and may be copied. Buffers + larger than this may also be copied if the total message length is less + than ``target_buffer_size``. 
+ """ + # Nothing to do + if len(buffers) == 1: + return buffers + + # If the whole message can be sent in <= target_buffer_size, always concatenate + if sum(map(len, buffers)) <= target_buffer_size: + return [b"".join(buffers)] + + out_buffers = [] + concat = [] # A list of buffers to concatenate + csize = 0 # The total size of the concatenated buffers + + def flush(): + nonlocal csize + if concat: + if len(concat) == 1: + out_buffers.append(concat[0]) + else: + out_buffers.append(b"".join(concat)) + concat.clear() + csize = 0 + + for b in buffers: + if isinstance(b, memoryview): + b = b.cast("B") + size = len(b) + if size <= small_buffer_size: + concat.append(b) + csize += size + if csize >= target_buffer_size: + flush() + else: + flush() + out_buffers.append(b) + flush() + + return out_buffers + + class DaskCommProtocol(asyncio.BufferedProtocol): def __init__(self, on_connection=None, min_read_size=128 * 1024): super().__init__() @@ -272,16 +326,13 @@ async def write(self, frames): # change to the comms. 
msg_nbytes = sum(frames_nbytes) + (nframes + 1) * 8 header = struct.pack(f"{nframes + 2}Q", msg_nbytes, nframes, *frames_nbytes) - frames = [header, *frames] - if msg_nbytes < 2 ** 17: # 128kiB - # small enough, send in one go - frames = [b"".join(frames)] + buffers = coalesce_buffers([header, *frames]) - if len(frames) > 1: - self._transport.writelines(frames) + if len(buffers) > 1: + self._transport.writelines(buffers) else: - self._transport.write(frames[0]) + self._transport.write(buffers[0]) if self._transport.is_closing(): await asyncio.sleep(0) elif self._paused: @@ -656,6 +707,9 @@ class _ZeroCopyWriter: Note that this workaround isn't used with the windows ProactorEventLoop or uvloop.""" + SENDMSG_MAX_SIZE = 1024 * 1024 # 1 MiB + SENDMSG_MAX_COUNT = 16 + def __init__(self, transport): self.transport = transport self._buffers = collections.deque() @@ -666,17 +720,34 @@ def _buffer_append(self, data): def _buffer_peek(self): offset = self._offset - buf = self._buffers[0] - return buf[offset:] if offset else buf + buffers = [] + size = 0 + count = 0 + for b in self._buffers: + if offset: + b = b[offset:] + offset = 0 + buffers.append(b) + size += len(b) + count += 1 + if size > self.SENDMSG_MAX_SIZE or count == self.SENDMSG_MAX_COUNT: + break + return buffers def _buffer_advance(self, size): - b = self._buffers[0] - b_len = len(b) - self._offset - if b_len == size: - self._buffers.popleft() - self._offset = 0 - else: - self._offset += size + offset = self._offset + buffers = self._buffers + while size: + b = buffers[0] + b_len = len(b) - offset + if b_len <= size: + buffers.popleft() + size -= b_len + offset = 0 + else: + offset += size + break + self._offset = offset def write(self, data): transport = self.transport @@ -749,9 +820,12 @@ def get_extra_info(self, key): return self.transport.get_extra_info(key) def _do_bulk_write(self): - buf = self._buffer_peek() - n = self.transport._sock.send(buf) - if n: + buffers = self._buffer_peek() + if len(buffers) == 
1: + n = self.transport._sock.send(buffers[0]) + self._buffer_advance(n) + else: + n = self.transport._sock.sendmsg(buffers) self._buffer_advance(n) def _on_write_ready(self): From e65467e5e68449ea335f46009fa39fb423b8b88c Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Mon, 6 Dec 2021 10:15:45 -0600 Subject: [PATCH 11/15] Fixup memoryview nbytes During extraction of `coalesce_buffers` a bug handling memoryviews of non-single-byte formats was introduced. This fixes that. --- distributed/comm/asyncio_tcp.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/distributed/comm/asyncio_tcp.py b/distributed/comm/asyncio_tcp.py index b08ec9123f..1fec4c2e4f 100644 --- a/distributed/comm/asyncio_tcp.py +++ b/distributed/comm/asyncio_tcp.py @@ -12,7 +12,7 @@ import dask -from ..utils import ensure_ip, get_ip, get_ipv6, nbytes +from ..utils import ensure_ip, get_ip, get_ipv6 from .addressing import parse_host_port, unparse_host_port from .core import Comm, CommClosedError, Connector, Listener from .registry import Backend @@ -62,8 +62,6 @@ def flush(): csize = 0 for b in buffers: - if isinstance(b, memoryview): - b = b.cast("B") size = len(b) if size <= small_buffer_size: concat.append(b) @@ -318,8 +316,11 @@ async def write(self, frames): if self._exception: raise self._exception + # Ensure all memoryviews are in single-byte format + frames = [f.cast("B") if isinstance(f, memoryview) else f for f in frames] + nframes = len(frames) - frames_nbytes = [nbytes(f) for f in frames] + frames_nbytes = [len(f) for f in frames] # TODO: the old TCP comm included an extra `msg_nbytes` prefix that # isn't really needed. 
We include it here for backwards compatibility, # but this could be removed if we ever want to make another breaking @@ -716,7 +717,7 @@ def __init__(self, transport): self._offset = 0 def _buffer_append(self, data): - self._buffers.append(memoryview(data).cast("B")) + self._buffers.append(memoryview(data)) def _buffer_peek(self): offset = self._offset From 01d99b4a21f00c69bbe6957996f91635744c4587 Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Mon, 6 Dec 2021 13:57:53 -0600 Subject: [PATCH 12/15] Add write backpressure, respond to comments - Adds write backpressure to the `_ZeroCopyWriter` patch - Sets the high water mark for all transports - the default can be quite low in some implementations. - Respond to some feedback (still have more to do). --- distributed/comm/asyncio_tcp.py | 119 +++++++++++++++++++++++--------- 1 file changed, 86 insertions(+), 33 deletions(-) diff --git a/distributed/comm/asyncio_tcp.py b/distributed/comm/asyncio_tcp.py index 1fec4c2e4f..e035fcd34b 100644 --- a/distributed/comm/asyncio_tcp.py +++ b/distributed/comm/asyncio_tcp.py @@ -91,8 +91,8 @@ def __init__(self, on_connection=None, min_read_size=128 * 1024): # Per-message state self._using_default_buffer = True - self._default_buffer = memoryview(bytearray(min_read_size)) - self._default_len = min_read_size + self._default_len = max(min_read_size, 16) # need at least 16 bytes of buffer + self._default_buffer = memoryview(bytearray(self._default_len)) self._default_start = 0 self._default_end = 0 @@ -146,9 +146,15 @@ async def _close(self): await self._is_closed def connection_made(self, transport): + # XXX: When using asyncio, the default builtin transport makes + # excessive copies when buffering. For the case of TCP on asyncio (no + # TLS) we patch around that with a wrapper class that handles the write + # side with minimal copying. 
if type(transport) is asyncio.selector_events._SelectorSocketTransport: - transport = _ZeroCopyWriter(transport) + transport = _ZeroCopyWriter(self, transport) self._transport = transport + # Set the buffer limits to something more optimal for large data transfer. + self._transport.set_write_buffer_limits(high=512 * 1024) # 512 KiB if self.on_connection is not None: self.on_connection(self) @@ -235,12 +241,12 @@ def _frames_check_remaining(self): def _parse_frames(self): while True: + available = self._default_end - self._default_start # Are we out of frames? if not self._frames_check_remaining(): self._message_completed() - return True + return bool(available) # Are we out of data? - available = self._default_end - self._default_start if not available: return False @@ -315,6 +321,10 @@ async def read(self): async def write(self, frames): if self._exception: raise self._exception + elif self._paused: + # Wait until there's room in the write buffer + self._drain_waiter = self._loop.create_future() + await self._drain_waiter # Ensure all memoryviews are in single-byte format frames = [f.cast("B") if isinstance(f, memoryview) else f for f in frames] @@ -334,11 +344,6 @@ async def write(self, frames): self._transport.writelines(buffers) else: self._transport.write(buffers[0]) - if self._transport.is_closing(): - await asyncio.sleep(0) - elif self._paused: - self._drain_waiter = self._loop.create_future() - await self._drain_waiter return msg_nbytes @@ -458,7 +463,7 @@ class TCPConnector(Connector): comm_class = TCP async def connect(self, address, deserialize=True, **kwargs): - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() ip, port = parse_host_port(address) kwargs = self._get_extra_kwargs(address, **kwargs) @@ -533,7 +538,7 @@ async def _start_all_interfaces_with_random_port(self): IPV6), rather than both interfaces sharing the same random port. We work around this here. 
See https://bugs.python.org/issue45693 for more info.""" - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() # Typically resolves to list with length == 2 (one IPV4, one IPV6). infos = await loop.getaddrinfo( None, @@ -604,7 +609,7 @@ async def _start_all_interfaces_with_random_port(self): return servers async def start(self): - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() if not self.ip and not self.port: servers = await self._start_all_interfaces_with_random_port() else: @@ -711,13 +716,68 @@ class _ZeroCopyWriter: SENDMSG_MAX_SIZE = 1024 * 1024 # 1 MiB SENDMSG_MAX_COUNT = 16 - def __init__(self, transport): + def __init__(self, protocol, transport): + self.protocol = protocol self.transport = transport + self._loop = asyncio.get_running_loop() + + # This class mucks with the builtin asyncio transport's internals. + # Check that the bits we touch still exist. + for attr in [ + "_sock", + "_sock_fd", + "_fatal_error", + "_eof", + "_closing", + "_conn_lost", + "_call_connection_lost", + ]: + assert hasattr(transport, attr) + # Likewise, this calls a few internal methods of `loop`, ensure they + # still exist. 
+ for attr in ["_add_writer", "_remove_writer"]: + assert hasattr(self._loop, attr) + self._buffers = collections.deque() self._offset = 0 + self._size = 0 + self._protocol_paused = False + self.set_write_buffer_limits() + + def set_write_buffer_limits(self, high=None, low=None): + if high is None: + if low is None: + high = 64 * 1024 # 64 KiB + else: + high = 4 * low + if low is None: + low = high // 4 + self._high_water = high + self._low_water = low + self._maybe_pause_protocol() + + def _maybe_pause_protocol(self): + if not self._protocol_paused and self._size > self._high_water: + self._protocol_paused = True + self.protocol.pause_writing() + + def _maybe_resume_protocol(self): + if self._protocol_paused and self._size <= self._low_water: + self._protocol_paused = False + self.protocol.resume_writing() + + def _buffer_clear(self): + self._buffers.clear() + self._size = 0 + self._offset = 0 def _buffer_append(self, data): - self._buffers.append(memoryview(data)) + if not isinstance(data, memoryview): + data = memoryview(data) + if data.format != "B": + data = data.cast("B") + self._size += len(data) + self._buffers.append(data) def _buffer_peek(self): offset = self._offset @@ -736,6 +796,8 @@ def _buffer_peek(self): return buffers def _buffer_advance(self, size): + self._size -= size + offset = self._offset buffers = self._buffers while size: @@ -755,8 +817,6 @@ def write(self, data): if transport._eof: raise RuntimeError("Cannot call write() after write_eof()") - if transport._empty_waiter is not None: - raise RuntimeError("unable to write; sendfile is in progress") if not data: return if transport._conn_lost: @@ -777,11 +837,11 @@ def write(self, data): if not data: return # Not all was written; register write handler. - transport._loop._add_writer(transport._sock_fd, self._on_write_ready) + self._loop._add_writer(transport._sock_fd, self._on_write_ready) # Add it to the buffer. 
self._buffer_append(data) - transport._maybe_pause_protocol() + self._maybe_pause_protocol() def writelines(self, buffers): waiting = bool(self._buffers) @@ -802,19 +862,16 @@ def writelines(self, buffers): if not self._buffers: return # Not all was written; register write handler. - self.transport._loop._add_writer( - self.transport._sock_fd, self._on_write_ready - ) - - self.transport._maybe_pause_protocol() + self._loop._add_writer(self.transport._sock_fd, self._on_write_ready) - def is_closing(self): - return self.transport.is_closing() + self._maybe_pause_protocol() def close(self): + self._buffer_clear() return self.transport.close() def abort(self): + self._buffer_clear() return self.transport.abort() def get_extra_info(self, key): @@ -840,17 +897,13 @@ def _on_write_ready(self): except (SystemExit, KeyboardInterrupt): raise except BaseException as exc: - transport._loop._remove_writer(transport._sock_fd) + self._loop._remove_writer(transport._sock_fd) self._buffers.clear() transport._fatal_error(exc, "Fatal write error on socket transport") - if transport._empty_waiter is not None: - transport._empty_waiter.set_exception(exc) else: - transport._maybe_resume_protocol() + self._maybe_resume_protocol() if not self._buffers: - transport._loop._remove_writer(transport._sock_fd) - if transport._empty_waiter is not None: - transport._empty_waiter.set_result(None) + self._loop._remove_writer(transport._sock_fd) if transport._closing: transport._call_connection_lost(None) elif transport._eof: From cd47373dee596f81a064971bb307138cb7dc39d8 Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Thu, 9 Dec 2021 10:50:52 -0600 Subject: [PATCH 13/15] A few fixups - Fix windows support (hopefully) - Lots of comments and docstrings - A bit of simplification of the buffer management --- distributed/comm/asyncio_tcp.py | 187 ++++++++++++++++++++++---------- 1 file changed, 132 insertions(+), 55 deletions(-) diff --git a/distributed/comm/asyncio_tcp.py 
b/distributed/comm/asyncio_tcp.py index e035fcd34b..0971889d1c 100644 --- a/distributed/comm/asyncio_tcp.py +++ b/distributed/comm/asyncio_tcp.py @@ -1,9 +1,13 @@ +from __future__ import annotations + import asyncio import collections import logging +import os import socket import struct import weakref +from typing import Any try: import ssl @@ -24,7 +28,11 @@ _COMM_CLOSED = object() -def coalesce_buffers(buffers, target_buffer_size=64 * 1024, small_buffer_size=2048): +def coalesce_buffers( + buffers: list[bytes], + target_buffer_size: int = 64 * 1024, + small_buffer_size: int = 2048, +) -> list[bytes]: """Given a list of buffers, coalesce them into a new list of buffers that minimizes both copying and tiny writes. @@ -47,8 +55,8 @@ def coalesce_buffers(buffers, target_buffer_size=64 * 1024, small_buffer_size=20 if sum(map(len, buffers)) <= target_buffer_size: return [b"".join(buffers)] - out_buffers = [] - concat = [] # A list of buffers to concatenate + out_buffers: list[bytes] = [] + concat: list[bytes] = [] # A list of buffers to concatenate csize = 0 # The total size of the concatenated buffers def flush(): @@ -77,40 +85,72 @@ def flush(): class DaskCommProtocol(asyncio.BufferedProtocol): + """Manages a state machine for parsing the message framing used by dask. + + Parameters + ---------- + on_connection : callable, optional + A callback to call on connection, used server side for handling + incoming connections. + min_read_size : int, optional + The minimum buffer size to pass to ``socket.recv_into``. Larger sizes + will result in fewer recv calls, at the cost of more copying. For + request-response comms (where only one message may be in the queue at a + time), a smaller value is likely more performant. 
+ """ + def __init__(self, on_connection=None, min_read_size=128 * 1024): super().__init__() self.on_connection = on_connection - self._exception = None + self._loop = asyncio.get_running_loop() + # On error (or close) this contains an exception to raise to the caller + self._exception: Exception | None = None + # A queue of received messages self._queue = asyncio.Queue() + # The corresponding transport, set on `connection_made` self._transport = None + # Is the protocol paused? self._paused = False - self._drain_waiter = None - self._loop = asyncio.get_running_loop() + # If the protocol is paused, this holds a future to wait on until it's + # unpaused. + self._drain_waiter: asyncio.Future | None = None + # A future for waiting until the protocol is actually closed self._is_closed = self._loop.create_future() + # In the interest of reducing the number of `recv` calls, we always + # want to provide the opportunity to read `min_read_size` bytes from + # the socket (since memcpy is much faster than recv). Each read event + # may read into either a default buffer (of size `min_read_size`), or + # directly into one of the message frames (if the frame size is > + # `min_read_size`). + # Per-message state self._using_default_buffer = True self._default_len = max(min_read_size, 16) # need at least 16 bytes of buffer self._default_buffer = memoryview(bytearray(self._default_len)) + # Index in default_buffer pointing to the first unparsed byte self._default_start = 0 + # Index in default_buffer pointing to the last written byte self._default_end = 0 - self._nframes = None - self._frame_lengths = None - self._frames = None - self._frame_index = None - self._frame_nbytes_needed = 0 + # Each message is composed of one or more frames, these attributes + # are filled in as the message is parsed, and cleared once a message + # is fully parsed. 
+ self._nframes: int | None = None + self._frame_lengths: list[int] | None = None + self._frames: list[memoryview] | None = None + self._frame_index: int | None = None # current frame to parse + self._frame_nbytes_needed: int = 0 # nbytes left for parsing current frame @property def local_addr(self): if self._transport is None: return "" - try: - host, port = self._transport.get_extra_info("socket").getsockname()[:2] - return unparse_host_port(host, port) - except Exception: - return "" + sockname = self._transport.get_extra_info("sockname") + if sockname is not None: + return unparse_host_port(*sockname[:2]) + return "" @property def peer_addr(self): @@ -159,6 +199,9 @@ def connection_made(self, transport): self.on_connection(self) def get_buffer(self, sizehint): + """Get a buffer to read into for this read event""" + # Read into the default buffer if there are no frames or the current + # frame is small. Otherwise read directly into the current frame. if self._frames is None or self._frame_nbytes_needed < self._default_len: self._using_default_buffer = True return self._default_buffer[self._default_end :] @@ -173,24 +216,28 @@ def buffer_updated(self, nbytes): if self._using_default_buffer: self._default_end += nbytes - while self._parse_next(): - pass - self._reset_default_buffer() + self._parse_default_buffer() else: self._frame_nbytes_needed -= nbytes if not self._frames_check_remaining(): self._message_completed() - def _parse_next(self): - if self._nframes is None: - if not self._parse_nframes(): - return False - if len(self._frame_lengths) < self._nframes: - if not self._parse_frame_lengths(): - return False - return self._parse_frames() + def _parse_default_buffer(self): + """Parse all messages in the default buffer.""" + while True: + if self._nframes is None: + if not self._parse_nframes(): + break + if len(self._frame_lengths) < self._nframes: + if not self._parse_frame_lengths(): + break + if not self._parse_frames(): + break + 
self._reset_default_buffer() def _parse_nframes(self): + """Fill in `_nframes` from the default buffer. Returns True if + successful, False if more data is needed""" # TODO: we drop the message total size prefix (sent as part of the # tornado-based tcp implementation), as it's not needed. If we ever # drop that prefix entirely, we can adjust this code (change 16 -> 8 @@ -205,6 +252,8 @@ def _parse_nframes(self): return False def _parse_frame_lengths(self): + """Fill in `_frame_lengths` from the default buffer. Returns True if + successful, False if more data is needed""" needed = self._nframes - len(self._frame_lengths) available = (self._default_end - self._default_start) // 8 n_read = min(available, needed) @@ -240,6 +289,8 @@ def _frames_check_remaining(self): return False def _parse_frames(self): + """Fill in `_frames` from the default buffer. Returns True if + successful, False if more data is needed""" while True: available = self._default_end - self._default_start # Are we out of frames? 
@@ -260,18 +311,22 @@ def _parse_frames(self): self._frame_nbytes_needed -= n_read def _reset_default_buffer(self): + """Reset the default buffer for the next read event""" start = self._default_start end = self._default_end if start < end and start != 0: + # Still some unparsed data, copy it to the front of the buffer self._default_buffer[: end - start] = self._default_buffer[start:end] self._default_start = 0 self._default_end = end - start elif start == end: + # All data is parsed, just reset the indices self._default_start = 0 self._default_end = 0 def _message_completed(self): + """Push a completed message to the queue and reset per-message state""" self._queue.put_nowait(self._frames) self._nframes = None self._frames = None @@ -308,7 +363,8 @@ def resume_writing(self): if not waiter.done(): waiter.set_result(None) - async def read(self): + async def read(self) -> list[bytes]: + """Read a single message from the comm.""" # Even if comm is closed, we still yield all received data before # erroring if self._queue is not None: @@ -316,15 +372,17 @@ async def read(self): if out is not _COMM_CLOSED: return out self._queue = None + assert self._exception is not None raise self._exception - async def write(self, frames): + async def write(self, frames: list[bytes]) -> int: + """Write a message to the comm.""" if self._exception: raise self._exception elif self._paused: # Wait until there's room in the write buffer self._drain_waiter = self._loop.create_future() - await self._drain_waiter + await self._drain_waiter # type: ignore # Ensure all memoryviews are in single-byte format frames = [f.cast("B") if isinstance(f, memoryview) else f for f in frames] @@ -706,6 +764,8 @@ class TLSBackend(TCPBackend): _listener_class = TLSListener +# This class is based on parts of `asyncio.selector_events._SelectorSocketTransport` +# (https://github.com/python/cpython/blob/dc4a212bd305831cb4b187a2e0cc82666fcb15ca/Lib/asyncio/selector_events.py#L757). 
class _ZeroCopyWriter: """The builtin socket transport in asyncio makes a bunch of copies, which can make sending large amounts of data much slower. This hacks around that. @@ -713,8 +773,19 @@ class _ZeroCopyWriter: Note that this workaround isn't used with the windows ProactorEventLoop or uvloop.""" - SENDMSG_MAX_SIZE = 1024 * 1024 # 1 MiB - SENDMSG_MAX_COUNT = 16 + # We use sendmsg for scatter IO if it's available. Since bookkeeping + # scatter IO has a small cost, we want to minimize the amount of processing + # we do for each send call. We assume the system send buffer is < 4 MiB + # (which would be very large), and set a limit on the number of buffers to + # pass to sendmsg. + SENDMSG_MAX_SIZE = 4 * 1024 * 1024 # 4 MiB + if hasattr(socket.socket, "sendmsg"): + try: + SENDMSG_MAX_COUNT = os.sysconf("SC_IOV_MAX") + except Exception: + SENDMSG_MAX_COUNT = 16 # Should be supported on all systems + else: + SENDMSG_MAX_COUNT = 1 # sendmsg not supported, use send instead def __init__(self, protocol, transport): self.protocol = protocol @@ -738,13 +809,18 @@ def __init__(self, protocol, transport): for attr in ["_add_writer", "_remove_writer"]: assert hasattr(self._loop, attr) - self._buffers = collections.deque() - self._offset = 0 + # A deque of buffers to send + self._buffers: collections.deque[memoryview] = collections.deque() + # The total size of all bytes left to send in _buffers self._size = 0 + # Is the backing protocol paused? 
self._protocol_paused = False + # Initialize the buffer limits self.set_write_buffer_limits() - def set_write_buffer_limits(self, high=None, low=None): + def set_write_buffer_limits(self, high: int = None, low: int = None): + """Set the write buffer limits""" + # Copied almost verbatim from asyncio.transports._FlowControlMixin if high is None: if low is None: high = 64 * 1024 # 64 KiB @@ -757,21 +833,24 @@ def set_write_buffer_limits(self, high=None, low=None): self._maybe_pause_protocol() def _maybe_pause_protocol(self): + """If the high water mark has been reached, pause the protocol""" if not self._protocol_paused and self._size > self._high_water: self._protocol_paused = True self.protocol.pause_writing() def _maybe_resume_protocol(self): + """If the low water mark has been reached, unpause the protocol""" if self._protocol_paused and self._size <= self._low_water: self._protocol_paused = False self.protocol.resume_writing() def _buffer_clear(self): + """Clear the send buffer""" self._buffers.clear() self._size = 0 - self._offset = 0 - def _buffer_append(self, data): + def _buffer_append(self, data: bytes) -> None: + """Append new data to the send buffer""" if not isinstance(data, memoryview): data = memoryview(data) if data.format != "B": @@ -779,15 +858,12 @@ def _buffer_append(self, data): self._size += len(data) self._buffers.append(data) - def _buffer_peek(self): - offset = self._offset + def _buffer_peek(self) -> list[memoryview]: + """Get one or more buffers to write to the socket""" buffers = [] size = 0 count = 0 for b in self._buffers: - if offset: - b = b[offset:] - offset = 0 buffers.append(b) size += len(b) count += 1 @@ -795,24 +871,23 @@ def _buffer_peek(self): break return buffers - def _buffer_advance(self, size): + def _buffer_advance(self, size: int) -> None: + """Advance the buffer index forward by `size`""" self._size -= size - offset = self._offset buffers = self._buffers while size: b = buffers[0] - b_len = len(b) - offset + b_len = 
len(b) if b_len <= size: buffers.popleft() size -= b_len - offset = 0 else: - offset += size + buffers[0] = b[size:] break - self._offset = offset - def write(self, data): + def write(self, data: bytes) -> None: + # Copied almost verbatim from asyncio.selector_events._SelectorSocketTransport transport = self.transport if transport._eof: @@ -843,7 +918,8 @@ def write(self, data): self._buffer_append(data) self._maybe_pause_protocol() - def writelines(self, buffers): + def writelines(self, buffers: list[bytes]) -> None: + # Based on modified version of `write` above waiting = bool(self._buffers) for b in buffers: self._buffer_append(b) @@ -866,18 +942,18 @@ def writelines(self, buffers): self._maybe_pause_protocol() - def close(self): + def close(self) -> None: self._buffer_clear() return self.transport.close() - def abort(self): + def abort(self) -> None: self._buffer_clear() return self.transport.abort() - def get_extra_info(self, key): + def get_extra_info(self, key: str) -> Any: return self.transport.get_extra_info(key) - def _do_bulk_write(self): + def _do_bulk_write(self) -> None: buffers = self._buffer_peek() if len(buffers) == 1: n = self.transport._sock.send(buffers[0]) @@ -886,7 +962,8 @@ def _do_bulk_write(self): n = self.transport._sock.sendmsg(buffers) self._buffer_advance(n) - def _on_write_ready(self): + def _on_write_ready(self) -> None: + # Copied almost verbatim from asyncio.selector_events._SelectorSocketTransport transport = self.transport if transport._conn_lost: return From 3a8acf5ed346adf798de60213cbda582cc752db5 Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Thu, 9 Dec 2021 13:05:24 -0600 Subject: [PATCH 14/15] More fixups - Closed comms always raise `CommClosedError` instead of underlying OS error. - Simplify buffer peek implementation - Decrease threshold for always concatenating buffers. With `sendmsg` there's less pressure to do so, as scatter IO is equally efficient.
--- distributed/comm/asyncio_tcp.py | 60 ++++++++++++--------------------- 1 file changed, 21 insertions(+), 39 deletions(-) diff --git a/distributed/comm/asyncio_tcp.py b/distributed/comm/asyncio_tcp.py index 0971889d1c..34ccfade04 100644 --- a/distributed/comm/asyncio_tcp.py +++ b/distributed/comm/asyncio_tcp.py @@ -7,6 +7,7 @@ import socket import struct import weakref +from itertools import islice from typing import Any try: @@ -43,18 +44,12 @@ def coalesce_buffers( The target intermediate buffer size from concatenating small buffers together. Coalesced buffers will be no larger than approximately this size. small_buffer_size : int, optional - Buffers <= this size are considered "small" and may be copied. Buffers - larger than this may also be copied if the total message length is less - than ``target_buffer_size``. + Buffers <= this size are considered "small" and may be copied. """ # Nothing to do if len(buffers) == 1: return buffers - # If the whole message can be sent in <= target_buffer_size, always concatenate - if sum(map(len, buffers)) <= target_buffer_size: - return [b"".join(buffers)] - out_buffers: list[bytes] = [] concat: list[bytes] = [] # A list of buffers to concatenate csize = 0 # The total size of the concatenated buffers @@ -103,8 +98,6 @@ def __init__(self, on_connection=None, min_read_size=128 * 1024): super().__init__() self.on_connection = on_connection self._loop = asyncio.get_running_loop() - # On error (or close) this contains an exception to raise to the caller - self._exception: Exception | None = None # A queue of received messages self._queue = asyncio.Queue() # The corresponding transport, set on `connection_made` @@ -115,7 +108,7 @@ def __init__(self, on_connection=None, min_read_size=128 * 1024): # unpaused. 
self._drain_waiter: asyncio.Future | None = None # A future for waiting until the protocol is actually closed - self._is_closed = self._loop.create_future() + self._closed_waiter = self._loop.create_future() # In the interest of reducing the number of `recv` calls, we always # want to provide the opportunity to read `min_read_size` bytes from @@ -145,7 +138,7 @@ def __init__(self, on_connection=None, min_read_size=128 * 1024): @property def local_addr(self): - if self._transport is None: + if self.is_closed: return "" sockname = self._transport.get_extra_info("sockname") if sockname is not None: @@ -154,7 +147,7 @@ def local_addr(self): @property def peer_addr(self): - if self._transport is None: + if self.is_closed: return "" peername = self._transport.get_extra_info("peername") if peername is not None: @@ -166,12 +159,12 @@ def is_closed(self): return self._transport is None def _abort(self): - if self._transport is not None: + if not self.is_closed: self._transport, transport = None, self._transport transport.abort() def _close_from_finalizer(self, comm_repr): - if self._transport is not None: + if not self.is_closed: logger.warning(f"Closing dangling comm `{comm_repr}`") try: self._abort() @@ -180,10 +173,10 @@ def _close_from_finalizer(self, comm_repr): pass async def _close(self): - if self._transport is not None: + if not self.is_closed: self._transport, transport = None, self._transport transport.close() - await self._is_closed + await self._closed_waiter def connection_made(self, transport): # XXX: When using asyncio, the default builtin transport makes @@ -334,11 +327,8 @@ def _message_completed(self): self._frame_nbytes_remaining = 0 def connection_lost(self, exc=None): - if exc is None: - exc = CommClosedError("Connection closed") - self._exception = exc self._transport = None - self._is_closed.set_result(None) + self._closed_waiter.set_result(None) # Unblock read, if any self._queue.put_nowait(_COMM_CLOSED) @@ -349,7 +339,7 @@ def connection_lost(self, 
exc=None): if waiter is not None: self._drain_waiter = None if not waiter.done(): - waiter.set_exception(exc) + waiter.set_exception(CommClosedError("Connection closed")) def pause_writing(self): self._paused = True @@ -372,13 +362,12 @@ async def read(self) -> list[bytes]: if out is not _COMM_CLOSED: return out self._queue = None - assert self._exception is not None - raise self._exception + raise CommClosedError("Connection closed") async def write(self, frames: list[bytes]) -> int: """Write a message to the comm.""" - if self._exception: - raise self._exception + if self.is_closed: + raise CommClosedError("Connection closed") elif self._paused: # Wait until there's room in the write buffer self._drain_waiter = self._loop.create_future() @@ -396,7 +385,11 @@ async def write(self, frames: list[bytes]) -> int: msg_nbytes = sum(frames_nbytes) + (nframes + 1) * 8 header = struct.pack(f"{nframes + 2}Q", msg_nbytes, nframes, *frames_nbytes) - buffers = coalesce_buffers([header, *frames]) + if msg_nbytes < 4 * 1024: + # Always concatenate small messages + buffers = [b"".join([header, *frames])] + else: + buffers = coalesce_buffers([header, *frames]) if len(buffers) > 1: self._transport.writelines(buffers) @@ -778,7 +771,6 @@ class _ZeroCopyWriter: # we do for each send call. We assume the system send buffer is < 4 MiB # (which would be very large), and set a limit on the number of buffers to # pass to sendmsg. 
- SENDMSG_MAX_SIZE = 4 * 1024 * 1024 # 4 MiB if hasattr(socket.socket, "sendmsg"): try: SENDMSG_MAX_COUNT = os.sysconf("SC_IOV_MAX") @@ -860,16 +852,7 @@ def _buffer_append(self, data: bytes) -> None: def _buffer_peek(self) -> list[memoryview]: """Get one or more buffers to write to the socket""" - buffers = [] - size = 0 - count = 0 - for b in self._buffers: - buffers.append(b) - size += len(b) - count += 1 - if size > self.SENDMSG_MAX_SIZE or count == self.SENDMSG_MAX_COUNT: - break - return buffers + return list(islice(self._buffers, self.SENDMSG_MAX_COUNT)) def _buffer_advance(self, size: int) -> None: """Advance the buffer index forward by `size`""" @@ -957,10 +940,9 @@ def _do_bulk_write(self) -> None: buffers = self._buffer_peek() if len(buffers) == 1: n = self.transport._sock.send(buffers[0]) - self._buffer_advance(n) else: n = self.transport._sock.sendmsg(buffers) - self._buffer_advance(n) + self._buffer_advance(n) def _on_write_ready(self) -> None: # Copied almost verbatim from asyncio.selector_events._SelectorSocketTransport From 8351fbfffdd77fbebed257fe07ca5daebbec30bd Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Fri, 10 Dec 2021 09:08:09 -0600 Subject: [PATCH 15/15] mypy fixup --- distributed/comm/asyncio_tcp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/comm/asyncio_tcp.py b/distributed/comm/asyncio_tcp.py index 34ccfade04..a281a1ea73 100644 --- a/distributed/comm/asyncio_tcp.py +++ b/distributed/comm/asyncio_tcp.py @@ -370,8 +370,8 @@ async def write(self, frames: list[bytes]) -> int: raise CommClosedError("Connection closed") elif self._paused: # Wait until there's room in the write buffer - self._drain_waiter = self._loop.create_future() - await self._drain_waiter # type: ignore + drain_waiter = self._drain_waiter = self._loop.create_future() + await drain_waiter # Ensure all memoryviews are in single-byte format frames = [f.cast("B") if isinstance(f, memoryview) else f for f in frames]