diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 516beed4..206873e5 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -7,6 +7,8 @@ This library adheres to `Semantic Versioning 2.0 `_. - Fixed erroneous ``TypedAttributeLookupError`` if a typed attribute getter raises ``KeyError`` +- Fixed ``SocketStream.receive()`` not detecting EOF on asyncio if there is also data in + the read buffer (`#701 `_) **4.3.0** diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 7be8b1d7..db15f93d 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -1047,6 +1047,7 @@ class StreamProtocol(asyncio.Protocol): read_event: asyncio.Event write_event: asyncio.Event exception: Exception | None = None + is_at_eof: bool = False def connection_made(self, transport: asyncio.BaseTransport) -> None: self.read_queue = deque() @@ -1068,6 +1069,7 @@ def data_received(self, data: bytes) -> None: self.read_event.set() def eof_received(self) -> bool | None: + self.is_at_eof = True self.read_event.set() return True @@ -1123,15 +1125,16 @@ def _raw_socket(self) -> socket.socket: async def receive(self, max_bytes: int = 65536) -> bytes: with self._receive_guard: - await AsyncIOBackend.checkpoint() - if ( not self._protocol.read_event.is_set() and not self._transport.is_closing() + and not self._protocol.is_at_eof ): self._transport.resume_reading() await self._protocol.read_event.wait() self._transport.pause_reading() + else: + await AsyncIOBackend.checkpoint_if_cancelled() try: chunk = self._protocol.read_queue.popleft() diff --git a/tests/test_sockets.py b/tests/test_sockets.py index a35b5003..43c7058d 100644 --- a/tests/test_sockets.py +++ b/tests/test_sockets.py @@ -28,6 +28,7 @@ BrokenResourceError, BusyResourceError, ClosedResourceError, + EndOfStream, Event, TypedAttributeLookupError, connect_tcp, @@ -681,6 +682,29 @@ async def handle(stream: SocketStream) -> None: tg.cancel_scope.cancel() + async def test_eof_after_send(self, family: AnyIPAddressFamily) -> None: + """Regression test for #701.""" + received_bytes = b"" + + async def handle(stream: SocketStream) -> None: + nonlocal received_bytes + async with stream: + received_bytes = await stream.receive() + with pytest.raises(EndOfStream), fail_after(1): + await stream.receive() + + tg.cancel_scope.cancel() + + multi = await create_tcp_listener(family=family, local_host="localhost") + async with multi, create_task_group() as tg: + with socket.socket(family) as client: + client.connect(multi.extra(SocketAttribute.local_address)) + client.send(b"Hello") + client.shutdown(socket.SHUT_WR) + await multi.serve(handle) + + assert received_bytes == b"Hello" + @skip_ipv6_mark @pytest.mark.skipif( sys.platform == "win32",