From 078995815973103489f136f6020327bea2019481 Mon Sep 17 00:00:00 2001 From: Yurii Karabas <1998uriyyo@gmail.com> Date: Mon, 24 Jun 2024 12:12:02 +0300 Subject: [PATCH] Add ability to use async function with async_paginator (#1204) --- fastapi_pagination/async_paginator.py | 14 ++++++++++---- fastapi_pagination/utils.py | 26 ++++++++++++++++++++++++-- tests/test_async_paginator.py | 12 ++++++++++-- 3 files changed, 44 insertions(+), 8 deletions(-) diff --git a/fastapi_pagination/async_paginator.py b/fastapi_pagination/async_paginator.py index 22bf3a1b..06e3b3f6 100644 --- a/fastapi_pagination/async_paginator.py +++ b/fastapi_pagination/async_paginator.py @@ -1,11 +1,11 @@ -from typing import Any, Callable, Optional, Sequence, TypeVar +from typing import Any, Awaitable, Callable, Optional, Sequence, TypeVar, Union __all__ = ["paginate"] from .api import apply_items_transformer, create_page from .bases import AbstractParams from .types import AdditionalData, SyncItemsTransformer -from .utils import check_installed_extensions, verify_params +from .utils import await_if_async, check_installed_extensions, verify_params T = TypeVar("T") @@ -14,7 +14,7 @@ async def paginate( sequence: Sequence[T], params: Optional[AbstractParams] = None, - length_function: Callable[[Sequence[T]], int] = len, + length_function: Optional[Callable[[Sequence[T]], Union[int, Awaitable[int]]]] = None, *, safe: bool = False, transformer: Optional[SyncItemsTransformer] = None, @@ -28,9 +28,15 @@ async def paginate( items = sequence[raw_params.as_slice()] t_items = await apply_items_transformer(items, transformer, async_=True) + length_function = length_function or len + + total = None + if raw_params.include_total: + total = await await_if_async(length_function, sequence) + return create_page( t_items, - total=length_function(sequence) if raw_params.include_total else None, + total=total, params=params, **(additional_data or {}), ) diff --git a/fastapi_pagination/utils.py b/fastapi_pagination/utils.py index 9427594e..97c701e8 100644 --- a/fastapi_pagination/utils.py +++ b/fastapi_pagination/utils.py @@ -6,6 +6,7 @@ "create_pydantic_model", "verify_params", "is_async_callable", + "await_if_async", "check_installed_extensions", "disable_installed_extensions_check", "FastAPIPaginationWarning", @@ -16,10 +17,10 @@ import functools import inspect import warnings -from typing import TYPE_CHECKING, Any, Optional, Tuple, Type, TypeVar, cast, overload +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, Type, TypeVar, cast, overload from pydantic import VERSION, BaseModel -from typing_extensions import Annotated, Literal, get_origin +from typing_extensions import Annotated, Literal, ParamSpec, get_origin if TYPE_CHECKING: from .bases import AbstractParams, BaseRawParams, CursorRawParams, RawParams @@ -66,6 +67,27 @@ def is_async_callable(obj: Any) -> bool: # pragma: no cover return asyncio.iscoroutinefunction(obj) or (callable(obj) and asyncio.iscoroutinefunction(obj.__call__)) +P = ParamSpec("P") +R = TypeVar("R") + + +@overload +async def await_if_async(func: Callable[P, Awaitable[R]], /, *args: P.args, **kwargs: P.kwargs) -> R: + pass + + +@overload +async def await_if_async(func: Callable[P, R], /, *args: P.args, **kwargs: P.kwargs) -> R: + pass + + +async def await_if_async(func: Callable[P, Any], /, *args: P.args, **kwargs: P.kwargs) -> Any: + if is_async_callable(func): + return await func(*args, **kwargs) + + return func(*args, **kwargs) + + _EXTENSIONS = [ "databases", "django", diff --git a/tests/test_async_paginator.py b/tests/test_async_paginator.py index 2b0209ae..a2ff5e84 100644 --- a/tests/test_async_paginator.py +++ b/tests/test_async_paginator.py @@ -12,9 +12,17 @@ from .utils import OptionalLimitOffsetPage, OptionalPage +async def _len_func(seq): + return len(seq) + + class TestAsyncPaginationParams(BasePaginationTestCase): + @fixture(scope="session", params=[len, _len_func], ids=["sync", "async"]) + def len_function(self, request): + return request.param + @fixture(scope="session") - def app(self, model_cls, entities): + def app(self, model_cls, entities, len_function): app = FastAPI() @app.get("/default", response_model=Page[model_cls]) @@ -22,6 +30,6 @@ def app(self, model_cls, entities): @app.get("/optional/default", response_model=OptionalPage[model_cls]) @app.get("/optional/limit-offset", response_model=OptionalLimitOffsetPage[model_cls]) async def route(): - return await paginate(entities) + return await paginate(entities, length_function=len_function) return add_pagination(app)