diff --git a/Lib/asyncio/tasks.py b/Lib/asyncio/tasks.py index 9a9d0d6e3cc269..d0d8c80a5b9324 100644 --- a/Lib/asyncio/tasks.py +++ b/Lib/asyncio/tasks.py @@ -407,9 +407,9 @@ async def wait_for(fut, timeout): if timeout is None: return await fut - if timeout <= 0: - fut = ensure_future(fut, loop=loop) + fut = ensure_future(fut, loop=loop) + if timeout <= 0: if fut.done(): return fut.result() @@ -421,47 +421,24 @@ async def wait_for(fut, timeout): else: raise exceptions.TimeoutError() - waiter = loop.create_future() - timeout_handle = loop.call_later(timeout, _release_waiter, waiter) - cb = functools.partial(_release_waiter, waiter) + timeout_occurred = False - fut = ensure_future(fut, loop=loop) - fut.add_done_callback(cb) + def _on_timeout(fut): + nonlocal timeout_occurred + timeout_occurred = fut.cancel() + timeout_handle = loop.call_later(timeout, _on_timeout, fut) + cb = fut.add_done_callback(lambda fut: timeout_handle.cancel()) try: - # wait until the future completes or the timeout - try: - await waiter - except exceptions.CancelledError: - if fut.done(): - return fut.result() - else: - fut.remove_done_callback(cb) - # We must ensure that the task is not running - # after wait_for() returns. - # See https://bugs.python.org/issue32751 - await _cancel_and_wait(fut, loop=loop) - raise - - if fut.done(): - return fut.result() + return await fut + except exceptions.CancelledError as exc: + if timeout_occurred: + raise exceptions.TimeoutError() from exc else: - fut.remove_done_callback(cb) - # We must ensure that the task is not running - # after wait_for() returns. - # See https://bugs.python.org/issue32751 - await _cancel_and_wait(fut, loop=loop) - # In case task cancellation failed with some - # exception, we should re-raise it - # See https://bugs.python.org/issue40607 - try: - fut.result() - except exceptions.CancelledError as exc: - raise exceptions.TimeoutError() from exc - else: - raise exceptions.TimeoutError() + raise finally: timeout_handle.cancel() + fut.remove_done_callback(cb) async def _wait(fs, timeout, return_when, loop):