Skip to content

Commit

Permalink
[PR #9326/fe26ae2 backport][3.10] Fix TimerContext not uncancelling t…
Browse files Browse the repository at this point in the history
…he current task (#9328)
  • Loading branch information
bdraco authored Sep 28, 2024
1 parent 52e0b91 commit a308f74
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGES/9326.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed cancellation leaking upwards on timeout -- by :user:`bdraco`.
23 changes: 20 additions & 3 deletions aiohttp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop
self._tasks: List[asyncio.Task[Any]] = []
self._cancelled = False
self._cancelling = 0

def assert_timeout(self) -> None:
"""Raise TimeoutError if timer has already been cancelled."""
Expand All @@ -694,12 +695,17 @@ def assert_timeout(self) -> None:

def __enter__(self) -> BaseTimerContext:
task = asyncio.current_task(loop=self._loop)

if task is None:
raise RuntimeError(
"Timeout context manager should be used " "inside a task"
)

if sys.version_info >= (3, 11):
# Remember if the task was already cancelling
# so when we __exit__ we can decide if we should
# raise asyncio.TimeoutError or let the cancellation propagate
self._cancelling = task.cancelling()

if self._cancelled:
raise asyncio.TimeoutError from None

Expand All @@ -712,11 +718,22 @@ def __exit__(
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> Optional[bool]:
enter_task: Optional[asyncio.Task[Any]] = None
if self._tasks:
self._tasks.pop()
enter_task = self._tasks.pop()

if exc_type is asyncio.CancelledError and self._cancelled:
raise asyncio.TimeoutError from None
assert enter_task is not None
# The timeout was hit, and the task was cancelled
# so we need to uncancel the last task that entered the context manager
# since the cancellation should not leak out of the context manager
if sys.version_info >= (3, 11):
# If the task was already cancelling don't raise
# asyncio.TimeoutError and instead return None
# to allow the cancellation to propagate
if enter_task.uncancel() > self._cancelling:
return None
raise asyncio.TimeoutError from exc_val
return None

def timeout(self) -> None:
Expand Down
56 changes: 55 additions & 1 deletion tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,61 @@ def test_timer_context_not_cancelled() -> None:
assert not m_asyncio.current_task.return_value.cancel.called


def test_timer_context_no_task(loop) -> None:
@pytest.mark.skipif(
sys.version_info < (3, 11), reason="Python 3.11+ is required for .cancelling()"
)
async def test_timer_context_timeout_does_not_leak_upward() -> None:
"""Verify that the TimerContext does not leak cancellation outside the context manager."""
loop = asyncio.get_running_loop()
ctx = helpers.TimerContext(loop)
current_task = asyncio.current_task()
assert current_task is not None
with pytest.raises(asyncio.TimeoutError):
with ctx:
assert current_task.cancelling() == 0
loop.call_soon(ctx.timeout)
await asyncio.sleep(1)

# After the context manager exits, the task should no longer be cancelling
assert current_task.cancelling() == 0


@pytest.mark.skipif(
sys.version_info < (3, 11), reason="Python 3.11+ is required for .cancelling()"
)
async def test_timer_context_timeout_does_swallow_cancellation() -> None:
"""Verify that the TimerContext does not swallow cancellation."""
loop = asyncio.get_running_loop()
current_task = asyncio.current_task()
assert current_task is not None
ctx = helpers.TimerContext(loop)

async def task_with_timeout() -> None:
nonlocal ctx
new_task = asyncio.current_task()
assert new_task is not None
with pytest.raises(asyncio.TimeoutError):
with ctx:
assert new_task.cancelling() == 0
await asyncio.sleep(1)

task = asyncio.create_task(task_with_timeout())
await asyncio.sleep(0)
task.cancel()
assert task.cancelling() == 1
ctx.timeout()

# Cancellation should not leak into the current task
assert current_task.cancelling() == 0
# Cancellation should not be swallowed if the task is cancelled
# and it also times out
await asyncio.sleep(0)
with pytest.raises(asyncio.CancelledError):
await task
assert task.cancelling() == 1


def test_timer_context_no_task(loop: asyncio.AbstractEventLoop) -> None:
with pytest.raises(RuntimeError):
with helpers.TimerContext(loop):
pass
Expand Down

0 comments on commit a308f74

Please sign in to comment.