diff --git a/starlette/datastructures.py b/starlette/datastructures.py index 17dc46eb6..192a6ca07 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -286,7 +286,7 @@ def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any: return self._dict[key] return default - def __getitem__(self, key: typing.Any) -> str: + def __getitem__(self, key: typing.Any) -> typing.Any: return self._dict[key] def __contains__(self, key: typing.Any) -> bool: @@ -459,6 +459,8 @@ class FormData(ImmutableMultiDict): An immutable multidict, containing both file uploads and text input. """ + headers: typing.Optional[ImmutableMultiDict] + def __init__( self, *args: typing.Union[ @@ -466,9 +468,17 @@ def __init__( typing.Mapping[str, typing.Union[str, UploadFile]], typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]], ], - **kwargs: typing.Union[str, UploadFile], + raw_headers: typing.Optional[ + typing.List[typing.Tuple[str, typing.List[typing.Tuple[bytes, bytes]]]] + ] = None, ) -> None: - super().__init__(*args, **kwargs) + super().__init__(*args) + if raw_headers is not None: + self.headers = ImmutableMultiDict( + [(field_name, Headers(raw=raw)) for field_name, raw in raw_headers] + ) + else: + self.headers = None async def close(self) -> None: for key, value in self.multi_items(): diff --git a/starlette/formparsers.py b/starlette/formparsers.py index 1614a9d69..db493ba70 100644 --- a/starlette/formparsers.py +++ b/starlette/formparsers.py @@ -184,6 +184,10 @@ async def parse(self) -> FormData: file: typing.Optional[UploadFile] = None items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = [] + item_headers: typing.List[typing.Tuple[bytes, bytes]] = [] + raw_headers: typing.List[ + typing.Tuple[str, typing.List[typing.Tuple[bytes, bytes]]] + ] = [] # Feed the parser with data from the request. async for chunk in self.stream: @@ -195,6 +199,7 @@ async def parse(self) -> FormData: content_disposition = None content_type = b"" data = b"" + item_headers = [] elif message_type == MultiPartMessage.HEADER_FIELD: header_field += message_bytes elif message_type == MultiPartMessage.HEADER_VALUE: @@ -205,6 +210,7 @@ async def parse(self) -> FormData: content_disposition = header_value elif field == b"content-type": content_type = header_value + item_headers.append((field, header_value)) header_field = b"" header_value = b"" elif message_type == MultiPartMessage.HEADERS_FINISHED: @@ -229,6 +235,7 @@ async def parse(self) -> FormData: else: await file.seek(0) items.append((field_name, file)) + raw_headers.append((field_name, item_headers)) parser.finalize() - return FormData(items) + return FormData(items, raw_headers=raw_headers) diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py index 8a1174e1d..17413dc55 100644 --- a/tests/test_formparsers.py +++ b/tests/test_formparsers.py @@ -29,6 +29,11 @@ async def app(scope, receive, send): else: output[key] = value await request.close() + if data.headers is not None: + output["__headers"] = [ + (field_name, dict(headers.items())) + for field_name, headers in data.headers.multi_items() + ] response = JSONResponse(output) await response(scope, receive, send) @@ -52,6 +57,11 @@ async def multi_items_app(scope, receive, send): else: output[key].append(value) await request.close() + if data.headers is not None: + output["__headers"] = [ + (field_name, dict(headers.items())) + for field_name, headers in data.headers.multi_items() + ] response = JSONResponse(output) await response(scope, receive, send) @@ -65,6 +75,11 @@ async def app_read_body(scope, receive, send): for key, value in data.items(): output[key] = value await request.close() + if data.headers is not None: + output["__headers"] = [ + (field_name, dict(headers.items())) + for field_name, headers in data.headers.multi_items() + ] response = JSONResponse(output) await response(scope, receive, send) @@ -72,7 +87,10 @@ async def app_read_body(scope, receive, send): def test_multipart_request_data(tmpdir, test_client_factory): client = test_client_factory(app) response = client.post("/", data={"some": "data"}, files=FORCE_MULTIPART) - assert response.json() == {"some": "data"} + assert response.json() == { + "some": "data", + "__headers": [["some", {"content-disposition": 'form-data; name="some"'}]], + } def test_multipart_request_files(tmpdir, test_client_factory): @@ -88,7 +106,15 @@ def test_multipart_request_files(tmpdir, test_client_factory): "filename": "test.txt", "content": "", "content_type": "", - } + }, + "__headers": [ + [ + "test", + { + "content-disposition": 'form-data; name="test"; filename="test.txt"' # noqa: E501 + }, + ] + ], } @@ -105,7 +131,16 @@ def test_multipart_request_files_with_content_type(tmpdir, test_client_factory): "filename": "test.txt", "content": "", "content_type": "text/plain", - } + }, + "__headers": [ + [ + "test", + { + "content-disposition": 'form-data; name="test"; filename="test.txt"', # noqa: E501 + "content-type": "text/plain", + }, + ] + ], } @@ -134,6 +169,21 @@ def test_multipart_request_multiple_files(tmpdir, test_client_factory): "content": "", "content_type": "text/plain", }, + "__headers": [ + [ + "test1", + { + "content-disposition": 'form-data; name="test1"; filename="test1.txt"', # noqa: E501 + }, + ], + [ + "test2", + { + "content-disposition": 'form-data; name="test2"; filename="test2.txt"', # noqa: E501 + "content-type": "text/plain", + }, + ], + ], } @@ -166,7 +216,23 @@ def test_multi_items(tmpdir, test_client_factory): "content": "", "content_type": "text/plain", }, - ] + ], + "__headers": [ + ["test1", {"content-disposition": 'form-data; name="test1"'}], + [ + "test1", + { + "content-disposition": 'form-data; name="test1"; filename="test1.txt"' # noqa: E501 + }, + ], + [ + "test1", + { + "content-disposition": 'form-data; name="test1"; filename="test2.txt"', # noqa: E501 + "content-type": "text/plain", + }, + ], + ], } @@ -197,13 +263,24 @@ def test_multipart_request_mixed_files_and_data(tmpdir, test_client_factory): }, ) assert response.json() == { + "field0": "value0", "file": { "filename": "file.txt", "content": "", "content_type": "text/plain", }, - "field0": "value0", "field1": "value1", + "__headers": [ + ["field0", {"content-disposition": 'form-data; name="field0"'}], + [ + "file", + { + "content-disposition": 'form-data; name="file"; filename="file.txt"', # noqa: E501 + "content-type": "text/plain", + }, + ], + ["field1", {"content-disposition": 'form-data; name="field1"'}], + ], } @@ -231,7 +308,16 @@ def test_multipart_request_with_charset_for_filename(tmpdir, test_client_factory "filename": "文書.txt", "content": "", "content_type": "text/plain", - } + }, + "__headers": [ + [ + "file", + { + "content-disposition": 'form-data; name="file"; filename="æ\x96\x87æ\x9b¸.txt"', # noqa: E501 + "content-type": "text/plain", + }, + ] + ], } @@ -258,7 +344,16 @@ def test_multipart_request_without_charset_for_filename(tmpdir, test_client_fact "filename": "画像.jpg", "content": "", "content_type": "image/jpeg", - } + }, + "__headers": [ + [ + "file", + { + "content-disposition": 'form-data; name="file"; filename="ç\x94»å\x83\x8f.jpg"', # noqa: E501 + "content-type": "image/jpeg", + }, + ] + ], } @@ -280,7 +375,10 @@ def test_multipart_request_with_encoded_value(tmpdir, test_client_factory): ) }, ) - assert response.json() == {"value": "Transférer"} + assert response.json() == { + "value": "Transférer", + "__headers": [["value", {"content-disposition": 'form-data; name="value"'}]], + } def test_urlencoded_request_data(tmpdir, test_client_factory): @@ -318,7 +416,14 @@ def test_multipart_multi_field_app_reads_body(tmpdir, test_client_factory): response = client.post( "/", data={"some": "data", "second": "key pair"}, files=FORCE_MULTIPART ) - assert response.json() == {"some": "data", "second": "key pair"} + assert response.json() == { + "some": "data", + "second": "key pair", + "__headers": [ + ["some", {"content-disposition": 'form-data; name="some"'}], + ["second", {"content-disposition": 'form-data; name="second"'}], + ], + } def test_user_safe_decode_helper():