From df6d341e0ccd5da6dd226689fc08d9897ef6bad3 Mon Sep 17 00:00:00 2001 From: Hasier Date: Mon, 29 Jan 2024 16:03:03 +0000 Subject: [PATCH 1/8] Add async strategies --- tenacity/__init__.py | 23 +- tenacity/{_asyncio.py => asyncio/__init__.py} | 108 ++++++- tenacity/asyncio/retry.py | 283 ++++++++++++++++++ tenacity/asyncio/stop.py | 122 ++++++++ tenacity/asyncio/wait.py | 219 ++++++++++++++ tests/test_asyncio.py | 16 +- 6 files changed, 751 insertions(+), 20 deletions(-) rename tenacity/{_asyncio.py => asyncio/__init__.py} (56%) create mode 100644 tenacity/asyncio/retry.py create mode 100644 tenacity/asyncio/stop.py create mode 100644 tenacity/asyncio/wait.py diff --git a/tenacity/__init__.py b/tenacity/__init__.py index bcee3f5..c160a62 100644 --- a/tenacity/__init__.py +++ b/tenacity/__init__.py @@ -24,7 +24,8 @@ import warnings from abc import ABC, abstractmethod from concurrent import futures -from inspect import iscoroutinefunction + +from . import asyncio as tasyncio # Import all built-in retry strategies for easier usage. from .retry import retry_base # noqa @@ -593,16 +594,16 @@ def retry(func: WrappedFn) -> WrappedFn: ... @t.overload def retry( - sleep: t.Callable[[t.Union[int, float]], t.Optional[t.Awaitable[None]]] = sleep, - stop: "StopBaseT" = stop_never, - wait: "WaitBaseT" = wait_none(), - retry: "RetryBaseT" = retry_if_exception_type(), - before: t.Callable[["RetryCallState"], None] = before_nothing, - after: t.Callable[["RetryCallState"], None] = after_nothing, - before_sleep: t.Optional[t.Callable[["RetryCallState"], None]] = None, + sleep: t.Callable[[t.Union[int, float]], t.Union[None, t.Awaitable[None]]] = sleep, + stop: "t.Union[StopBaseT, tasyncio.stop.StopBaseT]" = stop_never, + wait: "t.Union[WaitBaseT, tasyncio.wait.WaitBaseT]" = wait_none(), + retry: "t.Union[RetryBaseT, tasyncio.retry.RetryBaseT]" = retry_if_exception_type(), + before: t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]] = before_nothing, + after: t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]] = after_nothing, + before_sleep: t.Optional[t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]]] = None, reraise: bool = False, retry_error_cls: t.Type["RetryError"] = RetryError, - retry_error_callback: t.Optional[t.Callable[["RetryCallState"], t.Any]] = None, + retry_error_callback: t.Optional[t.Callable[["RetryCallState"], t.Union[t.Any, t.Awaitable[t.Any]]]] = None, ) -> t.Callable[[WrappedFn], WrappedFn]: ... @@ -624,7 +625,7 @@ def wrap(f: WrappedFn) -> WrappedFn: f"this will probably hang indefinitely (did you mean retry={f.__class__.__name__}(...)?)" ) r: "BaseRetrying" - if iscoroutinefunction(f): + if tasyncio.is_coroutine_callable(f): r = AsyncRetrying(*dargs, **dkw) elif ( tornado @@ -640,7 +641,7 @@ def wrap(f: WrappedFn) -> WrappedFn: return wrap -from tenacity._asyncio import AsyncRetrying # noqa:E402,I100 +from tenacity.asyncio import AsyncRetrying # noqa:E402,I100 if tornado: from tenacity.tornadoweb import TornadoRetrying diff --git a/tenacity/_asyncio.py b/tenacity/asyncio/__init__.py similarity index 56% rename from tenacity/_asyncio.py rename to tenacity/asyncio/__init__.py index b06303f..a0c4c89 100644 --- a/tenacity/_asyncio.py +++ b/tenacity/asyncio/__init__.py @@ -24,8 +24,48 @@ from tenacity import DoAttempt from tenacity import DoSleep from tenacity import RetryCallState +from tenacity import RetryError +from tenacity import after_nothing +from tenacity import before_nothing from tenacity import _utils +# Import all built-in retry strategies for easier usage. +from .retry import RetryBaseT +from .retry import retry_all # noqa +from .retry import retry_always # noqa +from .retry import retry_any # noqa +from .retry import retry_if_exception # noqa +from .retry import retry_if_exception_type # noqa +from .retry import retry_if_exception_cause_type # noqa +from .retry import retry_if_not_exception_type # noqa +from .retry import retry_if_not_result # noqa +from .retry import retry_if_result # noqa +from .retry import retry_never # noqa +from .retry import retry_unless_exception_type # noqa +from .retry import retry_if_exception_message # noqa +from .retry import retry_if_not_exception_message # noqa +# Import all built-in stop strategies for easier usage. +from .stop import StopBaseT +from .stop import stop_after_attempt # noqa +from .stop import stop_after_delay # noqa +from .stop import stop_before_delay # noqa +from .stop import stop_all # noqa +from .stop import stop_any # noqa +from .stop import stop_never # noqa +from .stop import stop_when_event_set # noqa +# Import all built-in wait strategies for easier usage. +from .wait import WaitBaseT +from .wait import wait_chain # noqa +from .wait import wait_combine # noqa +from .wait import wait_exponential # noqa +from .wait import wait_fixed # noqa +from .wait import wait_incrementing # noqa +from .wait import wait_none # noqa +from .wait import wait_random # noqa +from .wait import wait_random_exponential # noqa +from .wait import wait_random_exponential as wait_full_jitter # noqa +from .wait import wait_exponential_jitter # noqa + WrappedFnReturnT = t.TypeVar("WrappedFnReturnT") WrappedFn = t.TypeVar("WrappedFn", bound=t.Callable[..., t.Awaitable[t.Any]]) @@ -38,15 +78,31 @@ def asyncio_sleep(duration: float) -> t.Awaitable[None]: class AsyncRetrying(BaseRetrying): - sleep: t.Callable[[float], t.Awaitable[t.Any]] - def __init__( self, - sleep: t.Callable[[float], t.Awaitable[t.Any]] = asyncio_sleep, - **kwargs: t.Any, + sleep: t.Callable[[t.Union[int, float]], t.Union[None, t.Awaitable[None]]] = asyncio_sleep, + stop: "t.Union[StopBaseT, StopBaseT]" = stop_never, + wait: "t.Union[WaitBaseT, WaitBaseT]" = wait_none(), + retry: "t.Union[RetryBaseT, RetryBaseT]" = retry_if_exception_type(), + before: t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]] = before_nothing, + after: t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]] = after_nothing, + before_sleep: t.Optional[t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]]] = None, + reraise: bool = False, + retry_error_cls: t.Type["RetryError"] = RetryError, + retry_error_callback: t.Optional[t.Callable[["RetryCallState"], t.Union[t.Any, t.Awaitable[t.Any]]]] = None, ) -> None: - super().__init__(**kwargs) - self.sleep = sleep + super().__init__( + sleep=sleep, # type: ignore[arg-type] + stop=stop, # type: ignore[arg-type] + wait=wait, # type: ignore[arg-type] + retry=retry, # type: ignore[arg-type] + before=before, # type: ignore[arg-type] + after=after, # type: ignore[arg-type] + before_sleep=before_sleep, # type: ignore[arg-type] + reraise=reraise, + retry_error_cls=retry_error_cls, + retry_error_callback=retry_error_callback, + ) async def __call__( # type: ignore[override] self, fn: WrappedFn, *args: t.Any, **kwargs: t.Any @@ -65,7 +121,7 @@ async def __call__( # type: ignore[override] retry_state.set_result(result) elif isinstance(do, DoSleep): retry_state.prepare_for_next_attempt() - await self.sleep(do) + await self.sleep(do) # type: ignore[misc] else: return do # type: ignore[no-any-return] @@ -127,7 +183,7 @@ async def __anext__(self) -> AttemptManager: return AttemptManager(retry_state=self._retry_state) elif isinstance(do, DoSleep): self._retry_state.prepare_for_next_attempt() - await self.sleep(do) + await self.sleep(do) # type: ignore[misc] else: raise StopAsyncIteration @@ -146,3 +202,39 @@ async def async_wrapped(*args: t.Any, **kwargs: t.Any) -> t.Any: async_wrapped.retry_with = fn.retry_with # type: ignore[attr-defined] return async_wrapped # type: ignore[return-value] + + +__all__ = [ + "retry_all", + "retry_always", + "retry_any", + "retry_if_exception", + "retry_if_exception_type", + "retry_if_exception_cause_type", + "retry_if_not_exception_type", + "retry_if_not_result", + "retry_if_result", + "retry_never", + "retry_unless_exception_type", + "retry_if_exception_message", + "retry_if_not_exception_message", + "stop_after_attempt", + "stop_after_delay", + "stop_before_delay", + "stop_all", + "stop_any", + "stop_never", + "stop_when_event_set", + "wait_chain", + "wait_combine", + "wait_exponential", + "wait_fixed", + "wait_incrementing", + "wait_none", + "wait_random", + "wait_random_exponential", + "wait_full_jitter", + "wait_exponential_jitter", + "WrappedFn", + "AsyncRetrying", +] diff --git a/tenacity/asyncio/retry.py b/tenacity/asyncio/retry.py new file mode 100644 index 0000000..eb63286 --- /dev/null +++ b/tenacity/asyncio/retry.py @@ -0,0 +1,283 @@ +# Copyright 2016–2021 Julien Danjou +# Copyright 2016 Joshua Harlow +# Copyright 2013-2014 Ray Holder +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import abc +import re +import typing + +from tenacity import retry_base + +if typing.TYPE_CHECKING: + from tenacity import RetryCallState + + +class retry_base(retry_base): # type: ignore[no-redef] + """Abstract base class for retry strategies.""" + + @abc.abstractmethod + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + pass + + +RetryBaseT = typing.Union[retry_base, typing.Callable[["RetryCallState"], typing.Awaitable[bool]]] + + +class _retry_never(retry_base): + """Retry strategy that never rejects any result.""" + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + return False + + +retry_never = _retry_never() + + +class _retry_always(retry_base): + """Retry strategy that always rejects any result.""" + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + return True + + +retry_always = _retry_always() + + +class retry_if_exception(retry_base): + """Retry strategy that retries if an exception verifies a predicate.""" + + def __init__(self, predicate: typing.Callable[[BaseException], typing.Awaitable[bool]]) -> None: + self.predicate = predicate + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + if retry_state.outcome is None: + raise RuntimeError("__call__() called before outcome was set") + + if retry_state.outcome.failed: + exception = retry_state.outcome.exception() + if exception is None: + raise RuntimeError("outcome failed but the exception is None") + return await self.predicate(exception) + else: + return False + + +class retry_if_exception_type(retry_if_exception): + """Retries if an exception has been raised of one or more types.""" + + def __init__( + self, + exception_types: typing.Union[ + typing.Type[BaseException], + typing.Tuple[typing.Type[BaseException], ...], + ] = Exception, + ) -> None: + self.exception_types = exception_types + + async def predicate(e: BaseException) -> bool: + return isinstance(e, exception_types) + + super().__init__(predicate) + + +class retry_if_not_exception_type(retry_if_exception): + """Retries except an exception has been raised of one or more types.""" + + def __init__( + self, + exception_types: typing.Union[ + typing.Type[BaseException], + typing.Tuple[typing.Type[BaseException], ...], + ] = Exception, + ) -> None: + self.exception_types = exception_types + + async def predicate(e: BaseException) -> bool: + return not isinstance(e, exception_types) + + super().__init__(predicate) + + +class retry_unless_exception_type(retry_if_exception): + """Retries until an exception is raised of one or more types.""" + + def __init__( + self, + exception_types: typing.Union[ + typing.Type[BaseException], + typing.Tuple[typing.Type[BaseException], ...], + ] = Exception, + ) -> None: + self.exception_types = exception_types + + async def predicate(e: BaseException) -> bool: + return not isinstance(e, exception_types) + + super().__init__(predicate) + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + if retry_state.outcome is None: + raise RuntimeError("__call__() called before outcome was set") + + # always retry if no exception was raised + if not retry_state.outcome.failed: + return True + + exception = retry_state.outcome.exception() + if exception is None: + raise RuntimeError("outcome failed but the exception is None") + return await self.predicate(exception) + + +class retry_if_exception_cause_type(retry_base): + """Retries if any of the causes of the raised exception is of one or more types. + + The check on the type of the cause of the exception is done recursively (until finding + an exception in the chain that has no `__cause__`) + """ + + def __init__( + self, + exception_types: typing.Union[ + typing.Type[BaseException], + typing.Tuple[typing.Type[BaseException], ...], + ] = Exception, + ) -> None: + self.exception_cause_types = exception_types + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + if retry_state.outcome is None: + raise RuntimeError("__call__ called before outcome was set") + + if retry_state.outcome.failed: + exc = retry_state.outcome.exception() + while exc is not None: + if isinstance(exc.__cause__, self.exception_cause_types): + return True + exc = exc.__cause__ + + return False + + +class retry_if_result(retry_base): + """Retries if the result verifies a predicate.""" + + def __init__(self, predicate: typing.Callable[[typing.Any], typing.Awaitable[bool]]) -> None: + self.predicate = predicate + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + if retry_state.outcome is None: + raise RuntimeError("__call__() called before outcome was set") + + if not retry_state.outcome.failed: + return await self.predicate(retry_state.outcome.result()) + else: + return False + + +class retry_if_not_result(retry_base): + """Retries if the result refutes a predicate.""" + + def __init__(self, predicate: typing.Callable[[typing.Any], typing.Awaitable[bool]]) -> None: + self.predicate = predicate + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + if retry_state.outcome is None: + raise RuntimeError("__call__() called before outcome was set") + + if not retry_state.outcome.failed: + return not await self.predicate(retry_state.outcome.result()) + else: + return False + + +class retry_if_exception_message(retry_if_exception): + """Retries if an exception message equals or matches.""" + + def __init__( + self, + message: typing.Optional[str] = None, + match: typing.Optional[str] = None, + ) -> None: + if message and match: + raise TypeError(f"{self.__class__.__name__}() takes either 'message' or 'match', not both") + + # set predicate + if message: + + async def message_fnc(exception: BaseException) -> bool: + return message == str(exception) + + predicate = message_fnc + elif match: + prog = re.compile(match) + + async def match_fnc(exception: BaseException) -> bool: + return bool(prog.match(str(exception))) + + predicate = match_fnc + else: + raise TypeError(f"{self.__class__.__name__}() missing 1 required argument 'message' or 'match'") + + super().__init__(predicate) + + +class retry_if_not_exception_message(retry_if_exception_message): + """Retries until an exception message equals or matches.""" + + def __init__( + self, + message: typing.Optional[str] = None, + match: typing.Optional[str] = None, + ) -> None: + super().__init__(message, match) + if_predicate = self.predicate + + # invert predicate + async def predicate(e: BaseException) -> bool: + return not if_predicate(e) + + self.predicate = predicate + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + if retry_state.outcome is None: + raise RuntimeError("__call__() called before outcome was set") + + if not retry_state.outcome.failed: + return True + + exception = retry_state.outcome.exception() + if exception is None: + raise RuntimeError("outcome failed but the exception is None") + return await self.predicate(exception) + + +class retry_any(retry_base): + """Retries if any of the retries condition is valid.""" + + def __init__(self, *retries: retry_base) -> None: + self.retries = retries + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + return any(r(retry_state) for r in self.retries) + + +class retry_all(retry_base): + """Retries if all the retries condition are valid.""" + + def __init__(self, *retries: retry_base) -> None: + self.retries = retries + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + return all(r(retry_state) for r in self.retries) diff --git a/tenacity/asyncio/stop.py b/tenacity/asyncio/stop.py new file mode 100644 index 0000000..1528426 --- /dev/null +++ b/tenacity/asyncio/stop.py @@ -0,0 +1,122 @@ +# Copyright 2016–2021 Julien Danjou +# Copyright 2016 Joshua Harlow +# Copyright 2013-2014 Ray Holder +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import abc +import typing + +from tenacity import _utils +from tenacity.stop import stop_base + +if typing.TYPE_CHECKING: + import asyncio + + from tenacity import RetryCallState + + +class stop_base(stop_base): # type: ignore[no-redef] + """Abstract base class for stop strategies.""" + + @abc.abstractmethod + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + pass + + +StopBaseT = typing.Union[stop_base, typing.Callable[["RetryCallState"], typing.Awaitable[bool]]] + + +class stop_any(stop_base): + """Stop if any of the stop condition is valid.""" + + def __init__(self, *stops: stop_base) -> None: + self.stops = stops + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + return any(x(retry_state) for x in self.stops) + + +class stop_all(stop_base): + """Stop if all the stop conditions are valid.""" + + def __init__(self, *stops: stop_base) -> None: + self.stops = stops + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + return all(x(retry_state) for x in self.stops) + + +class _stop_never(stop_base): + """Never stop.""" + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + return False + + +stop_never = _stop_never() + + +class stop_when_event_set(stop_base): + """Stop when the given event is set.""" + + def __init__(self, event: "asyncio.Event") -> None: + self.event = event + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + return self.event.is_set() + + +class stop_after_attempt(stop_base): + """Stop when the previous attempt >= max_attempt.""" + + def __init__(self, max_attempt_number: int) -> None: + self.max_attempt_number = max_attempt_number + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + return retry_state.attempt_number >= self.max_attempt_number + + +class stop_after_delay(stop_base): + """ + Stop when the time from the first attempt >= limit. + + Note: `max_delay` will be exceeded, so when used with a `wait`, the actual total delay will be greater + than `max_delay` by some of the final sleep period before `max_delay` is exceeded. + + If you need stricter timing with waits, consider `stop_before_delay` instead. + """ + + def __init__(self, max_delay: _utils.time_unit_type) -> None: + self.max_delay = _utils.to_seconds(max_delay) + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + if retry_state.seconds_since_start is None: + raise RuntimeError("__call__() called but seconds_since_start is not set") + return retry_state.seconds_since_start >= self.max_delay + + +class stop_before_delay(stop_base): + """ + Stop right before the next attempt would take place after the time from the first attempt >= limit. + + Most useful when you are using with a `wait` function like wait_random_exponential, but need to make + sure that the max_delay is not exceeded. + """ + + def __init__(self, max_delay: _utils.time_unit_type) -> None: + self.max_delay = _utils.to_seconds(max_delay) + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + if retry_state.seconds_since_start is None: + raise RuntimeError("__call__() called but seconds_since_start is not set") + return retry_state.seconds_since_start + retry_state.upcoming_sleep >= self.max_delay diff --git a/tenacity/asyncio/wait.py b/tenacity/asyncio/wait.py new file mode 100644 index 0000000..021b34d --- /dev/null +++ b/tenacity/asyncio/wait.py @@ -0,0 +1,219 @@ +# Copyright 2016–2021 Julien Danjou +# Copyright 2016 Joshua Harlow +# Copyright 2013-2014 Ray Holder +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import abc +import random +import typing + +from tenacity import _utils +from tenacity.wait import wait_base + +if typing.TYPE_CHECKING: + from tenacity import RetryCallState + + +class wait_base(wait_base): # type: ignore[no-redef] + """Abstract base class for wait strategies.""" + + @abc.abstractmethod + async def __call__(self, retry_state: "RetryCallState") -> float: # type: ignore[override] + pass + + +WaitBaseT = typing.Union[wait_base, typing.Callable[["RetryCallState"], typing.Awaitable[typing.Union[float, int]]]] + + +class wait_fixed(wait_base): + """Wait strategy that waits a fixed amount of time between each retry.""" + + def __init__(self, wait: _utils.time_unit_type) -> None: + self.wait_fixed = _utils.to_seconds(wait) + + async def __call__(self, retry_state: "RetryCallState") -> float: # type: ignore[override] + return self.wait_fixed + + +class wait_none(wait_fixed): + """Wait strategy that doesn't wait at all before retrying.""" + + def __init__(self) -> None: + super().__init__(0) + + +class wait_random(wait_base): + """Wait strategy that waits a random amount of time between min/max.""" + + def __init__(self, min: _utils.time_unit_type = 0, max: _utils.time_unit_type = 1) -> None: # noqa + self.wait_random_min = _utils.to_seconds(min) + self.wait_random_max = _utils.to_seconds(max) + + async def __call__(self, retry_state: "RetryCallState") -> float: # type: ignore[override] + return self.wait_random_min + (random.random() * (self.wait_random_max - self.wait_random_min)) + + +class wait_combine(wait_base): + """Combine several waiting strategies.""" + + def __init__(self, *strategies: wait_base) -> None: + self.wait_funcs = strategies + + async def __call__(self, retry_state: "RetryCallState") -> float: # type: ignore[override] + return sum(x(retry_state=retry_state) for x in self.wait_funcs) + + +class wait_chain(wait_base): + """Chain two or more waiting strategies. + + If all strategies are exhausted, the very last strategy is used + thereafter. + + For example:: + + @retry(wait=wait_chain(*[wait_fixed(1) for i in range(3)] + + [wait_fixed(2) for j in range(5)] + + [wait_fixed(5) for k in range(4))) + def wait_chained(): + print("Wait 1s for 3 attempts, 2s for 5 attempts and 5s + thereafter.") + """ + + def __init__(self, *strategies: wait_base) -> None: + self.strategies = strategies + + async def __call__(self, retry_state: "RetryCallState") -> float: # type: ignore[override] + wait_func_no = min(max(retry_state.attempt_number, 1), len(self.strategies)) + wait_func = self.strategies[wait_func_no - 1] + return wait_func(retry_state=retry_state) + + +class wait_incrementing(wait_base): + """Wait an incremental amount of time after each attempt. + + Starting at a starting value and incrementing by a value for each attempt + (and restricting the upper limit to some maximum value). + """ + + def __init__( + self, + start: _utils.time_unit_type = 0, + increment: _utils.time_unit_type = 100, + max: _utils.time_unit_type = _utils.MAX_WAIT, # noqa + ) -> None: + self.start = _utils.to_seconds(start) + self.increment = _utils.to_seconds(increment) + self.max = _utils.to_seconds(max) + + async def __call__(self, retry_state: "RetryCallState") -> float: # type: ignore[override] + result = self.start + (self.increment * (retry_state.attempt_number - 1)) + return max(0, min(result, self.max)) + + +class wait_exponential(wait_base): + """Wait strategy that applies exponential backoff. + + It allows for a customized multiplier and an ability to restrict the + upper and lower limits to some maximum and minimum value. + + The intervals are fixed (i.e. there is no jitter), so this strategy is + suitable for balancing retries against latency when a required resource is + unavailable for an unknown duration, but *not* suitable for resolving + contention between multiple processes for a shared resource. Use + wait_random_exponential for the latter case. + """ + + def __init__( + self, + multiplier: typing.Union[int, float] = 1, + max: _utils.time_unit_type = _utils.MAX_WAIT, # noqa + exp_base: typing.Union[int, float] = 2, + min: _utils.time_unit_type = 0, # noqa + ) -> None: + self.multiplier = multiplier + self.min = _utils.to_seconds(min) + self.max = _utils.to_seconds(max) + self.exp_base = exp_base + + async def __call__(self, retry_state: "RetryCallState") -> float: # type: ignore[override] + try: + exp = self.exp_base ** (retry_state.attempt_number - 1) + result = self.multiplier * exp + except OverflowError: + return self.max + return max(max(0, self.min), min(result, self.max)) + + +class wait_random_exponential(wait_exponential): + """Random wait with exponentially widening window. + + An exponential backoff strategy used to mediate contention between multiple + uncoordinated processes for a shared resource in distributed systems. This + is the sense in which "exponential backoff" is meant in e.g. Ethernet + networking, and corresponds to the "Full Jitter" algorithm described in + this blog post: + + https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ + + Each retry occurs at a random time in a geometrically expanding interval. + It allows for a custom multiplier and an ability to restrict the upper + limit of the random interval to some maximum value. + + Example:: + + wait_random_exponential(multiplier=0.5, # initial window 0.5s + max=60) # max 60s timeout + + When waiting for an unavailable resource to become available again, as + opposed to trying to resolve contention for a shared resource, the + wait_exponential strategy (which uses a fixed interval) may be preferable. + + """ + + async def __call__(self, retry_state: "RetryCallState") -> float: # type: ignore[override] + high = await super().__call__(retry_state=retry_state) + return random.uniform(0, high) + + +class wait_exponential_jitter(wait_base): + """Wait strategy that applies exponential backoff and jitter. + + It allows for a customized initial wait, maximum wait and jitter. + + This implements the strategy described here: + https://cloud.google.com/storage/docs/retry-strategy + + The wait time is min(initial * 2**n + random.uniform(0, jitter), maximum) + where n is the retry count. + """ + + def __init__( + self, + initial: float = 1, + max: float = _utils.MAX_WAIT, # noqa + exp_base: float = 2, + jitter: float = 1, + ) -> None: + self.initial = initial + self.max = max + self.exp_base = exp_base + self.jitter = jitter + + async def __call__(self, retry_state: "RetryCallState") -> float: # type: ignore[override] + jitter = random.uniform(0, self.jitter) + try: + exp = self.exp_base ** (retry_state.attempt_number - 1) + result = self.initial * exp + jitter + except OverflowError: + result = self.max + return max(0, min(result, self.max)) diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 24cf6ed..60a8089 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -22,7 +22,7 @@ import tenacity from tenacity import AsyncRetrying, RetryError -from tenacity import _asyncio as tasyncio +from tenacity import asyncio as tasyncio from tenacity import retry, retry_if_result, stop_after_attempt from tenacity.wait import wait_fixed @@ -55,6 +55,12 @@ async def _retryable_coroutine_with_2_attempts(thing): thing.go() +@retry(stop=tasyncio.stop_after_attempt(2)) +async def _async_retryable_coroutine_with_2_attempts(thing): + await asyncio.sleep(0.00001) + thing.go() + + class TestAsync(unittest.TestCase): @asynctest async def test_retry(self): @@ -82,6 +88,14 @@ async def test_stop_after_attempt(self): except RetryError: assert thing.counter == 2 + @asynctest + async def test_stop_after_attempt_async(self): + thing = NoIOErrorAfterCount(2) + try: + await _async_retryable_coroutine_with_2_attempts(thing) + except RetryError: + assert thing.counter == 2 + def test_repr(self): repr(tasyncio.AsyncRetrying()) From 84fe2087ded98be1587fb2cf5fa2dbbfec3884ce Mon Sep 17 00:00:00 2001 From: Hasier Date: Sat, 3 Feb 2024 15:40:53 +0000 Subject: [PATCH 2/8] Fix init typing --- tenacity/asyncio/__init__.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tenacity/asyncio/__init__.py b/tenacity/asyncio/__init__.py index a0c4c89..39fee70 100644 --- a/tenacity/asyncio/__init__.py +++ b/tenacity/asyncio/__init__.py @@ -19,6 +19,7 @@ import sys import typing as t +import tenacity from tenacity import AttemptManager from tenacity import BaseRetrying from tenacity import DoAttempt @@ -66,6 +67,8 @@ from .wait import wait_random_exponential as wait_full_jitter # noqa from .wait import wait_exponential_jitter # noqa +from ..retry import RetryBaseT as SyncRetryBaseT + WrappedFnReturnT = t.TypeVar("WrappedFnReturnT") WrappedFn = t.TypeVar("WrappedFn", bound=t.Callable[..., t.Awaitable[t.Any]]) @@ -81,9 +84,9 @@ class AsyncRetrying(BaseRetrying): def __init__( self, sleep: t.Callable[[t.Union[int, float]], t.Union[None, t.Awaitable[None]]] = asyncio_sleep, - stop: "t.Union[StopBaseT, StopBaseT]" = stop_never, - wait: "t.Union[WaitBaseT, WaitBaseT]" = wait_none(), - retry: "t.Union[RetryBaseT, RetryBaseT]" = retry_if_exception_type(), + stop: "t.Union[tenacity.stop.StopBaseT, StopBaseT]" = stop_never, + wait: "t.Union[tenacity.wait.WaitBaseT, WaitBaseT]" = wait_none(), + retry: "t.Union[SyncRetryBaseT, RetryBaseT]" = retry_if_exception_type(), before: t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]] = before_nothing, after: t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]] = after_nothing, before_sleep: t.Optional[t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]]] = None, From e145ef3408e09d607a84e54980ec7ae01c0753ca Mon Sep 17 00:00:00 2001 From: Hasier Date: Mon, 5 Feb 2024 09:34:54 +0000 Subject: [PATCH 3/8] Reuse is_coroutine_callable --- tenacity/__init__.py | 5 +++-- tenacity/asyncio/__init__.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tenacity/__init__.py b/tenacity/__init__.py index c160a62..d626b2e 100644 --- a/tenacity/__init__.py +++ b/tenacity/__init__.py @@ -25,7 +25,7 @@ from abc import ABC, abstractmethod from concurrent import futures -from . import asyncio as tasyncio +from . import _utils # Import all built-in retry strategies for easier usage. from .retry import retry_base # noqa @@ -88,6 +88,7 @@ if t.TYPE_CHECKING: import types + from . import asyncio as tasyncio from .retry import RetryBaseT from .stop import StopBaseT from .wait import WaitBaseT @@ -625,7 +626,7 @@ def wrap(f: WrappedFn) -> WrappedFn: f"this will probably hang indefinitely (did you mean retry={f.__class__.__name__}(...)?)" ) r: "BaseRetrying" - if tasyncio.is_coroutine_callable(f): + if _utils.is_coroutine_callable(f): r = AsyncRetrying(*dargs, **dkw) elif ( tornado diff --git a/tenacity/asyncio/__init__.py b/tenacity/asyncio/__init__.py index 39fee70..f2cbae5 100644 --- a/tenacity/asyncio/__init__.py +++ b/tenacity/asyncio/__init__.py @@ -20,6 +20,7 @@ import typing as t import tenacity +from tenacity import _utils from tenacity import AttemptManager from tenacity import BaseRetrying from tenacity import DoAttempt From c97c13855e2570582ccb67577220bd966731334f Mon Sep 17 00:00:00 2001 From: Hasier Date: Mon, 5 Feb 2024 11:02:56 +0000 Subject: [PATCH 4/8] Keep only async predicate overrides and DRY implementations --- tenacity/__init__.py | 4 +- tenacity/_utils.py | 12 ++ tenacity/asyncio/__init__.py | 91 ++---------- tenacity/asyncio/retry.py | 262 +++++------------------------------ tenacity/asyncio/stop.py | 122 ---------------- tenacity/asyncio/wait.py | 219 ----------------------------- tests/test_asyncio.py | 68 +++++++-- 7 files changed, 119 insertions(+), 659 deletions(-) delete mode 100644 tenacity/asyncio/stop.py delete mode 100644 tenacity/asyncio/wait.py diff --git a/tenacity/__init__.py b/tenacity/__init__.py index d626b2e..d5bfe80 100644 --- a/tenacity/__init__.py +++ b/tenacity/__init__.py @@ -596,8 +596,8 @@ def retry(func: WrappedFn) -> WrappedFn: ... @t.overload def retry( sleep: t.Callable[[t.Union[int, float]], t.Union[None, t.Awaitable[None]]] = sleep, - stop: "t.Union[StopBaseT, tasyncio.stop.StopBaseT]" = stop_never, - wait: "t.Union[WaitBaseT, tasyncio.wait.WaitBaseT]" = wait_none(), + stop: "StopBaseT" = stop_never, + wait: "WaitBaseT" = wait_none(), retry: "t.Union[RetryBaseT, tasyncio.retry.RetryBaseT]" = retry_if_exception_type(), before: t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]] = before_nothing, after: t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]] = after_nothing, diff --git a/tenacity/_utils.py b/tenacity/_utils.py index 4e34115..f11a088 100644 --- a/tenacity/_utils.py +++ b/tenacity/_utils.py @@ -87,3 +87,15 @@ def is_coroutine_callable(call: typing.Callable[..., typing.Any]) -> bool: partial_call = isinstance(call, functools.partial) and call.func dunder_call = partial_call or getattr(call, "__call__", None) return inspect.iscoroutinefunction(dunder_call) + + +def wrap_to_async_func( + call: typing.Callable[..., typing.Any], +) -> typing.Callable[..., typing.Awaitable[typing.Any]]: + if is_coroutine_callable(call): + return call + + async def inner(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: + return call(*args, **kwargs) + + return inner diff --git a/tenacity/asyncio/__init__.py b/tenacity/asyncio/__init__.py index f2cbae5..019bbb6 100644 --- a/tenacity/asyncio/__init__.py +++ b/tenacity/asyncio/__init__.py @@ -20,13 +20,13 @@ import typing as t import tenacity -from tenacity import _utils from tenacity import AttemptManager from tenacity import BaseRetrying from tenacity import DoAttempt from tenacity import DoSleep from tenacity import RetryCallState from tenacity import RetryError +from tenacity import _utils from tenacity import after_nothing from tenacity import before_nothing from tenacity import _utils @@ -34,42 +34,15 @@ # Import all built-in retry strategies for easier usage. from .retry import RetryBaseT from .retry import retry_all # noqa -from .retry import retry_always # noqa from .retry import retry_any # noqa from .retry import retry_if_exception # noqa -from .retry import retry_if_exception_type # noqa -from .retry import retry_if_exception_cause_type # noqa -from .retry import retry_if_not_exception_type # noqa -from .retry import retry_if_not_result # noqa from .retry import retry_if_result # noqa -from .retry import retry_never # noqa -from .retry import retry_unless_exception_type # noqa -from .retry import retry_if_exception_message # noqa -from .retry import retry_if_not_exception_message # noqa -# Import all built-in stop strategies for easier usage. -from .stop import StopBaseT -from .stop import stop_after_attempt # noqa -from .stop import stop_after_delay # noqa -from .stop import stop_before_delay # noqa -from .stop import stop_all # noqa -from .stop import stop_any # noqa -from .stop import stop_never # noqa -from .stop import stop_when_event_set # noqa -# Import all built-in wait strategies for easier usage. -from .wait import WaitBaseT -from .wait import wait_chain # noqa -from .wait import wait_combine # noqa -from .wait import wait_exponential # noqa -from .wait import wait_fixed # noqa -from .wait import wait_incrementing # noqa -from .wait import wait_none # noqa -from .wait import wait_random # noqa -from .wait import wait_random_exponential # noqa -from .wait import wait_random_exponential as wait_full_jitter # noqa -from .wait import wait_exponential_jitter # noqa - from ..retry import RetryBaseT as SyncRetryBaseT +if t.TYPE_CHECKING: + from tenacity.stop import StopBaseT + from tenacity.wait import WaitBaseT + WrappedFnReturnT = t.TypeVar("WrappedFnReturnT") WrappedFn = t.TypeVar("WrappedFn", bound=t.Callable[..., t.Awaitable[t.Any]]) @@ -85,9 +58,9 @@ class AsyncRetrying(BaseRetrying): def __init__( self, sleep: t.Callable[[t.Union[int, float]], t.Union[None, t.Awaitable[None]]] = asyncio_sleep, - stop: "t.Union[tenacity.stop.StopBaseT, StopBaseT]" = stop_never, - wait: "t.Union[tenacity.wait.WaitBaseT, WaitBaseT]" = wait_none(), - retry: "t.Union[SyncRetryBaseT, RetryBaseT]" = retry_if_exception_type(), + stop: "StopBaseT" = tenacity.stop.stop_never, + wait: "WaitBaseT" = tenacity.wait.wait_none(), + retry: "t.Union[SyncRetryBaseT, RetryBaseT]" = tenacity.retry_if_exception_type(), before: t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]] = before_nothing, after: t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]] = after_nothing, before_sleep: t.Optional[t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]]] = None, @@ -97,8 +70,8 @@ def __init__( ) -> None: super().__init__( sleep=sleep, # type: ignore[arg-type] - stop=stop, # type: ignore[arg-type] - wait=wait, # type: ignore[arg-type] + stop=stop, + wait=wait, retry=retry, # type: ignore[arg-type] before=before, # type: ignore[arg-type] after=after, # type: ignore[arg-type] @@ -129,27 +102,17 @@ async def __call__( # type: ignore[override] else: return do # type: ignore[no-any-return] - @classmethod - def _wrap_action_func(cls, fn: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]: - if _utils.is_coroutine_callable(fn): - return fn - - async def inner(*args: t.Any, **kwargs: t.Any) -> t.Any: - return fn(*args, **kwargs) - - return inner - def _add_action_func(self, fn: t.Callable[..., t.Any]) -> None: - self.iter_state.actions.append(self._wrap_action_func(fn)) + self.iter_state.actions.append(_utils.wrap_to_async_func(fn)) async def _run_retry(self, retry_state: "RetryCallState") -> None: # type: ignore[override] - self.iter_state.retry_run_result = await self._wrap_action_func(self.retry)( + self.iter_state.retry_run_result = await _utils.wrap_to_async_func(self.retry)( retry_state ) async def _run_wait(self, retry_state: "RetryCallState") -> None: # type: ignore[override] if self.wait: - sleep = await self._wrap_action_func(self.wait)(retry_state) + sleep = await _utils.wrap_to_async_func(self.wait)(retry_state) else: sleep = 0.0 @@ -157,7 +120,7 @@ async def _run_wait(self, retry_state: "RetryCallState") -> None: # type: ignor async def _run_stop(self, retry_state: "RetryCallState") -> None: # type: ignore[override] self.statistics["delay_since_first_attempt"] = retry_state.seconds_since_start - self.iter_state.stop_run_result = await self._wrap_action_func(self.stop)( + self.iter_state.stop_run_result = await _utils.wrap_to_async_func(self.stop)( retry_state ) @@ -210,35 +173,9 @@ async def async_wrapped(*args: t.Any, **kwargs: t.Any) -> t.Any: __all__ = [ "retry_all", - "retry_always", "retry_any", "retry_if_exception", - "retry_if_exception_type", - "retry_if_exception_cause_type", - "retry_if_not_exception_type", - "retry_if_not_result", "retry_if_result", - "retry_never", - "retry_unless_exception_type", - "retry_if_exception_message", - "retry_if_not_exception_message", - "stop_after_attempt", - "stop_after_delay", - "stop_before_delay", - "stop_all", - "stop_any", - "stop_never", - "stop_when_event_set", - "wait_chain", - "wait_combine", - "wait_exponential", - "wait_fixed", - "wait_incrementing", - "wait_none", - "wait_random", - "wait_random_exponential", - "wait_full_jitter", - "wait_exponential_jitter", "WrappedFn", "AsyncRetrying", ] diff --git a/tenacity/asyncio/retry.py b/tenacity/asyncio/retry.py index eb63286..7e00e2d 100644 --- a/tenacity/asyncio/retry.py +++ b/tenacity/asyncio/retry.py @@ -14,270 +14,82 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc -import re +import inspect import typing +from tenacity import _utils from tenacity import retry_base +from tenacity import retry_if_exception as _retry_if_exception +from tenacity import retry_if_result as _retry_if_result if typing.TYPE_CHECKING: from tenacity import RetryCallState -class retry_base(retry_base): # type: ignore[no-redef] - """Abstract base class for retry strategies.""" +class async_retry_base(retry_base): + """Abstract base class for async retry strategies.""" @abc.abstractmethod async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] pass + def __and__(self, other: "typing.Union[retry_base, async_retry_base]") -> "retry_all": # type: ignore[override] + return retry_all(self, other) -RetryBaseT = typing.Union[retry_base, typing.Callable[["RetryCallState"], typing.Awaitable[bool]]] + def __or__(self, other: "typing.Union[retry_base, async_retry_base]") -> "retry_any": # type: ignore[override] + return retry_any(self, other) -class _retry_never(retry_base): - """Retry strategy that never rejects any result.""" +class async_predicate_mixin: + async def __call__(self, retry_state: "RetryCallState") -> bool: + result = super().__call__(retry_state) # type: ignore[misc] + if inspect.isawaitable(result): + result = await result + return typing.cast(bool, result) - async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] - return False - - -retry_never = _retry_never() +RetryBaseT = typing.Union[async_retry_base, typing.Callable[["RetryCallState"], typing.Awaitable[bool]]] -class _retry_always(retry_base): - """Retry strategy that always rejects any result.""" - - async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] - return True - -retry_always = _retry_always() - - -class retry_if_exception(retry_base): +class retry_if_exception(async_predicate_mixin, _retry_if_exception, async_retry_base): # type: ignore[misc] """Retry strategy that retries if an exception verifies a predicate.""" def __init__(self, predicate: typing.Callable[[BaseException], typing.Awaitable[bool]]) -> None: - self.predicate = predicate - - async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] - if retry_state.outcome is None: - raise RuntimeError("__call__() called before outcome was set") - - if retry_state.outcome.failed: - exception = retry_state.outcome.exception() - if exception is None: - raise RuntimeError("outcome failed but the exception is None") - return await self.predicate(exception) - else: - return False - - -class retry_if_exception_type(retry_if_exception): - """Retries if an exception has been raised of one or more types.""" - - def __init__( - self, - exception_types: typing.Union[ - typing.Type[BaseException], - typing.Tuple[typing.Type[BaseException], ...], - ] = Exception, - ) -> None: - self.exception_types = exception_types - - async def predicate(e: BaseException) -> bool: - return isinstance(e, exception_types) - - super().__init__(predicate) - - -class retry_if_not_exception_type(retry_if_exception): - """Retries except an exception has been raised of one or more types.""" - - def __init__( - self, - exception_types: typing.Union[ - typing.Type[BaseException], - typing.Tuple[typing.Type[BaseException], ...], - ] = Exception, - ) -> None: - self.exception_types = exception_types - - async def predicate(e: BaseException) -> bool: - return not isinstance(e, exception_types) - - super().__init__(predicate) - - -class retry_unless_exception_type(retry_if_exception): - """Retries until an exception is raised of one or more types.""" - - def __init__( - self, - exception_types: typing.Union[ - typing.Type[BaseException], - typing.Tuple[typing.Type[BaseException], ...], - ] = Exception, - ) -> None: - self.exception_types = exception_types - - async def predicate(e: BaseException) -> bool: - return not isinstance(e, exception_types) - - super().__init__(predicate) - - async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] - if retry_state.outcome is None: - raise RuntimeError("__call__() called before outcome was set") - - # always retry if no exception was raised - if not retry_state.outcome.failed: - return True - - exception = retry_state.outcome.exception() - if exception is None: - raise RuntimeError("outcome failed but the exception is None") - return await self.predicate(exception) - - -class retry_if_exception_cause_type(retry_base): - """Retries if any of the causes of the raised exception is of one or more types. - - The check on the type of the cause of the exception is done recursively (until finding - an exception in the chain that has no `__cause__`) - """ + super().__init__(predicate) # type: ignore[arg-type] - def __init__( - self, - exception_types: typing.Union[ - typing.Type[BaseException], - typing.Tuple[typing.Type[BaseException], ...], - ] = Exception, - ) -> None: - self.exception_cause_types = exception_types - async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] - if retry_state.outcome is None: - raise RuntimeError("__call__ called before outcome was set") - - if retry_state.outcome.failed: - exc = retry_state.outcome.exception() - while exc is not None: - if isinstance(exc.__cause__, self.exception_cause_types): - return True - exc = exc.__cause__ - - return False - - -class retry_if_result(retry_base): +class retry_if_result(async_predicate_mixin, _retry_if_result, async_retry_base): # type: ignore[misc] """Retries if the result verifies a predicate.""" def __init__(self, predicate: typing.Callable[[typing.Any], typing.Awaitable[bool]]) -> None: - self.predicate = predicate - - async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] - if retry_state.outcome is None: - raise RuntimeError("__call__() called before outcome was set") - - if not retry_state.outcome.failed: - return await self.predicate(retry_state.outcome.result()) - else: - return False - - -class retry_if_not_result(retry_base): - """Retries if the result refutes a predicate.""" - - def __init__(self, predicate: typing.Callable[[typing.Any], typing.Awaitable[bool]]) -> None: - self.predicate = predicate - - async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] - if retry_state.outcome is None: - raise RuntimeError("__call__() called before outcome was set") - - if not retry_state.outcome.failed: - return not await self.predicate(retry_state.outcome.result()) - else: - return False - - -class retry_if_exception_message(retry_if_exception): - """Retries if an exception message equals or matches.""" - - def __init__( - self, - message: typing.Optional[str] = None, - match: typing.Optional[str] = None, - ) -> None: - if message and match: - raise TypeError(f"{self.__class__.__name__}() takes either 'message' or 'match', not both") - - # set predicate - if message: - - async def message_fnc(exception: BaseException) -> bool: - return message == str(exception) - - predicate = message_fnc - elif match: - prog = re.compile(match) - - async def match_fnc(exception: BaseException) -> bool: - return bool(prog.match(str(exception))) - - predicate = match_fnc - else: - raise TypeError(f"{self.__class__.__name__}() missing 1 required argument 'message' or 'match'") - - super().__init__(predicate) - - -class retry_if_not_exception_message(retry_if_exception_message): - """Retries until an exception message equals or matches.""" - - def __init__( - self, - message: typing.Optional[str] = None, - match: typing.Optional[str] = None, - ) -> None: - super().__init__(message, match) - if_predicate = self.predicate - - # invert predicate - async def predicate(e: BaseException) -> bool: - return not if_predicate(e) - - self.predicate = predicate - - async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] - if retry_state.outcome is None: - raise RuntimeError("__call__() called before outcome was set") - - if not retry_state.outcome.failed: - return True - - exception = retry_state.outcome.exception() - if exception is None: - raise RuntimeError("outcome failed but the exception is None") - return await self.predicate(exception) + super().__init__(predicate) # type: ignore[arg-type] -class retry_any(retry_base): +class retry_any(async_retry_base): """Retries if any of the retries condition is valid.""" - def __init__(self, *retries: retry_base) -> None: + def __init__(self, *retries: typing.Union[retry_base, async_retry_base]) -> None: self.retries = retries async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] - return any(r(retry_state) for r in self.retries) + result = False + for r in self.retries: + result = result or await _utils.wrap_to_async_func(r)(retry_state) + if result: + break + return result -class retry_all(retry_base): +class retry_all(async_retry_base): """Retries if all the retries condition are valid.""" - def __init__(self, *retries: retry_base) -> None: + def __init__(self, *retries: typing.Union[retry_base, async_retry_base]) -> None: self.retries = retries async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] - return all(r(retry_state) for r in self.retries) + result = True + for r in self.retries: + result = result and await _utils.wrap_to_async_func(r)(retry_state) + if not result: + break + return result diff --git a/tenacity/asyncio/stop.py b/tenacity/asyncio/stop.py deleted file mode 100644 index 1528426..0000000 --- a/tenacity/asyncio/stop.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright 2016–2021 Julien Danjou -# Copyright 2016 Joshua Harlow -# Copyright 2013-2014 Ray Holder -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import abc -import typing - -from tenacity import _utils -from tenacity.stop import stop_base - -if typing.TYPE_CHECKING: - import asyncio - - from tenacity import RetryCallState - - -class stop_base(stop_base): # type: ignore[no-redef] - """Abstract base class for stop strategies.""" - - @abc.abstractmethod - async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] - pass - - -StopBaseT = typing.Union[stop_base, typing.Callable[["RetryCallState"], typing.Awaitable[bool]]] - - -class stop_any(stop_base): - """Stop if any of the stop condition is valid.""" - - def __init__(self, *stops: stop_base) -> None: - self.stops = stops - - async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] - return any(x(retry_state) for x in self.stops) - - -class stop_all(stop_base): - """Stop if all the stop conditions are valid.""" - - def __init__(self, *stops: stop_base) -> None: - self.stops = stops - - async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] - return all(x(retry_state) for x in self.stops) - - -class _stop_never(stop_base): - """Never stop.""" - - async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] - return False - - -stop_never = _stop_never() - - -class stop_when_event_set(stop_base): - """Stop when the given event is set.""" - - def __init__(self, event: "asyncio.Event") -> None: - self.event = event - - async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] - return self.event.is_set() - - -class stop_after_attempt(stop_base): - """Stop when the previous attempt >= max_attempt.""" - - def __init__(self, max_attempt_number: int) -> None: - self.max_attempt_number = max_attempt_number - - async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] - return retry_state.attempt_number >= self.max_attempt_number - - -class stop_after_delay(stop_base): - """ - Stop when the time from the first attempt >= limit. - - Note: `max_delay` will be exceeded, so when used with a `wait`, the actual total delay will be greater - than `max_delay` by some of the final sleep period before `max_delay` is exceeded. - - If you need stricter timing with waits, consider `stop_before_delay` instead. - """ - - def __init__(self, max_delay: _utils.time_unit_type) -> None: - self.max_delay = _utils.to_seconds(max_delay) - - async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] - if retry_state.seconds_since_start is None: - raise RuntimeError("__call__() called but seconds_since_start is not set") - return retry_state.seconds_since_start >= self.max_delay - - -class stop_before_delay(stop_base): - """ - Stop right before the next attempt would take place after the time from the first attempt >= limit. - - Most useful when you are using with a `wait` function like wait_random_exponential, but need to make - sure that the max_delay is not exceeded. - """ - - def __init__(self, max_delay: _utils.time_unit_type) -> None: - self.max_delay = _utils.to_seconds(max_delay) - - async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] - if retry_state.seconds_since_start is None: - raise RuntimeError("__call__() called but seconds_since_start is not set") - return retry_state.seconds_since_start + retry_state.upcoming_sleep >= self.max_delay diff --git a/tenacity/asyncio/wait.py b/tenacity/asyncio/wait.py deleted file mode 100644 index 021b34d..0000000 --- a/tenacity/asyncio/wait.py +++ /dev/null @@ -1,219 +0,0 @@ -# Copyright 2016–2021 Julien Danjou -# Copyright 2016 Joshua Harlow -# Copyright 2013-2014 Ray Holder -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import abc -import random -import typing - -from tenacity import _utils -from tenacity.wait import wait_base - -if typing.TYPE_CHECKING: - from tenacity import RetryCallState - - -class wait_base(wait_base): # type: ignore[no-redef] - """Abstract base class for wait strategies.""" - - @abc.abstractmethod - async def __call__(self, retry_state: "RetryCallState") -> float: # type: ignore[override] - pass - - -WaitBaseT = typing.Union[wait_base, typing.Callable[["RetryCallState"], typing.Awaitable[typing.Union[float, int]]]] - - -class wait_fixed(wait_base): - """Wait strategy that waits a fixed amount of time between each retry.""" - - def __init__(self, wait: _utils.time_unit_type) -> None: - self.wait_fixed = _utils.to_seconds(wait) - - async def __call__(self, retry_state: "RetryCallState") -> float: # type: ignore[override] - return self.wait_fixed - - -class wait_none(wait_fixed): - """Wait strategy that doesn't wait at all before retrying.""" - - def __init__(self) -> None: - super().__init__(0) - - -class wait_random(wait_base): - """Wait strategy that waits a random amount of time between min/max.""" - - def __init__(self, min: _utils.time_unit_type = 0, max: _utils.time_unit_type = 1) -> None: # noqa - self.wait_random_min = _utils.to_seconds(min) - self.wait_random_max = _utils.to_seconds(max) - - async def __call__(self, retry_state: "RetryCallState") -> float: # type: ignore[override] - return self.wait_random_min + (random.random() * (self.wait_random_max - self.wait_random_min)) - - -class wait_combine(wait_base): - """Combine several waiting strategies.""" - - def __init__(self, *strategies: wait_base) -> None: - self.wait_funcs = strategies - - async def __call__(self, retry_state: "RetryCallState") -> float: # type: ignore[override] - return sum(x(retry_state=retry_state) for x in self.wait_funcs) - - -class wait_chain(wait_base): - """Chain two or more waiting strategies. - - If all strategies are exhausted, the very last strategy is used - thereafter. - - For example:: - - @retry(wait=wait_chain(*[wait_fixed(1) for i in range(3)] + - [wait_fixed(2) for j in range(5)] + - [wait_fixed(5) for k in range(4))) - def wait_chained(): - print("Wait 1s for 3 attempts, 2s for 5 attempts and 5s - thereafter.") - """ - - def __init__(self, *strategies: wait_base) -> None: - self.strategies = strategies - - async def __call__(self, retry_state: "RetryCallState") -> float: # type: ignore[override] - wait_func_no = min(max(retry_state.attempt_number, 1), len(self.strategies)) - wait_func = self.strategies[wait_func_no - 1] - return wait_func(retry_state=retry_state) - - -class wait_incrementing(wait_base): - """Wait an incremental amount of time after each attempt. - - Starting at a starting value and incrementing by a value for each attempt - (and restricting the upper limit to some maximum value). - """ - - def __init__( - self, - start: _utils.time_unit_type = 0, - increment: _utils.time_unit_type = 100, - max: _utils.time_unit_type = _utils.MAX_WAIT, # noqa - ) -> None: - self.start = _utils.to_seconds(start) - self.increment = _utils.to_seconds(increment) - self.max = _utils.to_seconds(max) - - async def __call__(self, retry_state: "RetryCallState") -> float: # type: ignore[override] - result = self.start + (self.increment * (retry_state.attempt_number - 1)) - return max(0, min(result, self.max)) - - -class wait_exponential(wait_base): - """Wait strategy that applies exponential backoff. - - It allows for a customized multiplier and an ability to restrict the - upper and lower limits to some maximum and minimum value. - - The intervals are fixed (i.e. there is no jitter), so this strategy is - suitable for balancing retries against latency when a required resource is - unavailable for an unknown duration, but *not* suitable for resolving - contention between multiple processes for a shared resource. Use - wait_random_exponential for the latter case. - """ - - def __init__( - self, - multiplier: typing.Union[int, float] = 1, - max: _utils.time_unit_type = _utils.MAX_WAIT, # noqa - exp_base: typing.Union[int, float] = 2, - min: _utils.time_unit_type = 0, # noqa - ) -> None: - self.multiplier = multiplier - self.min = _utils.to_seconds(min) - self.max = _utils.to_seconds(max) - self.exp_base = exp_base - - async def __call__(self, retry_state: "RetryCallState") -> float: # type: ignore[override] - try: - exp = self.exp_base ** (retry_state.attempt_number - 1) - result = self.multiplier * exp - except OverflowError: - return self.max - return max(max(0, self.min), min(result, self.max)) - - -class wait_random_exponential(wait_exponential): - """Random wait with exponentially widening window. - - An exponential backoff strategy used to mediate contention between multiple - uncoordinated processes for a shared resource in distributed systems. This - is the sense in which "exponential backoff" is meant in e.g. Ethernet - networking, and corresponds to the "Full Jitter" algorithm described in - this blog post: - - https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ - - Each retry occurs at a random time in a geometrically expanding interval. - It allows for a custom multiplier and an ability to restrict the upper - limit of the random interval to some maximum value. - - Example:: - - wait_random_exponential(multiplier=0.5, # initial window 0.5s - max=60) # max 60s timeout - - When waiting for an unavailable resource to become available again, as - opposed to trying to resolve contention for a shared resource, the - wait_exponential strategy (which uses a fixed interval) may be preferable. - - """ - - async def __call__(self, retry_state: "RetryCallState") -> float: # type: ignore[override] - high = await super().__call__(retry_state=retry_state) - return random.uniform(0, high) - - -class wait_exponential_jitter(wait_base): - """Wait strategy that applies exponential backoff and jitter. - - It allows for a customized initial wait, maximum wait and jitter. - - This implements the strategy described here: - https://cloud.google.com/storage/docs/retry-strategy - - The wait time is min(initial * 2**n + random.uniform(0, jitter), maximum) - where n is the retry count. - """ - - def __init__( - self, - initial: float = 1, - max: float = _utils.MAX_WAIT, # noqa - exp_base: float = 2, - jitter: float = 1, - ) -> None: - self.initial = initial - self.max = max - self.exp_base = exp_base - self.jitter = jitter - - async def __call__(self, retry_state: "RetryCallState") -> float: # type: ignore[override] - jitter = random.uniform(0, self.jitter) - try: - exp = self.exp_base ** (retry_state.attempt_number - 1) - result = self.initial * exp + jitter - except OverflowError: - result = self.max - return max(0, min(result, self.max)) diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 60a8089..af045d1 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -55,12 +55,6 @@ async def _retryable_coroutine_with_2_attempts(thing): thing.go() -@retry(stop=tasyncio.stop_after_attempt(2)) -async def _async_retryable_coroutine_with_2_attempts(thing): - await asyncio.sleep(0.00001) - thing.go() - - class TestAsync(unittest.TestCase): @asynctest async def test_retry(self): @@ -88,14 +82,6 @@ async def test_stop_after_attempt(self): except RetryError: assert thing.counter == 2 - @asynctest - async def test_stop_after_attempt_async(self): - thing = NoIOErrorAfterCount(2) - try: - await _async_retryable_coroutine_with_2_attempts(thing) - except RetryError: - assert thing.counter == 2 - def test_repr(self): repr(tasyncio.AsyncRetrying()) @@ -216,6 +202,60 @@ def lt_3(x: float) -> bool: self.assertEqual(3, result) + @asynctest + async def test_retry_with_async_result_or(self): + async def test(): + attempts = 0 + + async def lt_3(x: float) -> bool: + return x < 3 + + class CustomException(Exception): + pass + + async def is_exc(e: BaseException) -> bool: + return isinstance(e, CustomException) + + retry_strategy = tasyncio.retry_if_result(lt_3) | tasyncio.retry_if_exception(is_exc) + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + if 1 < attempts < 3: + raise CustomException() + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + return attempts + + result = await test() + + self.assertEqual(3, result) + + @asynctest + async def test_retry_with_async_result_and(self): + async def test(): + attempts = 0 + + async def lt_3(x: float) -> bool: + return x < 3 + + def gt_0(x: float) -> bool: + return x > 0 + + retry_strategy = tasyncio.retry_if_result(lt_3) & retry_if_result(gt_0) + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + attempt.retry_state.set_result(attempts) + + return attempts + + result = await test() + + self.assertEqual(3, result) + @asynctest async def test_async_retying_iterator(self): thing = NoIOErrorAfterCount(5) From d0f11cb3a7f763b4c5d3a447f8a81e66ef9b966d Mon Sep 17 00:00:00 2001 From: Hasier Date: Mon, 18 Mar 2024 16:40:40 +0000 Subject: [PATCH 5/8] Ensure async and/or versions called when necessary --- tenacity/asyncio/retry.py | 30 ++++++++-- tenacity/retry.py | 10 +++- tests/test_asyncio.py | 115 ++++++++++++++++++++++++++++++++++++-- 3 files changed, 144 insertions(+), 11 deletions(-) diff --git a/tenacity/asyncio/retry.py b/tenacity/asyncio/retry.py index 7e00e2d..138cebe 100644 --- a/tenacity/asyncio/retry.py +++ b/tenacity/asyncio/retry.py @@ -33,12 +33,26 @@ class async_retry_base(retry_base): async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] pass - def __and__(self, other: "typing.Union[retry_base, async_retry_base]") -> "retry_all": # type: ignore[override] + def __and__( # type: ignore[override] + self, other: "typing.Union[retry_base, async_retry_base]" + ) -> "retry_all": return retry_all(self, other) - def __or__(self, other: "typing.Union[retry_base, async_retry_base]") -> "retry_any": # type: ignore[override] + def __rand__( # type: ignore[misc,override] + self, other: "typing.Union[retry_base, async_retry_base]" + ) -> "retry_all": + return retry_all(other, self) + + def __or__( # type: ignore[override] + self, other: "typing.Union[retry_base, async_retry_base]" + ) -> "retry_any": return retry_any(self, other) + def __ror__( # type: ignore[misc,override] + self, other: "typing.Union[retry_base, async_retry_base]" + ) -> "retry_any": + return retry_any(other, self) + class async_predicate_mixin: async def __call__(self, retry_state: "RetryCallState") -> bool: @@ -48,20 +62,26 @@ async def __call__(self, retry_state: "RetryCallState") -> bool: return typing.cast(bool, result) -RetryBaseT = typing.Union[async_retry_base, typing.Callable[["RetryCallState"], typing.Awaitable[bool]]] +RetryBaseT = typing.Union[ + async_retry_base, typing.Callable[["RetryCallState"], typing.Awaitable[bool]] +] class retry_if_exception(async_predicate_mixin, _retry_if_exception, async_retry_base): # type: ignore[misc] """Retry strategy that retries if an exception verifies a predicate.""" - def __init__(self, predicate: typing.Callable[[BaseException], typing.Awaitable[bool]]) -> None: + def __init__( + self, predicate: typing.Callable[[BaseException], typing.Awaitable[bool]] + ) -> None: super().__init__(predicate) # type: ignore[arg-type] class retry_if_result(async_predicate_mixin, _retry_if_result, async_retry_base): # type: ignore[misc] """Retries if the result verifies a predicate.""" - def __init__(self, predicate: typing.Callable[[typing.Any], typing.Awaitable[bool]]) -> None: + def __init__( + self, predicate: typing.Callable[[typing.Any], typing.Awaitable[bool]] + ) -> None: super().__init__(predicate) # type: ignore[arg-type] diff --git a/tenacity/retry.py b/tenacity/retry.py index c5e55a6..9211631 100644 --- a/tenacity/retry.py +++ b/tenacity/retry.py @@ -30,10 +30,16 @@ def __call__(self, retry_state: "RetryCallState") -> bool: pass def __and__(self, other: "retry_base") -> "retry_all": - return retry_all(self, other) + return other.__rand__(self) + + def __rand__(self, other: "retry_base") -> "retry_all": + return retry_all(other, self) def __or__(self, other: "retry_base") -> "retry_any": - return retry_any(self, other) + return other.__ror__(self) + + def __ror__(self, other: "retry_base") -> "retry_any": + return retry_any(other, self) RetryBaseT = typing.Union[retry_base, typing.Callable[["RetryCallState"], bool]] diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index af045d1..48f6286 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -23,7 +23,7 @@ import tenacity from tenacity import AsyncRetrying, RetryError from tenacity import asyncio as tasyncio -from tenacity import retry, retry_if_result, stop_after_attempt +from tenacity import retry, retry_if_exception, retry_if_result, stop_after_attempt from tenacity.wait import wait_fixed from .test_tenacity import NoIOErrorAfterCount, current_time_ms @@ -202,6 +202,59 @@ def lt_3(x: float) -> bool: self.assertEqual(3, result) + @asynctest + async def test_retry_with_async_result(self): + async def test(): + attempts = 0 + + async def lt_3(x: float) -> bool: + return x < 3 + + async for attempt in tasyncio.AsyncRetrying( + retry=tasyncio.retry_if_result(lt_3) + ): + with attempt: + attempts += 1 + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + return attempts + + result = await test() + + self.assertEqual(3, result) + + @asynctest + async def test_retry_with_async_exc(self): + async def test(): + attempts = 0 + + class CustomException(Exception): + pass + + async def is_exc(e: BaseException) -> bool: + return isinstance(e, CustomException) + + async for attempt in tasyncio.AsyncRetrying( + retry=tasyncio.retry_if_exception(is_exc) + ): + with attempt: + attempts += 1 + if attempts < 3: + raise CustomException() + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + return attempts + + result = await test() + + self.assertEqual(3, result) + @asynctest async def test_retry_with_async_result_or(self): async def test(): @@ -213,14 +266,45 @@ async def lt_3(x: float) -> bool: class CustomException(Exception): pass + def is_exc(e: BaseException) -> bool: + return isinstance(e, CustomException) + + retry_strategy = tasyncio.retry_if_result(lt_3) | retry_if_exception(is_exc) + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + if 2 < attempts < 4: + raise CustomException() + + assert attempt.retry_state.outcome # help mypy + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(attempts) + + return attempts + + result = await test() + + self.assertEqual(4, result) + + @asynctest + async def test_retry_with_async_result_ror(self): + async def test(): + attempts = 0 + + def lt_3(x: float) -> bool: + return x < 3 + + class CustomException(Exception): + pass + async def is_exc(e: BaseException) -> bool: return isinstance(e, CustomException) - retry_strategy = tasyncio.retry_if_result(lt_3) | tasyncio.retry_if_exception(is_exc) + retry_strategy = retry_if_result(lt_3) | tasyncio.retry_if_exception(is_exc) async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): with attempt: attempts += 1 - if 1 < attempts < 3: + if 2 < attempts < 4: raise CustomException() assert attempt.retry_state.outcome # help mypy @@ -231,7 +315,7 @@ async def is_exc(e: BaseException) -> bool: result = await test() - self.assertEqual(3, result) + self.assertEqual(4, result) @asynctest async def test_retry_with_async_result_and(self): @@ -256,6 +340,29 @@ def gt_0(x: float) -> bool: self.assertEqual(3, result) + @asynctest + async def test_retry_with_async_result_rand(self): + async def test(): + attempts = 0 + + async def lt_3(x: float) -> bool: + return x < 3 + + def gt_0(x: float) -> bool: + return x > 0 + + retry_strategy = retry_if_result(gt_0) & tasyncio.retry_if_result(lt_3) + async for attempt in tasyncio.AsyncRetrying(retry=retry_strategy): + with attempt: + attempts += 1 + attempt.retry_state.set_result(attempts) + + return attempts + + result = await test() + + self.assertEqual(3, result) + @asynctest async def test_async_retying_iterator(self): thing = NoIOErrorAfterCount(5) From 3c5f7887bafca6edd1ac4c72720956863e066e4a Mon Sep 17 00:00:00 2001 From: Hasier Date: Mon, 18 Mar 2024 16:48:24 +0000 Subject: [PATCH 6/8] Run ruff format --- tenacity/__init__.py | 16 ++++++++++++---- tenacity/asyncio/__init__.py | 21 +++++++++++++++------ 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/tenacity/__init__.py b/tenacity/__init__.py index d5bfe80..7de36d4 100644 --- a/tenacity/__init__.py +++ b/tenacity/__init__.py @@ -599,12 +599,20 @@ def retry( stop: "StopBaseT" = stop_never, wait: "WaitBaseT" = wait_none(), retry: "t.Union[RetryBaseT, tasyncio.retry.RetryBaseT]" = retry_if_exception_type(), - before: t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]] = before_nothing, - after: t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]] = after_nothing, - before_sleep: t.Optional[t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]]] = None, + before: t.Callable[ + ["RetryCallState"], t.Union[None, t.Awaitable[None]] + ] = before_nothing, + after: t.Callable[ + ["RetryCallState"], t.Union[None, t.Awaitable[None]] + ] = after_nothing, + before_sleep: t.Optional[ + t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]] + ] = None, reraise: bool = False, retry_error_cls: t.Type["RetryError"] = RetryError, - retry_error_callback: t.Optional[t.Callable[["RetryCallState"], t.Union[t.Any, t.Awaitable[t.Any]]]] = None, + retry_error_callback: t.Optional[ + t.Callable[["RetryCallState"], t.Union[t.Any, t.Awaitable[t.Any]]] + ] = None, ) -> t.Callable[[WrappedFn], WrappedFn]: ... diff --git a/tenacity/asyncio/__init__.py b/tenacity/asyncio/__init__.py index 019bbb6..3ec0088 100644 --- a/tenacity/asyncio/__init__.py +++ b/tenacity/asyncio/__init__.py @@ -26,7 +26,6 @@ from tenacity import DoSleep from tenacity import RetryCallState from tenacity import RetryError -from tenacity import _utils from tenacity import after_nothing from tenacity import before_nothing from tenacity import _utils @@ -57,16 +56,26 @@ def asyncio_sleep(duration: float) -> t.Awaitable[None]: class AsyncRetrying(BaseRetrying): def __init__( self, - sleep: t.Callable[[t.Union[int, float]], t.Union[None, t.Awaitable[None]]] = asyncio_sleep, + sleep: t.Callable[ + [t.Union[int, float]], t.Union[None, t.Awaitable[None]] + ] = asyncio_sleep, stop: "StopBaseT" = tenacity.stop.stop_never, wait: "WaitBaseT" = tenacity.wait.wait_none(), retry: "t.Union[SyncRetryBaseT, RetryBaseT]" = tenacity.retry_if_exception_type(), - before: t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]] = before_nothing, - after: t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]] = after_nothing, - before_sleep: t.Optional[t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]]] = None, + before: t.Callable[ + ["RetryCallState"], t.Union[None, t.Awaitable[None]] + ] = before_nothing, + after: t.Callable[ + ["RetryCallState"], t.Union[None, t.Awaitable[None]] + ] = after_nothing, + before_sleep: t.Optional[ + t.Callable[["RetryCallState"], t.Union[None, t.Awaitable[None]]] + ] = None, reraise: bool = False, retry_error_cls: t.Type["RetryError"] = RetryError, - retry_error_callback: t.Optional[t.Callable[["RetryCallState"], t.Union[t.Any, t.Awaitable[t.Any]]]] = None, + retry_error_callback: t.Optional[ + t.Callable[["RetryCallState"], t.Union[t.Any, t.Awaitable[t.Any]]] + ] = None, ) -> None: super().__init__( sleep=sleep, # type: ignore[arg-type] From 927965d09d4091c7a72fee896a7644fe4f5d2d78 Mon Sep 17 00:00:00 2001 From: Hasier Date: Mon, 3 Jun 2024 09:58:23 +0100 Subject: [PATCH 7/8] Copy over strategies as async --- tenacity/asyncio/retry.py | 40 ++++++++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/tenacity/asyncio/retry.py b/tenacity/asyncio/retry.py index 138cebe..94b8b15 100644 --- a/tenacity/asyncio/retry.py +++ b/tenacity/asyncio/retry.py @@ -14,13 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc -import inspect import typing from tenacity import _utils from tenacity import retry_base -from tenacity import retry_if_exception as _retry_if_exception -from tenacity import retry_if_result as _retry_if_result if typing.TYPE_CHECKING: from tenacity import RetryCallState @@ -54,35 +51,48 @@ def __ror__( # type: ignore[misc,override] return retry_any(other, self) -class async_predicate_mixin: - async def __call__(self, retry_state: "RetryCallState") -> bool: - result = super().__call__(retry_state) # type: ignore[misc] - if inspect.isawaitable(result): - result = await result - return typing.cast(bool, result) - - RetryBaseT = typing.Union[ async_retry_base, typing.Callable[["RetryCallState"], typing.Awaitable[bool]] ] -class retry_if_exception(async_predicate_mixin, _retry_if_exception, async_retry_base): # type: ignore[misc] +class retry_if_exception(async_retry_base): """Retry strategy that retries if an exception verifies a predicate.""" def __init__( self, predicate: typing.Callable[[BaseException], typing.Awaitable[bool]] ) -> None: - super().__init__(predicate) # type: ignore[arg-type] + self.predicate = predicate + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + if retry_state.outcome is None: + raise RuntimeError("__call__() called before outcome was set") + if retry_state.outcome.failed: + exception = retry_state.outcome.exception() + if exception is None: + raise RuntimeError("outcome failed but the exception is None") + return await self.predicate(exception) + else: + return False -class retry_if_result(async_predicate_mixin, _retry_if_result, async_retry_base): # type: ignore[misc] + +class retry_if_result(async_retry_base): """Retries if the result verifies a predicate.""" def __init__( self, predicate: typing.Callable[[typing.Any], typing.Awaitable[bool]] ) -> None: - super().__init__(predicate) # type: ignore[arg-type] + self.predicate = predicate + + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] + if retry_state.outcome is None: + raise RuntimeError("__call__() called before outcome was set") + + if not retry_state.outcome.failed: + return await self.predicate(retry_state.outcome.result()) + else: + return False class retry_any(async_retry_base): From 4aa9ccf271cb37b0ac37437075bb25b033a2da0a Mon Sep 17 00:00:00 2001 From: Hasier Date: Tue, 11 Jun 2024 10:58:15 +0100 Subject: [PATCH 8/8] Add release note --- releasenotes/notes/add-async-actions-b249c527d99723bb.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 releasenotes/notes/add-async-actions-b249c527d99723bb.yaml diff --git a/releasenotes/notes/add-async-actions-b249c527d99723bb.yaml b/releasenotes/notes/add-async-actions-b249c527d99723bb.yaml new file mode 100644 index 0000000..096a24f --- /dev/null +++ b/releasenotes/notes/add-async-actions-b249c527d99723bb.yaml @@ -0,0 +1,5 @@ +--- +features: + - | + Added the ability to use async functions for retries. This way, you can now use + asyncio coroutines for retry strategy predicates.