Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Improve typing #54

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,12 @@ arrive, instead of waiting for the entire request to be parsed:

def wsgi(environ, start_response):
assert environ["REQUEST_METHOD"] == "POST"
ctype, copts = mp.parse_options_header(environ.get("CONTENT_TYPE", ""))
ctype, copts = parse_options_header(environ.get("CONTENT_TYPE", ""))
boundary = copts.get("boundary")
charset = copts.get("charset", "utf8")
assert ctype == "multipart/form-data"

parser = mp.MultipartParser(environ["wsgi.input"], boundary, charset)
parser = MultipartParser(environ["wsgi.input"], boundary, charset)
for part in parser:
if part.filename:
print(f"{part.name}: File upload ({part.size} bytes)")
Expand All @@ -104,20 +104,20 @@ the other parsers in this library:

.. code-block:: python

from multipart import PushMultipartParser
from multipart import PushMultipartParser, MultipartSegment

async def process_multipart(reader: asyncio.StreamReader, boundary: str):
with PushMultipartParser(boundary) as parser:
while not parser.closed:
chunk = await reader.read(1024*46)
for event in parser.parse(chunk):
if isinstance(event, list):
print("== Start of segment")
for header, value in event:
if isinstance(event, MultipartSegment):
print(f"== Start of segment: {event.name}")
for header, value in event.headerlist:
print(f"{header}: {value}")
elif isinstance(event, bytearray):
elif event:
print(f"[{len(event)} bytes of data]")
elif event is None:
else:
print("== End of segment")


Expand Down
27 changes: 15 additions & 12 deletions multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@

import re
from io import BytesIO
from typing import Iterator, Union, Optional, Tuple, List
from typing import Iterator, Union, Optional, Tuple, List, MutableMapping, TypeVar
from urllib.parse import parse_qs
from wsgiref.headers import Headers
from collections.abc import MutableMapping as DictMixin
import tempfile


Expand All @@ -29,8 +28,10 @@
##############################################################################
# Some of these were copied from bottle: https://bottlepy.org

_V = TypeVar("V")
_D = TypeVar("D")

class MultiDict(DictMixin):
class MultiDict(MutableMapping[str, _V]):
""" A dict that stores multiple values per key. Most dict methods return the
last value by default. There are special methods to get all values.
"""
Expand All @@ -50,7 +51,7 @@ def __init__(self, *args, **kwargs):
def __len__(self):
return len(self.dict)

def __iter__(self):
def __iter__(self) -> Iterator[_V]:
return iter(self.dict)

def __contains__(self, key):
Expand All @@ -65,10 +66,10 @@ def __str__(self):
def __repr__(self):
return repr(self.dict)

def keys(self):
def keys(self) -> Iterator[str]:
return self.dict.keys()

def __getitem__(self, key):
def __getitem__(self, key) -> _V:
return self.get(key, KeyError, -1)

def __setitem__(self, key, value):
Expand All @@ -80,16 +81,16 @@ def append(self, key, value):
def replace(self, key, value):
self.dict[key] = [value]

def getall(self, key):
def getall(self, key) -> List[_V]:
return self.dict.get(key) or []

def get(self, key, default=None, index=-1):
def get(self, key, default:_D=None, index=-1) -> Union[_V,_D]:
if key not in self.dict and default != KeyError:
return [default][index]

return self.dict[key][index]

def iterallitems(self):
def iterallitems(self) -> Iterator[Tuple[str, _V]]:
""" Yield (key, value) keys, but for all values. """
for key, values in self.dict.items():
for value in values:
Expand Down Expand Up @@ -585,7 +586,7 @@ def __init__(
self._done = []
self._part_iter = None

def __iter__(self):
def __iter__(self) -> Iterator["MultipartPart"]:
"""Iterate over the parts of the multipart message."""
if not self._part_iter:
self._part_iter = self._iterparse()
Expand All @@ -601,7 +602,7 @@ def parts(self):
"""Returns a list with all parts of the multipart message."""
return list(self)

def get(self, name, default=None):
def get(self, name, default: _D = None):
"""Return the first part with that name or a default value."""
for part in self:
if name == part.name:
Expand Down Expand Up @@ -737,7 +738,9 @@ def close(self):
##############################################################################


def parse_form_data(environ, charset="utf8", strict=False, **kwargs):
def parse_form_data(
environ, charset="utf8", strict=False, **kwargs
) -> Tuple[MultiDict[str], MultiDict[MultipartPart]]:
""" Parses both types of form data (multipart and url-encoded) from a WSGI
environment and returns a (forms, files) tuple. Both are instances of
:class:`MultiDict` and may contain multiple values per key.
Expand Down