Skip to content

Commit

Permalink
feat: finished typing
Browse files Browse the repository at this point in the history
  • Loading branch information
kalombos committed Sep 9, 2024
1 parent c43cb59 commit c363a14
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 40 deletions.
66 changes: 33 additions & 33 deletions peewee_async/aio_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from .result_wrappers import fetch_models
from .utils import CursorProtocol
from typing_extensions import Self
from typing import Tuple, List, Any, cast, Optional
from typing import Tuple, List, Any, cast, Optional, Dict, Union


async def aio_prefetch(sq, *subqueries, prefetch_type: PREFETCH_TYPE = PREFETCH_TYPE.WHERE):
async def aio_prefetch(sq: Any, *subqueries: Any, prefetch_type: PREFETCH_TYPE = PREFETCH_TYPE.WHERE) -> Any:
"""Asynchronous version of `prefetch()`.
See also:
Expand All @@ -18,8 +18,8 @@ async def aio_prefetch(sq, *subqueries, prefetch_type: PREFETCH_TYPE = PREFETCH_
return sq

fixed_queries = peewee.prefetch_add_subquery(sq, subqueries, prefetch_type)
deps = {}
rel_map = {}
deps: Dict[Any, Any] = {}
rel_map: Dict[Any, Any] = {}

for pq in reversed(fixed_queries):
query_model = pq.model
Expand Down Expand Up @@ -49,27 +49,27 @@ class AioQueryMixin:
async def aio_execute(self, database: AioDatabase) -> Any:
return await database.aio_execute(self)

async def fetch_results(self, cursor: CursorProtocol) -> List[Any]:
async def fetch_results(self, cursor: CursorProtocol) -> Any:
return await fetch_models(cursor, self)


class AioModelDelete(peewee.ModelDelete, AioQueryMixin):
async def fetch_results(self, cursor: CursorProtocol):
async def fetch_results(self, cursor: CursorProtocol) -> Union[List[Any], int]:
if self._returning:
return await fetch_models(cursor, self)
return cursor.rowcount


class AioModelUpdate(peewee.ModelUpdate, AioQueryMixin):

async def fetch_results(self, cursor: CursorProtocol):
async def fetch_results(self, cursor: CursorProtocol) -> Union[List[Any], int]:
if self._returning:
return await fetch_models(cursor, self)
return cursor.rowcount


class AioModelInsert(peewee.ModelInsert, AioQueryMixin):
async def fetch_results(self, cursor: CursorProtocol):
async def fetch_results(self, cursor: CursorProtocol) -> Union[List[Any], Any, int]:
if self._returning is not None and len(self._returning) > 1:
return await fetch_models(cursor, self)

Expand All @@ -96,26 +96,26 @@ async def aio_scalar(self, database: AioDatabase, as_tuple: bool = False) -> Any
See also:
http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#SelectBase.scalar
"""
async def fetch_results(cursor):
async def fetch_results(cursor: CursorProtocol) -> Any:
return await cursor.fetchone()

rows = await database.aio_execute(self, fetch_results=fetch_results)

return rows[0] if rows and not as_tuple else rows

async def aio_get(self, database: Optional[AioDatabase] = None):
async def aio_get(self, database: Optional[AioDatabase] = None) -> Any:
"""
Async version of **peewee.SelectBase.get**
See also:
http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#SelectBase.get
"""
clone = self.paginate(1, 1)
clone = self.paginate(1, 1) # type: ignore
try:
return (await clone.aio_execute(database))[0]
except IndexError:
sql, params = clone.sql()
raise self.model.DoesNotExist('%s instance matching query does '
raise self.model.DoesNotExist('%s instance matching query does ' # type: ignore
'not exist:\nSQL: %s\nParams: %s' %
(clone.model, sql, params))

Expand All @@ -127,7 +127,7 @@ async def aio_count(self, database: AioDatabase, clear_limit: bool = False) -> i
See also:
http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#SelectBase.count
"""
clone = self.order_by().alias('_wrapped')
clone = self.order_by().alias('_wrapped') # type: ignore
if clear_limit:
clone._limit = clone._offset = None
try:
Expand All @@ -150,28 +150,28 @@ async def aio_exists(self, database: AioDatabase) -> bool:
See also:
http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#SelectBase.exists
"""
clone = self.columns(peewee.SQL('1'))
clone = self.columns(peewee.SQL('1')) # type: ignore
clone._limit = 1
clone._offset = None
return bool(await clone.aio_scalar())

def union_all(self, rhs):
return AioModelCompoundSelectQuery(self.model, self, 'UNION ALL', rhs)
def union_all(self, rhs: Any) -> "AioModelCompoundSelectQuery":
return AioModelCompoundSelectQuery(self.model, self, 'UNION ALL', rhs) # type: ignore
__add__ = union_all

def union(self, rhs):
return AioModelCompoundSelectQuery(self.model, self, 'UNION', rhs)
def union(self, rhs: Any) -> "AioModelCompoundSelectQuery":
return AioModelCompoundSelectQuery(self.model, self, 'UNION', rhs) # type: ignore
__or__ = union

def intersect(self, rhs):
return AioModelCompoundSelectQuery(self.model, self, 'INTERSECT', rhs)
def intersect(self, rhs: Any) -> "AioModelCompoundSelectQuery":
return AioModelCompoundSelectQuery(self.model, self, 'INTERSECT', rhs) # type: ignore
__and__ = intersect

def except_(self, rhs):
return AioModelCompoundSelectQuery(self.model, self, 'EXCEPT', rhs)
def except_(self, rhs: Any) -> "AioModelCompoundSelectQuery":
return AioModelCompoundSelectQuery(self.model, self, 'EXCEPT', rhs) # type: ignore
__sub__ = except_

def aio_prefetch(self, *subqueries, prefetch_type: PREFETCH_TYPE = PREFETCH_TYPE.WHERE):
def aio_prefetch(self, *subqueries: Any, prefetch_type: PREFETCH_TYPE = PREFETCH_TYPE.WHERE) -> Any:
"""
Async version of **peewee.ModelSelect.prefetch**
Expand Down Expand Up @@ -214,32 +214,32 @@ class User(peewee_async.AioModel):
"""

@classmethod
def select(cls, *fields) -> AioModelSelect:
def select(cls, *fields: Any) -> AioModelSelect:
is_default = not fields
if not fields:
fields = cls._meta.sorted_fields
return AioModelSelect(cls, fields, is_default=is_default)

@classmethod
def update(cls, __data=None, **update) -> AioModelUpdate:
def update(cls, __data: Any = None, **update: Any) -> AioModelUpdate:
return AioModelUpdate(cls, cls._normalize_data(__data, update))

@classmethod
def insert(cls, __data=None, **insert) -> AioModelInsert:
def insert(cls, __data: Any = None, **insert: Any) -> AioModelInsert:
return AioModelInsert(cls, cls._normalize_data(__data, insert))

@classmethod
def insert_many(cls, rows, fields=None) -> AioModelInsert:
def insert_many(cls, rows: Any, fields: Any = None) -> AioModelInsert:
return AioModelInsert(cls, insert=rows, columns=fields)

@classmethod
def insert_from(cls, query, fields) -> AioModelInsert:
def insert_from(cls, query: Any, fields: Any) -> AioModelInsert:
columns = [getattr(cls, field) if isinstance(field, str)
else field for field in fields]
return AioModelInsert(cls, insert=query, columns=columns)

@classmethod
def raw(cls, sql, *params) -> AioModelRaw:
def raw(cls, sql: Optional[str], *params: Optional[List[Any]]) -> AioModelRaw:
return AioModelRaw(cls, sql, params)

@classmethod
Expand All @@ -263,7 +263,7 @@ async def aio_delete_instance(self, recursive: bool = False, delete_nullable: bo
await model.delete().where(query).aio_execute()
return cast(int, await type(self).delete().where(self._pk_expr()).aio_execute())

async def aio_save(self, force_insert: bool = False, only=None) -> int:
async def aio_save(self, force_insert: bool = False, only: Any =None) -> int:
"""
Async version of **peewee.Model.save**
Expand All @@ -273,7 +273,7 @@ async def aio_save(self, force_insert: bool = False, only=None) -> int:
field_dict = self.__data__.copy()
if self._meta.primary_key is not False:
pk_field = self._meta.primary_key
pk_value = self._pk
pk_value = self._pk # type: ignore
else:
pk_field = pk_value = None
if only is not None:
Expand Down Expand Up @@ -313,7 +313,7 @@ async def aio_save(self, force_insert: bool = False, only=None) -> int:
return rows

@classmethod
async def aio_get(cls, *query, **filters) -> Self:
async def aio_get(cls, *query: Any, **filters: Any) -> Self:
"""Async version of **peewee.Model.get**
See also:
Expand All @@ -327,7 +327,7 @@ async def aio_get(cls, *query, **filters) -> Self:
sq = sq.where(*query)
if filters:
sq = sq.filter(**filters)
return await sq.aio_get()
return cast(Self, await sq.aio_get())

@classmethod
async def aio_get_or_none(cls, *query: Any, **filters: Any) -> Optional[Self]:
Expand Down
13 changes: 9 additions & 4 deletions peewee_async/databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .connection import connection_context, ConnectionContextManager
from .pool import PoolBackend, PostgresqlPoolBackend, MysqlPoolBackend
from .transactions import Transaction
from .utils import aiopg, aiomysql, __log__
from .utils import aiopg, aiomysql, __log__, FetchResults


class AioDatabase(peewee.Database):
Expand Down Expand Up @@ -109,7 +109,7 @@ def allow_sync(self) -> Iterator[None]:
self._allow_sync = old_allow_sync
self.close()

def execute_sql(self, *args: Any, **kwargs: Any):
def execute_sql(self, *args: Any, **kwargs: Any) -> Any:
"""Sync execute SQL query, `allow_sync` must be set to True.
"""
assert self._allow_sync, (
Expand All @@ -129,7 +129,12 @@ def aio_connection(self) -> ConnectionContextManager:

return ConnectionContextManager(self.pool_backend)

async def aio_execute_sql(self, sql: str, params: Optional[List[Any]] = None, fetch_results=None):
async def aio_execute_sql(
self,
sql: str,
params: Optional[List[Any]] = None,
fetch_results: Optional[FetchResults] = None
) -> Any:
__log__.debug(sql, params)
with peewee.__exception_wrapper__:
async with self.aio_connection() as connection:
Expand All @@ -138,7 +143,7 @@ async def aio_execute_sql(self, sql: str, params: Optional[List[Any]] = None, fe
if fetch_results is not None:
return await fetch_results(cursor)

async def aio_execute(self, query, fetch_results=None) -> Any:
async def aio_execute(self, query: Any, fetch_results: Optional[FetchResults] = None) -> Any:
"""Execute *SELECT*, *INSERT*, *UPDATE* or *DELETE* query asyncronously.
:param query: peewee query instance created with ``Model.select()``,
Expand Down
8 changes: 5 additions & 3 deletions peewee_async/utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import logging
from typing import Any, Protocol, Optional, Sequence, Set, AsyncContextManager, List
from typing import Any, Protocol, Optional, Sequence, Set, AsyncContextManager, List, Callable, Awaitable

try:
import aiopg
import psycopg2
except ImportError:
aiopg = None
aiopg = None # type: ignore
psycopg2 = None

try:
import aiomysql
import pymysql
except ImportError:
aiomysql = None
pymysql = None
pymysql = None # type: ignore

__log__ = logging.getLogger('peewee.async')
__log__.addHandler(logging.NullHandler())
Expand Down Expand Up @@ -70,3 +70,5 @@ def terminate(self) -> None:
async def wait_closed(self) -> None:
...


FetchResults = Callable[[CursorProtocol], Awaitable[Any]]

0 comments on commit c363a14

Please sign in to comment.