diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 906babfc..020c473d 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -12,6 +12,8 @@ This library adheres to `Semantic Versioning 2.0 `_. - Removed a checkpoint when exiting a task group - Bumped minimum version of trio to v0.23 - Exposed the ``ResourceGuard`` class in the public API +- Fixed ``RuntimeError: Runner is closed`` when running higher-scoped async generator + fixtures in some cases (`#619 `_) **4.0.0** diff --git a/src/anyio/pytest_plugin.py b/src/anyio/pytest_plugin.py index 762e9e83..a8dd6f3e 100644 --- a/src/anyio/pytest_plugin.py +++ b/src/anyio/pytest_plugin.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Iterator -from contextlib import contextmanager +from contextlib import ExitStack, contextmanager from inspect import isasyncgenfunction, iscoroutinefunction from typing import Any, Dict, Tuple, cast @@ -12,6 +12,8 @@ from .abc import TestRunner _current_runner: TestRunner | None = None +_runner_stack: ExitStack | None = None +_runner_leases = 0 def extract_backend_and_options(backend: object) -> tuple[str, dict[str, Any]]: @@ -28,27 +30,30 @@ def extract_backend_and_options(backend: object) -> tuple[str, dict[str, Any]]: def get_runner( backend_name: str, backend_options: dict[str, Any] ) -> Iterator[TestRunner]: - global _current_runner - if _current_runner: - yield _current_runner - return + global _current_runner, _runner_leases, _runner_stack + if _current_runner is None: + asynclib = get_async_backend(backend_name) + _runner_stack = ExitStack() + if sniffio.current_async_library_cvar.get(None) is None: + # Since we're in control of the event loop, we can cache the name of the + # async library + token = sniffio.current_async_library_cvar.set(backend_name) + _runner_stack.callback(sniffio.current_async_library_cvar.reset, token) - asynclib = get_async_backend(backend_name) - token = None - if sniffio.current_async_library_cvar.get(None) is None: - # Since we're in control of the event loop, we can cache the name of the async - # library - token = sniffio.current_async_library_cvar.set(backend_name) + backend_options = backend_options or {} + _current_runner = _runner_stack.enter_context( + asynclib.create_test_runner(backend_options) + ) + _runner_leases += 1 try: - backend_options = backend_options or {} - with asynclib.create_test_runner(backend_options) as runner: - _current_runner = runner - yield runner + yield _current_runner finally: - _current_runner = None - if token: - sniffio.current_async_library_cvar.reset(token) + _runner_leases -= 1 + if not _runner_leases: + assert _runner_stack is not None + _runner_stack.close() + _runner_stack = _current_runner = None def pytest_configure(config: Any) -> None: diff --git a/tests/test_pytest_plugin.py b/tests/test_pytest_plugin.py index 296099d0..1193aca1 100644 --- a/tests/test_pytest_plugin.py +++ b/tests/test_pytest_plugin.py @@ -308,6 +308,44 @@ async def test_task_group(streams): result.assert_outcomes(passed=len(get_all_backends())) +def test_async_fixture_teardown_after_sync_test(testdir: Pytester) -> None: + # Regression test for #619 + testdir.makepyfile( + """ + import pytest + + from anyio import create_task_group, sleep + + @pytest.fixture(scope="session") + def anyio_backend(): + return "asyncio" + + + @pytest.fixture(scope="module") + async def bbbbbb(): + yield "" + + + @pytest.fixture(scope="module") + async def aaaaaa(): + yield "" + + + @pytest.mark.anyio + async def test_1(bbbbbb): + pass + + + @pytest.mark.anyio + async def test_2(aaaaaa, bbbbbb): + pass + """ + ) + + result = testdir.runpytest_subprocess(*pytest_args) + result.assert_outcomes(passed=2) + + def test_hypothesis_module_mark(testdir: Pytester) -> None: testdir.makepyfile( """