From 41f44b4dfce93d1dc3e74424a873e71cc2f18808 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sun, 22 Sep 2024 16:41:54 +0100 Subject: [PATCH 1/9] Test typing round 6 --- aiohttp/web_ws.py | 6 +- tests/test_web_runner.py | 90 +++---- tests/test_web_sendfile_functional.py | 196 ++++++++------- tests/test_web_server.py | 53 ++-- tests/test_web_websocket.py | 324 +++++++++++++------------ tests/test_web_websocket_functional.py | 198 +++++++-------- 6 files changed, 449 insertions(+), 418 deletions(-) diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index d5a48e478db..14d47b66e40 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -404,14 +404,14 @@ async def pong(self, message: bytes = b"") -> None: raise RuntimeError("Call .prepare() first") await self._writer.pong(message) - async def send_str(self, data: str, compress: Optional[bool] = None) -> None: + async def send_str(self, data: str, compress: Optional[int] = None) -> None: if self._writer is None: raise RuntimeError("Call .prepare() first") if not isinstance(data, str): raise TypeError("data argument must be str (%r)" % type(data)) await self._writer.send(data, binary=False, compress=compress) - async def send_bytes(self, data: bytes, compress: Optional[bool] = None) -> None: + async def send_bytes(self, data: bytes, compress: Optional[int] = None) -> None: if self._writer is None: raise RuntimeError("Call .prepare() first") if not isinstance(data, (bytes, bytearray, memoryview)): @@ -421,7 +421,7 @@ async def send_bytes(self, data: bytes, compress: Optional[bool] = None) -> None async def send_json( self, data: Any, - compress: Optional[bool] = None, + compress: Optional[int] = None, *, dumps: JSONEncoder = json.dumps, ) -> None: diff --git a/tests/test_web_runner.py b/tests/test_web_runner.py index 302afbb95d6..4bf0a9fc796 100644 --- a/tests/test_web_runner.py +++ b/tests/test_web_runner.py @@ -1,29 +1,34 @@ -# type: ignore import asyncio import platform import signal -from typing import Any -from unittest.mock import patch +from typing import Any, Iterator, NoReturn, Protocol, Union +from unittest import mock import pytest from aiohttp import web from aiohttp.abc import AbstractAccessLogger from aiohttp.test_utils import get_unused_port_socket +from aiohttp.web_log import AccessLogger + + +class _RunnerMaker(Protocol): + def __call__(self, handle_signals: bool = ..., **kwargs: Any) -> web.AppRunner: + ... @pytest.fixture -def app(): +def app() -> web.Application: return web.Application() @pytest.fixture -def make_runner(loop: Any, app: Any): +def make_runner(loop: asyncio.AbstractEventLoop, app: web.Application) -> Iterator[_RunnerMaker]: asyncio.set_event_loop(loop) runners = [] - def go(**kwargs): - runner = web.AppRunner(app, **kwargs) + def go(handle_signals: bool = False, **kwargs: Any) -> web.AppRunner: + runner = web.AppRunner(app, handle_signals=handle_signals, **kwargs) runners.append(runner) return runner @@ -32,7 +37,7 @@ def go(**kwargs): loop.run_until_complete(runner.cleanup()) -async def test_site_for_nonfrozen_app(make_runner: Any) -> None: +async def test_site_for_nonfrozen_app(make_runner: _RunnerMaker) -> None: runner = make_runner() with pytest.raises(RuntimeError): web.TCPSite(runner) @@ -42,7 +47,7 @@ async def test_site_for_nonfrozen_app(make_runner: Any) -> None: @pytest.mark.skipif( platform.system() == "Windows", reason="the test is not valid for Windows" ) -async def test_runner_setup_handle_signals(make_runner: Any) -> None: +async def test_runner_setup_handle_signals(make_runner: _RunnerMaker) -> None: runner = make_runner(handle_signals=True) await runner.setup() assert signal.getsignal(signal.SIGTERM) is not signal.SIG_DFL @@ -53,7 +58,7 @@ async def test_runner_setup_handle_signals(make_runner: Any) -> None: @pytest.mark.skipif( platform.system() == "Windows", reason="the test is not valid for Windows" ) -async def test_runner_setup_without_signal_handling(make_runner: Any) -> None: +async def test_runner_setup_without_signal_handling(make_runner: _RunnerMaker) -> None: runner = make_runner(handle_signals=False) await runner.setup() assert signal.getsignal(signal.SIGTERM) is signal.SIG_DFL @@ -61,7 +66,7 @@ async def test_runner_setup_without_signal_handling(make_runner: Any) -> None: assert signal.getsignal(signal.SIGTERM) is signal.SIG_DFL -async def test_site_double_added(make_runner: Any) -> None: +async def test_site_double_added(make_runner: _RunnerMaker) -> None: _sock = get_unused_port_socket("127.0.0.1") runner = make_runner() await runner.setup() @@ -73,7 +78,7 @@ async def test_site_double_added(make_runner: Any) -> None: assert len(runner.sites) == 1 -async def test_site_stop_not_started(make_runner: Any) -> None: +async def test_site_stop_not_started(make_runner: _RunnerMaker) -> None: runner = make_runner() await runner.setup() site = web.TCPSite(runner) @@ -83,13 +88,14 @@ async def test_site_stop_not_started(make_runner: Any) -> None: assert len(runner.sites) == 0 -async def test_custom_log_format(make_runner: Any) -> None: +async def test_custom_log_format(make_runner: _RunnerMaker) -> None: runner = make_runner(access_log_format="abc") await runner.setup() + assert runner.server is not None assert runner.server._kwargs["access_log_format"] == "abc" -async def test_unreg_site(make_runner: Any) -> None: +async def test_unreg_site(make_runner: _RunnerMaker) -> None: runner = make_runner() await runner.setup() site = web.TCPSite(runner) @@ -97,20 +103,20 @@ async def test_unreg_site(make_runner: Any) -> None: runner._unreg_site(site) -async def test_app_property(make_runner: Any, app: Any) -> None: +async def test_app_property(make_runner: _RunnerMaker, app: web.Application) -> None: runner = make_runner() assert runner.app is app def test_non_app() -> None: with pytest.raises(TypeError): - web.AppRunner(object()) + web.AppRunner(object()) # type: ignore[arg-type] def test_app_handler_args() -> None: app = web.Application(handler_args={"test": True}) runner = web.AppRunner(app) - assert runner._kwargs == {"access_log_class": web.AccessLogger, "test": True} + assert runner._kwargs == {"access_log_class": AccessLogger, "test": True} async def test_app_handler_args_failure() -> None: @@ -132,7 +138,7 @@ async def test_app_handler_args_failure() -> None: ("2", 2), ), ) -async def test_app_handler_args_ceil_threshold(value: Any, expected: Any) -> None: +async def test_app_handler_args_ceil_threshold(value: Union[int, str, None], expected: int) -> None: app = web.Application(handler_args={"timeout_ceil_threshold": value}) runner = web.AppRunner(app) await runner.setup() @@ -150,7 +156,7 @@ class Logger: app = web.Application() with pytest.raises(TypeError): - web.AppRunner(app, access_log_class=Logger) + web.AppRunner(app, access_log_class=Logger) # type: ignore[arg-type] async def test_app_make_handler_access_log_class_bad_type2() -> None: @@ -165,7 +171,7 @@ class Logger: async def test_app_make_handler_access_log_class1() -> None: class Logger(AbstractAccessLogger): - def log(self, request, response, time): + def log(self, request: web.BaseRequest, response: web.StreamResponse, time: float) -> None: """Pass log method.""" app = web.Application() @@ -175,7 +181,7 @@ def log(self, request, response, time): async def test_app_make_handler_access_log_class2() -> None: class Logger(AbstractAccessLogger): - def log(self, request, response, time): + def log(self, request: web.BaseRequest, response: web.StreamResponse, time: float) -> None: """Pass log method.""" app = web.Application(handler_args={"access_log_class": Logger}) @@ -183,7 +189,7 @@ def log(self, request, response, time): assert runner._kwargs["access_log_class"] is Logger -async def test_addresses(make_runner: Any, unix_sockname: Any) -> None: +async def test_addresses(make_runner: _RunnerMaker, unix_sockname: str) -> None: _sock = get_unused_port_socket("127.0.0.1") runner = make_runner() await runner.setup() @@ -200,7 +206,7 @@ async def test_addresses(make_runner: Any, unix_sockname: Any) -> None: platform.system() != "Windows", reason="Proactor Event loop present only in Windows" ) async def test_named_pipe_runner_wrong_loop( - app: Any, selector_loop: Any, pipe_name: Any + app: web.Application, selector_loop: asyncio.AbstractEventLoop, pipe_name: str ) -> None: runner = web.AppRunner(app) await runner.setup() @@ -212,7 +218,7 @@ async def test_named_pipe_runner_wrong_loop( platform.system() != "Windows", reason="Proactor Event loop present only in Windows" ) async def test_named_pipe_runner_proactor_loop( - proactor_loop: Any, app: Any, pipe_name: Any + proactor_loop: asyncio.AbstractEventLoop, app: web.Application, pipe_name: str ) -> None: runner = web.AppRunner(app) await runner.setup() @@ -221,29 +227,22 @@ async def test_named_pipe_runner_proactor_loop( await runner.cleanup() -async def test_tcpsite_default_host(make_runner: Any) -> None: +async def test_tcpsite_default_host(make_runner: _RunnerMaker) -> None: runner = make_runner() await runner.setup() site = web.TCPSite(runner) assert site.name == "http://0.0.0.0:8080" - calls = [] - - async def mock_create_server(*args, **kwargs): - calls.append((args, kwargs)) - - with patch("asyncio.get_event_loop") as mock_get_loop: - mock_get_loop.return_value.create_server = mock_create_server + m = mock.create_autospec(asyncio.AbstractEventLoop, spec_set=True, instance=True) + with mock.patch("asyncio.get_event_loop", autospec=True, spec_set=True, return_value=m) as mock_get_loop: await site.start() - assert len(calls) == 1 - server, host, port = calls[0][0] - assert server is runner.server - assert host is None - assert port == 8080 + m.create_server.assert_called_once() + args, kwargs = m.create_server.call_args + assert args == (runner.server, None, 8080) -async def test_tcpsite_empty_str_host(make_runner: Any) -> None: +async def test_tcpsite_empty_str_host(make_runner: _RunnerMaker) -> None: runner = make_runner() await runner.setup() site = web.TCPSite(runner, host="") @@ -251,15 +250,16 @@ async def test_tcpsite_empty_str_host(make_runner: Any) -> None: def test_run_after_asyncio_run() -> None: - async def nothing(): - pass + called = False - def spy(): - spy.called = True + async def nothing() -> None: + pass - spy.called = False + def spy() -> None: + nonlocal called + called = True - async def shutdown(): + async def shutdown() -> NoReturn: spy() raise web.GracefulExit() @@ -271,4 +271,4 @@ async def shutdown(): app.on_startup.append(lambda a: asyncio.create_task(shutdown())) web.run_app(app) - assert spy.called, "run_app() should work after asyncio.run()." + assert called, "run_app() should work after asyncio.run()." diff --git a/tests/test_web_sendfile_functional.py b/tests/test_web_sendfile_functional.py index 7f8ae587ba9..f8e2ee5fc52 100644 --- a/tests/test_web_sendfile_functional.py +++ b/tests/test_web_sendfile_functional.py @@ -1,17 +1,20 @@ -# type: ignore import asyncio import bz2 import gzip import pathlib import socket +import ssl import zlib -from typing import Any, Iterable, Optional +from typing import Callable, Iterable, Iterator, NoReturn, Optional, Tuple from unittest import mock import pytest +from _pytest.fixtures import SubRequest import aiohttp from aiohttp import web +from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer +from aiohttp.typedefs import PathLike try: import brotlicffi as brotli @@ -21,14 +24,16 @@ try: import ssl except ImportError: - ssl = None + ssl = None # type: ignore[assignment] + +_Sender = Callable[..., web.FileResponse] HELLO_AIOHTTP = b"Hello aiohttp! :-)\n" @pytest.fixture(scope="module") -def hello_txt(request, tmp_path_factory) -> pathlib.Path: +def hello_txt(request: pytest.FixtureRequest, tmp_path_factory: pytest.TempPathFactory) -> pathlib.Path: """Create a temp path with hello.txt and compressed versions. The uncompressed text file path is returned by default. Alternatively, an @@ -50,22 +55,22 @@ def hello_txt(request, tmp_path_factory) -> pathlib.Path: @pytest.fixture -def loop_with_mocked_native_sendfile(loop: Any): - def sendfile(transport, fobj, offset, count): +def loop_with_mocked_native_sendfile(loop: asyncio.AbstractEventLoop) -> Iterator[asyncio.AbstractEventLoop]: + def sendfile(transport: object, fobj: object, offset: int, count: int) -> NoReturn: if count == 0: raise ValueError("count must be a positive integer (got 0)") raise NotImplementedError - loop.sendfile = sendfile - return loop + with mock.patch.object(loop, "sendfile", sendfile): + yield loop @pytest.fixture(params=["sendfile", "no_sendfile"], ids=["sendfile", "no_sendfile"]) -def sender(request: Any, loop: Any): +def sender(request: SubRequest, loop: asyncio.AbstractEventLoop) -> Iterator[_Sender]: sendfile_mock = None - def maker(*args, **kwargs): - ret = web.FileResponse(*args, **kwargs) + def maker(path: PathLike, chunk_size: int = 256 * 1024) -> web.FileResponse: + ret = web.FileResponse(path, chunk_size=chunk_size) rloop = asyncio.get_running_loop() is_patched = rloop.sendfile is sendfile_mock assert is_patched if request.param == "no_sendfile" else not is_patched @@ -85,11 +90,11 @@ def maker(*args, **kwargs): @pytest.fixture -def app_with_static_route(sender: Any) -> web.Application: +def app_with_static_route(sender: _Sender) -> web.Application: filename = "data.unknown_mime_type" filepath = pathlib.Path(__file__).parent / filename - async def handler(request): + async def handler(request: web.Request) -> web.FileResponse: return sender(filepath) app = web.Application() @@ -98,7 +103,7 @@ async def handler(request): async def test_static_file_ok( - aiohttp_client: Any, app_with_static_route: web.Application + aiohttp_client: AiohttpClient, app_with_static_route: web.Application ) -> None: client = await aiohttp_client(app_with_static_route) @@ -112,10 +117,10 @@ async def test_static_file_ok( await client.close() -async def test_zero_bytes_file_ok(aiohttp_client: Any, sender: Any) -> None: +async def test_zero_bytes_file_ok(aiohttp_client: AiohttpClient, sender: _Sender) -> None: filepath = pathlib.Path(__file__).parent / "data.zero_bytes" - async def handler(request): + async def handler(request: web.Request) -> web.FileResponse: return sender(filepath) app = web.Application() @@ -138,11 +143,11 @@ async def handler(request): async def test_zero_bytes_file_mocked_native_sendfile( - aiohttp_client: Any, loop_with_mocked_native_sendfile: Any + aiohttp_client: AiohttpClient, loop_with_mocked_native_sendfile: asyncio.AbstractEventLoop ) -> None: filepath = pathlib.Path(__file__).parent / "data.zero_bytes" - async def handler(request): + async def handler(request: web.Request) -> web.FileResponse: asyncio.set_event_loop(loop_with_mocked_native_sendfile) return web.FileResponse(filepath) @@ -167,7 +172,7 @@ async def handler(request): async def test_static_file_ok_string_path( - aiohttp_client: Any, app_with_static_route: web.Application + aiohttp_client: AiohttpClient, app_with_static_route: web.Application ) -> None: client = await aiohttp_client(app_with_static_route) @@ -181,7 +186,7 @@ async def test_static_file_ok_string_path( await client.close() -async def test_static_file_not_exists(aiohttp_client: Any) -> None: +async def test_static_file_not_exists(aiohttp_client: AiohttpClient) -> None: app = web.Application() client = await aiohttp_client(app) @@ -191,7 +196,7 @@ async def test_static_file_not_exists(aiohttp_client: Any) -> None: await client.close() -async def test_static_file_name_too_long(aiohttp_client: Any) -> None: +async def test_static_file_name_too_long(aiohttp_client: AiohttpClient) -> None: app = web.Application() client = await aiohttp_client(app) @@ -201,7 +206,7 @@ async def test_static_file_name_too_long(aiohttp_client: Any) -> None: await client.close() -async def test_static_file_upper_directory(aiohttp_client: Any) -> None: +async def test_static_file_upper_directory(aiohttp_client: AiohttpClient) -> None: app = web.Application() client = await aiohttp_client(app) @@ -211,10 +216,10 @@ async def test_static_file_upper_directory(aiohttp_client: Any) -> None: await client.close() -async def test_static_file_with_content_type(aiohttp_client: Any, sender: Any) -> None: +async def test_static_file_with_content_type(aiohttp_client: AiohttpClient, sender: _Sender) -> None: filepath = pathlib.Path(__file__).parent / "aiohttp.jpg" - async def handler(request): + async def handler(request: web.Request) -> web.FileResponse: return sender(filepath, chunk_size=16) app = web.Application() @@ -236,11 +241,11 @@ async def handler(request): @pytest.mark.parametrize("hello_txt", ["gzip", "br"], indirect=True) async def test_static_file_custom_content_type( - hello_txt: pathlib.Path, aiohttp_client: Any, sender: Any + hello_txt: pathlib.Path, aiohttp_client: AiohttpClient, sender: _Sender ) -> None: """Test that custom type without encoding is returned for encoded request.""" - async def handler(request): + async def handler(request: web.Request) -> web.FileResponse: resp = sender(hello_txt, chunk_size=16) resp.content_type = "application/pdf" return resp @@ -265,14 +270,14 @@ async def handler(request): ) async def test_static_file_custom_content_type_compress( hello_txt: pathlib.Path, - aiohttp_client: Any, - sender: Any, + aiohttp_client: AiohttpClient, + sender: _Sender, accept_encoding: str, expect_encoding: str, -): +) -> None: """Test that custom type with encoding is returned for unencoded requests.""" - async def handler(request): + async def handler(request: web.Request) -> web.FileResponse: resp = sender(hello_txt, chunk_size=16) resp.content_type = "application/pdf" return resp @@ -298,15 +303,15 @@ async def handler(request): @pytest.mark.parametrize("forced_compression", [None, web.ContentCoding.gzip]) async def test_static_file_with_encoding_and_enable_compression( hello_txt: pathlib.Path, - aiohttp_client: Any, - sender: Any, + aiohttp_client: AiohttpClient, + sender: _Sender, accept_encoding: str, expect_encoding: str, forced_compression: Optional[web.ContentCoding], -): +) -> None: """Test that enable_compression does not double compress when an encoded file is also present.""" - async def handler(request): + async def handler(request: web.Request) -> web.FileResponse: resp = sender(hello_txt) resp.enable_compression(forced_compression) return resp @@ -335,11 +340,11 @@ async def handler(request): indirect=["hello_txt"], ) async def test_static_file_with_content_encoding( - hello_txt: pathlib.Path, aiohttp_client: Any, sender: Any, expect_type: str + hello_txt: pathlib.Path, aiohttp_client: AiohttpClient, sender: _Sender, expect_type: str ) -> None: """Test requesting static compressed files returns the correct content type and encoding.""" - async def handler(request): + async def handler(request: web.Request) -> web.FileResponse: return sender(hello_txt) app = web.Application() @@ -358,7 +363,7 @@ async def handler(request): async def test_static_file_if_modified_since( - aiohttp_client: Any, app_with_static_route: web.Application + aiohttp_client: AiohttpClient, app_with_static_route: web.Application ) -> None: client = await aiohttp_client(app_with_static_route) @@ -381,7 +386,7 @@ async def test_static_file_if_modified_since( async def test_static_file_if_modified_since_past_date( - aiohttp_client: Any, app_with_static_route: web.Application + aiohttp_client: AiohttpClient, app_with_static_route: web.Application ) -> None: client = await aiohttp_client(app_with_static_route) @@ -396,8 +401,8 @@ async def test_static_file_if_modified_since_past_date( async def test_static_file_if_modified_since_invalid_date( - aiohttp_client: Any, app_with_static_route: web.Application -): + aiohttp_client: AiohttpClient, app_with_static_route: web.Application +) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "not a valid HTTP-date" @@ -411,8 +416,8 @@ async def test_static_file_if_modified_since_invalid_date( async def test_static_file_if_modified_since_future_date( - aiohttp_client: Any, app_with_static_route: web.Application -): + aiohttp_client: AiohttpClient, app_with_static_route: web.Application +) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "Fri, 31 Dec 9999 23:59:59 GMT" @@ -431,7 +436,7 @@ async def test_static_file_if_modified_since_future_date( @pytest.mark.parametrize("if_unmodified_since", ("", "Fri, 31 Dec 0000 23:59:59 GMT")) async def test_static_file_if_match( - aiohttp_client: Any, + aiohttp_client: AiohttpClient, app_with_static_route: web.Application, if_unmodified_since: str, ) -> None: @@ -467,11 +472,11 @@ async def test_static_file_if_match( ], ) async def test_static_file_if_match_custom_tags( - aiohttp_client: Any, + aiohttp_client: AiohttpClient, app_with_static_route: web.Application, if_unmodified_since: str, - etags: Iterable[str], - expected_status: Iterable[int], + etags: Tuple[str], + expected_status: int, ) -> None: client = await aiohttp_client(app_with_static_route) @@ -496,7 +501,7 @@ async def test_static_file_if_match_custom_tags( ), ) async def test_static_file_if_none_match( - aiohttp_client: Any, + aiohttp_client: AiohttpClient, app_with_static_route: web.Application, if_modified_since: str, additional_etags: Iterable[str], @@ -529,7 +534,7 @@ async def test_static_file_if_none_match( async def test_static_file_if_none_match_star( - aiohttp_client: Any, + aiohttp_client: AiohttpClient, app_with_static_route: web.Application, ) -> None: client = await aiohttp_client(app_with_static_route) @@ -549,10 +554,10 @@ async def test_static_file_if_none_match_star( @pytest.mark.skipif(not ssl, reason="ssl not supported") async def test_static_file_ssl( - aiohttp_server: Any, - ssl_ctx: Any, - aiohttp_client: Any, - client_ssl_ctx: Any, + aiohttp_server: AiohttpServer, + ssl_ctx: ssl.SSLContext, + aiohttp_client: AiohttpClient, + client_ssl_ctx: ssl.SSLContext, ) -> None: dirname = pathlib.Path(__file__).parent filename = "data.unknown_mime_type" @@ -574,7 +579,7 @@ async def test_static_file_ssl( await client.close() -async def test_static_file_directory_traversal_attack(aiohttp_client: Any) -> None: +async def test_static_file_directory_traversal_attack(aiohttp_client: AiohttpClient) -> None: dirname = pathlib.Path(__file__).parent relpath = "../README.rst" full_path = dirname / relpath @@ -601,7 +606,7 @@ async def test_static_file_directory_traversal_attack(aiohttp_client: Any) -> No await client.close() -async def test_static_file_huge(aiohttp_client: Any, tmp_path: Any) -> None: +async def test_static_file_huge(aiohttp_client: AiohttpClient, tmp_path: pathlib.Path) -> None: file_path = tmp_path / "huge_data.unknown_mime_type" # fill 20MB file @@ -620,29 +625,29 @@ async def test_static_file_huge(aiohttp_client: Any, tmp_path: Any) -> None: ct = resp.headers["CONTENT-TYPE"] assert "application/octet-stream" == ct assert resp.headers.get("CONTENT-ENCODING") is None - assert int(resp.headers.get("CONTENT-LENGTH")) == file_st.st_size + assert int(resp.headers["CONTENT-LENGTH"]) == file_st.st_size - f = file_path.open("rb") + f2 = file_path.open("rb") off = 0 cnt = 0 while off < file_st.st_size: chunk = await resp.content.readany() - expected = f.read(len(chunk)) + expected = f2.read(len(chunk)) assert chunk == expected off += len(chunk) cnt += 1 - f.close() + f2.close() resp.release() await client.close() -async def test_static_file_range(aiohttp_client: Any, sender: Any) -> None: +async def test_static_file_range(aiohttp_client: AiohttpClient, sender: _Sender) -> None: filepath = pathlib.Path(__file__).parent / "sample.txt" filesize = filepath.stat().st_size - async def handler(request): + async def handler(request: web.Request) -> web.FileResponse: return sender(filepath, chunk_size=16) app = web.Application() @@ -696,10 +701,10 @@ async def handler(request): await client.close() -async def test_static_file_range_end_bigger_than_size(aiohttp_client: Any, sender: Any): +async def test_static_file_range_end_bigger_than_size(aiohttp_client: AiohttpClient, sender: _Sender) -> None: filepath = pathlib.Path(__file__).parent / "aiohttp.png" - async def handler(request): + async def handler(request: web.Request) -> web.FileResponse: return sender(filepath, chunk_size=16) app = web.Application() @@ -730,10 +735,10 @@ async def handler(request): await client.close() -async def test_static_file_range_beyond_eof(aiohttp_client: Any, sender: Any) -> None: +async def test_static_file_range_beyond_eof(aiohttp_client: AiohttpClient, sender: _Sender) -> None: filepath = pathlib.Path(__file__).parent / "aiohttp.png" - async def handler(request): + async def handler(request: web.Request) -> web.FileResponse: return sender(filepath, chunk_size=16) app = web.Application() @@ -751,10 +756,10 @@ async def handler(request): await client.close() -async def test_static_file_range_tail(aiohttp_client: Any, sender: Any) -> None: +async def test_static_file_range_tail(aiohttp_client: AiohttpClient, sender: _Sender) -> None: filepath = pathlib.Path(__file__).parent / "aiohttp.png" - async def handler(request): + async def handler(request: web.Request) -> web.FileResponse: return sender(filepath, chunk_size=16) app = web.Application() @@ -786,10 +791,10 @@ async def handler(request): await client.close() -async def test_static_file_invalid_range(aiohttp_client: Any, sender: Any) -> None: +async def test_static_file_invalid_range(aiohttp_client: AiohttpClient, sender: _Sender) -> None: filepath = pathlib.Path(__file__).parent / "aiohttp.png" - async def handler(request): + async def handler(request: web.Request) -> web.FileResponse: return sender(filepath, chunk_size=16) app = web.Application() @@ -836,8 +841,8 @@ async def handler(request): async def test_static_file_if_unmodified_since_past_with_range( - aiohttp_client: Any, app_with_static_route: web.Application -): + aiohttp_client: AiohttpClient, app_with_static_route: web.Application +) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "Mon, 1 Jan 1990 01:01:01 GMT" @@ -853,8 +858,8 @@ async def test_static_file_if_unmodified_since_past_with_range( async def test_static_file_if_unmodified_since_future_with_range( - aiohttp_client: Any, app_with_static_route: web.Application -): + aiohttp_client: AiohttpClient, app_with_static_route: web.Application +) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "Fri, 31 Dec 9999 23:59:59 GMT" @@ -872,8 +877,8 @@ async def test_static_file_if_unmodified_since_future_with_range( async def test_static_file_if_range_past_with_range( - aiohttp_client: Any, app_with_static_route: web.Application -): + aiohttp_client: AiohttpClient, app_with_static_route: web.Application +) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "Mon, 1 Jan 1990 01:01:01 GMT" @@ -887,8 +892,8 @@ async def test_static_file_if_range_past_with_range( async def test_static_file_if_range_future_with_range( - aiohttp_client: Any, app_with_static_route: web.Application -): + aiohttp_client: AiohttpClient, app_with_static_route: web.Application +) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "Fri, 31 Dec 9999 23:59:59 GMT" @@ -904,8 +909,8 @@ async def test_static_file_if_range_future_with_range( async def test_static_file_if_unmodified_since_past_without_range( - aiohttp_client: Any, app_with_static_route: web.Application -): + aiohttp_client: AiohttpClient, app_with_static_route: web.Application +) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "Mon, 1 Jan 1990 01:01:01 GMT" @@ -919,8 +924,8 @@ async def test_static_file_if_unmodified_since_past_without_range( async def test_static_file_if_unmodified_since_future_without_range( - aiohttp_client: Any, app_with_static_route: web.Application -): + aiohttp_client: AiohttpClient, app_with_static_route: web.Application +) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "Fri, 31 Dec 9999 23:59:59 GMT" @@ -935,8 +940,8 @@ async def test_static_file_if_unmodified_since_future_without_range( async def test_static_file_if_range_past_without_range( - aiohttp_client: Any, app_with_static_route: web.Application -): + aiohttp_client: AiohttpClient, app_with_static_route: web.Application +) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "Mon, 1 Jan 1990 01:01:01 GMT" @@ -951,8 +956,8 @@ async def test_static_file_if_range_past_without_range( async def test_static_file_if_range_future_without_range( - aiohttp_client: Any, app_with_static_route: web.Application -): + aiohttp_client: AiohttpClient, app_with_static_route: web.Application +) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "Fri, 31 Dec 9999 23:59:59 GMT" @@ -967,8 +972,8 @@ async def test_static_file_if_range_future_without_range( async def test_static_file_if_unmodified_since_invalid_date( - aiohttp_client: Any, app_with_static_route: web.Application -): + aiohttp_client: AiohttpClient, app_with_static_route: web.Application +) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "not a valid HTTP-date" @@ -982,8 +987,8 @@ async def test_static_file_if_unmodified_since_invalid_date( async def test_static_file_if_range_invalid_date( - aiohttp_client: Any, app_with_static_route: web.Application -): + aiohttp_client: AiohttpClient, app_with_static_route: web.Application +) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "not a valid HTTP-date" @@ -996,10 +1001,10 @@ async def test_static_file_if_range_invalid_date( await client.close() -async def test_static_file_compression(aiohttp_client: Any, sender: Any) -> None: +async def test_static_file_compression(aiohttp_client: AiohttpClient, sender: _Sender) -> None: filepath = pathlib.Path(__file__).parent / "data.unknown_mime_type" - async def handler(request): + async def handler(request: web.Request) -> web.FileResponse: ret = sender(filepath) ret.enable_compression() return ret @@ -1020,7 +1025,7 @@ async def handler(request): await client.close() -async def test_static_file_huge_cancel(aiohttp_client: Any, tmp_path: Any) -> None: +async def test_static_file_huge_cancel(aiohttp_client: AiohttpClient, tmp_path: pathlib.Path) -> None: file_path = tmp_path / "huge_data.unknown_mime_type" # fill 100MB file @@ -1030,11 +1035,12 @@ async def test_static_file_huge_cancel(aiohttp_client: Any, tmp_path: Any) -> No task = None - async def handler(request): + async def handler(request: web.Request) -> web.FileResponse: nonlocal task task = request.task # reduce send buffer size tr = request.transport + assert tr is not None sock = tr.get_extra_info("socket") sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024) ret = web.FileResponse(file_path) @@ -1047,6 +1053,7 @@ async def handler(request): resp = await client.get("/") assert resp.status == 200 + assert task is not None task.cancel() await asyncio.sleep(0) data = b"" @@ -1061,7 +1068,7 @@ async def handler(request): await client.close() -async def test_static_file_huge_error(aiohttp_client: Any, tmp_path: Any) -> None: +async def test_static_file_huge_error(aiohttp_client: AiohttpClient, tmp_path: pathlib.Path) -> None: file_path = tmp_path / "huge_data.unknown_mime_type" # fill 20MB file @@ -1069,9 +1076,10 @@ async def test_static_file_huge_error(aiohttp_client: Any, tmp_path: Any) -> Non f.seek(20 * 1024 * 1024) f.write(b"1") - async def handler(request): + async def handler(request: web.Request) -> web.FileResponse: # reduce send buffer size tr = request.transport + assert tr is not None sock = tr.get_extra_info("socket") sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024) ret = web.FileResponse(file_path) diff --git a/tests/test_web_server.py b/tests/test_web_server.py index 619b83e947c..0cd462af3b6 100644 --- a/tests/test_web_server.py +++ b/tests/test_web_server.py @@ -1,16 +1,16 @@ -# type: ignore import asyncio from contextlib import suppress -from typing import Any +from typing import Callable, NoReturn from unittest import mock import pytest from aiohttp import client, web +from aiohttp.pytest_plugin import AiohttpClient, AiohttpRawServer -async def test_simple_server(aiohttp_raw_server: Any, aiohttp_client: Any) -> None: - async def handler(request): +async def test_simple_server(aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient) -> None: + async def handler(request: web.BaseRequest) -> web.Response: return web.Response(text=str(request.rel_url)) server = await aiohttp_raw_server(handler) @@ -21,10 +21,10 @@ async def handler(request): assert txt == "/path/to" -async def test_unsupported_upgrade(aiohttp_raw_server, aiohttp_client) -> None: +async def test_unsupported_upgrade(aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient) -> None: # don't fail if a client probes for an unsupported protocol upgrade # https://github.com/aio-libs/aiohttp/issues/6446#issuecomment-999032039 - async def handler(request: web.Request): + async def handler(request: web.BaseRequest) -> web.Response: return web.Response(body=await request.read()) upgrade_headers = {"Connection": "Upgrade", "Upgrade": "unsupported_proto"} @@ -38,14 +38,14 @@ async def handler(request: web.Request): async def test_raw_server_not_http_exception( - aiohttp_raw_server: Any, aiohttp_client: Any, loop: Any + aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient, loop: asyncio.AbstractEventLoop ) -> None: # disable debug mode not to print traceback loop.set_debug(False) exc = RuntimeError("custom runtime error") - async def handler(request): + async def handler(request: web.BaseRequest) -> NoReturn: raise exc logger = mock.Mock() @@ -63,13 +63,13 @@ async def handler(request): async def test_raw_server_handler_timeout( - aiohttp_raw_server: Any, aiohttp_client: Any + aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient ) -> None: loop = asyncio.get_event_loop() loop.set_debug(True) exc = asyncio.TimeoutError("error") - async def handler(request): + async def handler(request: web.BaseRequest) -> NoReturn: raise exc logger = mock.Mock() @@ -83,9 +83,9 @@ async def handler(request): async def test_raw_server_do_not_swallow_exceptions( - aiohttp_raw_server: Any, aiohttp_client: Any + aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient ) -> None: - async def handler(request): + async def handler(request: web.BaseRequest) -> NoReturn: raise asyncio.CancelledError() loop = asyncio.get_event_loop() @@ -101,13 +101,13 @@ async def handler(request): async def test_raw_server_cancelled_in_write_eof( - aiohttp_raw_server: Any, aiohttp_client: Any -): + aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient +) -> None: class MyResponse(web.Response): - async def write_eof(self, data=b""): + async def write_eof(self, data: bytes = b"") -> NoReturn: raise asyncio.CancelledError("error") - async def handler(request): + async def handler(request: web.BaseRequest) -> MyResponse: resp = MyResponse(text=str(request.rel_url)) return resp @@ -125,11 +125,11 @@ async def handler(request): async def test_raw_server_not_http_exception_debug( - aiohttp_raw_server: Any, aiohttp_client: Any + aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient ) -> None: exc = RuntimeError("custom runtime error") - async def handler(request): + async def handler(request: web.BaseRequest) -> NoReturn: raise exc loop = asyncio.get_event_loop() @@ -148,14 +148,14 @@ async def handler(request): async def test_raw_server_html_exception( - aiohttp_raw_server: Any, aiohttp_client: Any, loop: Any + aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient, loop: asyncio.AbstractEventLoop ) -> None: # disable debug mode not to print traceback loop.set_debug(False) exc = RuntimeError("custom runtime error") - async def handler(request): + async def handler(request: web.BaseRequest) -> NoReturn: raise exc logger = mock.Mock() @@ -177,11 +177,11 @@ async def handler(request): async def test_raw_server_html_exception_debug( - aiohttp_raw_server: Any, aiohttp_client: Any + aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient ) -> None: exc = RuntimeError("custom runtime error") - async def handler(request): + async def handler(request: web.BaseRequest) -> NoReturn: raise exc loop = asyncio.get_event_loop() @@ -204,11 +204,11 @@ async def handler(request): logger.exception.assert_called_with("Error handling request", exc_info=exc) -async def test_handler_cancellation(aiohttp_unused_port) -> None: +async def test_handler_cancellation(aiohttp_unused_port: Callable[[], int]) -> None: event = asyncio.Event() port = aiohttp_unused_port() - async def on_request(_: web.Request) -> web.Response: + async def on_request(request: web.Request) -> web.Response: nonlocal event try: await asyncio.sleep(10) @@ -228,6 +228,7 @@ async def on_request(_: web.Request) -> web.Response: await site.start() + assert runner.server is not None try: assert runner.server.handler_cancellation, "Flag was not propagated" @@ -244,13 +245,13 @@ async def on_request(_: web.Request) -> web.Response: await asyncio.gather(runner.shutdown(), site.stop()) -async def test_no_handler_cancellation(aiohttp_unused_port) -> None: +async def test_no_handler_cancellation(aiohttp_unused_port: Callable[[], int]) -> None: timeout_event = asyncio.Event() done_event = asyncio.Event() port = aiohttp_unused_port() started = False - async def on_request(_: web.Request) -> web.Response: + async def on_request(request: web.Request) -> web.Response: nonlocal done_event, started, timeout_event started = True await asyncio.wait_for(timeout_event.wait(), timeout=5) diff --git a/tests/test_web_websocket.py b/tests/test_web_websocket.py index 067634f0a4d..ff6ad25ae55 100644 --- a/tests/test_web_websocket.py +++ b/tests/test_web_websocket.py @@ -1,40 +1,42 @@ -# type: ignore import asyncio import time -from typing import Any +from typing import Optional, Protocol from unittest import mock import aiosignal import pytest from multidict import CIMultiDict +from pytest_mock import MockerFixture -from aiohttp import WSMsgType +from aiohttp import WSMsgType, web +from aiohttp.http import WS_CLOSED_MESSAGE, WSMessage from aiohttp.streams import EofStream from aiohttp.test_utils import make_mocked_coro, make_mocked_request -from aiohttp.web import HTTPBadRequest, WebSocketResponse -from aiohttp.web_ws import WS_CLOSED_MESSAGE, WebSocketReady, WSMessage +from aiohttp.typedefs import LooseHeaders +from aiohttp.web_ws import WebSocketReady + +class _RequestMaker(Protocol): + def __call__(self, method: str, path: str, headers: Optional[CIMultiDict[str]] = None, protocols: bool = False) -> web.Request: + ... @pytest.fixture -def app(loop: Any): - ret = mock.Mock() - ret.loop = loop - ret._debug = False - ret.on_response_prepare = aiosignal.Signal(ret) +def app(loop: asyncio.AbstractEventLoop) -> web.Application: + ret: web.Application = mock.create_autospec(web.Application, spec_set=True, on_response_prepare=aiosignal.Signal(ret)) ret.on_response_prepare.freeze() return ret @pytest.fixture -def protocol(): +def protocol() -> web.RequestHandler[web.Request]: ret = mock.Mock() ret.set_parser.return_value = ret return ret @pytest.fixture -def make_request(app: Any, protocol: Any): - def maker(method, path, headers=None, protocols=False): +def make_request(app: web.Application, protocol: web.RequestHandler[web.Request]) -> _RequestMaker: + def maker(method: str, path: str, headers: Optional[CIMultiDict[str]] = None, protocols: bool = False) -> web.Request: if headers is None: headers = CIMultiDict( { @@ -49,107 +51,106 @@ def maker(method, path, headers=None, protocols=False): if protocols: headers["SEC-WEBSOCKET-PROTOCOL"] = "chat, superchat" - return make_mocked_request( - method, path, headers, app=app, protocol=protocol, loop=app.loop - ) + return make_mocked_request(method, path, headers, app=app, protocol=protocol) return maker async def test_nonstarted_ping() -> None: - ws = WebSocketResponse() + ws = web.WebSocketResponse() with pytest.raises(RuntimeError): await ws.ping() async def test_nonstarted_pong() -> None: - ws = WebSocketResponse() + ws = web.WebSocketResponse() with pytest.raises(RuntimeError): await ws.pong() async def test_nonstarted_send_str() -> None: - ws = WebSocketResponse() + ws = web.WebSocketResponse() with pytest.raises(RuntimeError): await ws.send_str("string") async def test_nonstarted_send_bytes() -> None: - ws = WebSocketResponse() + ws = web.WebSocketResponse() with pytest.raises(RuntimeError): await ws.send_bytes(b"bytes") async def test_nonstarted_send_json() -> None: - ws = WebSocketResponse() + ws = web.WebSocketResponse() with pytest.raises(RuntimeError): await ws.send_json({"type": "json"}) async def test_nonstarted_close() -> None: - ws = WebSocketResponse() + ws = web.WebSocketResponse() with pytest.raises(RuntimeError): await ws.close() async def test_nonstarted_receive_str() -> None: - ws = WebSocketResponse() + ws = web.WebSocketResponse() with pytest.raises(RuntimeError): await ws.receive_str() async def test_nonstarted_receive_bytes() -> None: - ws = WebSocketResponse() + ws = web.WebSocketResponse() with pytest.raises(RuntimeError): await ws.receive_bytes() async def test_nonstarted_receive_json() -> None: - ws = WebSocketResponse() + ws = web.WebSocketResponse() with pytest.raises(RuntimeError): await ws.receive_json() -async def test_send_str_nonstring(make_request: Any) -> None: +async def test_send_str_nonstring(make_request: _RequestMaker) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) with pytest.raises(TypeError): - await ws.send_str(b"bytes") + await ws.send_str(b"bytes") # type: ignore[arg-type] -async def test_send_bytes_nonbytes(make_request: Any) -> None: +async def test_send_bytes_nonbytes(make_request: _RequestMaker) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) with pytest.raises(TypeError): - await ws.send_bytes("string") + await ws.send_bytes("string") # type: ignore[arg-type] -async def test_send_json_nonjson(make_request: Any) -> None: +async def test_send_json_nonjson(make_request: _RequestMaker) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) with pytest.raises(TypeError): await ws.send_json(set()) async def test_write_non_prepared() -> None: - ws = WebSocketResponse() + ws = web.WebSocketResponse() with pytest.raises(RuntimeError): await ws.write(b"data") -async def test_heartbeat_timeout(make_request: Any) -> None: +async def test_heartbeat_timeout(make_request: _RequestMaker) -> None: """Verify the transport is closed when the heartbeat timeout is reached.""" loop = asyncio.get_running_loop() future = loop.create_future() req = make_request("GET", "/") + assert req.transport is not None + req.transport.close.side_effect = lambda: future.set_result(None) # type: ignore[attr-defined] lowest_time = time.get_clock_info("monotonic").resolution req._protocol._timeout_ceil_threshold = lowest_time - ws = WebSocketResponse(heartbeat=lowest_time, timeout=lowest_time) + ws = web.WebSocketResponse(heartbeat=lowest_time, timeout=lowest_time) await ws.prepare(req) - ws._req.transport.close.side_effect = lambda: future.set_result(None) await future assert ws.closed @@ -182,27 +183,27 @@ def test_bool_websocket_not_ready() -> None: assert bool(websocket_ready) is False -def test_can_prepare_ok(make_request: Any) -> None: +def test_can_prepare_ok(make_request: _RequestMaker) -> None: req = make_request("GET", "/", protocols=True) - ws = WebSocketResponse(protocols=("chat",)) + ws = web.WebSocketResponse(protocols=("chat",)) assert WebSocketReady(True, "chat") == ws.can_prepare(req) -def test_can_prepare_unknown_protocol(make_request: Any) -> None: +def test_can_prepare_unknown_protocol(make_request: _RequestMaker) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() assert WebSocketReady(True, None) == ws.can_prepare(req) -def test_can_prepare_without_upgrade(make_request: Any) -> None: +def test_can_prepare_without_upgrade(make_request: _RequestMaker) -> None: req = make_request("GET", "/", headers=CIMultiDict({})) - ws = WebSocketResponse() + ws = web.WebSocketResponse() assert WebSocketReady(False, None) == ws.can_prepare(req) -async def test_can_prepare_started(make_request: Any) -> None: +async def test_can_prepare_started(make_request: _RequestMaker) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) with pytest.raises(RuntimeError) as ctx: ws.can_prepare(req) @@ -211,27 +212,30 @@ async def test_can_prepare_started(make_request: Any) -> None: def test_closed_after_ctor() -> None: - ws = WebSocketResponse() + ws = web.WebSocketResponse() assert not ws.closed assert ws.close_code is None -async def test_send_str_closed(make_request: Any) -> None: +async def test_send_str_closed(make_request: _RequestMaker) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) + assert ws._reader is not None ws._reader.feed_data(WS_CLOSED_MESSAGE) await ws.close() - assert len(ws._req.transport.close.mock_calls) == 1 + assert req.transport is not None + assert len(req.transport.close.mock_calls) == 1 # type: ignore[attr-defined] with pytest.raises(ConnectionError): await ws.send_str("string") -async def test_send_bytes_closed(make_request: Any) -> None: +async def test_send_bytes_closed(make_request: _RequestMaker) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) + assert ws._reader is not None ws._reader.feed_data(WS_CLOSED_MESSAGE) await ws.close() @@ -239,10 +243,11 @@ async def test_send_bytes_closed(make_request: Any) -> None: await ws.send_bytes(b"bytes") -async def test_send_json_closed(make_request: Any) -> None: +async def test_send_json_closed(make_request: _RequestMaker) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) + assert ws._reader is not None ws._reader.feed_data(WS_CLOSED_MESSAGE) await ws.close() @@ -250,10 +255,11 @@ async def test_send_json_closed(make_request: Any) -> None: await ws.send_json({"type": "json"}) -async def test_ping_closed(make_request: Any) -> None: +async def test_ping_closed(make_request: _RequestMaker) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) + assert ws._reader is not None ws._reader.feed_data(WS_CLOSED_MESSAGE) await ws.close() @@ -261,10 +267,11 @@ async def test_ping_closed(make_request: Any) -> None: await ws.ping() -async def test_pong_closed(make_request: Any, mocker: Any) -> None: +async def test_pong_closed(make_request: _RequestMaker, mocker: MockerFixture) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) + assert ws._reader is not None ws._reader.feed_data(WS_CLOSED_MESSAGE) await ws.close() @@ -272,62 +279,66 @@ async def test_pong_closed(make_request: Any, mocker: Any) -> None: await ws.pong() -async def test_close_idempotent(make_request: Any) -> None: +async def test_close_idempotent(make_request: _RequestMaker) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) + assert ws._reader is not None ws._reader.feed_data(WS_CLOSED_MESSAGE) - assert await ws.close(code=1, message="message1") + assert await ws.close(code=1, message=b"message1") assert ws.closed - assert len(ws._req.transport.close.mock_calls) == 1 + assert req.transport is not None + assert len(req.transport.close.mock_calls) == 1 # type: ignore[attr-defined] - assert not (await ws.close(code=2, message="message2")) + assert not (await ws.close(code=2, message=b"message2")) -async def test_prepare_post_method_ok(make_request: Any) -> None: +async def test_prepare_post_method_ok(make_request: _RequestMaker) -> None: req = make_request("POST", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) assert ws.prepared -async def test_prepare_without_upgrade(make_request: Any) -> None: +async def test_prepare_without_upgrade(make_request: _RequestMaker) -> None: req = make_request("GET", "/", headers=CIMultiDict({})) - ws = WebSocketResponse() - with pytest.raises(HTTPBadRequest): + ws = web.WebSocketResponse() + with pytest.raises(web.HTTPBadRequest): await ws.prepare(req) async def test_wait_closed_before_start() -> None: - ws = WebSocketResponse() + ws = web.WebSocketResponse() with pytest.raises(RuntimeError): await ws.close() async def test_write_eof_not_started() -> None: - ws = WebSocketResponse() + ws = web.WebSocketResponse() with pytest.raises(RuntimeError): await ws.write_eof() -async def test_write_eof_idempotent(make_request: Any) -> None: +async def test_write_eof_idempotent(make_request: _RequestMaker) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) - assert len(ws._req.transport.close.mock_calls) == 0 + assert req.transport is not None + assert len(req.transport.close.mock_calls) == 0 # type: ignore[attr-defined] + assert ws._reader is not None ws._reader.feed_data(WS_CLOSED_MESSAGE) await ws.close() await ws.write_eof() await ws.write_eof() await ws.write_eof() - assert len(ws._req.transport.close.mock_calls) == 1 + assert len(req.transport.close.mock_calls) == 1 # type: ignore[attr-defined] -async def test_receive_eofstream_in_reader(make_request: Any, loop: Any) -> None: +async def test_receive_eofstream_in_reader(make_request: _RequestMaker, loop: asyncio.AbstractEventLoop) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) ws._reader = mock.Mock() @@ -335,18 +346,18 @@ async def test_receive_eofstream_in_reader(make_request: Any, loop: Any) -> None res = loop.create_future() res.set_exception(exc) ws._reader.read = make_mocked_coro(res) - ws._payload_writer.drain = mock.Mock() - ws._payload_writer.drain.return_value = loop.create_future() - ws._payload_writer.drain.return_value.set_result(True) - - msg = await ws.receive() - assert msg.type == WSMsgType.CLOSED - assert ws.closed + assert ws._payload_writer is not None + f = loop.create_future() + f.set_result(True) + with mock.patch.object(ws._payload_writer, "drain", autospec=True, spec_set=True, return_value=f): + msg = await ws.receive() + assert msg.type == WSMsgType.CLOSED + assert ws.closed -async def test_receive_exception_in_reader(make_request: Any, loop: Any) -> None: +async def test_receive_exception_in_reader(make_request: _RequestMaker, loop: asyncio.AbstractEventLoop) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) ws._reader = mock.Mock() @@ -354,88 +365,92 @@ async def test_receive_exception_in_reader(make_request: Any, loop: Any) -> None res = loop.create_future() res.set_exception(exc) ws._reader.read = make_mocked_coro(res) - ws._payload_writer.drain = mock.Mock() - ws._payload_writer.drain.return_value = loop.create_future() - ws._payload_writer.drain.return_value.set_result(True) - msg = await ws.receive() - assert msg.type == WSMsgType.ERROR - assert ws.closed - assert len(ws._req.transport.close.mock_calls) == 1 + f = loop.create_future() + f.set_result(True) + with mock.patch.object(ws._payload_writer, "drain", autospec=True, spec_set=True, return_value=f): + msg = await ws.receive() + assert msg.type == WSMsgType.ERROR + assert ws.closed + assert req.transport is not None + assert len(req.transport.close.mock_calls) == 1 # type: ignore[attr-defined] -async def test_receive_close_but_left_open(make_request: Any, loop: Any) -> None: +async def test_receive_close_but_left_open(make_request: _RequestMaker, loop: asyncio.AbstractEventLoop) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) close_message = WSMessage(WSMsgType.CLOSE, 1000, "close") ws._reader = mock.Mock() ws._reader.read = mock.AsyncMock(return_value=close_message) - ws._payload_writer.drain = mock.Mock() - ws._payload_writer.drain.return_value = loop.create_future() - ws._payload_writer.drain.return_value.set_result(True) - msg = await ws.receive() - assert msg.type == WSMsgType.CLOSE - assert ws.closed - assert len(ws._req.transport.close.mock_calls) == 1 + f = loop.create_future() + f.set_result(True) + with mock.patch.object(ws._payload_writer, "drain", autospec=True, spec_set=True, return_value=f): + msg = await ws.receive() + assert msg.type == WSMsgType.CLOSE + assert ws.closed + assert req.transport is not None + assert len(req.transport.close.mock_calls) == 1 # type: ignore[attr-defined] -async def test_receive_closing(make_request: Any, loop: Any) -> None: +async def test_receive_closing(make_request: _RequestMaker, loop: asyncio.AbstractEventLoop) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) closing_message = WSMessage(WSMsgType.CLOSING, 1000, "closing") ws._reader = mock.Mock() read_mock = mock.AsyncMock(return_value=closing_message) ws._reader.read = read_mock - ws._payload_writer.drain = mock.Mock() - ws._payload_writer.drain.return_value = loop.create_future() - ws._payload_writer.drain.return_value.set_result(True) - msg = await ws.receive() - assert msg.type == WSMsgType.CLOSING - assert not ws.closed + f = loop.create_future() + f.set_result(True) + with mock.patch.object(ws._payload_writer, "drain", autospec=True, spec_set=True, return_value=f): + msg = await ws.receive() + assert msg.type == WSMsgType.CLOSING + assert not ws.closed - msg = await ws.receive() - assert msg.type == WSMsgType.CLOSING - assert not ws.closed + msg = await ws.receive() + assert msg.type == WSMsgType.CLOSING + assert not ws.closed - ws._cancel(ConnectionResetError("Connection lost")) + ws._cancel(ConnectionResetError("Connection lost")) - msg = await ws.receive() - assert msg.type == WSMsgType.CLOSING + msg = await ws.receive() + assert msg.type == WSMsgType.CLOSING -async def test_close_after_closing(make_request: Any, loop: Any) -> None: +async def test_close_after_closing(make_request: _RequestMaker, loop: asyncio.AbstractEventLoop) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) closing_message = WSMessage(WSMsgType.CLOSING, 1000, "closing") ws._reader = mock.Mock() ws._reader.read = mock.AsyncMock(return_value=closing_message) - ws._payload_writer.drain = mock.Mock() - ws._payload_writer.drain.return_value = loop.create_future() - ws._payload_writer.drain.return_value.set_result(True) - msg = await ws.receive() - assert msg.type == WSMsgType.CLOSING - assert not ws.closed - assert len(ws._req.transport.close.mock_calls) == 0 + f = loop.create_future() + f.set_result(True) + with mock.patch.object(ws._payload_writer, "drain", autospec=True, spec_set=True, return_value=f): + msg = await ws.receive() + assert msg.type == WSMsgType.CLOSING + assert not ws.closed + assert req.transport is not None + assert len(req.transport.close.mock_calls) == 0 # type: ignore[attr-defined] - await ws.close() - assert ws.closed - assert len(ws._req.transport.close.mock_calls) == 1 + await ws.close() + assert ws.closed + assert len(req.transport.close.mock_calls) == 1 -async def test_receive_timeouterror(make_request: Any, loop: Any) -> None: +async def test_receive_timeouterror(make_request: _RequestMaker, loop: asyncio.AbstractEventLoop) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) - assert len(ws._req.transport.close.mock_calls) == 0 + assert req.transport is not None + assert len(req.transport.close.mock_calls) == 0 # type: ignore[attr-defined] ws._reader = mock.Mock() res = loop.create_future() @@ -446,13 +461,14 @@ async def test_receive_timeouterror(make_request: Any, loop: Any) -> None: await ws.receive() # Should not close the connection on timeout - assert len(ws._req.transport.close.mock_calls) == 0 + assert len(req.transport.close.mock_calls) == 0 # type: ignore[attr-defined] -async def test_multiple_receive_on_close_connection(make_request: Any) -> None: +async def test_multiple_receive_on_close_connection(make_request: _RequestMaker) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) + assert ws._reader is not None ws._reader.feed_data(WS_CLOSED_MESSAGE) await ws.close() @@ -465,9 +481,9 @@ async def test_multiple_receive_on_close_connection(make_request: Any) -> None: await ws.receive() -async def test_concurrent_receive(make_request: Any) -> None: +async def test_concurrent_receive(make_request: _RequestMaker) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) ws._waiting = True @@ -475,11 +491,12 @@ async def test_concurrent_receive(make_request: Any) -> None: await ws.receive() -async def test_close_exc(make_request: Any) -> None: +async def test_close_exc(make_request: _RequestMaker) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) - assert len(ws._req.transport.close.mock_calls) == 0 + assert req.transport is not None + assert len(req.transport.close.mock_calls) == 0 # type: ignore[attr-defined] exc = ValueError() ws._writer = mock.Mock() @@ -487,7 +504,7 @@ async def test_close_exc(make_request: Any) -> None: await ws.close() assert ws.closed assert ws.exception() is exc - assert len(ws._req.transport.close.mock_calls) == 1 + assert len(req.transport.close.mock_calls) == 1 # type: ignore[attr-defined] ws._closed = False ws._writer.close.side_effect = asyncio.CancelledError() @@ -495,34 +512,33 @@ async def test_close_exc(make_request: Any) -> None: await ws.close() -async def test_prepare_twice_idempotent(make_request: Any) -> None: +async def test_prepare_twice_idempotent(make_request: _RequestMaker) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() impl1 = await ws.prepare(req) impl2 = await ws.prepare(req) assert impl1 is impl2 -async def test_send_with_per_message_deflate(make_request: Any, mocker: Any) -> None: +async def test_send_with_per_message_deflate(make_request: _RequestMaker, mocker: MockerFixture) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) - writer_send = ws._writer.send = make_mocked_coro() - - await ws.send_str("string", compress=15) - writer_send.assert_called_with("string", binary=False, compress=15) + with mock.patch.object(ws._writer, "send", autospec=True, spec_set=True) as m: + await ws.send_str("string", compress=15) + m.assert_called_with("string", binary=False, compress=15) - await ws.send_bytes(b"bytes", compress=0) - writer_send.assert_called_with(b"bytes", binary=True, compress=0) + await ws.send_bytes(b"bytes", compress=0) + m.assert_called_with(b"bytes", binary=True, compress=0) - await ws.send_json("[{}]", compress=9) - writer_send.assert_called_with('"[{}]"', binary=False, compress=9) + await ws.send_json("[{}]", compress=9) + m.assert_called_with('"[{}]"', binary=False, compress=9) -async def test_no_transfer_encoding_header(make_request: Any, mocker: Any) -> None: +async def test_no_transfer_encoding_header(make_request: _RequestMaker, mocker: MockerFixture) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws._start(req) assert "Transfer-Encoding" not in ws.headers @@ -546,13 +562,13 @@ async def test_no_transfer_encoding_header(make_request: Any, mocker: Any) -> No ], ) async def test_get_extra_info( - make_request: Any, mocker: Any, ws_transport: Any, expected_result: Any + make_request: _RequestMaker, mocker: MockerFixture, ws_transport: Optional[mock.MagicMock], expected_result: str ) -> None: valid_key = "test" default_value = "default" req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) ws._writer = ws_transport diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index ff0a4cc6b47..f28ec42fa8b 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -1,11 +1,10 @@ -# type: ignore # HTTP websocket server functional tests import asyncio import contextlib import sys import weakref -from typing import Any, NoReturn, Optional +from typing import NoReturn, Optional from unittest import mock import pytest @@ -13,11 +12,11 @@ import aiohttp from aiohttp import WSServerHandshakeError, web from aiohttp.http import WSCloseCode, WSMsgType -from aiohttp.pytest_plugin import AiohttpClient +from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer -async def test_websocket_can_prepare(loop: Any, aiohttp_client: Any) -> None: - async def handler(request): +async def test_websocket_can_prepare(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: + async def handler(request: web.Request) -> web.Response: ws = web.WebSocketResponse() if not ws.can_prepare(request): raise web.HTTPUpgradeRequired() @@ -32,11 +31,11 @@ async def handler(request): assert resp.status == 426 -async def test_websocket_json(loop: Any, aiohttp_client: Any) -> None: - async def handler(request): +async def test_websocket_json(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() if not ws.can_prepare(request): - return web.HTTPUpgradeRequired() + raise web.HTTPUpgradeRequired() await ws.prepare(request) msg = await ws.receive() @@ -61,8 +60,8 @@ async def handler(request): assert resp.data == expected_value -async def test_websocket_json_invalid_message(loop: Any, aiohttp_client: Any) -> None: - async def handler(request): +async def test_websocket_json_invalid_message(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) try: @@ -87,8 +86,8 @@ async def handler(request): assert "ValueError was raised" in data -async def test_websocket_send_json(loop: Any, aiohttp_client: Any) -> None: - async def handler(request): +async def test_websocket_send_json(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) @@ -110,8 +109,8 @@ async def handler(request): assert data["test"] == expected_value -async def test_websocket_receive_json(loop: Any, aiohttp_client: Any) -> None: - async def handler(request): +async def test_websocket_receive_json(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) @@ -135,10 +134,10 @@ async def handler(request): assert resp.data == expected_value -async def test_send_recv_text(loop: Any, aiohttp_client: Any) -> None: +async def test_send_recv_text(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: closed = loop.create_future() - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) msg = await ws.receive_str() @@ -168,10 +167,10 @@ async def handler(request): await closed -async def test_send_recv_bytes(loop: Any, aiohttp_client: Any) -> None: +async def test_send_recv_bytes(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: closed = loop.create_future() - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) @@ -202,10 +201,10 @@ async def handler(request): await closed -async def test_send_recv_json(loop: Any, aiohttp_client: Any) -> None: +async def test_send_recv_json(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: closed = loop.create_future() - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) data = await ws.receive_json() @@ -236,16 +235,17 @@ async def handler(request): await closed -async def test_close_timeout(loop: Any, aiohttp_client: Any) -> None: +async def test_close_timeout(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: aborted = loop.create_future() elapsed = 1e10 # something big - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: nonlocal elapsed ws = web.WebSocketResponse(timeout=0.1) await ws.prepare(request) assert "request" == (await ws.receive_str()) await ws.send_str("reply") + assert ws._loop is not None begin = ws._loop.time() assert await ws.close() elapsed = ws._loop.time() - begin @@ -277,10 +277,10 @@ async def handler(request): await ws.close() -async def test_concurrent_close(loop: Any, aiohttp_client: Any) -> None: +async def test_concurrent_close(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: srv_ws = None - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: nonlocal srv_ws ws = srv_ws = web.WebSocketResponse(autoclose=False, protocols=("foo", "bar")) await ws.prepare(request) @@ -304,6 +304,7 @@ async def handler(request): ws = await client.ws_connect("/", autoclose=False, protocols=("eggs", "bar")) + assert srv_ws is not None await srv_ws.close(code=WSCloseCode.INVALID_TEXT) msg = await ws.receive() @@ -314,10 +315,10 @@ async def handler(request): assert msg.type == WSMsgType.CLOSED -async def test_concurrent_close_multiple_tasks(loop: Any, aiohttp_client: Any) -> None: +async def test_concurrent_close_multiple_tasks(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: srv_ws = None - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: nonlocal srv_ws ws = srv_ws = web.WebSocketResponse(autoclose=False, protocols=("foo", "bar")) await ws.prepare(request) @@ -341,6 +342,7 @@ async def handler(request): ws = await client.ws_connect("/", autoclose=False, protocols=("eggs", "bar")) + assert srv_ws is not None task1 = asyncio.create_task(srv_ws.close(code=WSCloseCode.INVALID_TEXT)) task2 = asyncio.create_task(srv_ws.close(code=WSCloseCode.INVALID_TEXT)) @@ -355,10 +357,12 @@ async def handler(request): assert msg.type == WSMsgType.CLOSED -async def test_close_op_code_from_client(loop: Any, aiohttp_client: Any) -> None: +async def test_close_op_code_from_client( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: srv_ws: Optional[web.WebSocketResponse] = None - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: nonlocal srv_ws ws = srv_ws = web.WebSocketResponse(protocols=("foo", "bar")) await ws.prepare(request) @@ -372,7 +376,7 @@ async def handler(request): app.router.add_get("/", handler) client = await aiohttp_client(app) - ws: web.WebSocketResponse = await client.ws_connect("/", protocols=("eggs", "bar")) + ws = await client.ws_connect("/", protocols=("eggs", "bar")) await ws._writer._send_frame(b"", WSMsgType.CLOSE) @@ -384,10 +388,10 @@ async def handler(request): assert msg.type == WSMsgType.CLOSED -async def test_auto_pong_with_closing_by_peer(loop: Any, aiohttp_client: Any) -> None: +async def test_auto_pong_with_closing_by_peer(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: closed = loop.create_future() - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) await ws.receive() @@ -409,18 +413,18 @@ async def handler(request): msg = await ws.receive() assert msg.type == WSMsgType.PONG - await ws.close(code=WSCloseCode.OK, message="exit message") + await ws.close(code=WSCloseCode.OK, message=b"exit message") await closed -async def test_ping(loop: Any, aiohttp_client: Any) -> None: +async def test_ping(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: closed = loop.create_future() - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) - await ws.ping("data") + await ws.ping(b"data") await ws.receive() closed.set_result(None) return ws @@ -439,10 +443,10 @@ async def handler(request): await closed -async def aiohttp_client_ping(loop: Any, aiohttp_client: Any): +async def test_client_ping(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: closed = loop.create_future() - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) @@ -456,7 +460,7 @@ async def handler(request): ws = await client.ws_connect("/", autoping=False) - await ws.ping("data") + await ws.ping(b"data") msg = await ws.receive() assert msg.type == WSMsgType.PONG assert msg.data == b"data" @@ -464,16 +468,16 @@ async def handler(request): await ws.close() -async def test_pong(loop: Any, aiohttp_client: Any) -> None: +async def test_pong(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: closed = loop.create_future() - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse(autoping=False) await ws.prepare(request) msg = await ws.receive() assert msg.type == WSMsgType.PING - await ws.pong("data") + await ws.pong(b"data") msg = await ws.receive() assert msg.type == WSMsgType.CLOSE @@ -488,20 +492,20 @@ async def handler(request): ws = await client.ws_connect("/", autoping=False) - await ws.ping("data") + await ws.ping(b"data") msg = await ws.receive() assert msg.type == WSMsgType.PONG assert msg.data == b"data" - await ws.close(code=WSCloseCode.OK, message="exit message") + await ws.close(code=WSCloseCode.OK, message=b"exit message") await closed -async def test_change_status(loop: Any, aiohttp_client: Any) -> None: +async def test_change_status(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: closed = loop.create_future() - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() ws.set_status(200) assert 200 == ws.status @@ -522,10 +526,10 @@ async def handler(request): await ws.close() -async def test_handle_protocol(loop: Any, aiohttp_client: Any) -> None: +async def test_handle_protocol(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: closed = loop.create_future() - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse(protocols=("foo", "bar")) await ws.prepare(request) await ws.close() @@ -543,10 +547,10 @@ async def handler(request): await closed -async def test_server_close_handshake(loop: Any, aiohttp_client: Any) -> None: +async def test_server_close_handshake(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: closed = loop.create_future() - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse(protocols=("foo", "bar")) await ws.prepare(request) await ws.close() @@ -565,10 +569,10 @@ async def handler(request): await closed -async def aiohttp_client_close_handshake(loop: Any, aiohttp_client: Any): +async def test_client_close_handshake(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: closed = loop.create_future() - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse(autoclose=False, protocols=("foo", "bar")) await ws.prepare(request) @@ -598,11 +602,11 @@ async def handler(request): async def test_server_close_handshake_server_eats_client_messages( - loop: Any, aiohttp_client: Any -): + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: closed = loop.create_future() - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse(protocols=("foo", "bar")) await ws.prepare(request) await ws.close() @@ -628,10 +632,10 @@ async def handler(request): await closed -async def test_receive_timeout(loop: Any, aiohttp_client: Any) -> None: +async def test_receive_timeout(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: raised = False - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse(receive_timeout=0.1) await ws.prepare(request) @@ -654,10 +658,10 @@ async def handler(request): assert raised -async def test_custom_receive_timeout(loop: Any, aiohttp_client: Any) -> None: +async def test_custom_receive_timeout(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: raised = False - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse(receive_timeout=None) await ws.prepare(request) @@ -680,8 +684,8 @@ async def handler(request): assert raised -async def test_heartbeat(loop: Any, aiohttp_client: Any) -> None: - async def handler(request): +async def test_heartbeat(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse(heartbeat=0.05) await ws.prepare(request) await ws.receive() @@ -700,8 +704,8 @@ async def handler(request): await ws.close() -async def test_heartbeat_no_pong(loop: Any, aiohttp_client: Any) -> None: - async def handler(request): +async def test_heartbeat_no_pong(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse(heartbeat=0.05) await ws.prepare(request) @@ -731,6 +735,8 @@ async def handler(request: web.Request) -> NoReturn: # We patch write here to simulate a connection reset error # since if we closed the connection normally, the server would # would cancel the heartbeat task and we wouldn't get a ping + assert ws_server._req is not None + assert ws_server._writer is not None with mock.patch.object( ws_server._req.transport, "write", side_effect=ConnectionResetError ), mock.patch.object( @@ -789,11 +795,11 @@ async def handler(request: web.Request) -> NoReturn: async def test_heartbeat_no_pong_send_many_messages( - loop: Any, aiohttp_client: Any + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient ) -> None: """Test no pong after sending many messages.""" - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse(heartbeat=0.05) await ws.prepare(request) for _ in range(10): @@ -818,11 +824,11 @@ async def handler(request): async def test_heartbeat_no_pong_receive_many_messages( - loop: Any, aiohttp_client: Any + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient ) -> None: """Test no pong after receiving many messages.""" - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse(heartbeat=0.05) await ws.prepare(request) for _ in range(10): @@ -845,10 +851,10 @@ async def handler(request): await ws.close() -async def test_server_ws_async_for(loop: Any, aiohttp_server: Any) -> None: +async def test_server_ws_async_for(loop: asyncio.AbstractEventLoop, aiohttp_server: AiohttpServer) -> None: closed = loop.create_future() - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) async for msg in ws: @@ -876,10 +882,10 @@ async def handler(request): await closed -async def test_closed_async_for(loop: Any, aiohttp_client: Any) -> None: +async def test_closed_async_for(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: closed = loop.create_future() - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) @@ -911,8 +917,8 @@ async def handler(request): await closed -async def test_websocket_disable_keepalive(loop: Any, aiohttp_client: Any) -> None: - async def handler(request): +async def test_websocket_disable_keepalive(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: + async def handler(request: web.Request) -> web.StreamResponse: ws = web.WebSocketResponse() if not ws.can_prepare(request): return web.Response(text="OK") @@ -938,11 +944,11 @@ async def handler(request): assert data == "OK" -async def test_receive_str_nonstring(loop: Any, aiohttp_client: Any) -> None: - async def handler(request): +async def test_receive_str_nonstring(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() if not ws.can_prepare(request): - return web.HTTPUpgradeRequired() + raise web.HTTPUpgradeRequired() await ws.prepare(request) await ws.send_bytes(b"answer") @@ -958,14 +964,14 @@ async def handler(request): await ws.receive_str() -async def test_receive_bytes_nonbytes(loop: Any, aiohttp_client: Any) -> None: - async def handler(request): +async def test_receive_bytes_nonbytes(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() if not ws.can_prepare(request): - return web.HTTPUpgradeRequired() + raise web.HTTPUpgradeRequired() await ws.prepare(request) - await ws.send_bytes("answer") + await ws.send_bytes(b"answer") await ws.close() return ws @@ -978,11 +984,11 @@ async def handler(request): await ws.receive_bytes() -async def test_bug3380(loop: Any, aiohttp_client: Any) -> None: - async def handle_null(request): - return aiohttp.web.json_response({"err": None}) +async def test_bug3380(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: + async def handle_null(request: web.Request) -> web.Response: + return web.json_response({"err": None}) - async def ws_handler(request): + async def ws_handler(request: web.Request) -> web.Response: return web.Response(status=401) app = web.Application() @@ -1004,11 +1010,11 @@ async def ws_handler(request): async def test_receive_being_cancelled_keeps_connection_open( - loop: Any, aiohttp_client: Any + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient ) -> None: closed = loop.create_future() - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse(autoping=False) await ws.prepare(request) @@ -1021,7 +1027,7 @@ async def handler(request): msg = await ws.receive() assert msg.type == WSMsgType.PING await asyncio.sleep(0) - await ws.pong("data") + await ws.pong(b"data") msg = await ws.receive() assert msg.type == WSMsgType.CLOSE @@ -1037,24 +1043,24 @@ async def handler(request): ws = await client.ws_connect("/", autoping=False) await asyncio.sleep(0) - await ws.ping("data") + await ws.ping(b"data") msg = await ws.receive() assert msg.type == WSMsgType.PONG assert msg.data == b"data" - await ws.close(code=WSCloseCode.OK, message="exit message") + await ws.close(code=WSCloseCode.OK, message=b"exit message") await closed async def test_receive_timeout_keeps_connection_open( - loop: Any, aiohttp_client: Any + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient ) -> None: closed = loop.create_future() timed_out = loop.create_future() - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse(autoping=False) await ws.prepare(request) @@ -1067,7 +1073,7 @@ async def handler(request): msg = await ws.receive() assert msg.type == WSMsgType.PING await asyncio.sleep(0) - await ws.pong("data") + await ws.pong(b"data") msg = await ws.receive() assert msg.type == WSMsgType.CLOSE @@ -1083,13 +1089,13 @@ async def handler(request): ws = await client.ws_connect("/", autoping=False) await timed_out - await ws.ping("data") + await ws.ping(b"data") msg = await ws.receive() assert msg.type == WSMsgType.PONG assert msg.data == b"data" - await ws.close(code=WSCloseCode.OK, message="exit message") + await ws.close(code=WSCloseCode.OK, message=b"exit message") await closed @@ -1098,11 +1104,11 @@ async def test_websocket_shutdown(aiohttp_client: AiohttpClient) -> None: """Test that the client websocket gets the close message when the server is shutting down.""" url = "/ws" app = web.Application() - websockets = web.AppKey("websockets", weakref.WeakSet) + websockets = web.AppKey("websockets", weakref.WeakSet[web.WebSocketResponse]) app[websockets] = weakref.WeakSet() # need for send signal shutdown server - shutdown_websockets = web.AppKey("shutdown_websockets", weakref.WeakSet) + shutdown_websockets = web.AppKey("shutdown_websockets", weakref.WeakSet[web.WebSocketResponse]) app[shutdown_websockets] = weakref.WeakSet() async def websocket_handler(request: web.Request) -> web.WebSocketResponse: @@ -1124,7 +1130,7 @@ async def on_shutdown(app: web.Application) -> None: websocket = app[shutdown_websockets].pop() await websocket.close( code=aiohttp.WSCloseCode.GOING_AWAY, - message="Server shutdown", + message=b"Server shutdown", ) app.router.add_get(url, websocket_handler) From 12e97c1639f692b54357005a8d6a6803c5a4df5a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 22 Sep 2024 15:43:13 +0000 Subject: [PATCH 2/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_web_runner.py | 23 +++-- tests/test_web_sendfile_functional.py | 64 +++++++++---- tests/test_web_server.py | 16 +++- tests/test_web_websocket.py | 86 +++++++++++++----- tests/test_web_websocket_functional.py | 120 ++++++++++++++++++------- 5 files changed, 232 insertions(+), 77 deletions(-) diff --git a/tests/test_web_runner.py b/tests/test_web_runner.py index 4bf0a9fc796..ba0f18a8783 100644 --- a/tests/test_web_runner.py +++ b/tests/test_web_runner.py @@ -13,8 +13,7 @@ class _RunnerMaker(Protocol): - def __call__(self, handle_signals: bool = ..., **kwargs: Any) -> web.AppRunner: - ... + def __call__(self, handle_signals: bool = ..., **kwargs: Any) -> web.AppRunner: ... @pytest.fixture @@ -23,7 +22,9 @@ def app() -> web.Application: @pytest.fixture -def make_runner(loop: asyncio.AbstractEventLoop, app: web.Application) -> Iterator[_RunnerMaker]: +def make_runner( + loop: asyncio.AbstractEventLoop, app: web.Application +) -> Iterator[_RunnerMaker]: asyncio.set_event_loop(loop) runners = [] @@ -138,7 +139,9 @@ async def test_app_handler_args_failure() -> None: ("2", 2), ), ) -async def test_app_handler_args_ceil_threshold(value: Union[int, str, None], expected: int) -> None: +async def test_app_handler_args_ceil_threshold( + value: Union[int, str, None], expected: int +) -> None: app = web.Application(handler_args={"timeout_ceil_threshold": value}) runner = web.AppRunner(app) await runner.setup() @@ -171,7 +174,9 @@ class Logger: async def test_app_make_handler_access_log_class1() -> None: class Logger(AbstractAccessLogger): - def log(self, request: web.BaseRequest, response: web.StreamResponse, time: float) -> None: + def log( + self, request: web.BaseRequest, response: web.StreamResponse, time: float + ) -> None: """Pass log method.""" app = web.Application() @@ -181,7 +186,9 @@ def log(self, request: web.BaseRequest, response: web.StreamResponse, time: floa async def test_app_make_handler_access_log_class2() -> None: class Logger(AbstractAccessLogger): - def log(self, request: web.BaseRequest, response: web.StreamResponse, time: float) -> None: + def log( + self, request: web.BaseRequest, response: web.StreamResponse, time: float + ) -> None: """Pass log method.""" app = web.Application(handler_args={"access_log_class": Logger}) @@ -234,7 +241,9 @@ async def test_tcpsite_default_host(make_runner: _RunnerMaker) -> None: assert site.name == "http://0.0.0.0:8080" m = mock.create_autospec(asyncio.AbstractEventLoop, spec_set=True, instance=True) - with mock.patch("asyncio.get_event_loop", autospec=True, spec_set=True, return_value=m) as mock_get_loop: + with mock.patch( + "asyncio.get_event_loop", autospec=True, spec_set=True, return_value=m + ) as mock_get_loop: await site.start() m.create_server.assert_called_once() diff --git a/tests/test_web_sendfile_functional.py b/tests/test_web_sendfile_functional.py index f8e2ee5fc52..dd8ec29f07a 100644 --- a/tests/test_web_sendfile_functional.py +++ b/tests/test_web_sendfile_functional.py @@ -33,7 +33,9 @@ @pytest.fixture(scope="module") -def hello_txt(request: pytest.FixtureRequest, tmp_path_factory: pytest.TempPathFactory) -> pathlib.Path: +def hello_txt( + request: pytest.FixtureRequest, tmp_path_factory: pytest.TempPathFactory +) -> pathlib.Path: """Create a temp path with hello.txt and compressed versions. The uncompressed text file path is returned by default. Alternatively, an @@ -55,7 +57,9 @@ def hello_txt(request: pytest.FixtureRequest, tmp_path_factory: pytest.TempPathF @pytest.fixture -def loop_with_mocked_native_sendfile(loop: asyncio.AbstractEventLoop) -> Iterator[asyncio.AbstractEventLoop]: +def loop_with_mocked_native_sendfile( + loop: asyncio.AbstractEventLoop, +) -> Iterator[asyncio.AbstractEventLoop]: def sendfile(transport: object, fobj: object, offset: int, count: int) -> NoReturn: if count == 0: raise ValueError("count must be a positive integer (got 0)") @@ -117,7 +121,9 @@ async def test_static_file_ok( await client.close() -async def test_zero_bytes_file_ok(aiohttp_client: AiohttpClient, sender: _Sender) -> None: +async def test_zero_bytes_file_ok( + aiohttp_client: AiohttpClient, sender: _Sender +) -> None: filepath = pathlib.Path(__file__).parent / "data.zero_bytes" async def handler(request: web.Request) -> web.FileResponse: @@ -143,7 +149,8 @@ async def handler(request: web.Request) -> web.FileResponse: async def test_zero_bytes_file_mocked_native_sendfile( - aiohttp_client: AiohttpClient, loop_with_mocked_native_sendfile: asyncio.AbstractEventLoop + aiohttp_client: AiohttpClient, + loop_with_mocked_native_sendfile: asyncio.AbstractEventLoop, ) -> None: filepath = pathlib.Path(__file__).parent / "data.zero_bytes" @@ -216,7 +223,9 @@ async def test_static_file_upper_directory(aiohttp_client: AiohttpClient) -> Non await client.close() -async def test_static_file_with_content_type(aiohttp_client: AiohttpClient, sender: _Sender) -> None: +async def test_static_file_with_content_type( + aiohttp_client: AiohttpClient, sender: _Sender +) -> None: filepath = pathlib.Path(__file__).parent / "aiohttp.jpg" async def handler(request: web.Request) -> web.FileResponse: @@ -340,7 +349,10 @@ async def handler(request: web.Request) -> web.FileResponse: indirect=["hello_txt"], ) async def test_static_file_with_content_encoding( - hello_txt: pathlib.Path, aiohttp_client: AiohttpClient, sender: _Sender, expect_type: str + hello_txt: pathlib.Path, + aiohttp_client: AiohttpClient, + sender: _Sender, + expect_type: str, ) -> None: """Test requesting static compressed files returns the correct content type and encoding.""" @@ -579,7 +591,9 @@ async def test_static_file_ssl( await client.close() -async def test_static_file_directory_traversal_attack(aiohttp_client: AiohttpClient) -> None: +async def test_static_file_directory_traversal_attack( + aiohttp_client: AiohttpClient, +) -> None: dirname = pathlib.Path(__file__).parent relpath = "../README.rst" full_path = dirname / relpath @@ -606,7 +620,9 @@ async def test_static_file_directory_traversal_attack(aiohttp_client: AiohttpCli await client.close() -async def test_static_file_huge(aiohttp_client: AiohttpClient, tmp_path: pathlib.Path) -> None: +async def test_static_file_huge( + aiohttp_client: AiohttpClient, tmp_path: pathlib.Path +) -> None: file_path = tmp_path / "huge_data.unknown_mime_type" # fill 20MB file @@ -642,7 +658,9 @@ async def test_static_file_huge(aiohttp_client: AiohttpClient, tmp_path: pathlib await client.close() -async def test_static_file_range(aiohttp_client: AiohttpClient, sender: _Sender) -> None: +async def test_static_file_range( + aiohttp_client: AiohttpClient, sender: _Sender +) -> None: filepath = pathlib.Path(__file__).parent / "sample.txt" filesize = filepath.stat().st_size @@ -701,7 +719,9 @@ async def handler(request: web.Request) -> web.FileResponse: await client.close() -async def test_static_file_range_end_bigger_than_size(aiohttp_client: AiohttpClient, sender: _Sender) -> None: +async def test_static_file_range_end_bigger_than_size( + aiohttp_client: AiohttpClient, sender: _Sender +) -> None: filepath = pathlib.Path(__file__).parent / "aiohttp.png" async def handler(request: web.Request) -> web.FileResponse: @@ -735,7 +755,9 @@ async def handler(request: web.Request) -> web.FileResponse: await client.close() -async def test_static_file_range_beyond_eof(aiohttp_client: AiohttpClient, sender: _Sender) -> None: +async def test_static_file_range_beyond_eof( + aiohttp_client: AiohttpClient, sender: _Sender +) -> None: filepath = pathlib.Path(__file__).parent / "aiohttp.png" async def handler(request: web.Request) -> web.FileResponse: @@ -756,7 +778,9 @@ async def handler(request: web.Request) -> web.FileResponse: await client.close() -async def test_static_file_range_tail(aiohttp_client: AiohttpClient, sender: _Sender) -> None: +async def test_static_file_range_tail( + aiohttp_client: AiohttpClient, sender: _Sender +) -> None: filepath = pathlib.Path(__file__).parent / "aiohttp.png" async def handler(request: web.Request) -> web.FileResponse: @@ -791,7 +815,9 @@ async def handler(request: web.Request) -> web.FileResponse: await client.close() -async def test_static_file_invalid_range(aiohttp_client: AiohttpClient, sender: _Sender) -> None: +async def test_static_file_invalid_range( + aiohttp_client: AiohttpClient, sender: _Sender +) -> None: filepath = pathlib.Path(__file__).parent / "aiohttp.png" async def handler(request: web.Request) -> web.FileResponse: @@ -1001,7 +1027,9 @@ async def test_static_file_if_range_invalid_date( await client.close() -async def test_static_file_compression(aiohttp_client: AiohttpClient, sender: _Sender) -> None: +async def test_static_file_compression( + aiohttp_client: AiohttpClient, sender: _Sender +) -> None: filepath = pathlib.Path(__file__).parent / "data.unknown_mime_type" async def handler(request: web.Request) -> web.FileResponse: @@ -1025,7 +1053,9 @@ async def handler(request: web.Request) -> web.FileResponse: await client.close() -async def test_static_file_huge_cancel(aiohttp_client: AiohttpClient, tmp_path: pathlib.Path) -> None: +async def test_static_file_huge_cancel( + aiohttp_client: AiohttpClient, tmp_path: pathlib.Path +) -> None: file_path = tmp_path / "huge_data.unknown_mime_type" # fill 100MB file @@ -1068,7 +1098,9 @@ async def handler(request: web.Request) -> web.FileResponse: await client.close() -async def test_static_file_huge_error(aiohttp_client: AiohttpClient, tmp_path: pathlib.Path) -> None: +async def test_static_file_huge_error( + aiohttp_client: AiohttpClient, tmp_path: pathlib.Path +) -> None: file_path = tmp_path / "huge_data.unknown_mime_type" # fill 20MB file diff --git a/tests/test_web_server.py b/tests/test_web_server.py index 0cd462af3b6..312dc4ac85f 100644 --- a/tests/test_web_server.py +++ b/tests/test_web_server.py @@ -9,7 +9,9 @@ from aiohttp.pytest_plugin import AiohttpClient, AiohttpRawServer -async def test_simple_server(aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient) -> None: +async def test_simple_server( + aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient +) -> None: async def handler(request: web.BaseRequest) -> web.Response: return web.Response(text=str(request.rel_url)) @@ -21,7 +23,9 @@ async def handler(request: web.BaseRequest) -> web.Response: assert txt == "/path/to" -async def test_unsupported_upgrade(aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient) -> None: +async def test_unsupported_upgrade( + aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient +) -> None: # don't fail if a client probes for an unsupported protocol upgrade # https://github.com/aio-libs/aiohttp/issues/6446#issuecomment-999032039 async def handler(request: web.BaseRequest) -> web.Response: @@ -38,7 +42,9 @@ async def handler(request: web.BaseRequest) -> web.Response: async def test_raw_server_not_http_exception( - aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient, loop: asyncio.AbstractEventLoop + aiohttp_raw_server: AiohttpRawServer, + aiohttp_client: AiohttpClient, + loop: asyncio.AbstractEventLoop, ) -> None: # disable debug mode not to print traceback loop.set_debug(False) @@ -148,7 +154,9 @@ async def handler(request: web.BaseRequest) -> NoReturn: async def test_raw_server_html_exception( - aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient, loop: asyncio.AbstractEventLoop + aiohttp_raw_server: AiohttpRawServer, + aiohttp_client: AiohttpClient, + loop: asyncio.AbstractEventLoop, ) -> None: # disable debug mode not to print traceback loop.set_debug(False) diff --git a/tests/test_web_websocket.py b/tests/test_web_websocket.py index ff6ad25ae55..61963a4b6bc 100644 --- a/tests/test_web_websocket.py +++ b/tests/test_web_websocket.py @@ -15,14 +15,22 @@ from aiohttp.typedefs import LooseHeaders from aiohttp.web_ws import WebSocketReady + class _RequestMaker(Protocol): - def __call__(self, method: str, path: str, headers: Optional[CIMultiDict[str]] = None, protocols: bool = False) -> web.Request: - ... + def __call__( + self, + method: str, + path: str, + headers: Optional[CIMultiDict[str]] = None, + protocols: bool = False, + ) -> web.Request: ... @pytest.fixture def app(loop: asyncio.AbstractEventLoop) -> web.Application: - ret: web.Application = mock.create_autospec(web.Application, spec_set=True, on_response_prepare=aiosignal.Signal(ret)) + ret: web.Application = mock.create_autospec( + web.Application, spec_set=True, on_response_prepare=aiosignal.Signal(ret) + ) ret.on_response_prepare.freeze() return ret @@ -35,8 +43,15 @@ def protocol() -> web.RequestHandler[web.Request]: @pytest.fixture -def make_request(app: web.Application, protocol: web.RequestHandler[web.Request]) -> _RequestMaker: - def maker(method: str, path: str, headers: Optional[CIMultiDict[str]] = None, protocols: bool = False) -> web.Request: +def make_request( + app: web.Application, protocol: web.RequestHandler[web.Request] +) -> _RequestMaker: + def maker( + method: str, + path: str, + headers: Optional[CIMultiDict[str]] = None, + protocols: bool = False, + ) -> web.Request: if headers is None: headers = CIMultiDict( { @@ -336,7 +351,9 @@ async def test_write_eof_idempotent(make_request: _RequestMaker) -> None: assert len(req.transport.close.mock_calls) == 1 # type: ignore[attr-defined] -async def test_receive_eofstream_in_reader(make_request: _RequestMaker, loop: asyncio.AbstractEventLoop) -> None: +async def test_receive_eofstream_in_reader( + make_request: _RequestMaker, loop: asyncio.AbstractEventLoop +) -> None: req = make_request("GET", "/") ws = web.WebSocketResponse() await ws.prepare(req) @@ -349,13 +366,17 @@ async def test_receive_eofstream_in_reader(make_request: _RequestMaker, loop: as assert ws._payload_writer is not None f = loop.create_future() f.set_result(True) - with mock.patch.object(ws._payload_writer, "drain", autospec=True, spec_set=True, return_value=f): + with mock.patch.object( + ws._payload_writer, "drain", autospec=True, spec_set=True, return_value=f + ): msg = await ws.receive() assert msg.type == WSMsgType.CLOSED assert ws.closed -async def test_receive_exception_in_reader(make_request: _RequestMaker, loop: asyncio.AbstractEventLoop) -> None: +async def test_receive_exception_in_reader( + make_request: _RequestMaker, loop: asyncio.AbstractEventLoop +) -> None: req = make_request("GET", "/") ws = web.WebSocketResponse() await ws.prepare(req) @@ -368,7 +389,9 @@ async def test_receive_exception_in_reader(make_request: _RequestMaker, loop: as f = loop.create_future() f.set_result(True) - with mock.patch.object(ws._payload_writer, "drain", autospec=True, spec_set=True, return_value=f): + with mock.patch.object( + ws._payload_writer, "drain", autospec=True, spec_set=True, return_value=f + ): msg = await ws.receive() assert msg.type == WSMsgType.ERROR assert ws.closed @@ -376,7 +399,9 @@ async def test_receive_exception_in_reader(make_request: _RequestMaker, loop: as assert len(req.transport.close.mock_calls) == 1 # type: ignore[attr-defined] -async def test_receive_close_but_left_open(make_request: _RequestMaker, loop: asyncio.AbstractEventLoop) -> None: +async def test_receive_close_but_left_open( + make_request: _RequestMaker, loop: asyncio.AbstractEventLoop +) -> None: req = make_request("GET", "/") ws = web.WebSocketResponse() await ws.prepare(req) @@ -387,7 +412,9 @@ async def test_receive_close_but_left_open(make_request: _RequestMaker, loop: as f = loop.create_future() f.set_result(True) - with mock.patch.object(ws._payload_writer, "drain", autospec=True, spec_set=True, return_value=f): + with mock.patch.object( + ws._payload_writer, "drain", autospec=True, spec_set=True, return_value=f + ): msg = await ws.receive() assert msg.type == WSMsgType.CLOSE assert ws.closed @@ -395,7 +422,9 @@ async def test_receive_close_but_left_open(make_request: _RequestMaker, loop: as assert len(req.transport.close.mock_calls) == 1 # type: ignore[attr-defined] -async def test_receive_closing(make_request: _RequestMaker, loop: asyncio.AbstractEventLoop) -> None: +async def test_receive_closing( + make_request: _RequestMaker, loop: asyncio.AbstractEventLoop +) -> None: req = make_request("GET", "/") ws = web.WebSocketResponse() await ws.prepare(req) @@ -407,7 +436,9 @@ async def test_receive_closing(make_request: _RequestMaker, loop: asyncio.Abstra f = loop.create_future() f.set_result(True) - with mock.patch.object(ws._payload_writer, "drain", autospec=True, spec_set=True, return_value=f): + with mock.patch.object( + ws._payload_writer, "drain", autospec=True, spec_set=True, return_value=f + ): msg = await ws.receive() assert msg.type == WSMsgType.CLOSING assert not ws.closed @@ -422,7 +453,9 @@ async def test_receive_closing(make_request: _RequestMaker, loop: asyncio.Abstra assert msg.type == WSMsgType.CLOSING -async def test_close_after_closing(make_request: _RequestMaker, loop: asyncio.AbstractEventLoop) -> None: +async def test_close_after_closing( + make_request: _RequestMaker, loop: asyncio.AbstractEventLoop +) -> None: req = make_request("GET", "/") ws = web.WebSocketResponse() await ws.prepare(req) @@ -433,7 +466,9 @@ async def test_close_after_closing(make_request: _RequestMaker, loop: asyncio.Ab f = loop.create_future() f.set_result(True) - with mock.patch.object(ws._payload_writer, "drain", autospec=True, spec_set=True, return_value=f): + with mock.patch.object( + ws._payload_writer, "drain", autospec=True, spec_set=True, return_value=f + ): msg = await ws.receive() assert msg.type == WSMsgType.CLOSING assert not ws.closed @@ -445,7 +480,9 @@ async def test_close_after_closing(make_request: _RequestMaker, loop: asyncio.Ab assert len(req.transport.close.mock_calls) == 1 -async def test_receive_timeouterror(make_request: _RequestMaker, loop: asyncio.AbstractEventLoop) -> None: +async def test_receive_timeouterror( + make_request: _RequestMaker, loop: asyncio.AbstractEventLoop +) -> None: req = make_request("GET", "/") ws = web.WebSocketResponse() await ws.prepare(req) @@ -464,7 +501,9 @@ async def test_receive_timeouterror(make_request: _RequestMaker, loop: asyncio.A assert len(req.transport.close.mock_calls) == 0 # type: ignore[attr-defined] -async def test_multiple_receive_on_close_connection(make_request: _RequestMaker) -> None: +async def test_multiple_receive_on_close_connection( + make_request: _RequestMaker, +) -> None: req = make_request("GET", "/") ws = web.WebSocketResponse() await ws.prepare(req) @@ -521,7 +560,9 @@ async def test_prepare_twice_idempotent(make_request: _RequestMaker) -> None: assert impl1 is impl2 -async def test_send_with_per_message_deflate(make_request: _RequestMaker, mocker: MockerFixture) -> None: +async def test_send_with_per_message_deflate( + make_request: _RequestMaker, mocker: MockerFixture +) -> None: req = make_request("GET", "/") ws = web.WebSocketResponse() await ws.prepare(req) @@ -536,7 +577,9 @@ async def test_send_with_per_message_deflate(make_request: _RequestMaker, mocker m.assert_called_with('"[{}]"', binary=False, compress=9) -async def test_no_transfer_encoding_header(make_request: _RequestMaker, mocker: MockerFixture) -> None: +async def test_no_transfer_encoding_header( + make_request: _RequestMaker, mocker: MockerFixture +) -> None: req = make_request("GET", "/") ws = web.WebSocketResponse() await ws._start(req) @@ -562,7 +605,10 @@ async def test_no_transfer_encoding_header(make_request: _RequestMaker, mocker: ], ) async def test_get_extra_info( - make_request: _RequestMaker, mocker: MockerFixture, ws_transport: Optional[mock.MagicMock], expected_result: str + make_request: _RequestMaker, + mocker: MockerFixture, + ws_transport: Optional[mock.MagicMock], + expected_result: str, ) -> None: valid_key = "test" default_value = "default" diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index f28ec42fa8b..8823e2c855a 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -15,7 +15,9 @@ from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer -async def test_websocket_can_prepare(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: +async def test_websocket_can_prepare( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: async def handler(request: web.Request) -> web.Response: ws = web.WebSocketResponse() if not ws.can_prepare(request): @@ -31,7 +33,9 @@ async def handler(request: web.Request) -> web.Response: assert resp.status == 426 -async def test_websocket_json(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: +async def test_websocket_json( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() if not ws.can_prepare(request): @@ -60,7 +64,9 @@ async def handler(request: web.Request) -> web.WebSocketResponse: assert resp.data == expected_value -async def test_websocket_json_invalid_message(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: +async def test_websocket_json_invalid_message( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) @@ -86,7 +92,9 @@ async def handler(request: web.Request) -> web.WebSocketResponse: assert "ValueError was raised" in data -async def test_websocket_send_json(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: +async def test_websocket_send_json( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) @@ -109,7 +117,9 @@ async def handler(request: web.Request) -> web.WebSocketResponse: assert data["test"] == expected_value -async def test_websocket_receive_json(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: +async def test_websocket_receive_json( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) @@ -134,7 +144,9 @@ async def handler(request: web.Request) -> web.WebSocketResponse: assert resp.data == expected_value -async def test_send_recv_text(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: +async def test_send_recv_text( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: closed = loop.create_future() async def handler(request: web.Request) -> web.WebSocketResponse: @@ -167,7 +179,9 @@ async def handler(request: web.Request) -> web.WebSocketResponse: await closed -async def test_send_recv_bytes(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: +async def test_send_recv_bytes( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: closed = loop.create_future() async def handler(request: web.Request) -> web.WebSocketResponse: @@ -201,7 +215,9 @@ async def handler(request: web.Request) -> web.WebSocketResponse: await closed -async def test_send_recv_json(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: +async def test_send_recv_json( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: closed = loop.create_future() async def handler(request: web.Request) -> web.WebSocketResponse: @@ -235,7 +251,9 @@ async def handler(request: web.Request) -> web.WebSocketResponse: await closed -async def test_close_timeout(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: +async def test_close_timeout( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: aborted = loop.create_future() elapsed = 1e10 # something big @@ -277,7 +295,9 @@ async def handler(request: web.Request) -> web.WebSocketResponse: await ws.close() -async def test_concurrent_close(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: +async def test_concurrent_close( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: srv_ws = None async def handler(request: web.Request) -> web.WebSocketResponse: @@ -315,7 +335,9 @@ async def handler(request: web.Request) -> web.WebSocketResponse: assert msg.type == WSMsgType.CLOSED -async def test_concurrent_close_multiple_tasks(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: +async def test_concurrent_close_multiple_tasks( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: srv_ws = None async def handler(request: web.Request) -> web.WebSocketResponse: @@ -388,7 +410,9 @@ async def handler(request: web.Request) -> web.WebSocketResponse: assert msg.type == WSMsgType.CLOSED -async def test_auto_pong_with_closing_by_peer(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: +async def test_auto_pong_with_closing_by_peer( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: closed = loop.create_future() async def handler(request: web.Request) -> web.WebSocketResponse: @@ -417,7 +441,9 @@ async def handler(request: web.Request) -> web.WebSocketResponse: await closed -async def test_ping(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: +async def test_ping( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: closed = loop.create_future() async def handler(request: web.Request) -> web.WebSocketResponse: @@ -443,7 +469,9 @@ async def handler(request: web.Request) -> web.WebSocketResponse: await closed -async def test_client_ping(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: +async def test_client_ping( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: closed = loop.create_future() async def handler(request: web.Request) -> web.WebSocketResponse: @@ -468,7 +496,9 @@ async def handler(request: web.Request) -> web.WebSocketResponse: await ws.close() -async def test_pong(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: +async def test_pong( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: closed = loop.create_future() async def handler(request: web.Request) -> web.WebSocketResponse: @@ -502,7 +532,9 @@ async def handler(request: web.Request) -> web.WebSocketResponse: await closed -async def test_change_status(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: +async def test_change_status( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: closed = loop.create_future() async def handler(request: web.Request) -> web.WebSocketResponse: @@ -526,7 +558,9 @@ async def handler(request: web.Request) -> web.WebSocketResponse: await ws.close() -async def test_handle_protocol(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: +async def test_handle_protocol( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: closed = loop.create_future() async def handler(request: web.Request) -> web.WebSocketResponse: @@ -547,7 +581,9 @@ async def handler(request: web.Request) -> web.WebSocketResponse: await closed -async def test_server_close_handshake(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: +async def test_server_close_handshake( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: closed = loop.create_future() async def handler(request: web.Request) -> web.WebSocketResponse: @@ -569,7 +605,9 @@ async def handler(request: web.Request) -> web.WebSocketResponse: await closed -async def test_client_close_handshake(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: +async def test_client_close_handshake( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: closed = loop.create_future() async def handler(request: web.Request) -> web.WebSocketResponse: @@ -632,7 +670,9 @@ async def handler(request: web.Request) -> web.WebSocketResponse: await closed -async def test_receive_timeout(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: +async def test_receive_timeout( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: raised = False async def handler(request: web.Request) -> web.WebSocketResponse: @@ -658,7 +698,9 @@ async def handler(request: web.Request) -> web.WebSocketResponse: assert raised -async def test_custom_receive_timeout(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: +async def test_custom_receive_timeout( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: raised = False async def handler(request: web.Request) -> web.WebSocketResponse: @@ -684,7 +726,9 @@ async def handler(request: web.Request) -> web.WebSocketResponse: assert raised -async def test_heartbeat(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: +async def test_heartbeat( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse(heartbeat=0.05) await ws.prepare(request) @@ -704,7 +748,9 @@ async def handler(request: web.Request) -> web.WebSocketResponse: await ws.close() -async def test_heartbeat_no_pong(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: +async def test_heartbeat_no_pong( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse(heartbeat=0.05) await ws.prepare(request) @@ -851,7 +897,9 @@ async def handler(request: web.Request) -> web.WebSocketResponse: await ws.close() -async def test_server_ws_async_for(loop: asyncio.AbstractEventLoop, aiohttp_server: AiohttpServer) -> None: +async def test_server_ws_async_for( + loop: asyncio.AbstractEventLoop, aiohttp_server: AiohttpServer +) -> None: closed = loop.create_future() async def handler(request: web.Request) -> web.WebSocketResponse: @@ -882,7 +930,9 @@ async def handler(request: web.Request) -> web.WebSocketResponse: await closed -async def test_closed_async_for(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: +async def test_closed_async_for( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: closed = loop.create_future() async def handler(request: web.Request) -> web.WebSocketResponse: @@ -917,7 +967,9 @@ async def handler(request: web.Request) -> web.WebSocketResponse: await closed -async def test_websocket_disable_keepalive(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: +async def test_websocket_disable_keepalive( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: async def handler(request: web.Request) -> web.StreamResponse: ws = web.WebSocketResponse() if not ws.can_prepare(request): @@ -944,7 +996,9 @@ async def handler(request: web.Request) -> web.StreamResponse: assert data == "OK" -async def test_receive_str_nonstring(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: +async def test_receive_str_nonstring( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() if not ws.can_prepare(request): @@ -964,7 +1018,9 @@ async def handler(request: web.Request) -> web.WebSocketResponse: await ws.receive_str() -async def test_receive_bytes_nonbytes(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: +async def test_receive_bytes_nonbytes( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() if not ws.can_prepare(request): @@ -984,7 +1040,9 @@ async def handler(request: web.Request) -> web.WebSocketResponse: await ws.receive_bytes() -async def test_bug3380(loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient) -> None: +async def test_bug3380( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: async def handle_null(request: web.Request) -> web.Response: return web.json_response({"err": None}) @@ -1108,7 +1166,9 @@ async def test_websocket_shutdown(aiohttp_client: AiohttpClient) -> None: app[websockets] = weakref.WeakSet() # need for send signal shutdown server - shutdown_websockets = web.AppKey("shutdown_websockets", weakref.WeakSet[web.WebSocketResponse]) + shutdown_websockets = web.AppKey( + "shutdown_websockets", weakref.WeakSet[web.WebSocketResponse] + ) app[shutdown_websockets] = weakref.WeakSet() async def websocket_handler(request: web.Request) -> web.WebSocketResponse: From bbb585a5311c8d3d52fa932757f4b050f59c0bcf Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Wed, 25 Sep 2024 17:27:28 +0100 Subject: [PATCH 3/9] Fix tests --- tests/test_web_websocket.py | 89 ++++++++++++-------------- tests/test_web_websocket_functional.py | 2 +- 2 files changed, 42 insertions(+), 49 deletions(-) diff --git a/tests/test_web_websocket.py b/tests/test_web_websocket.py index 61963a4b6bc..55e97ab74f6 100644 --- a/tests/test_web_websocket.py +++ b/tests/test_web_websocket.py @@ -28,9 +28,8 @@ def __call__( @pytest.fixture def app(loop: asyncio.AbstractEventLoop) -> web.Application: - ret: web.Application = mock.create_autospec( - web.Application, spec_set=True, on_response_prepare=aiosignal.Signal(ret) - ) + ret: web.Application = mock.create_autospec(web.Application, spec_set=True) + ret.on_response_prepare = aiosignal.Signal(ret) # type: ignore[misc] ret.on_response_prepare.freeze() return ret @@ -366,12 +365,10 @@ async def test_receive_eofstream_in_reader( assert ws._payload_writer is not None f = loop.create_future() f.set_result(True) - with mock.patch.object( - ws._payload_writer, "drain", autospec=True, spec_set=True, return_value=f - ): - msg = await ws.receive() - assert msg.type == WSMsgType.CLOSED - assert ws.closed + ws._payload_writer.drain.return_value = f # type: ignore[attr-defined] + msg = await ws.receive() + assert msg.type == WSMsgType.CLOSED + assert ws.closed async def test_receive_exception_in_reader( @@ -388,15 +385,14 @@ async def test_receive_exception_in_reader( ws._reader.read = make_mocked_coro(res) f = loop.create_future() + assert ws._payload_writer is not None + ws._payload_writer.drain.return_value = f # type: ignore[attr-defined] f.set_result(True) - with mock.patch.object( - ws._payload_writer, "drain", autospec=True, spec_set=True, return_value=f - ): - msg = await ws.receive() - assert msg.type == WSMsgType.ERROR - assert ws.closed - assert req.transport is not None - assert len(req.transport.close.mock_calls) == 1 # type: ignore[attr-defined] + msg = await ws.receive() + assert msg.type == WSMsgType.ERROR + assert ws.closed + assert req.transport is not None + assert len(req.transport.close.mock_calls) == 1 # type: ignore[attr-defined] async def test_receive_close_but_left_open( @@ -411,15 +407,14 @@ async def test_receive_close_but_left_open( ws._reader.read = mock.AsyncMock(return_value=close_message) f = loop.create_future() + assert ws._payload_writer is not None + ws._payload_writer.drain.return_value = f # type: ignore[attr-defined] f.set_result(True) - with mock.patch.object( - ws._payload_writer, "drain", autospec=True, spec_set=True, return_value=f - ): - msg = await ws.receive() - assert msg.type == WSMsgType.CLOSE - assert ws.closed - assert req.transport is not None - assert len(req.transport.close.mock_calls) == 1 # type: ignore[attr-defined] + msg = await ws.receive() + assert msg.type == WSMsgType.CLOSE + assert ws.closed + assert req.transport is not None + assert len(req.transport.close.mock_calls) == 1 # type: ignore[attr-defined] async def test_receive_closing( @@ -435,22 +430,21 @@ async def test_receive_closing( ws._reader.read = read_mock f = loop.create_future() + assert ws._payload_writer is not None + ws._payload_writer.drain.return_value = f # type: ignore[attr-defined] f.set_result(True) - with mock.patch.object( - ws._payload_writer, "drain", autospec=True, spec_set=True, return_value=f - ): - msg = await ws.receive() - assert msg.type == WSMsgType.CLOSING - assert not ws.closed + msg = await ws.receive() + assert msg.type == WSMsgType.CLOSING + assert not ws.closed - msg = await ws.receive() - assert msg.type == WSMsgType.CLOSING - assert not ws.closed + msg = await ws.receive() + assert msg.type == WSMsgType.CLOSING + assert not ws.closed - ws._cancel(ConnectionResetError("Connection lost")) + ws._cancel(ConnectionResetError("Connection lost")) - msg = await ws.receive() - assert msg.type == WSMsgType.CLOSING + msg = await ws.receive() + assert msg.type == WSMsgType.CLOSING async def test_close_after_closing( @@ -465,19 +459,18 @@ async def test_close_after_closing( ws._reader.read = mock.AsyncMock(return_value=closing_message) f = loop.create_future() + assert ws._payload_writer is not None + ws._payload_writer.drain.return_value = f # type: ignore[attr-defined] f.set_result(True) - with mock.patch.object( - ws._payload_writer, "drain", autospec=True, spec_set=True, return_value=f - ): - msg = await ws.receive() - assert msg.type == WSMsgType.CLOSING - assert not ws.closed - assert req.transport is not None - assert len(req.transport.close.mock_calls) == 0 # type: ignore[attr-defined] + msg = await ws.receive() + assert msg.type == WSMsgType.CLOSING + assert not ws.closed + assert req.transport is not None + assert len(req.transport.close.mock_calls) == 0 # type: ignore[attr-defined] - await ws.close() - assert ws.closed - assert len(req.transport.close.mock_calls) == 1 + await ws.close() + assert ws.closed + assert len(req.transport.close.mock_calls) == 1 async def test_receive_timeouterror( diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index 8823e2c855a..c6492333007 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -1027,7 +1027,7 @@ async def handler(request: web.Request) -> web.WebSocketResponse: raise web.HTTPUpgradeRequired() await ws.prepare(request) - await ws.send_bytes(b"answer") + await ws.send_bytes("answer") # type: ignore[arg-type] await ws.close() return ws From 1430599052b152fc5f8161fb868c34cc7226c192 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Wed, 25 Sep 2024 17:37:43 +0100 Subject: [PATCH 4/9] Fix test --- tests/test_web_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_web_runner.py b/tests/test_web_runner.py index ba0f18a8783..2fe47bd8f72 100644 --- a/tests/test_web_runner.py +++ b/tests/test_web_runner.py @@ -241,6 +241,7 @@ async def test_tcpsite_default_host(make_runner: _RunnerMaker) -> None: assert site.name == "http://0.0.0.0:8080" m = mock.create_autospec(asyncio.AbstractEventLoop, spec_set=True, instance=True) + m.create_server.return_value = mock.create_autospec(asyncio.Server, spec_set=True) with mock.patch( "asyncio.get_event_loop", autospec=True, spec_set=True, return_value=m ) as mock_get_loop: From 95fda3e83e32c217bb311de0ac604cfeb935db3e Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Wed, 25 Sep 2024 17:44:46 +0100 Subject: [PATCH 5/9] Lint --- tests/test_web_runner.py | 2 +- tests/test_web_sendfile_functional.py | 3 +-- tests/test_web_websocket.py | 7 ++++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_web_runner.py b/tests/test_web_runner.py index 2fe47bd8f72..d75e68ee153 100644 --- a/tests/test_web_runner.py +++ b/tests/test_web_runner.py @@ -244,7 +244,7 @@ async def test_tcpsite_default_host(make_runner: _RunnerMaker) -> None: m.create_server.return_value = mock.create_autospec(asyncio.Server, spec_set=True) with mock.patch( "asyncio.get_event_loop", autospec=True, spec_set=True, return_value=m - ) as mock_get_loop: + ): await site.start() m.create_server.assert_called_once() diff --git a/tests/test_web_sendfile_functional.py b/tests/test_web_sendfile_functional.py index 6517454ea43..9e5c05651c1 100644 --- a/tests/test_web_sendfile_functional.py +++ b/tests/test_web_sendfile_functional.py @@ -3,7 +3,6 @@ import gzip import pathlib import socket -import ssl import zlib from typing import Callable, Iterable, Iterator, NoReturn, Optional, Tuple from unittest import mock @@ -565,7 +564,7 @@ async def test_static_file_if_none_match_star( @pytest.mark.parametrize("if_modified_since", ("", "Fri, 31 Dec 9999 23:59:59 GMT")) async def test_static_file_if_none_match_weak( - aiohttp_client: Any, + aiohttp_client: AiohttpClient, app_with_static_route: web.Application, if_modified_since: str, ) -> None: diff --git a/tests/test_web_websocket.py b/tests/test_web_websocket.py index 55e97ab74f6..bc89bc78967 100644 --- a/tests/test_web_websocket.py +++ b/tests/test_web_websocket.py @@ -12,7 +12,6 @@ from aiohttp.http import WS_CLOSED_MESSAGE, WSMessage from aiohttp.streams import EofStream from aiohttp.test_utils import make_mocked_coro, make_mocked_request -from aiohttp.typedefs import LooseHeaders from aiohttp.web_ws import WebSocketReady @@ -299,12 +298,14 @@ async def test_close_idempotent(make_request: _RequestMaker) -> None: await ws.prepare(req) assert ws._reader is not None ws._reader.feed_data(WS_CLOSED_MESSAGE) - assert await ws.close(code=1, message=b"message1") + close_code = await ws.close(code=1, message=b"message1") + assert close_code == 1 assert ws.closed assert req.transport is not None assert len(req.transport.close.mock_calls) == 1 # type: ignore[attr-defined] - assert not (await ws.close(code=2, message=b"message2")) + close_code = await ws.close(code=2, message=b"message2") + assert close_code == 0 async def test_prepare_post_method_ok(make_request: _RequestMaker) -> None: From eeb3216f090f5d502abed74b49b9766638a0f244 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Wed, 25 Sep 2024 18:20:48 +0100 Subject: [PATCH 6/9] Cleanup --- tests/test_web_sendfile_functional.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_web_sendfile_functional.py b/tests/test_web_sendfile_functional.py index 9e5c05651c1..df5e41075f6 100644 --- a/tests/test_web_sendfile_functional.py +++ b/tests/test_web_sendfile_functional.py @@ -25,7 +25,9 @@ except ImportError: ssl = None # type: ignore[assignment] -_Sender = Callable[..., web.FileResponse] +class _Sender(Protocol): + def __call__(self, path: PathLike, chunk_size: int = 256 * 1024) -> web.FileResponse: + ... HELLO_AIOHTTP = b"Hello aiohttp! :-)\n" From 10339220638e419de9585d88c2d7bbbfbbc90b94 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Wed, 25 Sep 2024 18:34:46 +0100 Subject: [PATCH 7/9] Coverage --- aiohttp/http_websocket.py | 8 ++------ tests/test_web_sendfile_functional.py | 2 +- tests/test_web_websocket_functional.py | 22 ++++++++-------------- 3 files changed, 11 insertions(+), 21 deletions(-) diff --git a/aiohttp/http_websocket.py b/aiohttp/http_websocket.py index dd17b7675de..b63d41f860f 100644 --- a/aiohttp/http_websocket.py +++ b/aiohttp/http_websocket.py @@ -708,16 +708,12 @@ def _write(self, data: bytes) -> None: raise ClientConnectionResetError("Cannot write to closing transport") self.transport.write(data) - async def pong(self, message: Union[bytes, str] = b"") -> None: + async def pong(self, message: bytes = b"") -> None: """Send pong message.""" - if isinstance(message, str): - message = message.encode("utf-8") await self._send_frame(message, WSMsgType.PONG) - async def ping(self, message: Union[bytes, str] = b"") -> None: + async def ping(self, message: bytes = b"") -> None: """Send ping message.""" - if isinstance(message, str): - message = message.encode("utf-8") await self._send_frame(message, WSMsgType.PING) async def send( diff --git a/tests/test_web_sendfile_functional.py b/tests/test_web_sendfile_functional.py index df5e41075f6..28914b9a0b7 100644 --- a/tests/test_web_sendfile_functional.py +++ b/tests/test_web_sendfile_functional.py @@ -4,7 +4,7 @@ import pathlib import socket import zlib -from typing import Callable, Iterable, Iterator, NoReturn, Optional, Tuple +from typing import Callable, Iterable, Iterator, NoReturn, Optional, Protocol, Tuple from unittest import mock import pytest diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index c6492333007..f718126b630 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -18,12 +18,10 @@ async def test_websocket_can_prepare( loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient ) -> None: - async def handler(request: web.Request) -> web.Response: + async def handler(request: web.Request) -> NoReturn: ws = web.WebSocketResponse() - if not ws.can_prepare(request): - raise web.HTTPUpgradeRequired() - - return web.Response() + assert not ws.can_prepare(request) + raise web.HTTPUpgradeRequired() app = web.Application() app.router.add_route("GET", "/", handler) @@ -38,8 +36,7 @@ async def test_websocket_json( ) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() - if not ws.can_prepare(request): - raise web.HTTPUpgradeRequired() + assert ws.can_prepare(request) await ws.prepare(request) msg = await ws.receive() @@ -1001,8 +998,7 @@ async def test_receive_str_nonstring( ) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() - if not ws.can_prepare(request): - raise web.HTTPUpgradeRequired() + assert ws.can_prepare(request) await ws.prepare(request) await ws.send_bytes(b"answer") @@ -1021,15 +1017,13 @@ async def handler(request: web.Request) -> web.WebSocketResponse: async def test_receive_bytes_nonbytes( loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient ) -> None: - async def handler(request: web.Request) -> web.WebSocketResponse: + async def handler(request: web.Request) -> NoReturn: ws = web.WebSocketResponse() - if not ws.can_prepare(request): - raise web.HTTPUpgradeRequired() + assert ws.can_prepare(request) await ws.prepare(request) await ws.send_bytes("answer") # type: ignore[arg-type] - await ws.close() - return ws + assert False app = web.Application() app.router.add_route("GET", "/", handler) From d635d80c2b9dc80e2ad09bb38944cc5b8d2c9b28 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 25 Sep 2024 17:35:28 +0000 Subject: [PATCH 8/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_web_sendfile_functional.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_web_sendfile_functional.py b/tests/test_web_sendfile_functional.py index 28914b9a0b7..3b642b276c8 100644 --- a/tests/test_web_sendfile_functional.py +++ b/tests/test_web_sendfile_functional.py @@ -25,9 +25,11 @@ except ImportError: ssl = None # type: ignore[assignment] + class _Sender(Protocol): - def __call__(self, path: PathLike, chunk_size: int = 256 * 1024) -> web.FileResponse: - ... + def __call__( + self, path: PathLike, chunk_size: int = 256 * 1024 + ) -> web.FileResponse: ... HELLO_AIOHTTP = b"Hello aiohttp! :-)\n" From 9162625bac389a228fdf4c64b45cbfe150b11840 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Wed, 25 Sep 2024 18:42:39 +0100 Subject: [PATCH 9/9] Update test_web_sendfile_functional.py --- tests/test_web_sendfile_functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_web_sendfile_functional.py b/tests/test_web_sendfile_functional.py index 3b642b276c8..a6921cf4105 100644 --- a/tests/test_web_sendfile_functional.py +++ b/tests/test_web_sendfile_functional.py @@ -4,7 +4,7 @@ import pathlib import socket import zlib -from typing import Callable, Iterable, Iterator, NoReturn, Optional, Protocol, Tuple +from typing import Iterable, Iterator, NoReturn, Optional, Protocol, Tuple from unittest import mock import pytest