Skip to content

Commit

Permalink
Wait until state is CLOSED to acces close_exc.
Browse files Browse the repository at this point in the history
Fix #1449.
  • Loading branch information
aaugustin committed Sep 21, 2024
1 parent 20739e0 commit 3640923
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 5 deletions.
4 changes: 4 additions & 0 deletions docs/project/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ Bug fixes
start the connection handler anymore when ``process_request`` or
``process_response`` returns an HTTP response.

* Fixed a bug in the :mod:`threading` implementation that could lead to
incorrect error reporting when closing a connection while
:meth:`~sync.connection.Connection.recv` is running.

13.0.1
------

Expand Down
14 changes: 11 additions & 3 deletions src/websockets/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,8 @@ async def recv(self, decode: bool | None = None) -> Data:
try:
return await self.recv_messages.get(decode)
except EOFError:
# Wait for the protocol state to be CLOSED before accessing close_exc.
await asyncio.shield(self.connection_lost_waiter)
raise self.protocol.close_exc from self.recv_exc
except ConcurrencyError:
raise ConcurrencyError(
Expand Down Expand Up @@ -329,6 +331,8 @@ async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data
async for frame in self.recv_messages.get_iter(decode):
yield frame
except EOFError:
# Wait for the protocol state to be CLOSED before accessing close_exc.
await asyncio.shield(self.connection_lost_waiter)
raise self.protocol.close_exc from self.recv_exc
except ConcurrencyError:
raise ConcurrencyError(
Expand Down Expand Up @@ -864,6 +868,7 @@ async def send_context(
# raise an exception.
if raise_close_exc:
self.close_transport()
# Wait for the protocol state to be CLOSED before accessing close_exc.
await asyncio.shield(self.connection_lost_waiter)
raise self.protocol.close_exc from original_exc

Expand Down Expand Up @@ -926,11 +931,14 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None:
self.transport = transport

def connection_lost(self, exc: Exception | None) -> None:
self.protocol.receive_eof() # receive_eof is idempotent
# Calling protocol.receive_eof() is safe because it's idempotent.
# This guarantees that the protocol state becomes CLOSED.
self.protocol.receive_eof()
assert self.protocol.state is CLOSED

# Abort recv() and pending pings with a ConnectionClosed exception.
# Set recv_exc first to get proper exception reporting.
self.set_recv_exc(exc)

# Abort recv() and pending pings with a ConnectionClosed exception.
self.recv_messages.close()
self.abort_pings()

Expand Down
14 changes: 12 additions & 2 deletions src/websockets/sync/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,8 @@ def recv(self, timeout: float | None = None) -> Data:
try:
return self.recv_messages.get(timeout)
except EOFError:
# Wait for the protocol state to be CLOSED before accessing close_exc.
self.recv_events_thread.join()
raise self.protocol.close_exc from self.recv_exc
except ConcurrencyError:
raise ConcurrencyError(
Expand Down Expand Up @@ -240,6 +242,8 @@ def recv_streaming(self) -> Iterator[Data]:
for frame in self.recv_messages.get_iter():
yield frame
except EOFError:
# Wait for the protocol state to be CLOSED before accessing close_exc.
self.recv_events_thread.join()
raise self.protocol.close_exc from self.recv_exc
except ConcurrencyError:
raise ConcurrencyError(
Expand Down Expand Up @@ -629,8 +633,6 @@ def recv_events(self) -> None:
self.logger.error("unexpected internal error", exc_info=True)
with self.protocol_mutex:
self.set_recv_exc(exc)
# We don't know where we crashed. Force protocol state to CLOSED.
self.protocol.state = CLOSED
finally:
# This isn't expected to raise an exception.
self.close_socket()
Expand Down Expand Up @@ -738,6 +740,7 @@ def send_context(
# raise an exception.
if raise_close_exc:
self.close_socket()
# Wait for the protocol state to be CLOSED before accessing close_exc.
self.recv_events_thread.join()
raise self.protocol.close_exc from original_exc

Expand Down Expand Up @@ -788,4 +791,11 @@ def close_socket(self) -> None:
except OSError:
pass # socket is already closed
self.socket.close()

# Calling protocol.receive_eof() is safe because it's idempotent.
# This guarantees that the protocol state becomes CLOSED.
self.protocol.receive_eof()
assert self.protocol.state is CLOSED

# Abort recv() with a ConnectionClosed exception.
self.recv_messages.close()

0 comments on commit 3640923

Please sign in to comment.