Skip to content
This repository has been archived by the owner on Mar 6, 2024. It is now read-only.

Commit

Permalink
Pre-empt current task before running handle, allowing unpatched tasks,
Browse files Browse the repository at this point in the history
…fixes #80
  • Loading branch information
erdewit committed Jan 21, 2024
1 parent 35618de commit c31e7c4
Showing 1 changed file with 14 additions and 38 deletions.
52 changes: 14 additions & 38 deletions nest_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ def apply(loop=None):
"""Patch asyncio to make its event loop reentrant."""
_patch_asyncio()
_patch_policy()
_patch_task()
_patch_tornado()

loop = loop or asyncio.get_event_loop()
Expand Down Expand Up @@ -126,9 +125,20 @@ def _run_once(self):
break
handle = ready.popleft()
if not handle._cancelled:
handle._run()
# preempt the current task so that that checks in
# Task.__step do not raise
curr_task = curr_tasks.pop(self, None)

try:
handle._run()
finally:
# restore the current task
if curr_task is not None:
curr_tasks[self] = curr_task

handle = None


@contextmanager
def manage_run(self):
"""Set up the loop for running."""
Expand Down Expand Up @@ -193,45 +203,11 @@ def _check_running(self):
os.name == 'nt' and issubclass(cls, asyncio.ProactorEventLoop))
if sys.version_info < (3, 7, 0):
cls._set_coroutine_origin_tracking = cls._set_coroutine_wrapper
curr_tasks = asyncio.tasks._current_tasks \
if sys.version_info >= (3, 7, 0) else asyncio.Task._current_tasks
cls._nest_patched = True


def _patch_task():
"""Patch the Task's step and enter/leave methods to make it reentrant."""

def step(task, exc=None):
curr_task = curr_tasks.get(task._loop)
try:
step_orig(task, exc)
finally:
if curr_task is None:
curr_tasks.pop(task._loop, None)
else:
curr_tasks[task._loop] = curr_task

Task = asyncio.Task
if hasattr(Task, '_nest_patched'):
return
if sys.version_info >= (3, 7, 0):

def enter_task(loop, task):
curr_tasks[loop] = task

def leave_task(loop, task):
curr_tasks.pop(loop, None)

asyncio.tasks._enter_task = enter_task
asyncio.tasks._leave_task = leave_task
curr_tasks = asyncio.tasks._current_tasks
step_orig = Task._Task__step
Task._Task__step = step
else:
curr_tasks = Task._current_tasks
step_orig = Task._step
Task._step = step
Task._nest_patched = True


def _patch_tornado():
"""
If tornado is imported before nest_asyncio, make tornado aware of
Expand Down

0 comments on commit c31e7c4

Please sign in to comment.