diff --git a/aiohttp/http_websocket.py b/aiohttp/http_websocket.py index dd17b7675de..b63d41f860f 100644 --- a/aiohttp/http_websocket.py +++ b/aiohttp/http_websocket.py @@ -708,16 +708,12 @@ def _write(self, data: bytes) -> None: raise ClientConnectionResetError("Cannot write to closing transport") self.transport.write(data) - async def pong(self, message: Union[bytes, str] = b"") -> None: + async def pong(self, message: bytes = b"") -> None: """Send pong message.""" - if isinstance(message, str): - message = message.encode("utf-8") await self._send_frame(message, WSMsgType.PONG) - async def ping(self, message: Union[bytes, str] = b"") -> None: + async def ping(self, message: bytes = b"") -> None: """Send ping message.""" - if isinstance(message, str): - message = message.encode("utf-8") await self._send_frame(message, WSMsgType.PING) async def send( diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index d5a48e478db..14d47b66e40 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -404,14 +404,14 @@ async def pong(self, message: bytes = b"") -> None: raise RuntimeError("Call .prepare() first") await self._writer.pong(message) - async def send_str(self, data: str, compress: Optional[bool] = None) -> None: + async def send_str(self, data: str, compress: Optional[int] = None) -> None: if self._writer is None: raise RuntimeError("Call .prepare() first") if not isinstance(data, str): raise TypeError("data argument must be str (%r)" % type(data)) await self._writer.send(data, binary=False, compress=compress) - async def send_bytes(self, data: bytes, compress: Optional[bool] = None) -> None: + async def send_bytes(self, data: bytes, compress: Optional[int] = None) -> None: if self._writer is None: raise RuntimeError("Call .prepare() first") if not isinstance(data, (bytes, bytearray, memoryview)): @@ -421,7 +421,7 @@ async def send_bytes(self, data: bytes, compress: Optional[bool] = None) -> None async def send_json( self, data: Any, - compress: Optional[bool] = None, + compress: Optional[int] = None, *, dumps: JSONEncoder = json.dumps, ) -> None: diff --git a/tests/test_web_runner.py b/tests/test_web_runner.py index 302afbb95d6..d75e68ee153 100644 --- a/tests/test_web_runner.py +++ b/tests/test_web_runner.py @@ -1,29 +1,35 @@ -# type: ignore import asyncio import platform import signal -from typing import Any -from unittest.mock import patch +from typing import Any, Iterator, NoReturn, Protocol, Union +from unittest import mock import pytest from aiohttp import web from aiohttp.abc import AbstractAccessLogger from aiohttp.test_utils import get_unused_port_socket +from aiohttp.web_log import AccessLogger + + +class _RunnerMaker(Protocol): + def __call__(self, handle_signals: bool = ..., **kwargs: Any) -> web.AppRunner: ... @pytest.fixture -def app(): +def app() -> web.Application: return web.Application() @pytest.fixture -def make_runner(loop: Any, app: Any): +def make_runner( + loop: asyncio.AbstractEventLoop, app: web.Application +) -> Iterator[_RunnerMaker]: asyncio.set_event_loop(loop) runners = [] - def go(**kwargs): - runner = web.AppRunner(app, **kwargs) + def go(handle_signals: bool = False, **kwargs: Any) -> web.AppRunner: + runner = web.AppRunner(app, handle_signals=handle_signals, **kwargs) runners.append(runner) return runner @@ -32,7 +38,7 @@ def go(**kwargs): loop.run_until_complete(runner.cleanup()) -async def test_site_for_nonfrozen_app(make_runner: Any) -> None: +async def test_site_for_nonfrozen_app(make_runner: _RunnerMaker) -> None: runner = make_runner() with pytest.raises(RuntimeError): web.TCPSite(runner) @@ -42,7 +48,7 @@ async def test_site_for_nonfrozen_app(make_runner: Any) -> None: @pytest.mark.skipif( platform.system() == "Windows", reason="the test is not valid for Windows" ) -async def test_runner_setup_handle_signals(make_runner: Any) -> None: +async def test_runner_setup_handle_signals(make_runner: _RunnerMaker) -> None: runner = make_runner(handle_signals=True) await runner.setup() assert signal.getsignal(signal.SIGTERM) is not signal.SIG_DFL @@ -53,7 +59,7 @@ async def test_runner_setup_handle_signals(make_runner: Any) -> None: @pytest.mark.skipif( platform.system() == "Windows", reason="the test is not valid for Windows" ) -async def test_runner_setup_without_signal_handling(make_runner: Any) -> None: +async def test_runner_setup_without_signal_handling(make_runner: _RunnerMaker) -> None: runner = make_runner(handle_signals=False) await runner.setup() assert signal.getsignal(signal.SIGTERM) is signal.SIG_DFL @@ -61,7 +67,7 @@ async def test_runner_setup_without_signal_handling(make_runner: Any) -> None: assert signal.getsignal(signal.SIGTERM) is signal.SIG_DFL -async def test_site_double_added(make_runner: Any) -> None: +async def test_site_double_added(make_runner: _RunnerMaker) -> None: _sock = get_unused_port_socket("127.0.0.1") runner = make_runner() await runner.setup() @@ -73,7 +79,7 @@ async def test_site_double_added(make_runner: Any) -> None: assert len(runner.sites) == 1 -async def test_site_stop_not_started(make_runner: Any) -> None: +async def test_site_stop_not_started(make_runner: _RunnerMaker) -> None: runner = make_runner() await runner.setup() site = web.TCPSite(runner) @@ -83,13 +89,14 @@ async def test_site_stop_not_started(make_runner: Any) -> None: assert len(runner.sites) == 0 -async def test_custom_log_format(make_runner: Any) -> None: +async def test_custom_log_format(make_runner: _RunnerMaker) -> None: runner = make_runner(access_log_format="abc") await runner.setup() + assert runner.server is not None assert runner.server._kwargs["access_log_format"] == "abc" -async def test_unreg_site(make_runner: Any) -> None: +async def test_unreg_site(make_runner: _RunnerMaker) -> None: runner = make_runner() await runner.setup() site = web.TCPSite(runner) @@ -97,20 +104,20 @@ async def test_unreg_site(make_runner: Any) -> None: runner._unreg_site(site) -async def test_app_property(make_runner: Any, app: Any) -> None: +async def test_app_property(make_runner: _RunnerMaker, app: web.Application) -> None: runner = make_runner() assert runner.app is app def test_non_app() -> None: with pytest.raises(TypeError): - web.AppRunner(object()) + web.AppRunner(object()) # type: ignore[arg-type] def test_app_handler_args() -> None: app = web.Application(handler_args={"test": True}) runner = web.AppRunner(app) - assert runner._kwargs == {"access_log_class": web.AccessLogger, "test": True} + assert runner._kwargs == {"access_log_class": AccessLogger, "test": True} async def test_app_handler_args_failure() -> None: @@ -132,7 +139,9 @@ async def test_app_handler_args_failure() -> None: ("2", 2), ), ) -async def test_app_handler_args_ceil_threshold(value: Any, expected: Any) -> None: +async def test_app_handler_args_ceil_threshold( + value: Union[int, str, None], expected: int +) -> None: app = web.Application(handler_args={"timeout_ceil_threshold": value}) runner = web.AppRunner(app) await runner.setup() @@ -150,7 +159,7 @@ class Logger: app = web.Application() with pytest.raises(TypeError): - web.AppRunner(app, access_log_class=Logger) + web.AppRunner(app, access_log_class=Logger) # type: ignore[arg-type] async def test_app_make_handler_access_log_class_bad_type2() -> None: @@ -165,7 +174,9 @@ class Logger: async def test_app_make_handler_access_log_class1() -> None: class Logger(AbstractAccessLogger): - def log(self, request, response, time): + def log( + self, request: web.BaseRequest, response: web.StreamResponse, time: float + ) -> None: """Pass log method.""" app = web.Application() @@ -175,7 +186,9 @@ def log(self, request, response, time): async def test_app_make_handler_access_log_class2() -> None: class Logger(AbstractAccessLogger): - def log(self, request, response, time): + def log( + self, request: web.BaseRequest, response: web.StreamResponse, time: float + ) -> None: """Pass log method.""" app = web.Application(handler_args={"access_log_class": Logger}) @@ -183,7 +196,7 @@ def log(self, request, response, time): assert runner._kwargs["access_log_class"] is Logger -async def test_addresses(make_runner: Any, unix_sockname: Any) -> None: +async def test_addresses(make_runner: _RunnerMaker, unix_sockname: str) -> None: _sock = get_unused_port_socket("127.0.0.1") runner = make_runner() await runner.setup() @@ -200,7 +213,7 @@ async def test_addresses(make_runner: Any, unix_sockname: Any) -> None: platform.system() != "Windows", reason="Proactor Event loop present only in Windows" ) async def test_named_pipe_runner_wrong_loop( - app: Any, selector_loop: Any, pipe_name: Any + app: web.Application, selector_loop: asyncio.AbstractEventLoop, pipe_name: str ) -> None: runner = web.AppRunner(app) await runner.setup() @@ -212,7 +225,7 @@ async def test_named_pipe_runner_wrong_loop( platform.system() != "Windows", reason="Proactor Event loop present only in Windows" ) async def test_named_pipe_runner_proactor_loop( - proactor_loop: Any, app: Any, pipe_name: Any + proactor_loop: asyncio.AbstractEventLoop, app: web.Application, pipe_name: str ) -> None: runner = web.AppRunner(app) await runner.setup() @@ -221,29 +234,25 @@ async def test_named_pipe_runner_proactor_loop( await runner.cleanup() -async def test_tcpsite_default_host(make_runner: Any) -> None: +async def test_tcpsite_default_host(make_runner: _RunnerMaker) -> None: runner = make_runner() await runner.setup() site = web.TCPSite(runner) assert site.name == "http://0.0.0.0:8080" - calls = [] - - async def mock_create_server(*args, **kwargs): - calls.append((args, kwargs)) - - with patch("asyncio.get_event_loop") as mock_get_loop: - mock_get_loop.return_value.create_server = mock_create_server + m = mock.create_autospec(asyncio.AbstractEventLoop, spec_set=True, instance=True) + m.create_server.return_value = mock.create_autospec(asyncio.Server, spec_set=True) + with mock.patch( + "asyncio.get_event_loop", autospec=True, spec_set=True, return_value=m + ): await site.start() - assert len(calls) == 1 - server, host, port = calls[0][0] - assert server is runner.server - assert host is None - assert port == 8080 + m.create_server.assert_called_once() + args, kwargs = m.create_server.call_args + assert args == (runner.server, None, 8080) -async def test_tcpsite_empty_str_host(make_runner: Any) -> None: +async def test_tcpsite_empty_str_host(make_runner: _RunnerMaker) -> None: runner = make_runner() await runner.setup() site = web.TCPSite(runner, host="") @@ -251,15 +260,16 @@ async def test_tcpsite_empty_str_host(make_runner: Any) -> None: def test_run_after_asyncio_run() -> None: - async def nothing(): - pass + called = False - def spy(): - spy.called = True + async def nothing() -> None: + pass - spy.called = False + def spy() -> None: + nonlocal called + called = True - async def shutdown(): + async def shutdown() -> NoReturn: spy() raise web.GracefulExit() @@ -271,4 +281,4 @@ async def shutdown(): app.on_startup.append(lambda a: asyncio.create_task(shutdown())) web.run_app(app) - assert spy.called, "run_app() should work after asyncio.run()." + assert called, "run_app() should work after asyncio.run()." diff --git a/tests/test_web_sendfile_functional.py b/tests/test_web_sendfile_functional.py index f55d329c36a..a6921cf4105 100644 --- a/tests/test_web_sendfile_functional.py +++ b/tests/test_web_sendfile_functional.py @@ -1,17 +1,19 @@ -# type: ignore import asyncio import bz2 import gzip import pathlib import socket import zlib -from typing import Any, Iterable, Optional +from typing import Iterable, Iterator, NoReturn, Optional, Protocol, Tuple from unittest import mock import pytest +from _pytest.fixtures import SubRequest import aiohttp from aiohttp import web +from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer +from aiohttp.typedefs import PathLike try: import brotlicffi as brotli @@ -21,14 +23,22 @@ try: import ssl except ImportError: - ssl = None + ssl = None # type: ignore[assignment] + + +class _Sender(Protocol): + def __call__( + self, path: PathLike, chunk_size: int = 256 * 1024 + ) -> web.FileResponse: ... HELLO_AIOHTTP = b"Hello aiohttp! :-)\n" @pytest.fixture(scope="module") -def hello_txt(request, tmp_path_factory) -> pathlib.Path: +def hello_txt( + request: pytest.FixtureRequest, tmp_path_factory: pytest.TempPathFactory +) -> pathlib.Path: """Create a temp path with hello.txt and compressed versions. The uncompressed text file path is returned by default. Alternatively, an @@ -50,22 +60,24 @@ def hello_txt(request, tmp_path_factory) -> pathlib.Path: @pytest.fixture -def loop_with_mocked_native_sendfile(loop: Any): - def sendfile(transport, fobj, offset, count): +def loop_with_mocked_native_sendfile( + loop: asyncio.AbstractEventLoop, +) -> Iterator[asyncio.AbstractEventLoop]: + def sendfile(transport: object, fobj: object, offset: int, count: int) -> NoReturn: if count == 0: raise ValueError("count must be a positive integer (got 0)") raise NotImplementedError - loop.sendfile = sendfile - return loop + with mock.patch.object(loop, "sendfile", sendfile): + yield loop @pytest.fixture(params=["sendfile", "no_sendfile"], ids=["sendfile", "no_sendfile"]) -def sender(request: Any, loop: Any): +def sender(request: SubRequest, loop: asyncio.AbstractEventLoop) -> Iterator[_Sender]: sendfile_mock = None - def maker(*args, **kwargs): - ret = web.FileResponse(*args, **kwargs) + def maker(path: PathLike, chunk_size: int = 256 * 1024) -> web.FileResponse: + ret = web.FileResponse(path, chunk_size=chunk_size) rloop = asyncio.get_running_loop() is_patched = rloop.sendfile is sendfile_mock assert is_patched if request.param == "no_sendfile" else not is_patched @@ -85,11 +97,11 @@ def maker(*args, **kwargs): @pytest.fixture -def app_with_static_route(sender: Any) -> web.Application: +def app_with_static_route(sender: _Sender) -> web.Application: filename = "data.unknown_mime_type" filepath = pathlib.Path(__file__).parent / filename - async def handler(request): + async def handler(request: web.Request) -> web.FileResponse: return sender(filepath) app = web.Application() @@ -98,7 +110,7 @@ async def handler(request): async def test_static_file_ok( - aiohttp_client: Any, app_with_static_route: web.Application + aiohttp_client: AiohttpClient, app_with_static_route: web.Application ) -> None: client = await aiohttp_client(app_with_static_route) @@ -112,10 +124,12 @@ async def test_static_file_ok( await client.close() -async def test_zero_bytes_file_ok(aiohttp_client: Any, sender: Any) -> None: +async def test_zero_bytes_file_ok( + aiohttp_client: AiohttpClient, sender: _Sender +) -> None: filepath = pathlib.Path(__file__).parent / "data.zero_bytes" - async def handler(request): + async def handler(request: web.Request) -> web.FileResponse: return sender(filepath) app = web.Application() @@ -138,11 +152,12 @@ async def handler(request): async def test_zero_bytes_file_mocked_native_sendfile( - aiohttp_client: Any, loop_with_mocked_native_sendfile: Any + aiohttp_client: AiohttpClient, + loop_with_mocked_native_sendfile: asyncio.AbstractEventLoop, ) -> None: filepath = pathlib.Path(__file__).parent / "data.zero_bytes" - async def handler(request): + async def handler(request: web.Request) -> web.FileResponse: asyncio.set_event_loop(loop_with_mocked_native_sendfile) return web.FileResponse(filepath) @@ -167,7 +182,7 @@ async def handler(request): async def test_static_file_ok_string_path( - aiohttp_client: Any, app_with_static_route: web.Application + aiohttp_client: AiohttpClient, app_with_static_route: web.Application ) -> None: client = await aiohttp_client(app_with_static_route) @@ -181,7 +196,7 @@ async def test_static_file_ok_string_path( await client.close() -async def test_static_file_not_exists(aiohttp_client: Any) -> None: +async def test_static_file_not_exists(aiohttp_client: AiohttpClient) -> None: app = web.Application() client = await aiohttp_client(app) @@ -191,7 +206,7 @@ async def test_static_file_not_exists(aiohttp_client: Any) -> None: await client.close() -async def test_static_file_name_too_long(aiohttp_client: Any) -> None: +async def test_static_file_name_too_long(aiohttp_client: AiohttpClient) -> None: app = web.Application() client = await aiohttp_client(app) @@ -201,7 +216,7 @@ async def test_static_file_name_too_long(aiohttp_client: Any) -> None: await client.close() -async def test_static_file_upper_directory(aiohttp_client: Any) -> None: +async def test_static_file_upper_directory(aiohttp_client: AiohttpClient) -> None: app = web.Application() client = await aiohttp_client(app) @@ -211,10 +226,12 @@ async def test_static_file_upper_directory(aiohttp_client: Any) -> None: await client.close() -async def test_static_file_with_content_type(aiohttp_client: Any, sender: Any) -> None: +async def test_static_file_with_content_type( + aiohttp_client: AiohttpClient, sender: _Sender +) -> None: filepath = pathlib.Path(__file__).parent / "aiohttp.jpg" - async def handler(request): + async def handler(request: web.Request) -> web.FileResponse: return sender(filepath, chunk_size=16) app = web.Application() @@ -236,11 +253,11 @@ async def handler(request): @pytest.mark.parametrize("hello_txt", ["gzip", "br"], indirect=True) async def test_static_file_custom_content_type( - hello_txt: pathlib.Path, aiohttp_client: Any, sender: Any + hello_txt: pathlib.Path, aiohttp_client: AiohttpClient, sender: _Sender ) -> None: """Test that custom type without encoding is returned for encoded request.""" - async def handler(request): + async def handler(request: web.Request) -> web.FileResponse: resp = sender(hello_txt, chunk_size=16) resp.content_type = "application/pdf" return resp @@ -265,14 +282,14 @@ async def handler(request): ) async def test_static_file_custom_content_type_compress( hello_txt: pathlib.Path, - aiohttp_client: Any, - sender: Any, + aiohttp_client: AiohttpClient, + sender: _Sender, accept_encoding: str, expect_encoding: str, -): +) -> None: """Test that custom type with encoding is returned for unencoded requests.""" - async def handler(request): + async def handler(request: web.Request) -> web.FileResponse: resp = sender(hello_txt, chunk_size=16) resp.content_type = "application/pdf" return resp @@ -298,15 +315,15 @@ async def handler(request): @pytest.mark.parametrize("forced_compression", [None, web.ContentCoding.gzip]) async def test_static_file_with_encoding_and_enable_compression( hello_txt: pathlib.Path, - aiohttp_client: Any, - sender: Any, + aiohttp_client: AiohttpClient, + sender: _Sender, accept_encoding: str, expect_encoding: str, forced_compression: Optional[web.ContentCoding], -): +) -> None: """Test that enable_compression does not double compress when an encoded file is also present.""" - async def handler(request): + async def handler(request: web.Request) -> web.FileResponse: resp = sender(hello_txt) resp.enable_compression(forced_compression) return resp @@ -335,11 +352,14 @@ async def handler(request): indirect=["hello_txt"], ) async def test_static_file_with_content_encoding( - hello_txt: pathlib.Path, aiohttp_client: Any, sender: Any, expect_type: str + hello_txt: pathlib.Path, + aiohttp_client: AiohttpClient, + sender: _Sender, + expect_type: str, ) -> None: """Test requesting static compressed files returns the correct content type and encoding.""" - async def handler(request): + async def handler(request: web.Request) -> web.FileResponse: return sender(hello_txt) app = web.Application() @@ -358,7 +378,7 @@ async def handler(request): async def test_static_file_if_modified_since( - aiohttp_client: Any, app_with_static_route: web.Application + aiohttp_client: AiohttpClient, app_with_static_route: web.Application ) -> None: client = await aiohttp_client(app_with_static_route) @@ -381,7 +401,7 @@ async def test_static_file_if_modified_since( async def test_static_file_if_modified_since_past_date( - aiohttp_client: Any, app_with_static_route: web.Application + aiohttp_client: AiohttpClient, app_with_static_route: web.Application ) -> None: client = await aiohttp_client(app_with_static_route) @@ -396,8 +416,8 @@ async def test_static_file_if_modified_since_past_date( async def test_static_file_if_modified_since_invalid_date( - aiohttp_client: Any, app_with_static_route: web.Application -): + aiohttp_client: AiohttpClient, app_with_static_route: web.Application +) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "not a valid HTTP-date" @@ -411,8 +431,8 @@ async def test_static_file_if_modified_since_invalid_date( async def test_static_file_if_modified_since_future_date( - aiohttp_client: Any, app_with_static_route: web.Application -): + aiohttp_client: AiohttpClient, app_with_static_route: web.Application +) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "Fri, 31 Dec 9999 23:59:59 GMT" @@ -431,7 +451,7 @@ async def test_static_file_if_modified_since_future_date( @pytest.mark.parametrize("if_unmodified_since", ("", "Fri, 31 Dec 0000 23:59:59 GMT")) async def test_static_file_if_match( - aiohttp_client: Any, + aiohttp_client: AiohttpClient, app_with_static_route: web.Application, if_unmodified_since: str, ) -> None: @@ -467,11 +487,11 @@ async def test_static_file_if_match( ], ) async def test_static_file_if_match_custom_tags( - aiohttp_client: Any, + aiohttp_client: AiohttpClient, app_with_static_route: web.Application, if_unmodified_since: str, - etags: Iterable[str], - expected_status: Iterable[int], + etags: Tuple[str], + expected_status: int, ) -> None: client = await aiohttp_client(app_with_static_route) @@ -496,7 +516,7 @@ async def test_static_file_if_match_custom_tags( ), ) async def test_static_file_if_none_match( - aiohttp_client: Any, + aiohttp_client: AiohttpClient, app_with_static_route: web.Application, if_modified_since: str, additional_etags: Iterable[str], @@ -528,7 +548,7 @@ async def test_static_file_if_none_match( async def test_static_file_if_none_match_star( - aiohttp_client: Any, + aiohttp_client: AiohttpClient, app_with_static_route: web.Application, ) -> None: client = await aiohttp_client(app_with_static_route) @@ -548,7 +568,7 @@ async def test_static_file_if_none_match_star( @pytest.mark.parametrize("if_modified_since", ("", "Fri, 31 Dec 9999 23:59:59 GMT")) async def test_static_file_if_none_match_weak( - aiohttp_client: Any, + aiohttp_client: AiohttpClient, app_with_static_route: web.Application, if_modified_since: str, ) -> None: @@ -581,10 +601,10 @@ async def test_static_file_if_none_match_weak( @pytest.mark.skipif(not ssl, reason="ssl not supported") async def test_static_file_ssl( - aiohttp_server: Any, - ssl_ctx: Any, - aiohttp_client: Any, - client_ssl_ctx: Any, + aiohttp_server: AiohttpServer, + ssl_ctx: ssl.SSLContext, + aiohttp_client: AiohttpClient, + client_ssl_ctx: ssl.SSLContext, ) -> None: dirname = pathlib.Path(__file__).parent filename = "data.unknown_mime_type" @@ -606,7 +626,9 @@ async def test_static_file_ssl( await client.close() -async def test_static_file_directory_traversal_attack(aiohttp_client: Any) -> None: +async def test_static_file_directory_traversal_attack( + aiohttp_client: AiohttpClient, +) -> None: dirname = pathlib.Path(__file__).parent relpath = "../README.rst" full_path = dirname / relpath @@ -633,7 +655,9 @@ async def test_static_file_directory_traversal_attack(aiohttp_client: Any) -> No await client.close() -async def test_static_file_huge(aiohttp_client: Any, tmp_path: Any) -> None: +async def test_static_file_huge( + aiohttp_client: AiohttpClient, tmp_path: pathlib.Path +) -> None: file_path = tmp_path / "huge_data.unknown_mime_type" # fill 20MB file @@ -652,29 +676,31 @@ async def test_static_file_huge(aiohttp_client: Any, tmp_path: Any) -> None: ct = resp.headers["CONTENT-TYPE"] assert "application/octet-stream" == ct assert resp.headers.get("CONTENT-ENCODING") is None - assert int(resp.headers.get("CONTENT-LENGTH")) == file_st.st_size + assert int(resp.headers["CONTENT-LENGTH"]) == file_st.st_size - f = file_path.open("rb") + f2 = file_path.open("rb") off = 0 cnt = 0 while off < file_st.st_size: chunk = await resp.content.readany() - expected = f.read(len(chunk)) + expected = f2.read(len(chunk)) assert chunk == expected off += len(chunk) cnt += 1 - f.close() + f2.close() resp.release() await client.close() -async def test_static_file_range(aiohttp_client: Any, sender: Any) -> None: +async def test_static_file_range( + aiohttp_client: AiohttpClient, sender: _Sender +) -> None: filepath = pathlib.Path(__file__).parent / "sample.txt" filesize = filepath.stat().st_size - async def handler(request): + async def handler(request: web.Request) -> web.FileResponse: return sender(filepath, chunk_size=16) app = web.Application() @@ -728,10 +754,12 @@ async def handler(request): await client.close() -async def test_static_file_range_end_bigger_than_size(aiohttp_client: Any, sender: Any): +async def test_static_file_range_end_bigger_than_size( + aiohttp_client: AiohttpClient, sender: _Sender +) -> None: filepath = pathlib.Path(__file__).parent / "aiohttp.png" - async def handler(request): + async def handler(request: web.Request) -> web.FileResponse: return sender(filepath, chunk_size=16) app = web.Application() @@ -762,10 +790,12 @@ async def handler(request): await client.close() -async def test_static_file_range_beyond_eof(aiohttp_client: Any, sender: Any) -> None: +async def test_static_file_range_beyond_eof( + aiohttp_client: AiohttpClient, sender: _Sender +) -> None: filepath = pathlib.Path(__file__).parent / "aiohttp.png" - async def handler(request): + async def handler(request: web.Request) -> web.FileResponse: return sender(filepath, chunk_size=16) app = web.Application() @@ -783,10 +813,12 @@ async def handler(request): await client.close() -async def test_static_file_range_tail(aiohttp_client: Any, sender: Any) -> None: +async def test_static_file_range_tail( + aiohttp_client: AiohttpClient, sender: _Sender +) -> None: filepath = pathlib.Path(__file__).parent / "aiohttp.png" - async def handler(request): + async def handler(request: web.Request) -> web.FileResponse: return sender(filepath, chunk_size=16) app = web.Application() @@ -818,10 +850,12 @@ async def handler(request): await client.close() -async def test_static_file_invalid_range(aiohttp_client: Any, sender: Any) -> None: +async def test_static_file_invalid_range( + aiohttp_client: AiohttpClient, sender: _Sender +) -> None: filepath = pathlib.Path(__file__).parent / "aiohttp.png" - async def handler(request): + async def handler(request: web.Request) -> web.FileResponse: return sender(filepath, chunk_size=16) app = web.Application() @@ -868,8 +902,8 @@ async def handler(request): async def test_static_file_if_unmodified_since_past_with_range( - aiohttp_client: Any, app_with_static_route: web.Application -): + aiohttp_client: AiohttpClient, app_with_static_route: web.Application +) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "Mon, 1 Jan 1990 01:01:01 GMT" @@ -885,8 +919,8 @@ async def test_static_file_if_unmodified_since_past_with_range( async def test_static_file_if_unmodified_since_future_with_range( - aiohttp_client: Any, app_with_static_route: web.Application -): + aiohttp_client: AiohttpClient, app_with_static_route: web.Application +) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "Fri, 31 Dec 9999 23:59:59 GMT" @@ -904,8 +938,8 @@ async def test_static_file_if_unmodified_since_future_with_range( async def test_static_file_if_range_past_with_range( - aiohttp_client: Any, app_with_static_route: web.Application -): + aiohttp_client: AiohttpClient, app_with_static_route: web.Application +) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "Mon, 1 Jan 1990 01:01:01 GMT" @@ -919,8 +953,8 @@ async def test_static_file_if_range_past_with_range( async def test_static_file_if_range_future_with_range( - aiohttp_client: Any, app_with_static_route: web.Application -): + aiohttp_client: AiohttpClient, app_with_static_route: web.Application +) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "Fri, 31 Dec 9999 23:59:59 GMT" @@ -936,8 +970,8 @@ async def test_static_file_if_range_future_with_range( async def test_static_file_if_unmodified_since_past_without_range( - aiohttp_client: Any, app_with_static_route: web.Application -): + aiohttp_client: AiohttpClient, app_with_static_route: web.Application +) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "Mon, 1 Jan 1990 01:01:01 GMT" @@ -951,8 +985,8 @@ async def test_static_file_if_unmodified_since_past_without_range( async def test_static_file_if_unmodified_since_future_without_range( - aiohttp_client: Any, app_with_static_route: web.Application -): + aiohttp_client: AiohttpClient, app_with_static_route: web.Application +) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "Fri, 31 Dec 9999 23:59:59 GMT" @@ -967,8 +1001,8 @@ async def test_static_file_if_unmodified_since_future_without_range( async def test_static_file_if_range_past_without_range( - aiohttp_client: Any, app_with_static_route: web.Application -): + aiohttp_client: AiohttpClient, app_with_static_route: web.Application +) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "Mon, 1 Jan 1990 01:01:01 GMT" @@ -983,8 +1017,8 @@ async def test_static_file_if_range_past_without_range( async def test_static_file_if_range_future_without_range( - aiohttp_client: Any, app_with_static_route: web.Application -): + aiohttp_client: AiohttpClient, app_with_static_route: web.Application +) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "Fri, 31 Dec 9999 23:59:59 GMT" @@ -999,8 +1033,8 @@ async def test_static_file_if_range_future_without_range( async def test_static_file_if_unmodified_since_invalid_date( - aiohttp_client: Any, app_with_static_route: web.Application -): + aiohttp_client: AiohttpClient, app_with_static_route: web.Application +) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "not a valid HTTP-date" @@ -1014,8 +1048,8 @@ async def test_static_file_if_unmodified_since_invalid_date( async def test_static_file_if_range_invalid_date( - aiohttp_client: Any, app_with_static_route: web.Application -): + aiohttp_client: AiohttpClient, app_with_static_route: web.Application +) -> None: client = await aiohttp_client(app_with_static_route) lastmod = "not a valid HTTP-date" @@ -1028,10 +1062,12 @@ async def test_static_file_if_range_invalid_date( await client.close() -async def test_static_file_compression(aiohttp_client: Any, sender: Any) -> None: +async def test_static_file_compression( + aiohttp_client: AiohttpClient, sender: _Sender +) -> None: filepath = pathlib.Path(__file__).parent / "data.unknown_mime_type" - async def handler(request): + async def handler(request: web.Request) -> web.FileResponse: ret = sender(filepath) ret.enable_compression() return ret @@ -1052,7 +1088,9 @@ async def handler(request): await client.close() -async def test_static_file_huge_cancel(aiohttp_client: Any, tmp_path: Any) -> None: +async def test_static_file_huge_cancel( + aiohttp_client: AiohttpClient, tmp_path: pathlib.Path +) -> None: file_path = tmp_path / "huge_data.unknown_mime_type" # fill 100MB file @@ -1062,11 +1100,12 @@ async def test_static_file_huge_cancel(aiohttp_client: Any, tmp_path: Any) -> No task = None - async def handler(request): + async def handler(request: web.Request) -> web.FileResponse: nonlocal task task = request.task # reduce send buffer size tr = request.transport + assert tr is not None sock = tr.get_extra_info("socket") sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024) ret = web.FileResponse(file_path) @@ -1079,6 +1118,7 @@ async def handler(request): resp = await client.get("/") assert resp.status == 200 + assert task is not None task.cancel() await asyncio.sleep(0) data = b"" @@ -1093,7 +1133,9 @@ async def handler(request): await client.close() -async def test_static_file_huge_error(aiohttp_client: Any, tmp_path: Any) -> None: +async def test_static_file_huge_error( + aiohttp_client: AiohttpClient, tmp_path: pathlib.Path +) -> None: file_path = tmp_path / "huge_data.unknown_mime_type" # fill 20MB file @@ -1101,9 +1143,10 @@ async def test_static_file_huge_error(aiohttp_client: Any, tmp_path: Any) -> Non f.seek(20 * 1024 * 1024) f.write(b"1") - async def handler(request): + async def handler(request: web.Request) -> web.FileResponse: # reduce send buffer size tr = request.transport + assert tr is not None sock = tr.get_extra_info("socket") sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024) ret = web.FileResponse(file_path) diff --git a/tests/test_web_server.py b/tests/test_web_server.py index 619b83e947c..312dc4ac85f 100644 --- a/tests/test_web_server.py +++ b/tests/test_web_server.py @@ -1,16 +1,18 @@ -# type: ignore import asyncio from contextlib import suppress -from typing import Any +from typing import Callable, NoReturn from unittest import mock import pytest from aiohttp import client, web +from aiohttp.pytest_plugin import AiohttpClient, AiohttpRawServer -async def test_simple_server(aiohttp_raw_server: Any, aiohttp_client: Any) -> None: - async def handler(request): +async def test_simple_server( + aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient +) -> None: + async def handler(request: web.BaseRequest) -> web.Response: return web.Response(text=str(request.rel_url)) server = await aiohttp_raw_server(handler) @@ -21,10 +23,12 @@ async def handler(request): assert txt == "/path/to" -async def test_unsupported_upgrade(aiohttp_raw_server, aiohttp_client) -> None: +async def test_unsupported_upgrade( + aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient +) -> None: # don't fail if a client probes for an unsupported protocol upgrade # https://github.com/aio-libs/aiohttp/issues/6446#issuecomment-999032039 - async def handler(request: web.Request): + async def handler(request: web.BaseRequest) -> web.Response: return web.Response(body=await request.read()) upgrade_headers = {"Connection": "Upgrade", "Upgrade": "unsupported_proto"} @@ -38,14 +42,16 @@ async def handler(request: web.Request): async def test_raw_server_not_http_exception( - aiohttp_raw_server: Any, aiohttp_client: Any, loop: Any + aiohttp_raw_server: AiohttpRawServer, + aiohttp_client: AiohttpClient, + loop: asyncio.AbstractEventLoop, ) -> None: # disable debug mode not to print traceback loop.set_debug(False) exc = RuntimeError("custom runtime error") - async def handler(request): + async def handler(request: web.BaseRequest) -> NoReturn: raise exc logger = mock.Mock() @@ -63,13 +69,13 @@ async def handler(request): async def test_raw_server_handler_timeout( - aiohttp_raw_server: Any, aiohttp_client: Any + aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient ) -> None: loop = asyncio.get_event_loop() loop.set_debug(True) exc = asyncio.TimeoutError("error") - async def handler(request): + async def handler(request: web.BaseRequest) -> NoReturn: raise exc logger = mock.Mock() @@ -83,9 +89,9 @@ async def handler(request): async def test_raw_server_do_not_swallow_exceptions( - aiohttp_raw_server: Any, aiohttp_client: Any + aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient ) -> None: - async def handler(request): + async def handler(request: web.BaseRequest) -> NoReturn: raise asyncio.CancelledError() loop = asyncio.get_event_loop() @@ -101,13 +107,13 @@ async def handler(request): async def test_raw_server_cancelled_in_write_eof( - aiohttp_raw_server: Any, aiohttp_client: Any -): + aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient +) -> None: class MyResponse(web.Response): - async def write_eof(self, data=b""): + async def write_eof(self, data: bytes = b"") -> NoReturn: raise asyncio.CancelledError("error") - async def handler(request): + async def handler(request: web.BaseRequest) -> MyResponse: resp = MyResponse(text=str(request.rel_url)) return resp @@ -125,11 +131,11 @@ async def handler(request): async def test_raw_server_not_http_exception_debug( - aiohttp_raw_server: Any, aiohttp_client: Any + aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient ) -> None: exc = RuntimeError("custom runtime error") - async def handler(request): + async def handler(request: web.BaseRequest) -> NoReturn: raise exc loop = asyncio.get_event_loop() @@ -148,14 +154,16 @@ async def handler(request): async def test_raw_server_html_exception( - aiohttp_raw_server: Any, aiohttp_client: Any, loop: Any + aiohttp_raw_server: AiohttpRawServer, + aiohttp_client: AiohttpClient, + loop: asyncio.AbstractEventLoop, ) -> None: # disable debug mode not to print traceback loop.set_debug(False) exc = RuntimeError("custom runtime error") - async def handler(request): + async def handler(request: web.BaseRequest) -> NoReturn: raise exc logger = mock.Mock() @@ -177,11 +185,11 @@ async def handler(request): async def test_raw_server_html_exception_debug( - aiohttp_raw_server: Any, aiohttp_client: Any + aiohttp_raw_server: AiohttpRawServer, aiohttp_client: AiohttpClient ) -> None: exc = RuntimeError("custom runtime error") - async def handler(request): + async def handler(request: web.BaseRequest) -> NoReturn: raise exc loop = asyncio.get_event_loop() @@ -204,11 +212,11 @@ async def handler(request): logger.exception.assert_called_with("Error handling request", exc_info=exc) -async def test_handler_cancellation(aiohttp_unused_port) -> None: +async def test_handler_cancellation(aiohttp_unused_port: Callable[[], int]) -> None: event = asyncio.Event() port = aiohttp_unused_port() - async def on_request(_: web.Request) -> web.Response: + async def on_request(request: web.Request) -> web.Response: nonlocal event try: await asyncio.sleep(10) @@ -228,6 +236,7 @@ async def on_request(_: web.Request) -> web.Response: await site.start() + assert runner.server is not None try: assert runner.server.handler_cancellation, "Flag was not propagated" @@ -244,13 +253,13 @@ async def on_request(_: web.Request) -> web.Response: await asyncio.gather(runner.shutdown(), site.stop()) -async def test_no_handler_cancellation(aiohttp_unused_port) -> None: +async def test_no_handler_cancellation(aiohttp_unused_port: Callable[[], int]) -> None: timeout_event = asyncio.Event() done_event = asyncio.Event() port = aiohttp_unused_port() started = False - async def on_request(_: web.Request) -> web.Response: + async def on_request(request: web.Request) -> web.Response: nonlocal done_event, started, timeout_event started = True await asyncio.wait_for(timeout_event.wait(), timeout=5) diff --git a/tests/test_web_websocket.py b/tests/test_web_websocket.py index 067634f0a4d..bc89bc78967 100644 --- a/tests/test_web_websocket.py +++ b/tests/test_web_websocket.py @@ -1,40 +1,55 @@ -# type: ignore import asyncio import time -from typing import Any +from typing import Optional, Protocol from unittest import mock import aiosignal import pytest from multidict import CIMultiDict +from pytest_mock import MockerFixture -from aiohttp import WSMsgType +from aiohttp import WSMsgType, web +from aiohttp.http import WS_CLOSED_MESSAGE, WSMessage from aiohttp.streams import EofStream from aiohttp.test_utils import make_mocked_coro, make_mocked_request -from aiohttp.web import HTTPBadRequest, WebSocketResponse -from aiohttp.web_ws import WS_CLOSED_MESSAGE, WebSocketReady, WSMessage +from aiohttp.web_ws import WebSocketReady + + +class _RequestMaker(Protocol): + def __call__( + self, + method: str, + path: str, + headers: Optional[CIMultiDict[str]] = None, + protocols: bool = False, + ) -> web.Request: ... @pytest.fixture -def app(loop: Any): - ret = mock.Mock() - ret.loop = loop - ret._debug = False - ret.on_response_prepare = aiosignal.Signal(ret) +def app(loop: asyncio.AbstractEventLoop) -> web.Application: + ret: web.Application = mock.create_autospec(web.Application, spec_set=True) + ret.on_response_prepare = aiosignal.Signal(ret) # type: ignore[misc] ret.on_response_prepare.freeze() return ret @pytest.fixture -def protocol(): +def protocol() -> web.RequestHandler[web.Request]: ret = mock.Mock() ret.set_parser.return_value = ret return ret @pytest.fixture -def make_request(app: Any, protocol: Any): - def maker(method, path, headers=None, protocols=False): +def make_request( + app: web.Application, protocol: web.RequestHandler[web.Request] +) -> _RequestMaker: + def maker( + method: str, + path: str, + headers: Optional[CIMultiDict[str]] = None, + protocols: bool = False, + ) -> web.Request: if headers is None: headers = CIMultiDict( { @@ -49,107 +64,106 @@ def maker(method, path, headers=None, protocols=False): if protocols: headers["SEC-WEBSOCKET-PROTOCOL"] = "chat, superchat" - return make_mocked_request( - method, path, headers, app=app, protocol=protocol, loop=app.loop - ) + return make_mocked_request(method, path, headers, app=app, protocol=protocol) return maker async def test_nonstarted_ping() -> None: - ws = WebSocketResponse() + ws = web.WebSocketResponse() with pytest.raises(RuntimeError): await ws.ping() async def test_nonstarted_pong() -> None: - ws = WebSocketResponse() + ws = web.WebSocketResponse() with pytest.raises(RuntimeError): await ws.pong() async def test_nonstarted_send_str() -> None: - ws = WebSocketResponse() + ws = web.WebSocketResponse() with pytest.raises(RuntimeError): await ws.send_str("string") async def test_nonstarted_send_bytes() -> None: - ws = WebSocketResponse() + ws = web.WebSocketResponse() with pytest.raises(RuntimeError): await ws.send_bytes(b"bytes") async def test_nonstarted_send_json() -> None: - ws = WebSocketResponse() + ws = web.WebSocketResponse() with pytest.raises(RuntimeError): await ws.send_json({"type": "json"}) async def test_nonstarted_close() -> None: - ws = WebSocketResponse() + ws = web.WebSocketResponse() with pytest.raises(RuntimeError): await ws.close() async def test_nonstarted_receive_str() -> None: - ws = WebSocketResponse() + ws = web.WebSocketResponse() with pytest.raises(RuntimeError): await ws.receive_str() async def test_nonstarted_receive_bytes() -> None: - ws = WebSocketResponse() + ws = web.WebSocketResponse() with pytest.raises(RuntimeError): await ws.receive_bytes() async def test_nonstarted_receive_json() -> None: - ws = WebSocketResponse() + ws = web.WebSocketResponse() with pytest.raises(RuntimeError): await ws.receive_json() -async def test_send_str_nonstring(make_request: Any) -> None: +async def test_send_str_nonstring(make_request: _RequestMaker) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) with pytest.raises(TypeError): - await ws.send_str(b"bytes") + await ws.send_str(b"bytes") # type: ignore[arg-type] -async def test_send_bytes_nonbytes(make_request: Any) -> None: +async def test_send_bytes_nonbytes(make_request: _RequestMaker) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) with pytest.raises(TypeError): - await ws.send_bytes("string") + await ws.send_bytes("string") # type: ignore[arg-type] -async def test_send_json_nonjson(make_request: Any) -> None: +async def test_send_json_nonjson(make_request: _RequestMaker) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) with pytest.raises(TypeError): await ws.send_json(set()) async def test_write_non_prepared() -> None: - ws = WebSocketResponse() + ws = web.WebSocketResponse() with pytest.raises(RuntimeError): await ws.write(b"data") -async def test_heartbeat_timeout(make_request: Any) -> None: +async def test_heartbeat_timeout(make_request: _RequestMaker) -> None: """Verify the transport is closed when the heartbeat timeout is reached.""" loop = asyncio.get_running_loop() future = loop.create_future() req = make_request("GET", "/") + assert req.transport is not None + req.transport.close.side_effect = lambda: future.set_result(None) # type: ignore[attr-defined] lowest_time = time.get_clock_info("monotonic").resolution req._protocol._timeout_ceil_threshold = lowest_time - ws = WebSocketResponse(heartbeat=lowest_time, timeout=lowest_time) + ws = web.WebSocketResponse(heartbeat=lowest_time, timeout=lowest_time) await ws.prepare(req) - ws._req.transport.close.side_effect = lambda: future.set_result(None) await future assert ws.closed @@ -182,27 +196,27 @@ def test_bool_websocket_not_ready() -> None: assert bool(websocket_ready) is False -def test_can_prepare_ok(make_request: Any) -> None: +def test_can_prepare_ok(make_request: _RequestMaker) -> None: req = make_request("GET", "/", protocols=True) - ws = WebSocketResponse(protocols=("chat",)) + ws = web.WebSocketResponse(protocols=("chat",)) assert WebSocketReady(True, "chat") == ws.can_prepare(req) -def test_can_prepare_unknown_protocol(make_request: Any) -> None: +def test_can_prepare_unknown_protocol(make_request: _RequestMaker) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() assert WebSocketReady(True, None) == ws.can_prepare(req) -def test_can_prepare_without_upgrade(make_request: Any) -> None: +def test_can_prepare_without_upgrade(make_request: _RequestMaker) -> None: req = make_request("GET", "/", headers=CIMultiDict({})) - ws = WebSocketResponse() + ws = web.WebSocketResponse() assert WebSocketReady(False, None) == ws.can_prepare(req) -async def test_can_prepare_started(make_request: Any) -> None: +async def test_can_prepare_started(make_request: _RequestMaker) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) with pytest.raises(RuntimeError) as ctx: ws.can_prepare(req) @@ -211,27 +225,30 @@ async def test_can_prepare_started(make_request: Any) -> None: def test_closed_after_ctor() -> None: - ws = WebSocketResponse() + ws = web.WebSocketResponse() assert not ws.closed assert ws.close_code is None -async def test_send_str_closed(make_request: Any) -> None: +async def test_send_str_closed(make_request: _RequestMaker) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) + assert ws._reader is not None ws._reader.feed_data(WS_CLOSED_MESSAGE) await ws.close() - assert len(ws._req.transport.close.mock_calls) == 1 + assert req.transport is not None + assert len(req.transport.close.mock_calls) == 1 # type: ignore[attr-defined] with pytest.raises(ConnectionError): await ws.send_str("string") -async def test_send_bytes_closed(make_request: Any) -> None: +async def test_send_bytes_closed(make_request: _RequestMaker) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) + assert ws._reader is not None ws._reader.feed_data(WS_CLOSED_MESSAGE) await ws.close() @@ -239,10 +256,11 @@ async def test_send_bytes_closed(make_request: Any) -> None: await ws.send_bytes(b"bytes") -async def test_send_json_closed(make_request: Any) -> None: +async def test_send_json_closed(make_request: _RequestMaker) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) + assert ws._reader is not None ws._reader.feed_data(WS_CLOSED_MESSAGE) await ws.close() @@ -250,10 +268,11 @@ async def test_send_json_closed(make_request: Any) -> None: await ws.send_json({"type": "json"}) -async def test_ping_closed(make_request: Any) -> None: +async def test_ping_closed(make_request: _RequestMaker) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) + assert ws._reader is not None ws._reader.feed_data(WS_CLOSED_MESSAGE) await ws.close() @@ -261,10 +280,11 @@ async def test_ping_closed(make_request: Any) -> None: await ws.ping() -async def test_pong_closed(make_request: Any, mocker: Any) -> None: +async def test_pong_closed(make_request: _RequestMaker, mocker: MockerFixture) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) + assert ws._reader is not None ws._reader.feed_data(WS_CLOSED_MESSAGE) await ws.close() @@ -272,62 +292,70 @@ async def test_pong_closed(make_request: Any, mocker: Any) -> None: await ws.pong() -async def test_close_idempotent(make_request: Any) -> None: +async def test_close_idempotent(make_request: _RequestMaker) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) + assert ws._reader is not None ws._reader.feed_data(WS_CLOSED_MESSAGE) - assert await ws.close(code=1, message="message1") + close_code = await ws.close(code=1, message=b"message1") + assert close_code == 1 assert ws.closed - assert len(ws._req.transport.close.mock_calls) == 1 + assert req.transport is not None + assert len(req.transport.close.mock_calls) == 1 # type: ignore[attr-defined] - assert not (await ws.close(code=2, message="message2")) + close_code = await ws.close(code=2, message=b"message2") + assert close_code == 0 -async def test_prepare_post_method_ok(make_request: Any) -> None: +async def test_prepare_post_method_ok(make_request: _RequestMaker) -> None: req = make_request("POST", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) assert ws.prepared -async def test_prepare_without_upgrade(make_request: Any) -> None: +async def test_prepare_without_upgrade(make_request: _RequestMaker) -> None: req = make_request("GET", "/", headers=CIMultiDict({})) - ws = WebSocketResponse() - with pytest.raises(HTTPBadRequest): + ws = web.WebSocketResponse() + with pytest.raises(web.HTTPBadRequest): await ws.prepare(req) async def test_wait_closed_before_start() -> None: - ws = WebSocketResponse() + ws = web.WebSocketResponse() with pytest.raises(RuntimeError): await ws.close() async def test_write_eof_not_started() -> None: - ws = WebSocketResponse() + ws = web.WebSocketResponse() with pytest.raises(RuntimeError): await ws.write_eof() -async def test_write_eof_idempotent(make_request: Any) -> None: +async def test_write_eof_idempotent(make_request: _RequestMaker) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) - assert len(ws._req.transport.close.mock_calls) == 0 + assert req.transport is not None + assert len(req.transport.close.mock_calls) == 0 # type: ignore[attr-defined] + assert ws._reader is not None ws._reader.feed_data(WS_CLOSED_MESSAGE) await ws.close() await ws.write_eof() await ws.write_eof() await ws.write_eof() - assert len(ws._req.transport.close.mock_calls) == 1 + assert len(req.transport.close.mock_calls) == 1 # type: ignore[attr-defined] -async def test_receive_eofstream_in_reader(make_request: Any, loop: Any) -> None: +async def test_receive_eofstream_in_reader( + make_request: _RequestMaker, loop: asyncio.AbstractEventLoop +) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) ws._reader = mock.Mock() @@ -335,18 +363,20 @@ async def test_receive_eofstream_in_reader(make_request: Any, loop: Any) -> None res = loop.create_future() res.set_exception(exc) ws._reader.read = make_mocked_coro(res) - ws._payload_writer.drain = mock.Mock() - ws._payload_writer.drain.return_value = loop.create_future() - ws._payload_writer.drain.return_value.set_result(True) - + assert ws._payload_writer is not None + f = loop.create_future() + f.set_result(True) + ws._payload_writer.drain.return_value = f # type: ignore[attr-defined] msg = await ws.receive() assert msg.type == WSMsgType.CLOSED assert ws.closed -async def test_receive_exception_in_reader(make_request: Any, loop: Any) -> None: +async def test_receive_exception_in_reader( + make_request: _RequestMaker, loop: asyncio.AbstractEventLoop +) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) ws._reader = mock.Mock() @@ -354,47 +384,56 @@ async def test_receive_exception_in_reader(make_request: Any, loop: Any) -> None res = loop.create_future() res.set_exception(exc) ws._reader.read = make_mocked_coro(res) - ws._payload_writer.drain = mock.Mock() - ws._payload_writer.drain.return_value = loop.create_future() - ws._payload_writer.drain.return_value.set_result(True) + f = loop.create_future() + assert ws._payload_writer is not None + ws._payload_writer.drain.return_value = f # type: ignore[attr-defined] + f.set_result(True) msg = await ws.receive() assert msg.type == WSMsgType.ERROR assert ws.closed - assert len(ws._req.transport.close.mock_calls) == 1 + assert req.transport is not None + assert len(req.transport.close.mock_calls) == 1 # type: ignore[attr-defined] -async def test_receive_close_but_left_open(make_request: Any, loop: Any) -> None: +async def test_receive_close_but_left_open( + make_request: _RequestMaker, loop: asyncio.AbstractEventLoop +) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) close_message = WSMessage(WSMsgType.CLOSE, 1000, "close") ws._reader = mock.Mock() ws._reader.read = mock.AsyncMock(return_value=close_message) - ws._payload_writer.drain = mock.Mock() - ws._payload_writer.drain.return_value = loop.create_future() - ws._payload_writer.drain.return_value.set_result(True) + f = loop.create_future() + assert ws._payload_writer is not None + ws._payload_writer.drain.return_value = f # type: ignore[attr-defined] + f.set_result(True) msg = await ws.receive() assert msg.type == WSMsgType.CLOSE assert ws.closed - assert len(ws._req.transport.close.mock_calls) == 1 + assert req.transport is not None + assert len(req.transport.close.mock_calls) == 1 # type: ignore[attr-defined] -async def test_receive_closing(make_request: Any, loop: Any) -> None: +async def test_receive_closing( + make_request: _RequestMaker, loop: asyncio.AbstractEventLoop +) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) closing_message = WSMessage(WSMsgType.CLOSING, 1000, "closing") ws._reader = mock.Mock() read_mock = mock.AsyncMock(return_value=closing_message) ws._reader.read = read_mock - ws._payload_writer.drain = mock.Mock() - ws._payload_writer.drain.return_value = loop.create_future() - ws._payload_writer.drain.return_value.set_result(True) + f = loop.create_future() + assert ws._payload_writer is not None + ws._payload_writer.drain.return_value = f # type: ignore[attr-defined] + f.set_result(True) msg = await ws.receive() assert msg.type == WSMsgType.CLOSING assert not ws.closed @@ -409,33 +448,40 @@ async def test_receive_closing(make_request: Any, loop: Any) -> None: assert msg.type == WSMsgType.CLOSING -async def test_close_after_closing(make_request: Any, loop: Any) -> None: +async def test_close_after_closing( + make_request: _RequestMaker, loop: asyncio.AbstractEventLoop +) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) closing_message = WSMessage(WSMsgType.CLOSING, 1000, "closing") ws._reader = mock.Mock() ws._reader.read = mock.AsyncMock(return_value=closing_message) - ws._payload_writer.drain = mock.Mock() - ws._payload_writer.drain.return_value = loop.create_future() - ws._payload_writer.drain.return_value.set_result(True) + f = loop.create_future() + assert ws._payload_writer is not None + ws._payload_writer.drain.return_value = f # type: ignore[attr-defined] + f.set_result(True) msg = await ws.receive() assert msg.type == WSMsgType.CLOSING assert not ws.closed - assert len(ws._req.transport.close.mock_calls) == 0 + assert req.transport is not None + assert len(req.transport.close.mock_calls) == 0 # type: ignore[attr-defined] await ws.close() assert ws.closed - assert len(ws._req.transport.close.mock_calls) == 1 + assert len(req.transport.close.mock_calls) == 1 -async def test_receive_timeouterror(make_request: Any, loop: Any) -> None: +async def test_receive_timeouterror( + make_request: _RequestMaker, loop: asyncio.AbstractEventLoop +) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) - assert len(ws._req.transport.close.mock_calls) == 0 + assert req.transport is not None + assert len(req.transport.close.mock_calls) == 0 # type: ignore[attr-defined] ws._reader = mock.Mock() res = loop.create_future() @@ -446,13 +492,16 @@ async def test_receive_timeouterror(make_request: Any, loop: Any) -> None: await ws.receive() # Should not close the connection on timeout - assert len(ws._req.transport.close.mock_calls) == 0 + assert len(req.transport.close.mock_calls) == 0 # type: ignore[attr-defined] -async def test_multiple_receive_on_close_connection(make_request: Any) -> None: +async def test_multiple_receive_on_close_connection( + make_request: _RequestMaker, +) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) + assert ws._reader is not None ws._reader.feed_data(WS_CLOSED_MESSAGE) await ws.close() @@ -465,9 +514,9 @@ async def test_multiple_receive_on_close_connection(make_request: Any) -> None: await ws.receive() -async def test_concurrent_receive(make_request: Any) -> None: +async def test_concurrent_receive(make_request: _RequestMaker) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) ws._waiting = True @@ -475,11 +524,12 @@ async def test_concurrent_receive(make_request: Any) -> None: await ws.receive() -async def test_close_exc(make_request: Any) -> None: +async def test_close_exc(make_request: _RequestMaker) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) - assert len(ws._req.transport.close.mock_calls) == 0 + assert req.transport is not None + assert len(req.transport.close.mock_calls) == 0 # type: ignore[attr-defined] exc = ValueError() ws._writer = mock.Mock() @@ -487,7 +537,7 @@ async def test_close_exc(make_request: Any) -> None: await ws.close() assert ws.closed assert ws.exception() is exc - assert len(ws._req.transport.close.mock_calls) == 1 + assert len(req.transport.close.mock_calls) == 1 # type: ignore[attr-defined] ws._closed = False ws._writer.close.side_effect = asyncio.CancelledError() @@ -495,34 +545,37 @@ async def test_close_exc(make_request: Any) -> None: await ws.close() -async def test_prepare_twice_idempotent(make_request: Any) -> None: +async def test_prepare_twice_idempotent(make_request: _RequestMaker) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() impl1 = await ws.prepare(req) impl2 = await ws.prepare(req) assert impl1 is impl2 -async def test_send_with_per_message_deflate(make_request: Any, mocker: Any) -> None: +async def test_send_with_per_message_deflate( + make_request: _RequestMaker, mocker: MockerFixture +) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) - writer_send = ws._writer.send = make_mocked_coro() - - await ws.send_str("string", compress=15) - writer_send.assert_called_with("string", binary=False, compress=15) + with mock.patch.object(ws._writer, "send", autospec=True, spec_set=True) as m: + await ws.send_str("string", compress=15) + m.assert_called_with("string", binary=False, compress=15) - await ws.send_bytes(b"bytes", compress=0) - writer_send.assert_called_with(b"bytes", binary=True, compress=0) + await ws.send_bytes(b"bytes", compress=0) + m.assert_called_with(b"bytes", binary=True, compress=0) - await ws.send_json("[{}]", compress=9) - writer_send.assert_called_with('"[{}]"', binary=False, compress=9) + await ws.send_json("[{}]", compress=9) + m.assert_called_with('"[{}]"', binary=False, compress=9) -async def test_no_transfer_encoding_header(make_request: Any, mocker: Any) -> None: +async def test_no_transfer_encoding_header( + make_request: _RequestMaker, mocker: MockerFixture +) -> None: req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws._start(req) assert "Transfer-Encoding" not in ws.headers @@ -546,13 +599,16 @@ async def test_no_transfer_encoding_header(make_request: Any, mocker: Any) -> No ], ) async def test_get_extra_info( - make_request: Any, mocker: Any, ws_transport: Any, expected_result: Any + make_request: _RequestMaker, + mocker: MockerFixture, + ws_transport: Optional[mock.MagicMock], + expected_result: str, ) -> None: valid_key = "test" default_value = "default" req = make_request("GET", "/") - ws = WebSocketResponse() + ws = web.WebSocketResponse() await ws.prepare(req) ws._writer = ws_transport diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index ff0a4cc6b47..f718126b630 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -1,11 +1,10 @@ -# type: ignore # HTTP websocket server functional tests import asyncio import contextlib import sys import weakref -from typing import Any, NoReturn, Optional +from typing import NoReturn, Optional from unittest import mock import pytest @@ -13,16 +12,16 @@ import aiohttp from aiohttp import WSServerHandshakeError, web from aiohttp.http import WSCloseCode, WSMsgType -from aiohttp.pytest_plugin import AiohttpClient +from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer -async def test_websocket_can_prepare(loop: Any, aiohttp_client: Any) -> None: - async def handler(request): +async def test_websocket_can_prepare( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: + async def handler(request: web.Request) -> NoReturn: ws = web.WebSocketResponse() - if not ws.can_prepare(request): - raise web.HTTPUpgradeRequired() - - return web.Response() + assert not ws.can_prepare(request) + raise web.HTTPUpgradeRequired() app = web.Application() app.router.add_route("GET", "/", handler) @@ -32,11 +31,12 @@ async def handler(request): assert resp.status == 426 -async def test_websocket_json(loop: Any, aiohttp_client: Any) -> None: - async def handler(request): +async def test_websocket_json( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() - if not ws.can_prepare(request): - return web.HTTPUpgradeRequired() + assert ws.can_prepare(request) await ws.prepare(request) msg = await ws.receive() @@ -61,8 +61,10 @@ async def handler(request): assert resp.data == expected_value -async def test_websocket_json_invalid_message(loop: Any, aiohttp_client: Any) -> None: - async def handler(request): +async def test_websocket_json_invalid_message( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) try: @@ -87,8 +89,10 @@ async def handler(request): assert "ValueError was raised" in data -async def test_websocket_send_json(loop: Any, aiohttp_client: Any) -> None: - async def handler(request): +async def test_websocket_send_json( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) @@ -110,8 +114,10 @@ async def handler(request): assert data["test"] == expected_value -async def test_websocket_receive_json(loop: Any, aiohttp_client: Any) -> None: - async def handler(request): +async def test_websocket_receive_json( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) @@ -135,10 +141,12 @@ async def handler(request): assert resp.data == expected_value -async def test_send_recv_text(loop: Any, aiohttp_client: Any) -> None: +async def test_send_recv_text( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: closed = loop.create_future() - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) msg = await ws.receive_str() @@ -168,10 +176,12 @@ async def handler(request): await closed -async def test_send_recv_bytes(loop: Any, aiohttp_client: Any) -> None: +async def test_send_recv_bytes( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: closed = loop.create_future() - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) @@ -202,10 +212,12 @@ async def handler(request): await closed -async def test_send_recv_json(loop: Any, aiohttp_client: Any) -> None: +async def test_send_recv_json( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: closed = loop.create_future() - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) data = await ws.receive_json() @@ -236,16 +248,19 @@ async def handler(request): await closed -async def test_close_timeout(loop: Any, aiohttp_client: Any) -> None: +async def test_close_timeout( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: aborted = loop.create_future() elapsed = 1e10 # something big - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: nonlocal elapsed ws = web.WebSocketResponse(timeout=0.1) await ws.prepare(request) assert "request" == (await ws.receive_str()) await ws.send_str("reply") + assert ws._loop is not None begin = ws._loop.time() assert await ws.close() elapsed = ws._loop.time() - begin @@ -277,10 +292,12 @@ async def handler(request): await ws.close() -async def test_concurrent_close(loop: Any, aiohttp_client: Any) -> None: +async def test_concurrent_close( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: srv_ws = None - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: nonlocal srv_ws ws = srv_ws = web.WebSocketResponse(autoclose=False, protocols=("foo", "bar")) await ws.prepare(request) @@ -304,6 +321,7 @@ async def handler(request): ws = await client.ws_connect("/", autoclose=False, protocols=("eggs", "bar")) + assert srv_ws is not None await srv_ws.close(code=WSCloseCode.INVALID_TEXT) msg = await ws.receive() @@ -314,10 +332,12 @@ async def handler(request): assert msg.type == WSMsgType.CLOSED -async def test_concurrent_close_multiple_tasks(loop: Any, aiohttp_client: Any) -> None: +async def test_concurrent_close_multiple_tasks( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: srv_ws = None - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: nonlocal srv_ws ws = srv_ws = web.WebSocketResponse(autoclose=False, protocols=("foo", "bar")) await ws.prepare(request) @@ -341,6 +361,7 @@ async def handler(request): ws = await client.ws_connect("/", autoclose=False, protocols=("eggs", "bar")) + assert srv_ws is not None task1 = asyncio.create_task(srv_ws.close(code=WSCloseCode.INVALID_TEXT)) task2 = asyncio.create_task(srv_ws.close(code=WSCloseCode.INVALID_TEXT)) @@ -355,10 +376,12 @@ async def handler(request): assert msg.type == WSMsgType.CLOSED -async def test_close_op_code_from_client(loop: Any, aiohttp_client: Any) -> None: +async def test_close_op_code_from_client( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: srv_ws: Optional[web.WebSocketResponse] = None - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: nonlocal srv_ws ws = srv_ws = web.WebSocketResponse(protocols=("foo", "bar")) await ws.prepare(request) @@ -372,7 +395,7 @@ async def handler(request): app.router.add_get("/", handler) client = await aiohttp_client(app) - ws: web.WebSocketResponse = await client.ws_connect("/", protocols=("eggs", "bar")) + ws = await client.ws_connect("/", protocols=("eggs", "bar")) await ws._writer._send_frame(b"", WSMsgType.CLOSE) @@ -384,10 +407,12 @@ async def handler(request): assert msg.type == WSMsgType.CLOSED -async def test_auto_pong_with_closing_by_peer(loop: Any, aiohttp_client: Any) -> None: +async def test_auto_pong_with_closing_by_peer( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: closed = loop.create_future() - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) await ws.receive() @@ -409,18 +434,20 @@ async def handler(request): msg = await ws.receive() assert msg.type == WSMsgType.PONG - await ws.close(code=WSCloseCode.OK, message="exit message") + await ws.close(code=WSCloseCode.OK, message=b"exit message") await closed -async def test_ping(loop: Any, aiohttp_client: Any) -> None: +async def test_ping( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: closed = loop.create_future() - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) - await ws.ping("data") + await ws.ping(b"data") await ws.receive() closed.set_result(None) return ws @@ -439,10 +466,12 @@ async def handler(request): await closed -async def aiohttp_client_ping(loop: Any, aiohttp_client: Any): +async def test_client_ping( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: closed = loop.create_future() - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) @@ -456,7 +485,7 @@ async def handler(request): ws = await client.ws_connect("/", autoping=False) - await ws.ping("data") + await ws.ping(b"data") msg = await ws.receive() assert msg.type == WSMsgType.PONG assert msg.data == b"data" @@ -464,16 +493,18 @@ async def handler(request): await ws.close() -async def test_pong(loop: Any, aiohttp_client: Any) -> None: +async def test_pong( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: closed = loop.create_future() - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse(autoping=False) await ws.prepare(request) msg = await ws.receive() assert msg.type == WSMsgType.PING - await ws.pong("data") + await ws.pong(b"data") msg = await ws.receive() assert msg.type == WSMsgType.CLOSE @@ -488,20 +519,22 @@ async def handler(request): ws = await client.ws_connect("/", autoping=False) - await ws.ping("data") + await ws.ping(b"data") msg = await ws.receive() assert msg.type == WSMsgType.PONG assert msg.data == b"data" - await ws.close(code=WSCloseCode.OK, message="exit message") + await ws.close(code=WSCloseCode.OK, message=b"exit message") await closed -async def test_change_status(loop: Any, aiohttp_client: Any) -> None: +async def test_change_status( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: closed = loop.create_future() - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() ws.set_status(200) assert 200 == ws.status @@ -522,10 +555,12 @@ async def handler(request): await ws.close() -async def test_handle_protocol(loop: Any, aiohttp_client: Any) -> None: +async def test_handle_protocol( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: closed = loop.create_future() - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse(protocols=("foo", "bar")) await ws.prepare(request) await ws.close() @@ -543,10 +578,12 @@ async def handler(request): await closed -async def test_server_close_handshake(loop: Any, aiohttp_client: Any) -> None: +async def test_server_close_handshake( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: closed = loop.create_future() - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse(protocols=("foo", "bar")) await ws.prepare(request) await ws.close() @@ -565,10 +602,12 @@ async def handler(request): await closed -async def aiohttp_client_close_handshake(loop: Any, aiohttp_client: Any): +async def test_client_close_handshake( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: closed = loop.create_future() - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse(autoclose=False, protocols=("foo", "bar")) await ws.prepare(request) @@ -598,11 +637,11 @@ async def handler(request): async def test_server_close_handshake_server_eats_client_messages( - loop: Any, aiohttp_client: Any -): + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: closed = loop.create_future() - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse(protocols=("foo", "bar")) await ws.prepare(request) await ws.close() @@ -628,10 +667,12 @@ async def handler(request): await closed -async def test_receive_timeout(loop: Any, aiohttp_client: Any) -> None: +async def test_receive_timeout( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: raised = False - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse(receive_timeout=0.1) await ws.prepare(request) @@ -654,10 +695,12 @@ async def handler(request): assert raised -async def test_custom_receive_timeout(loop: Any, aiohttp_client: Any) -> None: +async def test_custom_receive_timeout( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: raised = False - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse(receive_timeout=None) await ws.prepare(request) @@ -680,8 +723,10 @@ async def handler(request): assert raised -async def test_heartbeat(loop: Any, aiohttp_client: Any) -> None: - async def handler(request): +async def test_heartbeat( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse(heartbeat=0.05) await ws.prepare(request) await ws.receive() @@ -700,8 +745,10 @@ async def handler(request): await ws.close() -async def test_heartbeat_no_pong(loop: Any, aiohttp_client: Any) -> None: - async def handler(request): +async def test_heartbeat_no_pong( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse(heartbeat=0.05) await ws.prepare(request) @@ -731,6 +778,8 @@ async def handler(request: web.Request) -> NoReturn: # We patch write here to simulate a connection reset error # since if we closed the connection normally, the server would # would cancel the heartbeat task and we wouldn't get a ping + assert ws_server._req is not None + assert ws_server._writer is not None with mock.patch.object( ws_server._req.transport, "write", side_effect=ConnectionResetError ), mock.patch.object( @@ -789,11 +838,11 @@ async def handler(request: web.Request) -> NoReturn: async def test_heartbeat_no_pong_send_many_messages( - loop: Any, aiohttp_client: Any + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient ) -> None: """Test no pong after sending many messages.""" - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse(heartbeat=0.05) await ws.prepare(request) for _ in range(10): @@ -818,11 +867,11 @@ async def handler(request): async def test_heartbeat_no_pong_receive_many_messages( - loop: Any, aiohttp_client: Any + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient ) -> None: """Test no pong after receiving many messages.""" - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse(heartbeat=0.05) await ws.prepare(request) for _ in range(10): @@ -845,10 +894,12 @@ async def handler(request): await ws.close() -async def test_server_ws_async_for(loop: Any, aiohttp_server: Any) -> None: +async def test_server_ws_async_for( + loop: asyncio.AbstractEventLoop, aiohttp_server: AiohttpServer +) -> None: closed = loop.create_future() - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) async for msg in ws: @@ -876,10 +927,12 @@ async def handler(request): await closed -async def test_closed_async_for(loop: Any, aiohttp_client: Any) -> None: +async def test_closed_async_for( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: closed = loop.create_future() - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() await ws.prepare(request) @@ -911,8 +964,10 @@ async def handler(request): await closed -async def test_websocket_disable_keepalive(loop: Any, aiohttp_client: Any) -> None: - async def handler(request): +async def test_websocket_disable_keepalive( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: + async def handler(request: web.Request) -> web.StreamResponse: ws = web.WebSocketResponse() if not ws.can_prepare(request): return web.Response(text="OK") @@ -938,11 +993,12 @@ async def handler(request): assert data == "OK" -async def test_receive_str_nonstring(loop: Any, aiohttp_client: Any) -> None: - async def handler(request): +async def test_receive_str_nonstring( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() - if not ws.can_prepare(request): - return web.HTTPUpgradeRequired() + assert ws.can_prepare(request) await ws.prepare(request) await ws.send_bytes(b"answer") @@ -958,16 +1014,16 @@ async def handler(request): await ws.receive_str() -async def test_receive_bytes_nonbytes(loop: Any, aiohttp_client: Any) -> None: - async def handler(request): +async def test_receive_bytes_nonbytes( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: + async def handler(request: web.Request) -> NoReturn: ws = web.WebSocketResponse() - if not ws.can_prepare(request): - return web.HTTPUpgradeRequired() + assert ws.can_prepare(request) await ws.prepare(request) - await ws.send_bytes("answer") - await ws.close() - return ws + await ws.send_bytes("answer") # type: ignore[arg-type] + assert False app = web.Application() app.router.add_route("GET", "/", handler) @@ -978,11 +1034,13 @@ async def handler(request): await ws.receive_bytes() -async def test_bug3380(loop: Any, aiohttp_client: Any) -> None: - async def handle_null(request): - return aiohttp.web.json_response({"err": None}) +async def test_bug3380( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: + async def handle_null(request: web.Request) -> web.Response: + return web.json_response({"err": None}) - async def ws_handler(request): + async def ws_handler(request: web.Request) -> web.Response: return web.Response(status=401) app = web.Application() @@ -1004,11 +1062,11 @@ async def ws_handler(request): async def test_receive_being_cancelled_keeps_connection_open( - loop: Any, aiohttp_client: Any + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient ) -> None: closed = loop.create_future() - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse(autoping=False) await ws.prepare(request) @@ -1021,7 +1079,7 @@ async def handler(request): msg = await ws.receive() assert msg.type == WSMsgType.PING await asyncio.sleep(0) - await ws.pong("data") + await ws.pong(b"data") msg = await ws.receive() assert msg.type == WSMsgType.CLOSE @@ -1037,24 +1095,24 @@ async def handler(request): ws = await client.ws_connect("/", autoping=False) await asyncio.sleep(0) - await ws.ping("data") + await ws.ping(b"data") msg = await ws.receive() assert msg.type == WSMsgType.PONG assert msg.data == b"data" - await ws.close(code=WSCloseCode.OK, message="exit message") + await ws.close(code=WSCloseCode.OK, message=b"exit message") await closed async def test_receive_timeout_keeps_connection_open( - loop: Any, aiohttp_client: Any + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient ) -> None: closed = loop.create_future() timed_out = loop.create_future() - async def handler(request): + async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse(autoping=False) await ws.prepare(request) @@ -1067,7 +1125,7 @@ async def handler(request): msg = await ws.receive() assert msg.type == WSMsgType.PING await asyncio.sleep(0) - await ws.pong("data") + await ws.pong(b"data") msg = await ws.receive() assert msg.type == WSMsgType.CLOSE @@ -1083,13 +1141,13 @@ async def handler(request): ws = await client.ws_connect("/", autoping=False) await timed_out - await ws.ping("data") + await ws.ping(b"data") msg = await ws.receive() assert msg.type == WSMsgType.PONG assert msg.data == b"data" - await ws.close(code=WSCloseCode.OK, message="exit message") + await ws.close(code=WSCloseCode.OK, message=b"exit message") await closed @@ -1098,11 +1156,13 @@ async def test_websocket_shutdown(aiohttp_client: AiohttpClient) -> None: """Test that the client websocket gets the close message when the server is shutting down.""" url = "/ws" app = web.Application() - websockets = web.AppKey("websockets", weakref.WeakSet) + websockets = web.AppKey("websockets", weakref.WeakSet[web.WebSocketResponse]) app[websockets] = weakref.WeakSet() # need for send signal shutdown server - shutdown_websockets = web.AppKey("shutdown_websockets", weakref.WeakSet) + shutdown_websockets = web.AppKey( + "shutdown_websockets", weakref.WeakSet[web.WebSocketResponse] + ) app[shutdown_websockets] = weakref.WeakSet() async def websocket_handler(request: web.Request) -> web.WebSocketResponse: @@ -1124,7 +1184,7 @@ async def on_shutdown(app: web.Application) -> None: websocket = app[shutdown_websockets].pop() await websocket.close( code=aiohttp.WSCloseCode.GOING_AWAY, - message="Server shutdown", + message=b"Server shutdown", ) app.router.add_get(url, websocket_handler)