From 8af424637de47c12476fc9afb3de92ce9466b71a Mon Sep 17 00:00:00 2001 From: Marcel Hellkamp Date: Sat, 21 Sep 2024 13:51:46 +0200 Subject: [PATCH] feat: Added some more typing annotations --- multipart.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/multipart.py b/multipart.py index 47e04e3..822d758 100644 --- a/multipart.py +++ b/multipart.py @@ -17,10 +17,9 @@ import re from io import BytesIO -from typing import Iterator, Union, Optional, Tuple, List +from typing import Iterator, Union, Optional, Tuple, List, MutableMapping, TypeVar from urllib.parse import parse_qs from wsgiref.headers import Headers -from collections.abc import MutableMapping as DictMixin import tempfile @@ -29,8 +28,10 @@ ############################################################################## # Some of these were copied from bottle: https://bottlepy.org +_V = TypeVar("V") +_D = TypeVar("D") -class MultiDict(DictMixin): +class MultiDict(MutableMapping[str, _V]): """ A dict that stores multiple values per key. Most dict methods return the last value by default. There are special methods to get all values. """ @@ -50,7 +51,7 @@ def __init__(self, *args, **kwargs): def __len__(self): return len(self.dict) - def __iter__(self): + def __iter__(self) -> Iterator[_V]: return iter(self.dict) def __contains__(self, key): @@ -65,10 +66,10 @@ def __str__(self): def __repr__(self): return repr(self.dict) - def keys(self): + def keys(self) -> Iterator[str]: return self.dict.keys() - def __getitem__(self, key): + def __getitem__(self, key) -> _V: return self.get(key, KeyError, -1) def __setitem__(self, key, value): @@ -80,16 +81,16 @@ def append(self, key, value): def replace(self, key, value): self.dict[key] = [value] - def getall(self, key): + def getall(self, key) -> List[_V]: return self.dict.get(key) or [] - def get(self, key, default=None, index=-1): + def get(self, key, default:_D=None, index=-1) -> Union[_V,_D]: if key not in self.dict and default != KeyError: return [default][index] return self.dict[key][index] - def iterallitems(self): + def iterallitems(self) -> Iterator[Tuple[str, _V]]: """ Yield (key, value) keys, but for all values. """ for key, values in self.dict.items(): for value in values: @@ -585,7 +586,7 @@ def __init__( self._done = [] self._part_iter = None - def __iter__(self): + def __iter__(self) -> Iterator["MultipartPart"]: """Iterate over the parts of the multipart message.""" if not self._part_iter: self._part_iter = self._iterparse() @@ -601,7 +602,7 @@ def parts(self): """Returns a list with all parts of the multipart message.""" return list(self) - def get(self, name, default=None): + def get(self, name, default: _D = None): """Return the first part with that name or a default value.""" for part in self: if name == part.name: @@ -737,7 +738,9 @@ def close(self): ############################################################################## -def parse_form_data(environ, charset="utf8", strict=False, **kwargs): +def parse_form_data( + environ, charset="utf8", strict=False, **kwargs +) -> Tuple[MultiDict[str], MultiDict[MultipartPart]]: """ Parses both types of form data (multipart and url-encoded) from a WSGI environment and returns a (forms, files) tuple. Both are instances of :class:`MultiDict` and may contain multiple values per key.