From 9f6e890d71474c2fc1b5050a9f30e2f658c23cfd Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 15 Oct 2021 14:41:28 -0500 Subject: [PATCH 1/6] feat: expose multipart headers to users --- starlette/datastructures.py | 132 +++++++++++++++++++++++------------- starlette/formparsers.py | 7 +- tests/test_formparsers.py | 89 ++++++------------------ 3 files changed, 110 insertions(+), 118 deletions(-) diff --git a/starlette/datastructures.py b/starlette/datastructures.py index 17dc46eb6..cf4e9b659 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -230,13 +230,21 @@ def __str__(self) -> str: return ", ".join(repr(item) for item in self) -class ImmutableMultiDict(typing.Mapping): +_KT = typing.TypeVar("_KT", bound=typing.Hashable) +_VT = typing.TypeVar("_VT") +_T = typing.TypeVar("_T") + + +_UNSET: typing.Any = object() + + +class ImmutableMultiDict(typing.Mapping[_KT, _VT]): def __init__( self, *args: typing.Union[ - "ImmutableMultiDict", - typing.Mapping, - typing.List[typing.Tuple[typing.Any, typing.Any]], + "ImmutableMultiDict[_KT, _VT]", + typing.Mapping[_KT, _VT], + typing.List[typing.Tuple[_KT, _VT]], ], **kwargs: typing.Any, ) -> None: @@ -266,33 +274,43 @@ def __init__( self._dict = {k: v for k, v in _items} self._list = _items - def getlist(self, key: typing.Any) -> typing.List[typing.Any]: + def getlist(self, key: _KT) -> typing.List[_VT]: return [item_value for item_key, item_value in self._list if item_key == key] - def keys(self) -> typing.KeysView: + def keys(self) -> typing.KeysView[_KT]: return self._dict.keys() - def values(self) -> typing.ValuesView: + def values(self) -> typing.ValuesView[_VT]: return self._dict.values() - def items(self) -> typing.ItemsView: + def items(self) -> typing.ItemsView[_KT, _VT]: return self._dict.items() - def multi_items(self) -> typing.List[typing.Tuple[str, str]]: + def multi_items(self) -> typing.List[typing.Tuple[_KT, _VT]]: return 
list(self._list) - - def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any: + + @typing.overload + def get(self, key: _KT) -> _VT: + ... + + @typing.overload + def get(self, key: _KT, default: _T) -> typing.Union[_VT, _T]: + ... + + def get(self, key: _KT, default: _T = _UNSET) -> typing.Union[_VT, _T]: if key in self._dict: return self._dict[key] + if default is _UNSET: + raise KeyError(key) return default - def __getitem__(self, key: typing.Any) -> str: + def __getitem__(self, key: _KT) -> _VT: return self._dict[key] def __contains__(self, key: typing.Any) -> bool: return key in self._dict - def __iter__(self) -> typing.Iterator[typing.Any]: + def __iter__(self) -> typing.Iterator[_KT]: return iter(self.keys()) def __len__(self) -> int: @@ -309,24 +327,36 @@ def __repr__(self) -> str: return f"{class_name}({items!r})" -class MultiDict(ImmutableMultiDict): - def __setitem__(self, key: typing.Any, value: typing.Any) -> None: +class MultiDict(ImmutableMultiDict[_KT, _VT]): + def __setitem__(self, key: _KT, value: typing.Any) -> None: self.setlist(key, [value]) - def __delitem__(self, key: typing.Any) -> None: + def __delitem__(self, key: _KT) -> None: self._list = [(k, v) for k, v in self._list if k != key] del self._dict[key] - def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any: - self._list = [(k, v) for k, v in self._list if k != key] - return self._dict.pop(key, default) + @typing.overload + def pop(self, key: _KT) -> _VT: + ... + + @typing.overload + def pop(self, key: _KT, default: _T) -> typing.Union[_VT, _T]: + ... 
+ + def pop(self, key: _KT, default: _T = _UNSET) -> typing.Union[_VT, _T]: + if key in self._dict: + self._list = [(k, v) for k, v in self._list if k != key] + return self._dict.pop(key, default) + if default is _UNSET: + raise KeyError(key) + return default - def popitem(self) -> typing.Tuple: + def popitem(self) -> typing.Tuple[_KT, _VT]: key, value = self._dict.popitem() self._list = [(k, v) for k, v in self._list if k != key] return key, value - def poplist(self, key: typing.Any) -> typing.List: + def poplist(self, key: _KT) -> typing.List[typing.Tuple[_KT, _VT]]: values = [v for k, v in self._list if k == key] self.pop(key) return values @@ -335,14 +365,14 @@ def clear(self) -> None: self._dict.clear() self._list.clear() - def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any: + def setdefault(self, key: _KT, default: typing.Optional[_VT] = None) -> _VT: if key not in self: self._dict[key] = default self._list.append((key, default)) return self[key] - def setlist(self, key: typing.Any, values: typing.List) -> None: + def setlist(self, key: _KT, values: typing.List[_VT]) -> None: if not values: self.pop(key, None) else: @@ -350,16 +380,16 @@ def setlist(self, key: typing.Any, values: typing.List) -> None: self._list = existing_items + [(key, value) for value in values] self._dict[key] = values[-1] - def append(self, key: typing.Any, value: typing.Any) -> None: + def append(self, key: _KT, value: _VT) -> None: self._list.append((key, value)) self._dict[key] = value def update( self, *args: typing.Union[ - "MultiDict", - typing.Mapping, - typing.List[typing.Tuple[typing.Any, typing.Any]], + "MultiDict[_KT, _VT]", + typing.Mapping[_KT, _VT], + typing.List[typing.Tuple[_KT, _VT]], ], **kwargs: typing.Any, ) -> None: @@ -454,28 +484,6 @@ async def close(self) -> None: await run_in_threadpool(self.file.close) -class FormData(ImmutableMultiDict): - """ - An immutable multidict, containing both file uploads and text input. 
- """ - - def __init__( - self, - *args: typing.Union[ - "FormData", - typing.Mapping[str, typing.Union[str, UploadFile]], - typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]], - ], - **kwargs: typing.Union[str, UploadFile], - ) -> None: - super().__init__(*args, **kwargs) - - async def close(self) -> None: - for key, value in self.multi_items(): - if isinstance(value, UploadFile): - await value.close() - - class Headers(typing.Mapping[str, str]): """ An immutable, case-insensitive multidict. @@ -665,3 +673,29 @@ def __getattr__(self, key: typing.Any) -> typing.Any: def __delattr__(self, key: typing.Any) -> None: del self._state[key] + + +class FormData(ImmutableMultiDict[str, typing.Union[str, UploadFile]]): + """ + An immutable multidict, containing both file uploads and text input. + """ + headers: typing.Optional[ImmutableMultiDict[str, Headers]] + def __init__( + self, + *args: typing.Union[ + "FormData", + typing.Mapping[str, typing.Union[str, UploadFile]], + typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]], + ], + raw_headers: typing.Optional[typing.List[typing.Tuple[str, typing.List[typing.Tuple[bytes, bytes]]]]] = None + ) -> None: + super().__init__(*args) + if raw_headers is not None: + self.headers = ImmutableMultiDict([(field_name, Headers(raw=raw)) for field_name, raw in raw_headers]) + else: + self.headers = None + + async def close(self) -> None: + for _, value in self.multi_items(): + if isinstance(value, UploadFile): + await value.close() diff --git a/starlette/formparsers.py b/starlette/formparsers.py index 1614a9d69..1b350d306 100644 --- a/starlette/formparsers.py +++ b/starlette/formparsers.py @@ -184,6 +184,8 @@ async def parse(self) -> FormData: file: typing.Optional[UploadFile] = None items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = [] + item_headers: typing.List[typing.Tuple[bytes, bytes]] = [] + raw_headers: typing.List[typing.Tuple[str, typing.List[typing.Tuple[bytes, bytes]]]] = [] # Feed the 
parser with data from the request. async for chunk in self.stream: @@ -195,6 +197,7 @@ async def parse(self) -> FormData: content_disposition = None content_type = b"" data = b"" + item_headers = [] elif message_type == MultiPartMessage.HEADER_FIELD: header_field += message_bytes elif message_type == MultiPartMessage.HEADER_VALUE: @@ -205,6 +208,7 @@ async def parse(self) -> FormData: content_disposition = header_value elif field == b"content-type": content_type = header_value + item_headers.append((field, header_value)) header_field = b"" header_value = b"" elif message_type == MultiPartMessage.HEADERS_FINISHED: @@ -229,6 +233,7 @@ async def parse(self) -> FormData: else: await file.seek(0) items.append((field_name, file)) + raw_headers.append((field_name, item_headers)) parser.finalize() - return FormData(items) + return FormData(items, raw_headers=raw_headers) diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py index 8a1174e1d..edeba73fa 100644 --- a/tests/test_formparsers.py +++ b/tests/test_formparsers.py @@ -29,6 +29,8 @@ async def app(scope, receive, send): else: output[key] = value await request.close() + if data.headers is not None: + output["__headers"] = [(field_name, dict(headers.items())) for field_name, headers in data.headers.multi_items()] response = JSONResponse(output) await response(scope, receive, send) @@ -52,6 +54,8 @@ async def multi_items_app(scope, receive, send): else: output[key].append(value) await request.close() + if data.headers is not None: + output["__headers"] = [(field_name, dict(headers.items())) for field_name, headers in data.headers.multi_items()] response = JSONResponse(output) await response(scope, receive, send) @@ -65,6 +69,8 @@ async def app_read_body(scope, receive, send): for key, value in data.items(): output[key] = value await request.close() + if data.headers is not None: + output["__headers"] = [(field_name, dict(headers.items())) for field_name, headers in data.headers.multi_items()] response = 
JSONResponse(output) await response(scope, receive, send) @@ -72,7 +78,7 @@ async def app_read_body(scope, receive, send): def test_multipart_request_data(tmpdir, test_client_factory): client = test_client_factory(app) response = client.post("/", data={"some": "data"}, files=FORCE_MULTIPART) - assert response.json() == {"some": "data"} + assert response.json() == {'some': 'data', '__headers': [['some', {'content-disposition': 'form-data; name="some"'}]]} def test_multipart_request_files(tmpdir, test_client_factory): @@ -83,13 +89,8 @@ def test_multipart_request_files(tmpdir, test_client_factory): client = test_client_factory(app) with open(path, "rb") as f: response = client.post("/", files={"test": f}) - assert response.json() == { - "test": { - "filename": "test.txt", - "content": "", - "content_type": "", - } - } + assert response.json() == {'test': {'filename': 'test.txt', 'content': '', 'content_type': ''}, '__headers': [['test', {'content-disposition': 'form-data; name="test"; filename="test.txt"'}]]} + def test_multipart_request_files_with_content_type(tmpdir, test_client_factory): @@ -100,13 +101,8 @@ def test_multipart_request_files_with_content_type(tmpdir, test_client_factory): client = test_client_factory(app) with open(path, "rb") as f: response = client.post("/", files={"test": ("test.txt", f, "text/plain")}) - assert response.json() == { - "test": { - "filename": "test.txt", - "content": "", - "content_type": "text/plain", - } - } + assert response.json() == {'test': {'filename': 'test.txt', 'content': '', 'content_type': 'text/plain'}, '__headers': [['test', {'content-disposition': 'form-data; name="test"; filename="test.txt"', 'content-type': 'text/plain'}]]} + def test_multipart_request_multiple_files(tmpdir, test_client_factory): @@ -123,18 +119,7 @@ def test_multipart_request_multiple_files(tmpdir, test_client_factory): response = client.post( "/", files={"test1": f1, "test2": ("test2.txt", f2, "text/plain")} ) - assert response.json() == { - 
"test1": { - "filename": "test1.txt", - "content": "", - "content_type": "", - }, - "test2": { - "filename": "test2.txt", - "content": "", - "content_type": "text/plain", - }, - } + assert response.json() == {'test1': {'filename': 'test1.txt', 'content': '', 'content_type': ''}, 'test2': {'filename': 'test2.txt', 'content': '', 'content_type': 'text/plain'}, '__headers': [['test1', {'content-disposition': 'form-data; name="test1"; filename="test1.txt"'}], ['test2', {'content-disposition': 'form-data; name="test2"; filename="test2.txt"', 'content-type': 'text/plain'}]]} def test_multi_items(tmpdir, test_client_factory): @@ -153,21 +138,7 @@ def test_multi_items(tmpdir, test_client_factory): data=[("test1", "abc")], files=[("test1", f1), ("test1", ("test2.txt", f2, "text/plain"))], ) - assert response.json() == { - "test1": [ - "abc", - { - "filename": "test1.txt", - "content": "", - "content_type": "", - }, - { - "filename": "test2.txt", - "content": "", - "content_type": "text/plain", - }, - ] - } + assert response.json() == {'test1': ['abc', {'filename': 'test1.txt', 'content': '', 'content_type': ''}, {'filename': 'test2.txt', 'content': '', 'content_type': 'text/plain'}], '__headers': [['test1', {'content-disposition': 'form-data; name="test1"'}], ['test1', {'content-disposition': 'form-data; name="test1"; filename="test1.txt"'}], ['test1', {'content-disposition': 'form-data; name="test1"; filename="test2.txt"', 'content-type': 'text/plain'}]]} def test_multipart_request_mixed_files_and_data(tmpdir, test_client_factory): @@ -196,15 +167,7 @@ def test_multipart_request_mixed_files_and_data(tmpdir, test_client_factory): ) }, ) - assert response.json() == { - "file": { - "filename": "file.txt", - "content": "", - "content_type": "text/plain", - }, - "field0": "value0", - "field1": "value1", - } + assert response.json() == {'field0': 'value0', 'file': {'filename': 'file.txt', 'content': '', 'content_type': 'text/plain'}, 'field1': 'value1', '__headers': [['field0', 
{'content-disposition': 'form-data; name="field0"'}], ['file', {'content-disposition': 'form-data; name="file"; filename="file.txt"', 'content-type': 'text/plain'}], ['field1', {'content-disposition': 'form-data; name="field1"'}]]} def test_multipart_request_with_charset_for_filename(tmpdir, test_client_factory): @@ -226,13 +189,8 @@ def test_multipart_request_with_charset_for_filename(tmpdir, test_client_factory ) }, ) - assert response.json() == { - "file": { - "filename": "文書.txt", - "content": "", - "content_type": "text/plain", - } - } + assert response.json() == {'file': {'filename': '文書.txt', 'content': '', 'content_type': 'text/plain'}, '__headers': [['file', {'content-disposition': 'form-data; name="file"; filename="æ\x96\x87æ\x9b¸.txt"', 'content-type': 'text/plain'}]]} + def test_multipart_request_without_charset_for_filename(tmpdir, test_client_factory): @@ -253,13 +211,8 @@ def test_multipart_request_without_charset_for_filename(tmpdir, test_client_fact ) }, ) - assert response.json() == { - "file": { - "filename": "画像.jpg", - "content": "", - "content_type": "image/jpeg", - } - } + assert response.json() == {'file': {'filename': '画像.jpg', 'content': '', 'content_type': 'image/jpeg'}, '__headers': [['file', {'content-disposition': 'form-data; name="file"; filename="ç\x94»å\x83\x8f.jpg"', 'content-type': 'image/jpeg'}]]} + def test_multipart_request_with_encoded_value(tmpdir, test_client_factory): @@ -280,7 +233,7 @@ def test_multipart_request_with_encoded_value(tmpdir, test_client_factory): ) }, ) - assert response.json() == {"value": "Transférer"} + assert response.json() == {'value': 'Transférer', '__headers': [['value', {'content-disposition': 'form-data; name="value"'}]]} def test_urlencoded_request_data(tmpdir, test_client_factory): @@ -318,7 +271,7 @@ def test_multipart_multi_field_app_reads_body(tmpdir, test_client_factory): response = client.post( "/", data={"some": "data", "second": "key pair"}, files=FORCE_MULTIPART ) - assert response.json() 
== {"some": "data", "second": "key pair"} + assert response.json() == {'some': 'data', 'second': 'key pair', '__headers': [['some', {'content-disposition': 'form-data; name="some"'}], ['second', {'content-disposition': 'form-data; name="second"'}]]} def test_user_safe_decode_helper(): @@ -328,4 +281,4 @@ def test_user_safe_decode_helper(): def test_user_safe_decode_ignores_wrong_charset(): result = _user_safe_decode(b"abc", "latin-8") - assert result == "abc" + assert result == "abc" \ No newline at end of file From 94d2c3a91a5032610c91bda32209782c0ebac1ec Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 15 Oct 2021 14:50:49 -0500 Subject: [PATCH 2/6] linting --- starlette/datastructures.py | 16 ++- starlette/formparsers.py | 4 +- tests/test_formparsers.py | 188 ++++++++++++++++++++++++++++++++---- 3 files changed, 184 insertions(+), 24 deletions(-) diff --git a/starlette/datastructures.py b/starlette/datastructures.py index cf4e9b659..125f54b34 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -288,11 +288,11 @@ def items(self) -> typing.ItemsView[_KT, _VT]: def multi_items(self) -> typing.List[typing.Tuple[_KT, _VT]]: return list(self._list) - + @typing.overload def get(self, key: _KT) -> _VT: ... - + @typing.overload def get(self, key: _KT, default: _T) -> typing.Union[_VT, _T]: ... @@ -338,7 +338,7 @@ def __delitem__(self, key: _KT) -> None: @typing.overload def pop(self, key: _KT) -> _VT: ... - + @typing.overload def pop(self, key: _KT, default: _T) -> typing.Union[_VT, _T]: ... @@ -679,7 +679,9 @@ class FormData(ImmutableMultiDict[str, typing.Union[str, UploadFile]]): """ An immutable multidict, containing both file uploads and text input. 
""" + headers: typing.Optional[ImmutableMultiDict[str, Headers]] + def __init__( self, *args: typing.Union[ @@ -687,11 +689,15 @@ def __init__( typing.Mapping[str, typing.Union[str, UploadFile]], typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]], ], - raw_headers: typing.Optional[typing.List[typing.Tuple[str, typing.List[typing.Tuple[bytes, bytes]]]]] = None + raw_headers: typing.Optional[ + typing.List[typing.Tuple[str, typing.List[typing.Tuple[bytes, bytes]]]] + ] = None, ) -> None: super().__init__(*args) if raw_headers is not None: - self.headers = ImmutableMultiDict([(field_name, Headers(raw=raw)) for field_name, raw in raw_headers]) + self.headers = ImmutableMultiDict( + [(field_name, Headers(raw=raw)) for field_name, raw in raw_headers] + ) else: self.headers = None diff --git a/starlette/formparsers.py b/starlette/formparsers.py index 1b350d306..db493ba70 100644 --- a/starlette/formparsers.py +++ b/starlette/formparsers.py @@ -185,7 +185,9 @@ async def parse(self) -> FormData: items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = [] item_headers: typing.List[typing.Tuple[bytes, bytes]] = [] - raw_headers: typing.List[typing.Tuple[str, typing.List[typing.Tuple[bytes, bytes]]]] = [] + raw_headers: typing.List[ + typing.Tuple[str, typing.List[typing.Tuple[bytes, bytes]]] + ] = [] # Feed the parser with data from the request. 
async for chunk in self.stream: diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py index edeba73fa..cb3fe2404 100644 --- a/tests/test_formparsers.py +++ b/tests/test_formparsers.py @@ -30,7 +30,10 @@ async def app(scope, receive, send): output[key] = value await request.close() if data.headers is not None: - output["__headers"] = [(field_name, dict(headers.items())) for field_name, headers in data.headers.multi_items()] + output["__headers"] = [ + (field_name, dict(headers.items())) + for field_name, headers in data.headers.multi_items() + ] response = JSONResponse(output) await response(scope, receive, send) @@ -55,7 +58,10 @@ async def multi_items_app(scope, receive, send): output[key].append(value) await request.close() if data.headers is not None: - output["__headers"] = [(field_name, dict(headers.items())) for field_name, headers in data.headers.multi_items()] + output["__headers"] = [ + (field_name, dict(headers.items())) + for field_name, headers in data.headers.multi_items() + ] response = JSONResponse(output) await response(scope, receive, send) @@ -70,7 +76,10 @@ async def app_read_body(scope, receive, send): output[key] = value await request.close() if data.headers is not None: - output["__headers"] = [(field_name, dict(headers.items())) for field_name, headers in data.headers.multi_items()] + output["__headers"] = [ + (field_name, dict(headers.items())) + for field_name, headers in data.headers.multi_items() + ] response = JSONResponse(output) await response(scope, receive, send) @@ -78,7 +87,10 @@ async def app_read_body(scope, receive, send): def test_multipart_request_data(tmpdir, test_client_factory): client = test_client_factory(app) response = client.post("/", data={"some": "data"}, files=FORCE_MULTIPART) - assert response.json() == {'some': 'data', '__headers': [['some', {'content-disposition': 'form-data; name="some"'}]]} + assert response.json() == { + "some": "data", + "__headers": [["some", {"content-disposition": 'form-data; 
name="some"'}]], + } def test_multipart_request_files(tmpdir, test_client_factory): @@ -89,8 +101,21 @@ def test_multipart_request_files(tmpdir, test_client_factory): client = test_client_factory(app) with open(path, "rb") as f: response = client.post("/", files={"test": f}) - assert response.json() == {'test': {'filename': 'test.txt', 'content': '', 'content_type': ''}, '__headers': [['test', {'content-disposition': 'form-data; name="test"; filename="test.txt"'}]]} - + assert response.json() == { + "test": { + "filename": "test.txt", + "content": "", + "content_type": "", + }, + "__headers": [ + [ + "test", + { + "content-disposition": 'form-data; name="test"; filename="test.txt"' + }, + ] + ], + } def test_multipart_request_files_with_content_type(tmpdir, test_client_factory): @@ -101,8 +126,22 @@ def test_multipart_request_files_with_content_type(tmpdir, test_client_factory): client = test_client_factory(app) with open(path, "rb") as f: response = client.post("/", files={"test": ("test.txt", f, "text/plain")}) - assert response.json() == {'test': {'filename': 'test.txt', 'content': '', 'content_type': 'text/plain'}, '__headers': [['test', {'content-disposition': 'form-data; name="test"; filename="test.txt"', 'content-type': 'text/plain'}]]} - + assert response.json() == { + "test": { + "filename": "test.txt", + "content": "", + "content_type": "text/plain", + }, + "__headers": [ + [ + "test", + { + "content-disposition": 'form-data; name="test"; filename="test.txt"', + "content-type": "text/plain", + }, + ] + ], + } def test_multipart_request_multiple_files(tmpdir, test_client_factory): @@ -119,7 +158,33 @@ def test_multipart_request_multiple_files(tmpdir, test_client_factory): response = client.post( "/", files={"test1": f1, "test2": ("test2.txt", f2, "text/plain")} ) - assert response.json() == {'test1': {'filename': 'test1.txt', 'content': '', 'content_type': ''}, 'test2': {'filename': 'test2.txt', 'content': '', 'content_type': 'text/plain'}, '__headers': 
[['test1', {'content-disposition': 'form-data; name="test1"; filename="test1.txt"'}], ['test2', {'content-disposition': 'form-data; name="test2"; filename="test2.txt"', 'content-type': 'text/plain'}]]} + assert response.json() == { + "test1": { + "filename": "test1.txt", + "content": "", + "content_type": "", + }, + "test2": { + "filename": "test2.txt", + "content": "", + "content_type": "text/plain", + }, + "__headers": [ + [ + "test1", + { + "content-disposition": 'form-data; name="test1"; filename="test1.txt"' + }, + ], + [ + "test2", + { + "content-disposition": 'form-data; name="test2"; filename="test2.txt"', + "content-type": "text/plain", + }, + ], + ], + } def test_multi_items(tmpdir, test_client_factory): @@ -138,7 +203,37 @@ def test_multi_items(tmpdir, test_client_factory): data=[("test1", "abc")], files=[("test1", f1), ("test1", ("test2.txt", f2, "text/plain"))], ) - assert response.json() == {'test1': ['abc', {'filename': 'test1.txt', 'content': '', 'content_type': ''}, {'filename': 'test2.txt', 'content': '', 'content_type': 'text/plain'}], '__headers': [['test1', {'content-disposition': 'form-data; name="test1"'}], ['test1', {'content-disposition': 'form-data; name="test1"; filename="test1.txt"'}], ['test1', {'content-disposition': 'form-data; name="test1"; filename="test2.txt"', 'content-type': 'text/plain'}]]} + assert response.json() == { + "test1": [ + "abc", + { + "filename": "test1.txt", + "content": "", + "content_type": "", + }, + { + "filename": "test2.txt", + "content": "", + "content_type": "text/plain", + }, + ], + "__headers": [ + ["test1", {"content-disposition": 'form-data; name="test1"'}], + [ + "test1", + { + "content-disposition": 'form-data; name="test1"; filename="test1.txt"' + }, + ], + [ + "test1", + { + "content-disposition": 'form-data; name="test1"; filename="test2.txt"', + "content-type": "text/plain", + }, + ], + ], + } def test_multipart_request_mixed_files_and_data(tmpdir, test_client_factory): @@ -167,7 +262,26 @@ def 
test_multipart_request_mixed_files_and_data(tmpdir, test_client_factory): ) }, ) - assert response.json() == {'field0': 'value0', 'file': {'filename': 'file.txt', 'content': '', 'content_type': 'text/plain'}, 'field1': 'value1', '__headers': [['field0', {'content-disposition': 'form-data; name="field0"'}], ['file', {'content-disposition': 'form-data; name="file"; filename="file.txt"', 'content-type': 'text/plain'}], ['field1', {'content-disposition': 'form-data; name="field1"'}]]} + assert response.json() == { + "field0": "value0", + "file": { + "filename": "file.txt", + "content": "", + "content_type": "text/plain", + }, + "field1": "value1", + "__headers": [ + ["field0", {"content-disposition": 'form-data; name="field0"'}], + [ + "file", + { + "content-disposition": 'form-data; name="file"; filename="file.txt"', + "content-type": "text/plain", + }, + ], + ["field1", {"content-disposition": 'form-data; name="field1"'}], + ], + } def test_multipart_request_with_charset_for_filename(tmpdir, test_client_factory): @@ -189,8 +303,22 @@ def test_multipart_request_with_charset_for_filename(tmpdir, test_client_factory ) }, ) - assert response.json() == {'file': {'filename': '文書.txt', 'content': '', 'content_type': 'text/plain'}, '__headers': [['file', {'content-disposition': 'form-data; name="file"; filename="æ\x96\x87æ\x9b¸.txt"', 'content-type': 'text/plain'}]]} - + assert response.json() == { + "file": { + "filename": "文書.txt", + "content": "", + "content_type": "text/plain", + }, + "__headers": [ + [ + "file", + { + "content-disposition": 'form-data; name="file"; filename="æ\x96\x87æ\x9b¸.txt"', + "content-type": "text/plain", + }, + ] + ], + } def test_multipart_request_without_charset_for_filename(tmpdir, test_client_factory): @@ -211,8 +339,22 @@ def test_multipart_request_without_charset_for_filename(tmpdir, test_client_fact ) }, ) - assert response.json() == {'file': {'filename': '画像.jpg', 'content': '', 'content_type': 'image/jpeg'}, '__headers': [['file', 
{'content-disposition': 'form-data; name="file"; filename="ç\x94»å\x83\x8f.jpg"', 'content-type': 'image/jpeg'}]]} - + assert response.json() == { + "file": { + "filename": "画像.jpg", + "content": "", + "content_type": "image/jpeg", + }, + "__headers": [ + [ + "file", + { + "content-disposition": 'form-data; name="file"; filename="ç\x94»å\x83\x8f.jpg"', + "content-type": "image/jpeg", + }, + ] + ], + } def test_multipart_request_with_encoded_value(tmpdir, test_client_factory): @@ -233,7 +375,10 @@ def test_multipart_request_with_encoded_value(tmpdir, test_client_factory): ) }, ) - assert response.json() == {'value': 'Transférer', '__headers': [['value', {'content-disposition': 'form-data; name="value"'}]]} + assert response.json() == { + "value": "Transférer", + "__headers": [["value", {"content-disposition": 'form-data; name="value"'}]], + } def test_urlencoded_request_data(tmpdir, test_client_factory): @@ -271,7 +416,14 @@ def test_multipart_multi_field_app_reads_body(tmpdir, test_client_factory): response = client.post( "/", data={"some": "data", "second": "key pair"}, files=FORCE_MULTIPART ) - assert response.json() == {'some': 'data', 'second': 'key pair', '__headers': [['some', {'content-disposition': 'form-data; name="some"'}], ['second', {'content-disposition': 'form-data; name="second"'}]]} + assert response.json() == { + "some": "data", + "second": "key pair", + "__headers": [ + ["some", {"content-disposition": 'form-data; name="some"'}], + ["second", {"content-disposition": 'form-data; name="second"'}], + ], + } def test_user_safe_decode_helper(): @@ -281,4 +433,4 @@ def test_user_safe_decode_helper(): def test_user_safe_decode_ignores_wrong_charset(): result = _user_safe_decode(b"abc", "latin-8") - assert result == "abc" \ No newline at end of file + assert result == "abc" From 41cba7d125cdf35cf5faafe53a50ef9b51b3fab3 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 15 Oct 2021 15:04:51 -0500 
Subject: [PATCH 3/6] linting --- starlette/datastructures.py | 76 ++++++++++++------------------------- tests/test_formparsers.py | 18 ++++----- 2 files changed, 34 insertions(+), 60 deletions(-) diff --git a/starlette/datastructures.py b/starlette/datastructures.py index 125f54b34..635604d95 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -232,19 +232,15 @@ def __str__(self) -> str: _KT = typing.TypeVar("_KT", bound=typing.Hashable) _VT = typing.TypeVar("_VT") -_T = typing.TypeVar("_T") - - -_UNSET: typing.Any = object() class ImmutableMultiDict(typing.Mapping[_KT, _VT]): def __init__( self, *args: typing.Union[ - "ImmutableMultiDict[_KT, _VT]", - typing.Mapping[_KT, _VT], - typing.List[typing.Tuple[_KT, _VT]], + "ImmutableMultiDict", + typing.Mapping, + typing.List[typing.Tuple[typing.Any, typing.Any]], ], **kwargs: typing.Any, ) -> None: @@ -274,43 +270,33 @@ def __init__( self._dict = {k: v for k, v in _items} self._list = _items - def getlist(self, key: _KT) -> typing.List[_VT]: + def getlist(self, key: typing.Any) -> typing.List[typing.Any]: return [item_value for item_key, item_value in self._list if item_key == key] - def keys(self) -> typing.KeysView[_KT]: + def keys(self) -> typing.KeysView: return self._dict.keys() - def values(self) -> typing.ValuesView[_VT]: + def values(self) -> typing.ValuesView: return self._dict.values() - def items(self) -> typing.ItemsView[_KT, _VT]: + def items(self) -> typing.ItemsView: return self._dict.items() - def multi_items(self) -> typing.List[typing.Tuple[_KT, _VT]]: + def multi_items(self) -> typing.List[typing.Tuple[str, str]]: return list(self._list) - @typing.overload - def get(self, key: _KT) -> _VT: - ... - - @typing.overload - def get(self, key: _KT, default: _T) -> typing.Union[_VT, _T]: - ... 
- - def get(self, key: _KT, default: _T = _UNSET) -> typing.Union[_VT, _T]: + def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any: if key in self._dict: return self._dict[key] - if default is _UNSET: - raise KeyError(key) return default - def __getitem__(self, key: _KT) -> _VT: + def __getitem__(self, key: typing.Any) -> _VT: return self._dict[key] def __contains__(self, key: typing.Any) -> bool: return key in self._dict - def __iter__(self) -> typing.Iterator[_KT]: + def __iter__(self) -> typing.Iterator[typing.Any]: return iter(self.keys()) def __len__(self) -> int: @@ -327,36 +313,24 @@ def __repr__(self) -> str: return f"{class_name}({items!r})" -class MultiDict(ImmutableMultiDict[_KT, _VT]): - def __setitem__(self, key: _KT, value: typing.Any) -> None: +class MultiDict(ImmutableMultiDict): + def __setitem__(self, key: typing.Any, value: typing.Any) -> None: self.setlist(key, [value]) - def __delitem__(self, key: _KT) -> None: + def __delitem__(self, key: typing.Any) -> None: self._list = [(k, v) for k, v in self._list if k != key] del self._dict[key] - @typing.overload - def pop(self, key: _KT) -> _VT: - ... - - @typing.overload - def pop(self, key: _KT, default: _T) -> typing.Union[_VT, _T]: - ... 
- - def pop(self, key: _KT, default: _T = _UNSET) -> typing.Union[_VT, _T]: - if key in self._dict: - self._list = [(k, v) for k, v in self._list if k != key] - return self._dict.pop(key, default) - if default is _UNSET: - raise KeyError(key) - return default + def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any: + self._list = [(k, v) for k, v in self._list if k != key] + return self._dict.pop(key, default) - def popitem(self) -> typing.Tuple[_KT, _VT]: + def popitem(self) -> typing.Tuple: key, value = self._dict.popitem() self._list = [(k, v) for k, v in self._list if k != key] return key, value - def poplist(self, key: _KT) -> typing.List[typing.Tuple[_KT, _VT]]: + def poplist(self, key: typing.Any) -> typing.List: values = [v for k, v in self._list if k == key] self.pop(key) return values @@ -365,14 +339,14 @@ def clear(self) -> None: self._dict.clear() self._list.clear() - def setdefault(self, key: _KT, default: typing.Optional[_VT] = None) -> _VT: + def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any: if key not in self: self._dict[key] = default self._list.append((key, default)) return self[key] - def setlist(self, key: _KT, values: typing.List[_VT]) -> None: + def setlist(self, key: typing.Any, values: typing.List) -> None: if not values: self.pop(key, None) else: @@ -380,16 +354,16 @@ def setlist(self, key: _KT, values: typing.List[_VT]) -> None: self._list = existing_items + [(key, value) for value in values] self._dict[key] = values[-1] - def append(self, key: _KT, value: _VT) -> None: + def append(self, key: typing.Any, value: typing.Any) -> None: self._list.append((key, value)) self._dict[key] = value def update( self, *args: typing.Union[ - "MultiDict[_KT, _VT]", - typing.Mapping[_KT, _VT], - typing.List[typing.Tuple[_KT, _VT]], + "MultiDict", + typing.Mapping, + typing.List[typing.Tuple[typing.Any, typing.Any]], ], **kwargs: typing.Any, ) -> None: diff --git a/tests/test_formparsers.py 
b/tests/test_formparsers.py index cb3fe2404..17413dc55 100644 --- a/tests/test_formparsers.py +++ b/tests/test_formparsers.py @@ -111,7 +111,7 @@ def test_multipart_request_files(tmpdir, test_client_factory): [ "test", { - "content-disposition": 'form-data; name="test"; filename="test.txt"' + "content-disposition": 'form-data; name="test"; filename="test.txt"' # noqa: E501 }, ] ], @@ -136,7 +136,7 @@ def test_multipart_request_files_with_content_type(tmpdir, test_client_factory): [ "test", { - "content-disposition": 'form-data; name="test"; filename="test.txt"', + "content-disposition": 'form-data; name="test"; filename="test.txt"', # noqa: E501 "content-type": "text/plain", }, ] @@ -173,13 +173,13 @@ def test_multipart_request_multiple_files(tmpdir, test_client_factory): [ "test1", { - "content-disposition": 'form-data; name="test1"; filename="test1.txt"' + "content-disposition": 'form-data; name="test1"; filename="test1.txt"', # noqa: E501 }, ], [ "test2", { - "content-disposition": 'form-data; name="test2"; filename="test2.txt"', + "content-disposition": 'form-data; name="test2"; filename="test2.txt"', # noqa: E501 "content-type": "text/plain", }, ], @@ -222,13 +222,13 @@ def test_multi_items(tmpdir, test_client_factory): [ "test1", { - "content-disposition": 'form-data; name="test1"; filename="test1.txt"' + "content-disposition": 'form-data; name="test1"; filename="test1.txt"' # noqa: E501 }, ], [ "test1", { - "content-disposition": 'form-data; name="test1"; filename="test2.txt"', + "content-disposition": 'form-data; name="test1"; filename="test2.txt"', # noqa: E501 "content-type": "text/plain", }, ], @@ -275,7 +275,7 @@ def test_multipart_request_mixed_files_and_data(tmpdir, test_client_factory): [ "file", { - "content-disposition": 'form-data; name="file"; filename="file.txt"', + "content-disposition": 'form-data; name="file"; filename="file.txt"', # noqa: E501 "content-type": "text/plain", }, ], @@ -313,7 +313,7 @@ def 
test_multipart_request_with_charset_for_filename(tmpdir, test_client_factory [ "file", { - "content-disposition": 'form-data; name="file"; filename="æ\x96\x87æ\x9b¸.txt"', + "content-disposition": 'form-data; name="file"; filename="æ\x96\x87æ\x9b¸.txt"', # noqa: E501 "content-type": "text/plain", }, ] @@ -349,7 +349,7 @@ def test_multipart_request_without_charset_for_filename(tmpdir, test_client_fact [ "file", { - "content-disposition": 'form-data; name="file"; filename="ç\x94»å\x83\x8f.jpg"', + "content-disposition": 'form-data; name="file"; filename="ç\x94»å\x83\x8f.jpg"', # noqa: E501 "content-type": "image/jpeg", }, ] From bd33808046a3e3c643a176bf79d2d8fa6ea985fd Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 16 Dec 2021 10:16:21 -0600 Subject: [PATCH 4/6] remove typevars --- starlette/datastructures.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/starlette/datastructures.py b/starlette/datastructures.py index 635604d95..9386cfbef 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -230,11 +230,7 @@ def __str__(self) -> str: return ", ".join(repr(item) for item in self) -_KT = typing.TypeVar("_KT", bound=typing.Hashable) -_VT = typing.TypeVar("_VT") - - -class ImmutableMultiDict(typing.Mapping[_KT, _VT]): +class ImmutableMultiDict(typing.Mapping): def __init__( self, *args: typing.Union[ @@ -290,7 +286,7 @@ def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any: return self._dict[key] return default - def __getitem__(self, key: typing.Any) -> _VT: + def __getitem__(self, key: typing.Any) -> typing.Any: return self._dict[key] def __contains__(self, key: typing.Any) -> bool: @@ -649,12 +645,12 @@ def __delattr__(self, key: typing.Any) -> None: del self._state[key] -class FormData(ImmutableMultiDict[str, typing.Union[str, UploadFile]]): +class FormData(ImmutableMultiDict): """ An immutable multidict, containing both file 
uploads and text input. """ - headers: typing.Optional[ImmutableMultiDict[str, Headers]] + headers: typing.Optional[ImmutableMultiDict] def __init__( self, From 9b3c033c7f4e46d74b958837f77c37bcae3504d7 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Wed, 22 Dec 2021 11:47:17 -0600 Subject: [PATCH 5/6] minimize diff by moving FormData back to its original location --- starlette/datastructures.py | 64 ++++++++++++++++++------------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/starlette/datastructures.py b/starlette/datastructures.py index 9386cfbef..936fcc26f 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -454,6 +454,38 @@ async def close(self) -> None: await run_in_threadpool(self.file.close) +class FormData(ImmutableMultiDict): + """ + An immutable multidict, containing both file uploads and text input. + """ + + headers: typing.Optional[ImmutableMultiDict] + + def __init__( + self, + *args: typing.Union[ + "FormData", + typing.Mapping[str, typing.Union[str, UploadFile]], + typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]], + ], + raw_headers: typing.Optional[ + typing.List[typing.Tuple[str, typing.List[typing.Tuple[bytes, bytes]]]] + ] = None, + ) -> None: + super().__init__(*args) + if raw_headers is not None: + self.headers = ImmutableMultiDict( + [(field_name, Headers(raw=raw)) for field_name, raw in raw_headers] + ) + else: + self.headers = None + + async def close(self) -> None: + for _, value in self.multi_items(): + if isinstance(value, UploadFile): + await value.close() + + class Headers(typing.Mapping[str, str]): """ An immutable, case-insensitive multidict. @@ -643,35 +675,3 @@ def __getattr__(self, key: typing.Any) -> typing.Any: def __delattr__(self, key: typing.Any) -> None: del self._state[key] - - -class FormData(ImmutableMultiDict): - """ - An immutable multidict, containing both file uploads and text input. 
- """ - - headers: typing.Optional[ImmutableMultiDict] - - def __init__( - self, - *args: typing.Union[ - "FormData", - typing.Mapping[str, typing.Union[str, UploadFile]], - typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]], - ], - raw_headers: typing.Optional[ - typing.List[typing.Tuple[str, typing.List[typing.Tuple[bytes, bytes]]]] - ] = None, - ) -> None: - super().__init__(*args) - if raw_headers is not None: - self.headers = ImmutableMultiDict( - [(field_name, Headers(raw=raw)) for field_name, raw in raw_headers] - ) - else: - self.headers = None - - async def close(self) -> None: - for _, value in self.multi_items(): - if isinstance(value, UploadFile): - await value.close() From 883c5369d4a391d454c13a553de5d544dc4b7dcf Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Wed, 22 Dec 2021 11:48:22 -0600 Subject: [PATCH 6/6] more diff minimization --- starlette/datastructures.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/starlette/datastructures.py b/starlette/datastructures.py index 936fcc26f..192a6ca07 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -481,7 +481,7 @@ def __init__( self.headers = None async def close(self) -> None: - for _, value in self.multi_items(): + for key, value in self.multi_items(): if isinstance(value, UploadFile): await value.close()