From 511c108ab7c47a0e8eec1ba539baaba699043ccc Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 7 Dec 2022 12:36:16 +0000 Subject: [PATCH 1/3] Lazy import either 'anyio' or 'trio' --- httpcore/_synchronization.py | 115 +++++++++++++++++++++++++++++++---- 1 file changed, 102 insertions(+), 13 deletions(-) diff --git a/httpcore/_synchronization.py b/httpcore/_synchronization.py index a3b2a27a..d086b50a 100644 --- a/httpcore/_synchronization.py +++ b/httpcore/_synchronization.py @@ -2,17 +2,39 @@ from types import TracebackType from typing import Optional, Type -import anyio +import sniffio from ._exceptions import ExceptionMapping, PoolTimeout, map_exceptions class AsyncLock: def __init__(self) -> None: - self._lock = anyio.Lock() + self._backend = "" + + def setup(self) -> None: + """ + Detect if we're running under 'asyncio' or 'trio' and create + a lock with the correct implementation. + """ + self._backend = sniffio.current_async_library() + if self._backend == "trio": + import trio + + self._trio_lock = trio.Lock() + else: + import anyio + + self._anyio_lock = anyio.Lock() async def __aenter__(self) -> "AsyncLock": - await self._lock.acquire() + if not self._backend: + self.setup() + + if self._backend == "trio": + await self._trio_lock.acquire() + else: + await self._anyio_lock.acquire() + return self async def __aexit__( @@ -21,32 +43,99 @@ async def __aexit__( exc_value: Optional[BaseException] = None, traceback: Optional[TracebackType] = None, ) -> None: - self._lock.release() + if self._backend == "trio": + self._trio_lock.release() + else: + self._anyio_lock.release() class AsyncEvent: def __init__(self) -> None: - self._event = anyio.Event() + self._backend = "" + + def setup(self) -> None: + """ + Detect if we're running under 'asyncio' or 'trio' and create + a lock with the correct implementation. + """ + self._backend = sniffio.current_async_library() + if self._backend == "trio": + import trio + + self._trio_event = trio.Event() + else: + import anyio + + self._anyio_event = anyio.Event() def set(self) -> None: - self._event.set() + if not self._backend: + self.setup() + + if self._backend == "trio": + self._trio_event.set() + else: + self._anyio_event.set() async def wait(self, timeout: Optional[float] = None) -> None: - exc_map: ExceptionMapping = {TimeoutError: PoolTimeout} - with map_exceptions(exc_map): - with anyio.fail_after(timeout): - await self._event.wait() + if not self._backend: + self.setup() + + if self._backend == "trio": + import trio + + trio_exc_map: ExceptionMapping = {trio.TooSlowError: PoolTimeout} + timeout_or_inf = float("inf") if timeout is None else timeout + with map_exceptions(trio_exc_map): + with trio.fail_after(timeout_or_inf): + await self._trio_event.wait() + else: + import anyio + + anyio_exc_map: ExceptionMapping = {TimeoutError: PoolTimeout} + with map_exceptions(anyio_exc_map): + with anyio.fail_after(timeout): + await self._anyio_event.wait() class AsyncSemaphore: def __init__(self, bound: int) -> None: - self._semaphore = anyio.Semaphore(initial_value=bound, max_value=bound) + self._bound = bound + self._backend = "" + + def setup(self) -> None: + """ + Detect if we're running under 'asyncio' or 'trio' and create + a semaphore with the correct implementation. + """ + self._backend = sniffio.current_async_library() + if self._backend == "trio": + import trio + + self._trio_semaphore = trio.Semaphore( + initial_value=self._bound, max_value=self._bound + ) + else: + import anyio + + self._anyio_semaphore = anyio.Semaphore( + initial_value=self._bound, max_value=self._bound + ) async def acquire(self) -> None: - await self._semaphore.acquire() + if not self._backend: + self.setup() + + if self._backend == "trio": + await self._trio_semaphore.acquire() + else: + await self._anyio_semaphore.acquire() async def release(self) -> None: - self._semaphore.release() + if self._backend == "trio": + self._trio_semaphore.release() + else: + self._anyio_semaphore.release() class Lock: From 54c27efd814ed4c051d469d156e8583796dcbf5b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 7 Dec 2022 13:02:25 +0000 Subject: [PATCH 2/3] Add comments in _synchronization.py --- httpcore/_synchronization.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/httpcore/_synchronization.py b/httpcore/_synchronization.py index d086b50a..dfc25d38 100644 --- a/httpcore/_synchronization.py +++ b/httpcore/_synchronization.py @@ -7,6 +7,11 @@ from ._exceptions import ExceptionMapping, PoolTimeout, map_exceptions +# Our async synchronization primatives use either 'anyio' or 'trio' depending +# on if they're running under asyncio or trio. +# +# We take care to only lazily import whichever of these two we need. + class AsyncLock: def __init__(self) -> None: self._backend = "" From df13bc6da80973cdd9260876ef7604b2708e327d Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 7 Dec 2022 13:04:54 +0000 Subject: [PATCH 3/3] Add comments in _synchronization.py --- httpcore/_synchronization.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/httpcore/_synchronization.py b/httpcore/_synchronization.py index dfc25d38..be3dcfda 100644 --- a/httpcore/_synchronization.py +++ b/httpcore/_synchronization.py @@ -6,12 +6,12 @@ from ._exceptions import ExceptionMapping, PoolTimeout, map_exceptions - # Our async synchronization primatives use either 'anyio' or 'trio' depending # on if they're running under asyncio or trio. # # We take care to only lazily import whichever of these two we need. + class AsyncLock: def __init__(self) -> None: self._backend = "" @@ -143,6 +143,9 @@ async def release(self) -> None: self._anyio_semaphore.release() +# Our thread-based synchronization primitives... + + class Lock: def __init__(self) -> None: self._lock = threading.Lock()