diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a96abdc..84ea2de 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -74,7 +74,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: ['3.12'] + python-version: ['3.13-dev'] steps: - uses: actions/checkout@v3 - name: Setup Python diff --git a/requirements-dev-full.txt b/requirements-dev-full.txt index d64900a..dbc8570 100644 --- a/requirements-dev-full.txt +++ b/requirements-dev-full.txt @@ -21,7 +21,7 @@ build==1.2.1 # via pip-tools certifi==2024.6.2 # via requests -cffi==1.16.0 +cffi==1.17.0 # via cryptography charset-normalizer==3.3.2 # via requests @@ -194,9 +194,8 @@ tomli==2.0.1 # pytest tomlkit==0.12.5 # via pylint -trio==0.24.0 +trio==0.25.1 # via - # -r requirements-dev.in # pytest-trio # trio-websocket (setup.py) trustme==1.1.0 diff --git a/requirements-dev.in b/requirements-dev.in index e342182..922fb76 100644 --- a/requirements-dev.in +++ b/requirements-dev.in @@ -4,5 +4,4 @@ pip-tools>=5.5.0 pytest>=4.6 pytest-cov pytest-trio>=0.5.0 -trio<0.25 trustme diff --git a/requirements-dev.txt b/requirements-dev.txt index c9dd128..f28f625 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -71,9 +71,8 @@ tomli==2.0.1 # coverage # pip-tools # pytest -trio==0.24.0 +trio==0.25.1 # via - # -r requirements-dev.in # pytest-trio # trio-websocket (setup.py) trustme==1.1.0 diff --git a/tests/test_connection.py b/tests/test_connection.py index 3d45eb7..6cccefa 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -32,7 +32,9 @@ from __future__ import annotations from functools import partial, wraps +import re import ssl +import sys from unittest.mock import patch import attr @@ -48,6 +50,13 @@ except ImportError: from trio.hazmat import current_task # type: ignore # pylint: disable=ungrouped-imports + +# only available on trio>=0.25, we don't use it when testing lower versions +try: + from trio.testing import RaisesGroup +except ImportError: + pass + from trio_websocket import ( connect_websocket, connect_websocket_url, @@ -60,12 +69,18 @@ open_websocket, open_websocket_url, serve_websocket, + WebSocketConnection, WebSocketServer, WebSocketRequest, wrap_client_stream, wrap_server_stream ) +from trio_websocket._impl import _TRIO_EXC_GROUP_TYPE + +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup # pylint: disable=redefined-builtin + WS_PROTO_VERSION = tuple(map(int, wsproto.__version__.split('.'))) HOST = '127.0.0.1' @@ -428,6 +443,92 @@ async def handler(request): assert header_value == b'My test header' + + +@fail_after(5) +async def test_open_websocket_internal_ki(nursery, monkeypatch, autojump_clock): + """_reader_task._handle_ping_event triggers KeyboardInterrupt. + user code also raises exception. + Make sure that KI is delivered, and the user exception is in the __cause__ exceptiongroup + """ + async def ki_raising_ping_handler(*args, **kwargs) -> None: + print("raising ki") + raise KeyboardInterrupt + monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", ki_raising_ping_handler) + async def handler(request): + server_ws = await request.accept() + await server_ws.ping(b"a") + + server = await nursery.start(serve_websocket, handler, HOST, 0, None) + with pytest.raises(KeyboardInterrupt) as exc_info: + async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False): + with trio.fail_after(1) as cs: + cs.shield = True + await trio.sleep(2) + + e_cause = exc_info.value.__cause__ + assert isinstance(e_cause, _TRIO_EXC_GROUP_TYPE) + assert any(isinstance(e, trio.TooSlowError) for e in e_cause.exceptions) + +@fail_after(5) +async def test_open_websocket_internal_exc(nursery, monkeypatch, autojump_clock): + """_reader_task._handle_ping_event triggers ValueError. + user code also raises exception. + internal exception is in __cause__ exceptiongroup and user exc is delivered + """ + my_value_error = ValueError() + async def raising_ping_event(*args, **kwargs) -> None: + raise my_value_error + + monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", raising_ping_event) + async def handler(request): + server_ws = await request.accept() + await server_ws.ping(b"a") + + server = await nursery.start(serve_websocket, handler, HOST, 0, None) + with pytest.raises(trio.TooSlowError) as exc_info: + async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False): + with trio.fail_after(1) as cs: + cs.shield = True + await trio.sleep(2) + + e_cause = exc_info.value.__cause__ + assert isinstance(e_cause, _TRIO_EXC_GROUP_TYPE) + assert my_value_error in e_cause.exceptions + +@fail_after(5) +async def test_open_websocket_cancellations(nursery, monkeypatch, autojump_clock): + """Both user code and _reader_task raise Cancellation. + Check that open_websocket reraises the one from user code for traceback reasons. + """ + + + async def sleeping_ping_event(*args, **kwargs) -> None: + await trio.sleep_forever() + + # We monkeypatch WebSocketConnection._handle_ping_event to ensure it will actually + # raise Cancelled upon being cancelled. For some reason it doesn't otherwise. + monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", sleeping_ping_event) + async def handler(request): + server_ws = await request.accept() + await server_ws.ping(b"a") + user_cancelled = None + + server = await nursery.start(serve_websocket, handler, HOST, 0, None) + with trio.move_on_after(2): + with pytest.raises(trio.Cancelled) as exc_info: + async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False): + try: + await trio.sleep_forever() + except trio.Cancelled as e: + user_cancelled = e + raise + assert exc_info.value is user_cancelled + +def _trio_default_non_strict_exception_groups() -> bool: + assert re.match(r'^0\.\d\d\.', trio.__version__), "unexpected trio versioning scheme" + return int(trio.__version__[2:4]) < 25 + @fail_after(1) async def test_handshake_exception_before_accept() -> None: ''' In #107, a request handler that throws an exception before finishing the @@ -436,7 +537,8 @@ async def test_handshake_exception_before_accept() -> None: async def handler(request): raise ValueError() - with pytest.raises(ValueError): + # pylint fails to resolve that BaseExceptionGroup will always be available + with pytest.raises((BaseExceptionGroup, ValueError)) as exc: # pylint: disable=possibly-used-before-assignment async with trio.open_nursery() as nursery: server = await nursery.start(serve_websocket, handler, HOST, 0, None) @@ -444,6 +546,19 @@ async def handler(request): use_ssl=False): pass + if _trio_default_non_strict_exception_groups(): + assert isinstance(exc.value, ValueError) + else: + # there's 4 levels of nurseries opened, leading to 4 nested groups: + # 1. this test + # 2. WebSocketServer.run + # 3. trio.serve_listeners + # 4. WebSocketServer._handle_connection + assert RaisesGroup( + RaisesGroup( + RaisesGroup( + RaisesGroup(ValueError)))).matches(exc.value) + @fail_after(1) async def test_reject_handshake(nursery): diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index b153034..a71e0be 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -13,6 +13,7 @@ import urllib.parse from typing import Iterable, List, Optional, Union +import outcome import trio import trio.abc from wsproto import ConnectionType, WSConnection @@ -35,7 +36,12 @@ # pylint doesn't care about the version_info check, so need to ignore the warning from exceptiongroup import BaseExceptionGroup # pylint: disable=redefined-builtin -_TRIO_MULTI_ERROR = tuple(map(int, trio.__version__.split('.')[:2])) < (0, 22) +_IS_TRIO_MULTI_ERROR = tuple(map(int, trio.__version__.split('.')[:2])) < (0, 22) + +if _IS_TRIO_MULTI_ERROR: + _TRIO_EXC_GROUP_TYPE = trio.MultiError # type: ignore[attr-defined] # pylint: disable=no-member +else: + _TRIO_EXC_GROUP_TYPE = BaseExceptionGroup # pylint: disable=possibly-used-before-assignment CONN_TIMEOUT = 60 # default connect & disconnect timeout, in seconds MESSAGE_QUEUE_SIZE = 1 @@ -44,6 +50,13 @@ logger = logging.getLogger('trio-websocket') +class TrioWebsocketInternalError(Exception): + """Raised as a fallback when open_websocket is unable to unwind an exceptiongroup + into a single preferred exception. This should never happen, if it does then + underlying assumptions about the internal code are incorrect. + """ + + def _ignore_cancel(exc): return None if isinstance(exc, trio.Cancelled) else exc @@ -70,7 +83,7 @@ def __exit__(self, ty, value, tb): if value is None or not self._armed: return False - if _TRIO_MULTI_ERROR: # pragma: no cover + if _IS_TRIO_MULTI_ERROR: # pragma: no cover filtered_exception = trio.MultiError.filter(_ignore_cancel, value) # pylint: disable=no-member elif isinstance(value, BaseExceptionGroup): # pylint: disable=possibly-used-before-assignment filtered_exception = value.subgroup(lambda exc: not isinstance(exc, trio.Cancelled)) @@ -125,10 +138,33 @@ async def open_websocket( client-side timeout (:exc:`ConnectionTimeout`, :exc:`DisconnectionTimeout`), or server rejection (:exc:`ConnectionRejected`) during handshakes. ''' - async with trio.open_nursery() as new_nursery: + + # This context manager tries very very hard not to raise an exceptiongroup + # in order to be as transparent as possible for the end user. + # In the trivial case, this means that if user code inside the cm raises + # we make sure that it doesn't get wrapped. + + # If opening the connection fails, then we will raise that exception. User + # code is never executed, so we will never have multiple exceptions. + + # After opening the connection, we spawn _reader_task in the background and + # yield to user code. If only one of those raise a non-cancelled exception + # we will raise that non-cancelled exception. + # If we get multiple cancelled, we raise the user's cancelled. + # If both raise exceptions, we raise the user code's exception with the entire + # exception group as the __cause__. + # If we somehow get multiple exceptions, but no user exception, then we raise + # TrioWebsocketInternalError. + + # If closing the connection fails, then that will be raised as the top + # exception in the last `finally`. If we encountered exceptions in user code + # or in reader task then they will be set as the `__cause__`. + + + async def _open_connection(nursery: trio.Nursery) -> WebSocketConnection: try: with trio.fail_after(connect_timeout): - connection = await connect_websocket(new_nursery, host, port, + return await connect_websocket(nursery, host, port, resource, use_ssl=use_ssl, subprotocols=subprotocols, extra_headers=extra_headers, message_queue_size=message_queue_size, @@ -137,14 +173,85 @@ async def open_websocket( raise ConnectionTimeout from None except OSError as e: raise HandshakeError from e + + async def _close_connection(connection: WebSocketConnection) -> None: try: - yield connection - finally: - try: - with trio.fail_after(disconnect_timeout): - await connection.aclose() - except trio.TooSlowError: - raise DisconnectionTimeout from None + with trio.fail_after(disconnect_timeout): + await connection.aclose() + except trio.TooSlowError: + raise DisconnectionTimeout from None + + connection: WebSocketConnection|None=None + close_result: outcome.Maybe[None] | None = None + user_error = None + + try: + async with trio.open_nursery() as new_nursery: + result = await outcome.acapture(_open_connection, new_nursery) + + if isinstance(result, outcome.Value): + connection = result.unwrap() + try: + yield connection + except BaseException as e: + user_error = e + raise + finally: + close_result = await outcome.acapture(_close_connection, connection) + # This exception handler should only be entered if either: + # 1. The _reader_task started in connect_websocket raises + # 2. User code raises an exception + # I.e. open/close_connection are not included + except _TRIO_EXC_GROUP_TYPE as e: + # user_error, or exception bubbling up from _reader_task + if len(e.exceptions) == 1: + raise e.exceptions[0] + + # contains at most 1 non-cancelled exceptions + exception_to_raise: BaseException|None = None + for sub_exc in e.exceptions: + if not isinstance(sub_exc, trio.Cancelled): + if exception_to_raise is not None: + # multiple non-cancelled + break + exception_to_raise = sub_exc + else: + if exception_to_raise is None: + # all exceptions are cancelled + # prefer raising the one from the user, for traceback reasons + if user_error is not None: + # no reason to raise from e, just to include a bunch of extra + # cancelleds. + raise user_error # pylint: disable=raise-missing-from + # multiple internal Cancelled is not possible afaik + raise e.exceptions[0] # pragma: no cover # pylint: disable=raise-missing-from + raise exception_to_raise + + # if we have any KeyboardInterrupt in the group, make sure to raise it. + for sub_exc in e.exceptions: + if isinstance(sub_exc, KeyboardInterrupt): + raise sub_exc from e + + # Both user code and internal code raised non-cancelled exceptions. + # We "hide" the internal exception(s) in the __cause__ and surface + # the user_error. + if user_error is not None: + raise user_error from e + + raise TrioWebsocketInternalError( + "The trio-websocket API is not expected to raise multiple exceptions. " + "Please report this as a bug to " + "https://github.com/python-trio/trio-websocket" + ) from e # pragma: no cover + + finally: + if close_result is not None: + close_result.unwrap() + + + # error setting up, unwrap that exception + if connection is None: + result.unwrap() async def connect_websocket(nursery, host, port, resource, *, use_ssl,