Skip to content

Commit

Permalink
feat: expose multipart headers to users
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb committed Oct 15, 2021
1 parent 3282d23 commit 9f6e890
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 118 deletions.
132 changes: 83 additions & 49 deletions starlette/datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,13 +230,21 @@ def __str__(self) -> str:
return ", ".join(repr(item) for item in self)


class ImmutableMultiDict(typing.Mapping):
_KT = typing.TypeVar("_KT", bound=typing.Hashable)
_VT = typing.TypeVar("_VT")
_T = typing.TypeVar("_T")


_UNSET: typing.Any = object()


class ImmutableMultiDict(typing.Mapping[_KT, _VT]):
def __init__(
self,
*args: typing.Union[
"ImmutableMultiDict",
typing.Mapping,
typing.List[typing.Tuple[typing.Any, typing.Any]],
"ImmutableMultiDict[_KT, _VT]",
typing.Mapping[_KT, _VT],
typing.List[typing.Tuple[_KT, _VT]],
],
**kwargs: typing.Any,
) -> None:
Expand Down Expand Up @@ -266,33 +274,43 @@ def __init__(
self._dict = {k: v for k, v in _items}
self._list = _items

def getlist(self, key: typing.Any) -> typing.List[typing.Any]:
def getlist(self, key: _KT) -> typing.List[_VT]:
return [item_value for item_key, item_value in self._list if item_key == key]

def keys(self) -> typing.KeysView:
def keys(self) -> typing.KeysView[_KT]:
return self._dict.keys()

def values(self) -> typing.ValuesView:
def values(self) -> typing.ValuesView[_VT]:
return self._dict.values()

def items(self) -> typing.ItemsView:
def items(self) -> typing.ItemsView[_KT, _VT]:
return self._dict.items()

def multi_items(self) -> typing.List[typing.Tuple[str, str]]:
def multi_items(self) -> typing.List[typing.Tuple[_KT, _VT]]:
return list(self._list)

def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any:

@typing.overload
def get(self, key: _KT) -> _VT:
...

@typing.overload
def get(self, key: _KT, default: _T) -> typing.Union[_VT, _T]:
...

def get(self, key: _KT, default: _T = _UNSET) -> typing.Union[_VT, _T]:
if key in self._dict:
return self._dict[key]
if default is _UNSET:
raise KeyError(key)
return default

def __getitem__(self, key: typing.Any) -> str:
def __getitem__(self, key: _KT) -> _VT:
return self._dict[key]

def __contains__(self, key: typing.Any) -> bool:
return key in self._dict

def __iter__(self) -> typing.Iterator[typing.Any]:
def __iter__(self) -> typing.Iterator[_KT]:
return iter(self.keys())

def __len__(self) -> int:
Expand All @@ -309,24 +327,36 @@ def __repr__(self) -> str:
return f"{class_name}({items!r})"


class MultiDict(ImmutableMultiDict):
def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
class MultiDict(ImmutableMultiDict[_KT, _VT]):
def __setitem__(self, key: _KT, value: typing.Any) -> None:
self.setlist(key, [value])

def __delitem__(self, key: typing.Any) -> None:
def __delitem__(self, key: _KT) -> None:
self._list = [(k, v) for k, v in self._list if k != key]
del self._dict[key]

def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
self._list = [(k, v) for k, v in self._list if k != key]
return self._dict.pop(key, default)
@typing.overload
def pop(self, key: _KT) -> _VT:
...

@typing.overload
def pop(self, key: _KT, default: _T) -> typing.Union[_VT, _T]:
...

def pop(self, key: _KT, default: _T = _UNSET) -> typing.Union[_VT, _T]:
if key in self._dict:
self._list = [(k, v) for k, v in self._list if k != key]
return self._dict.pop(key, default)
if default is _UNSET:
raise KeyError(key)
return default

def popitem(self) -> typing.Tuple:
def popitem(self) -> typing.Tuple[_KT, _VT]:
key, value = self._dict.popitem()
self._list = [(k, v) for k, v in self._list if k != key]
return key, value

def poplist(self, key: typing.Any) -> typing.List:
def poplist(self, key: _KT) -> typing.List[typing.Tuple[_KT, _VT]]:
values = [v for k, v in self._list if k == key]
self.pop(key)
return values
Expand All @@ -335,31 +365,31 @@ def clear(self) -> None:
self._dict.clear()
self._list.clear()

def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
def setdefault(self, key: _KT, default: typing.Optional[_VT] = None) -> _VT:
if key not in self:
self._dict[key] = default
self._list.append((key, default))

return self[key]

def setlist(self, key: typing.Any, values: typing.List) -> None:
def setlist(self, key: _KT, values: typing.List[_VT]) -> None:
if not values:
self.pop(key, None)
else:
existing_items = [(k, v) for (k, v) in self._list if k != key]
self._list = existing_items + [(key, value) for value in values]
self._dict[key] = values[-1]

def append(self, key: typing.Any, value: typing.Any) -> None:
def append(self, key: _KT, value: _VT) -> None:
self._list.append((key, value))
self._dict[key] = value

def update(
self,
*args: typing.Union[
"MultiDict",
typing.Mapping,
typing.List[typing.Tuple[typing.Any, typing.Any]],
"MultiDict[_KT, _VT]",
typing.Mapping[_KT, _VT],
typing.List[typing.Tuple[_KT, _VT]],
],
**kwargs: typing.Any,
) -> None:
Expand Down Expand Up @@ -454,28 +484,6 @@ async def close(self) -> None:
await run_in_threadpool(self.file.close)


class FormData(ImmutableMultiDict):
"""
An immutable multidict, containing both file uploads and text input.
"""

def __init__(
self,
*args: typing.Union[
"FormData",
typing.Mapping[str, typing.Union[str, UploadFile]],
typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]],
],
**kwargs: typing.Union[str, UploadFile],
) -> None:
super().__init__(*args, **kwargs)

async def close(self) -> None:
for key, value in self.multi_items():
if isinstance(value, UploadFile):
await value.close()


class Headers(typing.Mapping[str, str]):
"""
An immutable, case-insensitive multidict.
Expand Down Expand Up @@ -665,3 +673,29 @@ def __getattr__(self, key: typing.Any) -> typing.Any:

def __delattr__(self, key: typing.Any) -> None:
del self._state[key]


class FormData(ImmutableMultiDict[str, typing.Union[str, UploadFile]]):
"""
An immutable multidict, containing both file uploads and text input.
"""
headers: typing.Optional[ImmutableMultiDict[str, Headers]]
def __init__(
self,
*args: typing.Union[
"FormData",
typing.Mapping[str, typing.Union[str, UploadFile]],
typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]],
],
raw_headers: typing.Optional[typing.List[typing.Tuple[str, typing.List[typing.Tuple[bytes, bytes]]]]] = None
) -> None:
super().__init__(*args)
if raw_headers is not None:
self.headers = ImmutableMultiDict([(field_name, Headers(raw=raw)) for field_name, raw in raw_headers])
else:
self.headers = None

async def close(self) -> None:
for _, value in self.multi_items():
if isinstance(value, UploadFile):
await value.close()
7 changes: 6 additions & 1 deletion starlette/formparsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ async def parse(self) -> FormData:
file: typing.Optional[UploadFile] = None

items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = []
item_headers: typing.List[typing.Tuple[bytes, bytes]] = []
raw_headers: typing.List[typing.Tuple[str, typing.List[typing.Tuple[bytes, bytes]]]] = []

# Feed the parser with data from the request.
async for chunk in self.stream:
Expand All @@ -195,6 +197,7 @@ async def parse(self) -> FormData:
content_disposition = None
content_type = b""
data = b""
item_headers = []
elif message_type == MultiPartMessage.HEADER_FIELD:
header_field += message_bytes
elif message_type == MultiPartMessage.HEADER_VALUE:
Expand All @@ -205,6 +208,7 @@ async def parse(self) -> FormData:
content_disposition = header_value
elif field == b"content-type":
content_type = header_value
item_headers.append((field, header_value))
header_field = b""
header_value = b""
elif message_type == MultiPartMessage.HEADERS_FINISHED:
Expand All @@ -229,6 +233,7 @@ async def parse(self) -> FormData:
else:
await file.seek(0)
items.append((field_name, file))
raw_headers.append((field_name, item_headers))

parser.finalize()
return FormData(items)
return FormData(items, raw_headers=raw_headers)
Loading

0 comments on commit 9f6e890

Please sign in to comment.