Skip to content

Commit

Permalink
On predicate (#70)
Browse files Browse the repository at this point in the history
* Create type alias for 'on' parameter

* Parametrize tests on the 'on' parameter

* Allow 'on' parameter to be a predicate

* Backwards compatible annotation with Python<3.10

* Rename On to ExcOrPredicate

* Use fixture to paramatrize `on`

* Add version debug output

* Add docs

* Wordsmith

---------

Co-authored-by: Hynek Schlawack <[email protected]>
  • Loading branch information
gsakkis and hynek authored Aug 5, 2024
1 parent b4048f1 commit a5c37a8
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 41 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ You can find our backwards-compatibility policy [here](https://github.com/hynek/

## [Unreleased](https://github.com/hynek/stamina/compare/24.2.0...HEAD)

### Added

- The *on* argument in all retry functions now can be a callable that takes an exception and returns a bool which decides whether or not a retry should be scheduled.
[#70](https://github.com/hynek/stamina/pull/70)



## [24.2.0](https://github.com/hynek/stamina/compare/24.1.0...24.2.0) - 2024-01-31

Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ But bad retries can make things *much worse*.
Our goal is to be as **ergonomic** as possible, while doing the **right thing by default**, and minimizing the potential for **misuse**.
It is the result of years of copy-pasting the same configuration over and over again:

- Retry only on certain exceptions.
- Retry only on certain exceptions – or even a subset of them by introspecting them first using a predicate.
- Exponential **backoff** with **jitter** between retries.
- Limit the number of retries **and** total time.
- Automatic **async** support – including [Trio](https://trio.readthedocs.io/).
Expand Down
30 changes: 30 additions & 0 deletions docs/tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,36 @@ def do_it(code: int) -> httpx.Response:
This will retry the function up to 3 times if it raises an {class}`httpx.HTTPError` (or any subclass thereof).
Since retrying on {class}`Exception` is an [attractive nuisance](https://blog.ganssle.io/articles/2023/01/attractive-nuisances.html), *stamina* doesn't do it by default and forces you to be explicit.

---

Sometimes, an exception is too broad, though.
For example, *httpx* raises [`httpx.HTTPStatusError`](https://www.python-httpx.org/exceptions/) on all HTTP errors.
But some errors, like 404 (Not Found) or 403 (Forbidden), usually shouldn't be retried!

To solve problems like this, you can pass a *predicate* to `on`.
A predicate is a callable that's called with the exception that was raised and whose return value will be used to decide whether to retry or not.

So, calling the following `do_it` function will only retry if <https://httpbin.org> returns a 5xx status code:

```python
def retry_only_on_real_errors(exc: Exception) -> bool:
# If the error is an HTTP status error, only retry on 5xx errors.
if isinstance(exc, httpx.HTTPStatusError):
return exc.response.status_code >= 500

# Otherwise retry on all httpx errors.
return isinstance(exc, httpx.HTTPError)

@stamina.retry(on=retry_only_on_real_errors, attempts=3)
def do_it(code: int) -> httpx.Response:
resp = httpx.get(f"https://httpbin.org/status/{code}")
resp.raise_for_status()

return resp
```

---

To give you observability of your application's retrying, *stamina* will count the retries using [*prometheus-client*](https://github.com/prometheus/client_python) in the `stamina_retries_total` counter (if installed) and log them out using [*structlog*](https://www.structlog.org/) with a fallback to {mod}`logging`.


Expand Down
69 changes: 49 additions & 20 deletions src/stamina/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,21 @@
import datetime as dt
import sys

from collections.abc import Callable
from dataclasses import dataclass, replace
from functools import wraps
from inspect import iscoroutinefunction
from types import TracebackType
from typing import AsyncIterator, Awaitable, Iterator, TypedDict, TypeVar
from typing import (
AsyncIterator,
Awaitable,
Callable,
Iterator,
Tuple,
Type,
TypedDict,
TypeVar,
Union,
)

import tenacity as _t

Expand Down Expand Up @@ -51,10 +60,14 @@ async def _smart_sleep(delay: float) -> None:

T = TypeVar("T")
P = ParamSpec("P")
# for backwards compatibility with Python<3.10
ExcOrPredicate = Union[
Type[Exception], Tuple[Type[Exception], ...], Callable[[Exception], bool]
]


def retry_context(
on: type[Exception] | tuple[type[Exception], ...],
on: ExcOrPredicate,
attempts: int | None = 10,
timeout: float | dt.timedelta | None = 45.0,
wait_initial: float | dt.timedelta = 0.1,
Expand Down Expand Up @@ -187,7 +200,7 @@ class RetryingCaller(BaseRetryingCaller):

def __call__(
self,
on: type[Exception] | tuple[type[Exception], ...],
on: ExcOrPredicate,
callable_: Callable[P, T],
/,
*args: P.args,
Expand All @@ -211,9 +224,7 @@ def __call__(

raise SystemError("unreachable") # noqa: EM101

def on(
self, on: type[Exception] | tuple[type[Exception], ...], /
) -> BoundRetryingCaller:
def on(self, on: ExcOrPredicate, /) -> BoundRetryingCaller:
"""
Create a new instance of :class:`BoundRetryingCaller` with the same
parameters, but bound to a specific exception type.
Expand All @@ -240,12 +251,12 @@ class BoundRetryingCaller:
__slots__ = ("_caller", "_on")

_caller: RetryingCaller
_on: type[Exception] | tuple[type[Exception], ...]
_on: ExcOrPredicate

def __init__(
self,
caller: RetryingCaller,
on: type[Exception] | tuple[type[Exception], ...],
on: ExcOrPredicate,
):
self._caller = caller
self._on = on
Expand Down Expand Up @@ -274,7 +285,7 @@ class AsyncRetryingCaller(BaseRetryingCaller):

async def __call__(
self,
on: type[Exception] | tuple[type[Exception], ...],
on: ExcOrPredicate,
callable_: Callable[P, Awaitable[T]],
/,
*args: P.args,
Expand All @@ -289,9 +300,7 @@ async def __call__(

raise SystemError("unreachable") # noqa: EM101

def on(
self, on: type[Exception] | tuple[type[Exception], ...], /
) -> BoundAsyncRetryingCaller:
def on(self, on: ExcOrPredicate, /) -> BoundAsyncRetryingCaller:
"""
Create a new instance of :class:`BoundAsyncRetryingCaller` with the
same parameters, but bound to a specific exception type.
Expand All @@ -315,12 +324,12 @@ class BoundAsyncRetryingCaller:
__slots__ = ("_caller", "_on")

_caller: AsyncRetryingCaller
_on: type[Exception] | tuple[type[Exception], ...]
_on: ExcOrPredicate

def __init__(
self,
caller: AsyncRetryingCaller,
on: type[Exception] | tuple[type[Exception], ...],
on: ExcOrPredicate,
):
self._caller = caller
self._on = on
Expand Down Expand Up @@ -373,7 +382,7 @@ class _RetryContextIterator:
@classmethod
def from_params(
cls,
on: type[Exception] | tuple[type[Exception], ...],
on: ExcOrPredicate,
attempts: int | None,
timeout: float | dt.timedelta | None,
wait_initial: float | dt.timedelta,
Expand All @@ -384,12 +393,20 @@ def from_params(
args: tuple[object, ...],
kw: dict[str, object],
) -> _RetryContextIterator:
if (
isinstance(on, type)
and issubclass(on, BaseException)
or isinstance(on, tuple)
):
_retry = _t.retry_if_exception_type(on)
else:
_retry = _t.retry_if_exception(on)
return cls(
_name=name,
_args=args,
_kw=kw,
_t_kw={
"retry": _t.retry_if_exception_type(on),
"retry": _retry,
"wait": _t.wait_exponential_jitter(
initial=(
wait_initial.total_seconds()
Expand Down Expand Up @@ -521,7 +538,7 @@ def _make_stop(*, attempts: int | None, timeout: float | None) -> _t.stop_base:

def retry(
*,
on: type[Exception] | tuple[type[Exception], ...],
on: ExcOrPredicate,
attempts: int | None = 10,
timeout: float | dt.timedelta | None = 45.0,
wait_initial: float | dt.timedelta = 0.1,
Expand Down Expand Up @@ -552,8 +569,18 @@ def retry(
Args:
on:
An Exception or a tuple of Exceptions on which the decorated
callable will be retried. There is no default -- you *must* pass
this explicitly.
callable will be retried.
You can also pass a *predicate* in the form of a callable that
takes an exception and returns a bool which decides whether the
exception should be retried -- True meaning yes.
This allows more fine-grained control over when to retry. For
example, to only retry on HTTP errors in the 500s range that
indicate server errors, but not those in the 400s which indicate a
client error.
There is no default -- you *must* pass this explicitly.
attempts:
Maximum total number of attempts. Can be combined with *timeout*.
Expand All @@ -577,6 +604,8 @@ def retry(
:class:`datetime.timedelta`.
.. versionadded:: 23.3.0 `Trio <https://trio.readthedocs.io/>`_ support.
.. versionadded:: 24.3.0 *on* can be a callable now.
"""
retry_ctx = _RetryContextIterator.from_params(
on=on,
Expand Down
15 changes: 15 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,18 @@ def _reset_config():
@pytest.fixture(params=BACKENDS)
def anyio_backend(request):
return request.param


@pytest.fixture(
name="on",
params=[
ValueError,
(ValueError,),
lambda exc: isinstance(exc, ValueError),
],
)
def _on(request):
"""
Parametrize over different ways to specify the exception to retry on.
"""
return request.param
20 changes: 10 additions & 10 deletions tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,14 @@ async def f():

@pytest.mark.parametrize("timeout", [None, 1, dt.timedelta(days=1)])
@pytest.mark.parametrize("duration", [0, dt.timedelta(days=0)])
async def test_retries(duration, timeout):
async def test_retries(duration, timeout, on):
"""
Retries if the specific error is raised.
"""
i = 0

@stamina.retry(
on=ValueError,
on=on,
timeout=timeout,
wait_max=duration,
wait_initial=duration,
Expand All @@ -74,14 +74,14 @@ async def f():
assert 1 == i


async def test_retries_method():
async def test_retries_method(on):
"""
Retries if the specific error is raised.
"""
i = 0

class C:
@stamina.retry(on=ValueError, wait_max=0)
@stamina.retry(on=on, wait_max=0)
async def f(self):
nonlocal i
if i == 0:
Expand All @@ -94,12 +94,12 @@ async def f(self):
assert 1 == i


async def test_wrong_exception():
async def test_wrong_exception(on):
"""
Exceptions that are not passed as `on` are left through.
"""

@stamina.retry(on=ValueError)
@stamina.retry(on=on)
async def f():
raise TypeError("passed")

Expand Down Expand Up @@ -177,13 +177,13 @@ async def test_retry_inactive_block_ok():
assert 1 == num_called


async def test_retry_block():
async def test_retry_block(on):
"""
Async retry_context blocks are retried.
"""
num_called = 0

async for attempt in stamina.retry_context(on=ValueError, wait_max=0):
async for attempt in stamina.retry_context(on=on, wait_max=0):
with attempt:
num_called += 1

Expand Down Expand Up @@ -223,7 +223,7 @@ async def f():

assert 42 == await arc(f)

async def test_retries(self):
async def test_retries(self, on):
"""
Retries if the specific error is raised. Arguments are passed through.
"""
Expand All @@ -237,7 +237,7 @@ async def f(*args, **kw):

return args, kw

arc = stamina.AsyncRetryingCaller().on(ValueError)
arc = stamina.AsyncRetryingCaller().on(on)

args, kw = await arc(f, 42, foo="bar")

Expand Down
4 changes: 2 additions & 2 deletions tests/test_structlog.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_context_sync(log_output):
"""
from tests.test_sync import test_retry_block

test_retry_block()
test_retry_block(ValueError)

assert [
{
Expand All @@ -105,7 +105,7 @@ async def test_context_async(log_output):
"""
from tests.test_async import test_retry_block

await test_retry_block()
await test_retry_block(ValueError)

assert [
{
Expand Down
Loading

0 comments on commit a5c37a8

Please sign in to comment.