Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PR #8657/6c6ecfaf backport][3.10] Fix multipart reading with split boundary #8658

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/8653.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed multipart reading when stream buffer splits the boundary over several read() calls -- by :user:`Dreamsorcerer`.
19 changes: 15 additions & 4 deletions aiohttp/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ def __init__(
) -> None:
self.headers = headers
self._boundary = boundary
self._boundary_len = len(boundary) + 2 # Boundary + \r\n
self._content = content
self._default_charset = default_charset
self._at_eof = False
Expand Down Expand Up @@ -346,15 +347,25 @@ async def _read_chunk_from_stream(self, size: int) -> bytes:
# Reads content chunk of body part with unknown length.
# The Content-Length header for body part is not necessary.
assert (
size >= len(self._boundary) + 2
size >= self._boundary_len
), "Chunk size must be greater or equal than boundary length + 2"
first_chunk = self._prev_chunk is None
if first_chunk:
self._prev_chunk = await self._content.read(size)

chunk = await self._content.read(size)
self._content_eof += int(self._content.at_eof())
assert self._content_eof < 3, "Reading after EOF"
chunk = b""
# content.read() may return less than size, so we need to loop to ensure
# we have enough data to detect the boundary.
while len(chunk) < self._boundary_len:
chunk += await self._content.read(size)
self._content_eof += int(self._content.at_eof())
assert self._content_eof < 3, "Reading after EOF"
if self._content_eof:
break
if len(chunk) > size:
self._content.unread_data(chunk[size:])
chunk = chunk[:size]

assert self._prev_chunk is not None
window = self._prev_chunk + chunk
sub = b"\r\n" + self._boundary
Expand Down
61 changes: 61 additions & 0 deletions tests/test_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import io
import json
import pathlib
import sys
import zlib
from unittest import mock

Expand Down Expand Up @@ -754,6 +755,66 @@ async def test_invalid_boundary(self) -> None:
with pytest.raises(ValueError):
await reader.next()

@pytest.mark.skipif(sys.version_info < (3, 10), reason="Needs anext()")
async def test_read_boundary_across_chunks(self) -> None:
class SplitBoundaryStream:
def __init__(self) -> None:
self.content = [
b"--foobar\r\n\r\n",
b"Hello,\r\n-",
b"-fo",
b"ob",
b"ar\r\n",
b"\r\nwor",
b"ld!",
b"\r\n--f",
b"oobar--",
]

async def read(self, size=None) -> bytes:
chunk = self.content.pop(0)
assert len(chunk) <= size
return chunk

def at_eof(self) -> bool:
return not self.content

async def readline(self) -> bytes:
line = b""
while self.content and b"\n" not in line:
line += self.content.pop(0)
line, *extra = line.split(b"\n", maxsplit=1)
if extra and extra[0]:
self.content.insert(0, extra[0])
return line + b"\n"

def unread_data(self, data: bytes) -> None:
if self.content:
self.content[0] = data + self.content[0]
else:
self.content.append(data)

stream = SplitBoundaryStream()
reader = aiohttp.MultipartReader(
{CONTENT_TYPE: 'multipart/related;boundary="foobar"'}, stream
)
part = await anext(reader)
result = await part.read_chunk(10)
assert result == b"Hello,"
result = await part.read_chunk(10)
assert result == b""
assert part.at_eof()

part = await anext(reader)
result = await part.read_chunk(10)
assert result == b"world!"
result = await part.read_chunk(10)
assert result == b""
assert part.at_eof()

with pytest.raises(StopAsyncIteration):
await anext(reader)

async def test_release(self) -> None:
with Stream(
newline.join(
Expand Down
Loading