diff --git a/pyproject.toml b/pyproject.toml index 52156528e..b993c2a2d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,4 +103,5 @@ exclude_lines = [ "pragma: nocover", "if typing.TYPE_CHECKING:", "@typing.overload", + "raise NotImplementedError", ] diff --git a/starlette/responses.py b/starlette/responses.py index fc92cbab1..b951e5125 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -374,13 +374,7 @@ async def _handle_simple(self, send: Send, send_header_only: bool) -> None: while more_body: chunk = await file.read(self.chunk_size) more_body = len(chunk) == self.chunk_size - await send( - { - "type": "http.response.body", - "body": chunk, - "more_body": more_body, - } - ) + await send({"type": "http.response.body", "body": chunk, "more_body": more_body}) async def _handle_single_range( self, send: Send, start: int, end: int, file_size: int, send_header_only: bool @@ -419,10 +413,12 @@ async def _handle_multiple_ranges( else: async with await anyio.open_file(self.path, mode="rb") as file: for start, end in ranges: - await file.seek(start) - chunk = await file.read(min(self.chunk_size, end - start)) await send({"type": "http.response.body", "body": header_generator(start, end), "more_body": True}) - await send({"type": "http.response.body", "body": chunk, "more_body": True}) + await file.seek(start) + while start < end: + chunk = await file.read(min(self.chunk_size, end - start)) + start += len(chunk) + await send({"type": "http.response.body", "body": chunk, "more_body": True}) await send({"type": "http.response.body", "body": b"\n", "more_body": True}) await send( { diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 15080e5c5..e2375a7b9 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -289,7 +289,7 @@ async def passthrough( } async def receive() -> Message: - raise NotImplementedError("Should not be called!") # pragma: no cover + raise NotImplementedError("Should not be called!") async def send(message: Message) -> None: if message["type"] == "http.response.body": @@ -330,7 +330,7 @@ async def passthrough(request: Request, call_next: RequestResponseEndpoint) -> R } async def receive() -> Message: - raise NotImplementedError("Should not be called!") # pragma: no cover + raise NotImplementedError("Should not be called!") async def send(message: Message) -> None: if message["type"] == "http.response.body": @@ -403,7 +403,7 @@ async def passthrough( } async def receive() -> Message: - raise NotImplementedError("Should not be called!") # pragma: no cover + raise NotImplementedError("Should not be called!") async def send(message: Message) -> None: if message["type"] == "http.response.body": diff --git a/tests/test_responses.py b/tests/test_responses.py index 359516c75..645d26a68 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -4,7 +4,7 @@ import time from http.cookies import SimpleCookie from pathlib import Path -from typing import AsyncIterator, Iterator +from typing import Any, AsyncIterator, Iterator import anyio import pytest @@ -682,3 +682,58 @@ def test_file_response_insert_ranges(file_response_client: TestClient) -> None: "", f"--{boundary}--", ] + + +@pytest.mark.anyio +async def test_file_response_multi_small_chunk_size(readme_file: Path) -> None: + class SmallChunkSizeFileResponse(FileResponse): + chunk_size = 10 + + app = SmallChunkSizeFileResponse(path=str(readme_file)) + + received_chunks: list[bytes] = [] + start_message: dict[str, Any] = {} + + async def receive() -> Message: + raise NotImplementedError("Should not be called!") + + async def send(message: Message) -> None: + if message["type"] == "http.response.start": + start_message.update(message) + elif message["type"] == "http.response.body": + received_chunks.append(message["body"]) + + await app({"type": "http", "method": "get", "headers": [(b"range", b"bytes=0-15,20-35,35-50")]}, receive, send) + assert start_message["status"] == 206 + + headers = Headers(raw=start_message["headers"]) + assert headers.get("content-type") == "text/plain; charset=utf-8" + assert headers.get("accept-ranges") == "bytes" + assert "content-length" in headers + assert "last-modified" in headers + assert "etag" in headers + assert headers["content-range"].startswith("multipart/byteranges; boundary=") + boundary = headers["content-range"].split("boundary=")[1] + + assert received_chunks == [ + # Send the part headers. + f"--{boundary}\nContent-Type: text/plain; charset=utf-8\nContent-Range: bytes 0-15/526\n\n".encode(), + # Send the first chunk (10 bytes). + b"# B\xc3\xa1iZ\xc3\xa9\n", + # Send the second chunk (6 bytes). + b"\nPower", + # Send the new line to separate the parts. + b"\n", + # Send the part headers. We merge the ranges 20-35 and 35-50 into a single part. + f"--{boundary}\nContent-Type: text/plain; charset=utf-8\nContent-Range: bytes 20-50/526\n\n".encode(), + # Send the first chunk (10 bytes). + b"and exquis", + # Send the second chunk (10 bytes). + b"ite WSGI/A", + # Send the third chunk (10 bytes). + b"SGI framew", + # Send the last chunk (1 byte). + b"o", + b"\n", + f"\n--{boundary}--\n".encode(), + ]