From b93ef57cce9ea95954d17b7e4ca11883a4da2471 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 18 Sep 2024 17:35:53 +0200 Subject: [PATCH] Improve performance of starting web requests (#9172) --- CHANGES/9172.misc.rst | 1 + aiohttp/web_response.py | 6 +++--- tests/test_web_response.py | 16 ++++++++++++++++ 3 files changed, 20 insertions(+), 3 deletions(-) create mode 120000 CHANGES/9172.misc.rst diff --git a/CHANGES/9172.misc.rst b/CHANGES/9172.misc.rst new file mode 120000 index 0000000000..d6a2f2aaaa --- /dev/null +++ b/CHANGES/9172.misc.rst @@ -0,0 +1 @@ +9174.misc.rst \ No newline at end of file diff --git a/aiohttp/web_response.py b/aiohttp/web_response.py index f1c25ba4e8..1dde889115 100644 --- a/aiohttp/web_response.py +++ b/aiohttp/web_response.py @@ -679,10 +679,10 @@ async def write_eof(self, data: bytes = b"") -> None: await super().write_eof() async def _start(self, request: "BaseRequest") -> AbstractStreamWriter: - if should_remove_content_length(request.method, self.status): - if hdrs.CONTENT_LENGTH in self._headers: + if hdrs.CONTENT_LENGTH in self._headers: + if should_remove_content_length(request.method, self.status): del self._headers[hdrs.CONTENT_LENGTH] - elif not self._chunked and hdrs.CONTENT_LENGTH not in self._headers: + elif not self._chunked: if isinstance(self._body, Payload): if self._body.size is not None: self._headers[hdrs.CONTENT_LENGTH] = str(self._body.size) diff --git a/tests/test_web_response.py b/tests/test_web_response.py index 3cd1545c14..1c052d42ba 100644 --- a/tests/test_web_response.py +++ b/tests/test_web_response.py @@ -627,6 +627,22 @@ async def write_headers(status_line, headers): assert resp.content_length is None +async def test_rm_content_length_if_204() -> None: + """Ensure content-length is removed for 204 responses.""" + writer = mock.create_autospec(StreamWriter, spec_set=True, instance=True) + + async def write_headers(status_line, headers): + assert hdrs.CONTENT_LENGTH not in headers + + writer.write_headers.side_effect = write_headers + req = make_request("GET", "/", writer=writer) + payload = BytesPayload(b"answer", headers={"Content-Length": "6"}) + resp = Response(body=payload, status=204) + resp.body = payload + await resp.prepare(req) + assert resp.content_length is None + + @pytest.mark.parametrize("status", (100, 101, 204, 304)) async def test_rm_transfer_encoding_rfc_9112_6_3_http_11(status: int) -> None: """Remove transfer encoding for RFC 9112 sec 6.3 with HTTP/1.1."""