Skip to content

Commit

Permalink
Add ability to use async function with async_paginator (#1204)
Browse files Browse the repository at this point in the history
  • Loading branch information
uriyyo authored Jun 24, 2024
1 parent 894b12b commit 0789958
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 8 deletions.
14 changes: 10 additions & 4 deletions fastapi_pagination/async_paginator.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Any, Callable, Optional, Sequence, TypeVar
from typing import Any, Awaitable, Callable, Optional, Sequence, TypeVar, Union

__all__ = ["paginate"]

from .api import apply_items_transformer, create_page
from .bases import AbstractParams
from .types import AdditionalData, SyncItemsTransformer
from .utils import check_installed_extensions, verify_params
from .utils import await_if_async, check_installed_extensions, verify_params

T = TypeVar("T")

Expand All @@ -14,7 +14,7 @@
async def paginate(
sequence: Sequence[T],
params: Optional[AbstractParams] = None,
length_function: Callable[[Sequence[T]], int] = len,
length_function: Optional[Callable[[Sequence[T]], Union[int, Awaitable[int]]]] = None,
*,
safe: bool = False,
transformer: Optional[SyncItemsTransformer] = None,
Expand All @@ -28,9 +28,15 @@ async def paginate(
items = sequence[raw_params.as_slice()]
t_items = await apply_items_transformer(items, transformer, async_=True)

length_function = length_function or len

total = None
if raw_params.include_total:
total = await await_if_async(length_function, sequence)

return create_page(
t_items,
total=length_function(sequence) if raw_params.include_total else None,
total=total,
params=params,
**(additional_data or {}),
)
26 changes: 24 additions & 2 deletions fastapi_pagination/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"create_pydantic_model",
"verify_params",
"is_async_callable",
"await_if_async",
"check_installed_extensions",
"disable_installed_extensions_check",
"FastAPIPaginationWarning",
Expand All @@ -16,10 +17,10 @@
import functools
import inspect
import warnings
from typing import TYPE_CHECKING, Any, Optional, Tuple, Type, TypeVar, cast, overload
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, Type, TypeVar, cast, overload

from pydantic import VERSION, BaseModel
from typing_extensions import Annotated, Literal, get_origin
from typing_extensions import Annotated, Literal, ParamSpec, get_origin

if TYPE_CHECKING:
from .bases import AbstractParams, BaseRawParams, CursorRawParams, RawParams
Expand Down Expand Up @@ -66,6 +67,27 @@ def is_async_callable(obj: Any) -> bool: # pragma: no cover
return asyncio.iscoroutinefunction(obj) or (callable(obj) and asyncio.iscoroutinefunction(obj.__call__))


P = ParamSpec("P")
R = TypeVar("R")


@overload
async def await_if_async(func: Callable[P, Awaitable[R]], /, *args: P.args, **kwargs: P.kwargs) -> R:
pass


@overload
async def await_if_async(func: Callable[P, R], /, *args: P.args, **kwargs: P.kwargs) -> R:
pass


async def await_if_async(func: Callable[P, Any], /, *args: P.args, **kwargs: P.kwargs) -> Any:
if is_async_callable(func):
return await func(*args, **kwargs)

return func(*args, **kwargs)


_EXTENSIONS = [
"databases",
"django",
Expand Down
12 changes: 10 additions & 2 deletions tests/test_async_paginator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,24 @@
from .utils import OptionalLimitOffsetPage, OptionalPage


async def _len_func(seq):
return len(seq)


class TestAsyncPaginationParams(BasePaginationTestCase):
@fixture(scope="session", params=[len, _len_func], ids=["sync", "async"])
def len_function(self, request):
return request.param

@fixture(scope="session")
def app(self, model_cls, entities):
def app(self, model_cls, entities, len_function):
app = FastAPI()

@app.get("/default", response_model=Page[model_cls])
@app.get("/limit-offset", response_model=LimitOffsetPage[model_cls])
@app.get("/optional/default", response_model=OptionalPage[model_cls])
@app.get("/optional/limit-offset", response_model=OptionalLimitOffsetPage[model_cls])
async def route():
return await paginate(entities)
return await paginate(entities, length_function=len_function)

return add_pagination(app)

0 comments on commit 0789958

Please sign in to comment.