diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst
index 516beed4..206873e5 100644
--- a/docs/versionhistory.rst
+++ b/docs/versionhistory.rst
@@ -7,6 +7,8 @@ This library adheres to `Semantic Versioning 2.0 `_.
- Fixed erroneous ``TypedAttributeLookupError`` if a typed attribute getter raises
``KeyError``
+- Fixed ``SocketStream.receive()`` not detecting EOF on asyncio if there is also data in
+ the read buffer (`#701 `_)
**4.3.0**
diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py
index 7be8b1d7..db15f93d 100644
--- a/src/anyio/_backends/_asyncio.py
+++ b/src/anyio/_backends/_asyncio.py
@@ -1047,6 +1047,7 @@ class StreamProtocol(asyncio.Protocol):
read_event: asyncio.Event
write_event: asyncio.Event
exception: Exception | None = None
+ is_at_eof: bool = False
def connection_made(self, transport: asyncio.BaseTransport) -> None:
self.read_queue = deque()
@@ -1068,6 +1069,7 @@ def data_received(self, data: bytes) -> None:
self.read_event.set()
def eof_received(self) -> bool | None:
+ self.is_at_eof = True
self.read_event.set()
return True
@@ -1123,15 +1125,16 @@ def _raw_socket(self) -> socket.socket:
async def receive(self, max_bytes: int = 65536) -> bytes:
with self._receive_guard:
- await AsyncIOBackend.checkpoint()
-
if (
not self._protocol.read_event.is_set()
and not self._transport.is_closing()
+ and not self._protocol.is_at_eof
):
self._transport.resume_reading()
await self._protocol.read_event.wait()
self._transport.pause_reading()
+ else:
+ await AsyncIOBackend.checkpoint_if_cancelled()
try:
chunk = self._protocol.read_queue.popleft()
diff --git a/tests/test_sockets.py b/tests/test_sockets.py
index a35b5003..43c7058d 100644
--- a/tests/test_sockets.py
+++ b/tests/test_sockets.py
@@ -28,6 +28,7 @@
BrokenResourceError,
BusyResourceError,
ClosedResourceError,
+ EndOfStream,
Event,
TypedAttributeLookupError,
connect_tcp,
@@ -681,6 +682,29 @@ async def handle(stream: SocketStream) -> None:
tg.cancel_scope.cancel()
+ async def test_eof_after_send(self, family: AnyIPAddressFamily) -> None:
+ """Regression test for #701."""
+ received_bytes = b""
+
+ async def handle(stream: SocketStream) -> None:
+ nonlocal received_bytes
+ async with stream:
+ received_bytes = await stream.receive()
+ with pytest.raises(EndOfStream), fail_after(1):
+ await stream.receive()
+
+ tg.cancel_scope.cancel()
+
+ multi = await create_tcp_listener(family=family, local_host="localhost")
+ async with multi, create_task_group() as tg:
+ with socket.socket(family) as client:
+ client.connect(multi.extra(SocketAttribute.local_address))
+ client.send(b"Hello")
+ client.shutdown(socket.SHUT_WR)
+ await multi.serve(handle)
+
+ assert received_bytes == b"Hello"
+
@skip_ipv6_mark
@pytest.mark.skipif(
sys.platform == "win32",