Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: tweak type annotations to make Pyright pass #2375

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
10 changes: 8 additions & 2 deletions falcon/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
Optional,
Pattern,
Protocol,
Sequence,
Tuple,
TYPE_CHECKING,
TypeVar,
Expand Down Expand Up @@ -104,6 +105,9 @@ async def __call__(
HeaderMapping = Mapping[str, str]
HeaderIter = Iterable[Tuple[str, str]]
HeaderArg = Union[HeaderMapping, HeaderIter]

NarrowHeaderArg = Union[Mapping[str, str], Sequence[Tuple[str, str]]]

ResponseStatus = Union[http.HTTPStatus, str, int]
StoreArg = Optional[Dict[str, Any]]
Resource = object
Expand Down Expand Up @@ -164,6 +168,9 @@ async def __call__(
AsgiProcessResponseMethod = Callable[
['AsgiRequest', 'AsgiResponse', Resource, bool], Awaitable[None]
]
AsgiProcessStartupMethod = Callable[[Dict[str, Any], 'AsgiEvent'], Awaitable[None]]
AsgiProcessShutdownMethod = Callable[[Dict[str, Any], 'AsgiEvent'], Awaitable[None]]

AsgiProcessRequestWsMethod = Callable[['AsgiRequest', 'WebSocket'], Awaitable[None]]
AsgiProcessResourceWsMethod = Callable[
['AsgiRequest', 'WebSocket', Resource, Dict[str, Any]], Awaitable[None]
Expand All @@ -173,7 +180,6 @@ async def __call__(
Tuple[Callable[[], Awaitable[None]], Literal[True]],
]


# Routing

MethodDict = Union[
Expand All @@ -190,7 +196,7 @@ def __call__(

# Media
class SerializeSync(Protocol):
def __call__(self, media: Any, content_type: Optional[str] = ...) -> bytes: ...
def __call__(self, media: object, content_type: Optional[str] = ...) -> bytes: ...


DeserializeSync = Callable[[bytes], Any]
Expand Down
13 changes: 10 additions & 3 deletions falcon/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1237,6 +1237,13 @@ def _get_body(
# NOTE(kgriffs): Heuristic to quickly check if stream is
# file-like. Not perfect, but should be good enough until
# proven otherwise.
# TODO(jkmnt): The checks like these are a perfect candidates for the
# Python 3.13 TypeIs guard. The TypeGuard of Python 3.10+ seems to fit too,
# though it narrows type only for the 'if' branch.
# Something like:
# def is_readable_io(stream) -> TypeIs[ReadableStream]:
# return hasattr(stream, 'read')
#
if hasattr(stream, 'read'):
if wsgi_file_wrapper is not None:
# TODO(kgriffs): Make block size configurable at the
Expand All @@ -1251,17 +1258,17 @@ def _get_body(
self._STREAM_BLOCK_SIZE,
)
else:
iterable = stream
iterable = cast(Iterable[bytes], stream)

return iterable, None

return [], 0

def _update_sink_and_static_routes(self) -> None:
if self._sink_before_static_route:
self._sink_and_static_routes = tuple(self._sinks + self._static_routes) # type: ignore[operator]
self._sink_and_static_routes = (*self._sinks, *self._static_routes)
else:
self._sink_and_static_routes = tuple(self._static_routes + self._sinks) # type: ignore[operator]
self._sink_and_static_routes = (*self._static_routes, *self._sinks)


# TODO(myusko): This class is a compatibility alias, and should be removed
Expand Down
21 changes: 15 additions & 6 deletions falcon/app_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,16 @@
from __future__ import annotations

from inspect import iscoroutinefunction
from typing import IO, Iterable, List, Literal, Optional, overload, Tuple, Union
from typing import (
Callable,
Iterable,
List,
Literal,
Optional,
overload,
Tuple,
Union,
)

from falcon import util
from falcon._typing import AsgiProcessRequestMethod as APRequest
Expand All @@ -34,6 +43,7 @@
from falcon.errors import HTTPError
from falcon.request import Request
from falcon.response import Response
from falcon.typing import ReadableIO
from falcon.util.sync import _wrap_non_coroutine_unsafe

__all__ = (
Expand Down Expand Up @@ -367,7 +377,7 @@ class CloseableStreamIterator:
block_size (int): Number of bytes to read per iteration.
"""

def __init__(self, stream: IO[bytes], block_size: int) -> None:
def __init__(self, stream: ReadableIO, block_size: int) -> None:
self._stream = stream
self._block_size = block_size

Expand All @@ -383,7 +393,6 @@ def __next__(self) -> bytes:
return data

def close(self) -> None:
try:
self._stream.close()
except (AttributeError, TypeError):
pass
close: Optional[Callable[[], None]] = getattr(self._stream, 'close', None)
if close:
close()
54 changes: 33 additions & 21 deletions falcon/asgi/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,14 @@
Union,
)

from falcon import _logger
from falcon import constants
from falcon import responders
from falcon import routing
from falcon._typing import _UNSET
from falcon._typing import AsgiErrorHandler
from falcon._typing import AsgiProcessShutdownMethod
from falcon._typing import AsgiProcessStartupMethod
from falcon._typing import AsgiReceive
from falcon._typing import AsgiResponderCallable
from falcon._typing import AsgiResponderWsCallable
Expand All @@ -59,6 +62,7 @@
from falcon.constants import MEDIA_JSON
from falcon.errors import CompatibilityError
from falcon.errors import HTTPBadRequest
from falcon.errors import HTTPInternalServerError
from falcon.errors import WebSocketDisconnected
from falcon.http_error import HTTPError
from falcon.http_status import HTTPStatus
Expand Down Expand Up @@ -768,10 +772,16 @@ async def watch_disconnect() -> None:
# (c) async iterator
#

if hasattr(stream, 'read'):
read: Optional[Callable[[int], Awaitable[bytes]]] = getattr(
stream, 'read', None
)
close: Optional[Callable[[], Awaitable[None]]] = getattr(
stream, 'close', None
)
if read:
try:
while True:
data = await stream.read(self._STREAM_BLOCK_SIZE)
data = await read(self._STREAM_BLOCK_SIZE)
if data == b'':
break
else:
Expand All @@ -785,8 +795,8 @@ async def watch_disconnect() -> None:
}
)
finally:
if hasattr(stream, 'close'):
await stream.close()
if close:
await close()
else:
# NOTE(kgriffs): Works for both async generators and iterators
try:
Expand Down Expand Up @@ -822,11 +832,8 @@ async def watch_disconnect() -> None:
'Response.stream: ' + str(ex)
)
finally:
# NOTE(vytas): This could be DRYed with the above identical
# twoliner in a one large block, but OTOH we would be
# unable to reuse the current try.. except.
if hasattr(stream, 'close'):
await stream.close()
if close:
await close()

await send(_EVT_RESP_EOF)

Expand Down Expand Up @@ -1003,6 +1010,7 @@ async def handle(req, resp, ex, params):
'The handler must be an awaitable coroutine function in order '
'to be used safely with an ASGI app.'
)
assert handler
handler_callable: AsgiErrorHandler = handler

exception_tuple: Tuple[type[BaseException], ...]
Expand Down Expand Up @@ -1075,9 +1083,12 @@ async def _call_lifespan_handlers(
return

for handler in self._unprepared_middleware:
if hasattr(handler, 'process_startup'):
process_startup: Optional[AsgiProcessStartupMethod] = getattr(
handler, 'process_startup', None
)
if process_startup:
try:
await handler.process_startup(scope, event)
await process_startup(scope, event)
except Exception:
await send(
{
Expand All @@ -1091,9 +1102,12 @@ async def _call_lifespan_handlers(

elif event['type'] == 'lifespan.shutdown':
for handler in reversed(self._unprepared_middleware):
if hasattr(handler, 'process_shutdown'):
process_shutdown: Optional[AsgiProcessShutdownMethod] = getattr(
handler, 'process_shutdown', None
)
if process_shutdown:
try:
await handler.process_shutdown(scope, event)
await process_shutdown(scope, event)
except Exception:
await send(
{
Expand Down Expand Up @@ -1185,7 +1199,7 @@ async def _http_status_handler( # type: ignore[override]
self._compose_status_response(req, resp, status)
elif ws:
code = http_status_to_ws_code(status.status_code)
falcon._logger.error(
_logger.error(
'[FALCON] HTTPStatus %s raised while handling WebSocket. '
'Closing with code %s',
status,
Expand All @@ -1207,7 +1221,7 @@ async def _http_error_handler( # type: ignore[override]
self._compose_error_response(req, resp, error)
elif ws:
code = http_status_to_ws_code(error.status_code)
falcon._logger.error(
_logger.error(
'[FALCON] HTTPError %s raised while handling WebSocket. '
'Closing with code %s',
error,
Expand All @@ -1225,10 +1239,10 @@ async def _python_error_handler( # type: ignore[override]
params: Dict[str, Any],
ws: Optional[WebSocket] = None,
) -> None:
falcon._logger.error('[FALCON] Unhandled exception in ASGI app', exc_info=error)
_logger.error('[FALCON] Unhandled exception in ASGI app', exc_info=error)

if resp:
self._compose_error_response(req, resp, falcon.HTTPInternalServerError())
self._compose_error_response(req, resp, HTTPInternalServerError())
elif ws:
await self._ws_cleanup_on_error(ws)
else:
Expand All @@ -1244,9 +1258,7 @@ async def _ws_disconnected_error_handler(
) -> None:
assert resp is None
assert ws is not None
falcon._logger.debug(
'[FALCON] WebSocket client disconnected with code %i', error.code
)
_logger.debug('[FALCON] WebSocket client disconnected with code %i', error.code)
await self._ws_cleanup_on_error(ws)

if TYPE_CHECKING:
Expand Down Expand Up @@ -1323,7 +1335,7 @@ async def _ws_cleanup_on_error(self, ws: WebSocket) -> None:
if 'invalid close code' in str(ex).lower():
await ws.close(_FALLBACK_WS_ERROR_CODE)
else:
falcon._logger.warning(
_logger.warning(
(
'[FALCON] Attempt to close web connection cleanly '
'failed due to raised error.'
Expand Down
13 changes: 4 additions & 9 deletions falcon/asgi/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from enum import auto
from enum import Enum
import re
from typing import Any, Deque, Dict, Iterable, Mapping, Optional, Tuple, Union
from typing import Any, Deque, Dict, Mapping, Optional, Tuple, Union

from falcon import errors
from falcon import media
Expand All @@ -18,6 +18,7 @@
from falcon.asgi_spec import EventType
from falcon.asgi_spec import WSCloseCode
from falcon.constants import WebSocketPayloadType
from falcon.response_helpers import _headers_to_items
from falcon.util import misc

__all__ = ('WebSocket',)
Expand Down Expand Up @@ -210,15 +211,9 @@ async def accept(
'does not support accept headers.'
)

header_items = getattr(headers, 'items', None)
if callable(header_items):
headers_iterable: Iterable[tuple[str, str]] = header_items()
else:
headers_iterable = headers # type: ignore[assignment]

event['headers'] = parsed_headers = [
(name.lower().encode('ascii'), value.encode('ascii'))
for name, value in headers_iterable
for name, value in _headers_to_items(headers)
]

for name, __ in parsed_headers:
Expand Down Expand Up @@ -628,7 +623,7 @@ class WebSocketOptions:

@classmethod
def _init_default_close_reasons(cls) -> Dict[int, str]:
reasons = dict(cls._STANDARD_CLOSE_REASONS)
reasons: dict[int, str] = dict(cls._STANDARD_CLOSE_REASONS)
for status_constant in dir(status_codes):
if 'HTTP_100' <= status_constant < 'HTTP_599':
status_line = getattr(status_codes, status_constant)
Expand Down
Loading