Skip to content

Commit

Permalink
feat: Added some more typing annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
defnull committed Sep 21, 2024
1 parent 10767fc commit 8af4246
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@

import re
from io import BytesIO
from typing import Iterator, Union, Optional, Tuple, List
from typing import Iterator, Union, Optional, Tuple, List, MutableMapping, TypeVar
from urllib.parse import parse_qs
from wsgiref.headers import Headers
from collections.abc import MutableMapping as DictMixin
import tempfile


Expand All @@ -29,8 +28,10 @@
##############################################################################
# Some of these were copied from bottle: https://bottlepy.org

_V = TypeVar("V")
_D = TypeVar("D")

class MultiDict(DictMixin):
class MultiDict(MutableMapping[str, _V]):
""" A dict that stores multiple values per key. Most dict methods return the
last value by default. There are special methods to get all values.
"""
Expand All @@ -50,7 +51,7 @@ def __init__(self, *args, **kwargs):
def __len__(self):
return len(self.dict)

def __iter__(self):
def __iter__(self) -> Iterator[_V]:
return iter(self.dict)

def __contains__(self, key):
Expand All @@ -65,10 +66,10 @@ def __str__(self):
def __repr__(self):
return repr(self.dict)

def keys(self):
def keys(self) -> Iterator[str]:
return self.dict.keys()

def __getitem__(self, key):
def __getitem__(self, key) -> _V:
return self.get(key, KeyError, -1)

def __setitem__(self, key, value):
Expand All @@ -80,16 +81,16 @@ def append(self, key, value):
def replace(self, key, value):
self.dict[key] = [value]

def getall(self, key):
def getall(self, key) -> List[_V]:
return self.dict.get(key) or []

def get(self, key, default=None, index=-1):
def get(self, key, default:_D=None, index=-1) -> Union[_V,_D]:
if key not in self.dict and default != KeyError:
return [default][index]

return self.dict[key][index]

def iterallitems(self):
def iterallitems(self) -> Iterator[Tuple[str, _V]]:
""" Yield (key, value) keys, but for all values. """
for key, values in self.dict.items():
for value in values:
Expand Down Expand Up @@ -585,7 +586,7 @@ def __init__(
self._done = []
self._part_iter = None

def __iter__(self):
def __iter__(self) -> Iterator["MultipartPart"]:
"""Iterate over the parts of the multipart message."""
if not self._part_iter:
self._part_iter = self._iterparse()
Expand All @@ -601,7 +602,7 @@ def parts(self):
"""Returns a list with all parts of the multipart message."""
return list(self)

def get(self, name, default=None):
def get(self, name, default: _D = None):
"""Return the first part with that name or a default value."""
for part in self:
if name == part.name:
Expand Down Expand Up @@ -737,7 +738,9 @@ def close(self):
##############################################################################


def parse_form_data(environ, charset="utf8", strict=False, **kwargs):
def parse_form_data(
environ, charset="utf8", strict=False, **kwargs
) -> Tuple[MultiDict[str], MultiDict[MultipartPart]]:
""" Parses both types of form data (multipart and url-encoded) from a WSGI
environment and returns a (forms, files) tuple. Both are instances of
:class:`MultiDict` and may contain multiple values per key.
Expand Down

0 comments on commit 8af4246

Please sign in to comment.