From 08e555321431fd5297fc26fb7ae1f41742999896 Mon Sep 17 00:00:00 2001 From: Jochen Ott Date: Wed, 11 Nov 2020 12:37:24 +0100 Subject: [PATCH] comm: close comm on low-level errors --- distributed/comm/tcp.py | 22 +++++++++++++++++----- distributed/comm/tests/test_comms.py | 19 +++++++++++++++++++ distributed/core.py | 16 +++++++++++----- distributed/tests/test_batched.py | 8 +------- 4 files changed, 48 insertions(+), 17 deletions(-) diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index d92f83fa445..d672c5de3c2 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -200,6 +200,13 @@ async def read(self, deserializers=None): self.stream = None if not shutting_down(): convert_stream_closed_error(self, e) + except Exception: + # Some OSError or a another "low-level" exception. We do not really know what + # was already read from the underlying socket, so it is not even safe to retry + # here using the same stream. The only safe thing to do is to abort. + # (See also GitHub #4133). + self.abort() + raise else: try: msg = await from_frames( @@ -253,13 +260,18 @@ async def write(self, msg, serializers=None, on_error="message"): await future bytes_since_last_yield = 0 except StreamClosedError as e: - stream = None - convert_stream_closed_error(self, e) - except TypeError as e: + self.stream = None + if not shutting_down(): + convert_stream_closed_error(self, e) + except Exception: + # Some OSError or a another "low-level" exception. We do not really know what + # was already written to the underlying socket, so it is not even safe to retry + # here using the same stream. The only safe thing to do is to abort. + # (See also GitHub #4133). if stream._write_buffer is None: logger.info("tried to write message %s on closed stream", msg) - else: - raise + self.abort() + raise return sum(lengths) diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index 707819ba596..ce080b8ae88 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -790,6 +790,25 @@ async def handle_comm(comm): await comm.close() +@pytest.mark.asyncio +async def test_comm_closed_on_buffer_error(): + # Internal errors from comm.stream.write, such as + # BufferError should lead to the stream being closed + # and not re-used. See GitHub #4133 + reader, writer = await get_tcp_comm_pair() + + def _write(data): + raise BufferError + + writer.stream.write = _write + with pytest.raises(BufferError): + await writer.write("x") + assert writer.stream is None + + await reader.close() + await writer.close() + + # # Various stress tests # diff --git a/distributed/core.py b/distributed/core.py index 5ede25d7a05..20fb4b428a5 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -472,9 +472,12 @@ async def handle_comm(self, comm, shutting_down=shutting_down): ) break except Exception as e: - logger.exception(e) - await comm.write(error_message(e, status="uncaught-error")) - continue + logger.exception("Exception while reading from %s", address) + if comm.closed(): + raise + else: + await comm.write(error_message(e, status="uncaught-error")) + continue if not isinstance(msg, dict): raise TypeError( "Bad message type. Expected dict, got\n " + str(msg) @@ -531,8 +534,11 @@ async def handle_comm(self, comm, shutting_down=shutting_down): logger.info("Lost connection to %r: %s", address, e) break except Exception as e: - logger.exception(e) - result = error_message(e, status="uncaught-error") + logger.exception("Exception while handling op %s", op) + if comm.closed(): + raise + else: + result = error_message(e, status="uncaught-error") # result is not type stable: # when LHS is not Status then RHS must not be Status or it raises. diff --git a/distributed/tests/test_batched.py b/distributed/tests/test_batched.py index fa12f9649e9..5189353e43b 100644 --- a/distributed/tests/test_batched.py +++ b/distributed/tests/test_batched.py @@ -297,11 +297,5 @@ def raise_buffererror(*args, **kwargs): b.send("hello") b.send("world") await asyncio.sleep(0.020) - result = await comm.read() - assert result == ("hello", "hello", "world") - - b.send("raises when flushed") - await asyncio.sleep(0.020) # CommClosedError hit in callback - with pytest.raises(CommClosedError): - b.send("raises when sent") + await comm.read()