the Swiss Army knife of Python web development.
-{gyver}\n\n\n- -""".encode( - "latin1" - ) - ] - - return easteregged + return int(value) diff --git a/src/werkzeug/_reloader.py b/src/werkzeug/_reloader.py index 57f3117..d7e91a6 100644 --- a/src/werkzeug/_reloader.py +++ b/src/werkzeug/_reloader.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import fnmatch import os import subprocess @@ -20,7 +22,7 @@ if hasattr(sys, "real_prefix"): # virtualenv < 20 - prefix.add(sys.real_prefix) # type: ignore[attr-defined] + prefix.add(sys.real_prefix) _stat_ignore_scan = tuple(prefix) del prefix @@ -55,13 +57,13 @@ def _iter_module_paths() -> t.Iterator[str]: yield name -def _remove_by_pattern(paths: t.Set[str], exclude_patterns: t.Set[str]) -> None: +def _remove_by_pattern(paths: set[str], exclude_patterns: set[str]) -> None: for pattern in exclude_patterns: paths.difference_update(fnmatch.filter(paths, pattern)) def _find_stat_paths( - extra_files: t.Set[str], exclude_patterns: t.Set[str] + extra_files: set[str], exclude_patterns: set[str] ) -> t.Iterable[str]: """Find paths for the stat reloader to watch. Returns imported module files, Python files under non-system paths. Extra files and @@ -115,7 +117,7 @@ def _find_stat_paths( def _find_watchdog_paths( - extra_files: t.Set[str], exclude_patterns: t.Set[str] + extra_files: set[str], exclude_patterns: set[str] ) -> t.Iterable[str]: """Find paths for the stat reloader to watch. Looks at the same sources as the stat reloader, but watches everything under @@ -139,7 +141,7 @@ def _find_watchdog_paths( def _find_common_roots(paths: t.Iterable[str]) -> t.Iterable[str]: - root: t.Dict[str, dict] = {} + root: dict[str, dict[str, t.Any]] = {} for chunks in sorted((PurePath(x).parts for x in paths), key=len, reverse=True): node = root @@ -151,21 +153,28 @@ def _find_common_roots(paths: t.Iterable[str]) -> t.Iterable[str]: rv = set() - def _walk(node: t.Mapping[str, dict], path: t.Tuple[str, ...]) -> None: + def _walk(node: t.Mapping[str, dict[str, t.Any]], path: tuple[str, ...]) -> None: for prefix, child in node.items(): _walk(child, path + (prefix,)) - if not node: + # If there are no more nodes, and a path has been accumulated, add it. + # Path may be empty if the "" entry is in sys.path. + if not node and path: rv.add(os.path.join(*path)) _walk(root, ()) return rv -def _get_args_for_reloading() -> t.List[str]: +def _get_args_for_reloading() -> list[str]: """Determine how the script was executed, and return the args needed to execute it again in a new process. """ + if sys.version_info >= (3, 10): + # sys.orig_argv, added in Python 3.10, contains the exact args used to invoke + # Python. Still replace argv[0] with sys.executable for accuracy. + return [sys.executable, *sys.orig_argv[1:]] + rv = [sys.executable] py_script = sys.argv[0] args = sys.argv[1:] @@ -221,15 +230,15 @@ class ReloaderLoop: def __init__( self, - extra_files: t.Optional[t.Iterable[str]] = None, - exclude_patterns: t.Optional[t.Iterable[str]] = None, - interval: t.Union[int, float] = 1, + extra_files: t.Iterable[str] | None = None, + exclude_patterns: t.Iterable[str] | None = None, + interval: int | float = 1, ) -> None: - self.extra_files: t.Set[str] = {os.path.abspath(x) for x in extra_files or ()} - self.exclude_patterns: t.Set[str] = set(exclude_patterns or ()) + self.extra_files: set[str] = {os.path.abspath(x) for x in extra_files or ()} + self.exclude_patterns: set[str] = set(exclude_patterns or ()) self.interval = interval - def __enter__(self) -> "ReloaderLoop": + def __enter__(self) -> ReloaderLoop: """Do any setup, then run one step of the watch to populate the initial filesystem state. """ @@ -281,7 +290,7 @@ class StatReloaderLoop(ReloaderLoop): name = "stat" def __enter__(self) -> ReloaderLoop: - self.mtimes: t.Dict[str, float] = {} + self.mtimes: dict[str, float] = {} return super().__enter__() def run_step(self) -> None: @@ -303,17 +312,22 @@ def run_step(self) -> None: class WatchdogReloaderLoop(ReloaderLoop): def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: - from watchdog.observers import Observer + from watchdog.events import EVENT_TYPE_OPENED + from watchdog.events import FileModifiedEvent from watchdog.events import PatternMatchingEventHandler + from watchdog.observers import Observer super().__init__(*args, **kwargs) trigger_reload = self.trigger_reload - class EventHandler(PatternMatchingEventHandler): # type: ignore - def on_any_event(self, event): # type: ignore + class EventHandler(PatternMatchingEventHandler): + def on_any_event(self, event: FileModifiedEvent): # type: ignore + if event.event_type == EVENT_TYPE_OPENED: + return + trigger_reload(event.src_path) - reloader_name = Observer.__name__.lower() + reloader_name = Observer.__name__.lower() # type: ignore[attr-defined] if reloader_name.endswith("observer"): reloader_name = reloader_name[:-8] @@ -326,7 +340,7 @@ def on_any_event(self, event): # type: ignore # the source file (or initial pyc file) as well. Ignore Git and # Mercurial internal changes. extra_patterns = [p for p in self.extra_files if not os.path.isdir(p)] - self.event_handler = EventHandler( + self.event_handler = EventHandler( # type: ignore[no-untyped-call] patterns=["*.py", "*.pyc", "*.zip", *extra_patterns], ignore_patterns=[ *[f"*/{d}/*" for d in _ignore_common_dirs], @@ -343,12 +357,12 @@ def trigger_reload(self, filename: str) -> None: self.log_reload(filename) def __enter__(self) -> ReloaderLoop: - self.watches: t.Dict[str, t.Any] = {} - self.observer.start() + self.watches: dict[str, t.Any] = {} + self.observer.start() # type: ignore[no-untyped-call] return super().__enter__() def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore - self.observer.stop() + self.observer.stop() # type: ignore[no-untyped-call] self.observer.join() def run(self) -> None: @@ -364,7 +378,7 @@ def run_step(self) -> None: for path in _find_watchdog_paths(self.extra_files, self.exclude_patterns): if path not in self.watches: try: - self.watches[path] = self.observer.schedule( + self.watches[path] = self.observer.schedule( # type: ignore[no-untyped-call] self.event_handler, path, recursive=True ) except OSError: @@ -379,10 +393,10 @@ def run_step(self) -> None: watch = self.watches.pop(path, None) if watch is not None: - self.observer.unschedule(watch) + self.observer.unschedule(watch) # type: ignore[no-untyped-call] -reloader_loops: t.Dict[str, t.Type[ReloaderLoop]] = { +reloader_loops: dict[str, type[ReloaderLoop]] = { "stat": StatReloaderLoop, "watchdog": WatchdogReloaderLoop, } @@ -416,9 +430,9 @@ def ensure_echo_on() -> None: def run_with_reloader( main_func: t.Callable[[], None], - extra_files: t.Optional[t.Iterable[str]] = None, - exclude_patterns: t.Optional[t.Iterable[str]] = None, - interval: t.Union[int, float] = 1, + extra_files: t.Iterable[str] | None = None, + exclude_patterns: t.Iterable[str] | None = None, + interval: int | float = 1, reloader_type: str = "auto", ) -> None: """Run the given function in an independent Python interpreter.""" diff --git a/src/werkzeug/datastructures.py b/src/werkzeug/datastructures.py deleted file mode 100644 index 43ee8c7..0000000 --- a/src/werkzeug/datastructures.py +++ /dev/null @@ -1,3040 +0,0 @@ -import base64 -import codecs -import mimetypes -import os -import re -from collections.abc import Collection -from collections.abc import MutableSet -from copy import deepcopy -from io import BytesIO -from itertools import repeat -from os import fspath - -from . import exceptions -from ._internal import _missing - - -def is_immutable(self): - raise TypeError(f"{type(self).__name__!r} objects are immutable") - - -def iter_multi_items(mapping): - """Iterates over the items of a mapping yielding keys and values - without dropping any from more complex structures. - """ - if isinstance(mapping, MultiDict): - yield from mapping.items(multi=True) - elif isinstance(mapping, dict): - for key, value in mapping.items(): - if isinstance(value, (tuple, list)): - for v in value: - yield key, v - else: - yield key, value - else: - yield from mapping - - -class ImmutableListMixin: - """Makes a :class:`list` immutable. - - .. versionadded:: 0.5 - - :private: - """ - - _hash_cache = None - - def __hash__(self): - if self._hash_cache is not None: - return self._hash_cache - rv = self._hash_cache = hash(tuple(self)) - return rv - - def __reduce_ex__(self, protocol): - return type(self), (list(self),) - - def __delitem__(self, key): - is_immutable(self) - - def __iadd__(self, other): - is_immutable(self) - - def __imul__(self, other): - is_immutable(self) - - def __setitem__(self, key, value): - is_immutable(self) - - def append(self, item): - is_immutable(self) - - def remove(self, item): - is_immutable(self) - - def extend(self, iterable): - is_immutable(self) - - def insert(self, pos, value): - is_immutable(self) - - def pop(self, index=-1): - is_immutable(self) - - def reverse(self): - is_immutable(self) - - def sort(self, key=None, reverse=False): - is_immutable(self) - - -class ImmutableList(ImmutableListMixin, list): - """An immutable :class:`list`. - - .. versionadded:: 0.5 - - :private: - """ - - def __repr__(self): - return f"{type(self).__name__}({list.__repr__(self)})" - - -class ImmutableDictMixin: - """Makes a :class:`dict` immutable. - - .. versionadded:: 0.5 - - :private: - """ - - _hash_cache = None - - @classmethod - def fromkeys(cls, keys, value=None): - instance = super().__new__(cls) - instance.__init__(zip(keys, repeat(value))) - return instance - - def __reduce_ex__(self, protocol): - return type(self), (dict(self),) - - def _iter_hashitems(self): - return self.items() - - def __hash__(self): - if self._hash_cache is not None: - return self._hash_cache - rv = self._hash_cache = hash(frozenset(self._iter_hashitems())) - return rv - - def setdefault(self, key, default=None): - is_immutable(self) - - def update(self, *args, **kwargs): - is_immutable(self) - - def pop(self, key, default=None): - is_immutable(self) - - def popitem(self): - is_immutable(self) - - def __setitem__(self, key, value): - is_immutable(self) - - def __delitem__(self, key): - is_immutable(self) - - def clear(self): - is_immutable(self) - - -class ImmutableMultiDictMixin(ImmutableDictMixin): - """Makes a :class:`MultiDict` immutable. - - .. versionadded:: 0.5 - - :private: - """ - - def __reduce_ex__(self, protocol): - return type(self), (list(self.items(multi=True)),) - - def _iter_hashitems(self): - return self.items(multi=True) - - def add(self, key, value): - is_immutable(self) - - def popitemlist(self): - is_immutable(self) - - def poplist(self, key): - is_immutable(self) - - def setlist(self, key, new_list): - is_immutable(self) - - def setlistdefault(self, key, default_list=None): - is_immutable(self) - - -def _calls_update(name): - def oncall(self, *args, **kw): - rv = getattr(super(UpdateDictMixin, self), name)(*args, **kw) - - if self.on_update is not None: - self.on_update(self) - - return rv - - oncall.__name__ = name - return oncall - - -class UpdateDictMixin(dict): - """Makes dicts call `self.on_update` on modifications. - - .. versionadded:: 0.5 - - :private: - """ - - on_update = None - - def setdefault(self, key, default=None): - modified = key not in self - rv = super().setdefault(key, default) - if modified and self.on_update is not None: - self.on_update(self) - return rv - - def pop(self, key, default=_missing): - modified = key in self - if default is _missing: - rv = super().pop(key) - else: - rv = super().pop(key, default) - if modified and self.on_update is not None: - self.on_update(self) - return rv - - __setitem__ = _calls_update("__setitem__") - __delitem__ = _calls_update("__delitem__") - clear = _calls_update("clear") - popitem = _calls_update("popitem") - update = _calls_update("update") - - -class TypeConversionDict(dict): - """Works like a regular dict but the :meth:`get` method can perform - type conversions. :class:`MultiDict` and :class:`CombinedMultiDict` - are subclasses of this class and provide the same feature. - - .. versionadded:: 0.5 - """ - - def get(self, key, default=None, type=None): - """Return the default value if the requested data doesn't exist. - If `type` is provided and is a callable it should convert the value, - return it or raise a :exc:`ValueError` if that is not possible. In - this case the function will return the default as if the value was not - found: - - >>> d = TypeConversionDict(foo='42', bar='blub') - >>> d.get('foo', type=int) - 42 - >>> d.get('bar', -1, type=int) - -1 - - :param key: The key to be looked up. - :param default: The default value to be returned if the key can't - be looked up. If not further specified `None` is - returned. - :param type: A callable that is used to cast the value in the - :class:`MultiDict`. If a :exc:`ValueError` is raised - by this callable the default value is returned. - """ - try: - rv = self[key] - except KeyError: - return default - if type is not None: - try: - rv = type(rv) - except ValueError: - rv = default - return rv - - -class ImmutableTypeConversionDict(ImmutableDictMixin, TypeConversionDict): - """Works like a :class:`TypeConversionDict` but does not support - modifications. - - .. versionadded:: 0.5 - """ - - def copy(self): - """Return a shallow mutable copy of this object. Keep in mind that - the standard library's :func:`copy` function is a no-op for this class - like for any other python immutable type (eg: :class:`tuple`). - """ - return TypeConversionDict(self) - - def __copy__(self): - return self - - -class MultiDict(TypeConversionDict): - """A :class:`MultiDict` is a dictionary subclass customized to deal with - multiple values for the same key which is for example used by the parsing - functions in the wrappers. This is necessary because some HTML form - elements pass multiple values for the same key. - - :class:`MultiDict` implements all standard dictionary methods. - Internally, it saves all values for a key as a list, but the standard dict - access methods will only return the first value for a key. If you want to - gain access to the other values, too, you have to use the `list` methods as - explained below. - - Basic Usage: - - >>> d = MultiDict([('a', 'b'), ('a', 'c')]) - >>> d - MultiDict([('a', 'b'), ('a', 'c')]) - >>> d['a'] - 'b' - >>> d.getlist('a') - ['b', 'c'] - >>> 'a' in d - True - - It behaves like a normal dict thus all dict functions will only return the - first value when multiple values for one key are found. - - From Werkzeug 0.3 onwards, the `KeyError` raised by this class is also a - subclass of the :exc:`~exceptions.BadRequest` HTTP exception and will - render a page for a ``400 BAD REQUEST`` if caught in a catch-all for HTTP - exceptions. - - A :class:`MultiDict` can be constructed from an iterable of - ``(key, value)`` tuples, a dict, a :class:`MultiDict` or from Werkzeug 0.2 - onwards some keyword parameters. - - :param mapping: the initial value for the :class:`MultiDict`. Either a - regular dict, an iterable of ``(key, value)`` tuples - or `None`. - """ - - def __init__(self, mapping=None): - if isinstance(mapping, MultiDict): - dict.__init__(self, ((k, l[:]) for k, l in mapping.lists())) - elif isinstance(mapping, dict): - tmp = {} - for key, value in mapping.items(): - if isinstance(value, (tuple, list)): - if len(value) == 0: - continue - value = list(value) - else: - value = [value] - tmp[key] = value - dict.__init__(self, tmp) - else: - tmp = {} - for key, value in mapping or (): - tmp.setdefault(key, []).append(value) - dict.__init__(self, tmp) - - def __getstate__(self): - return dict(self.lists()) - - def __setstate__(self, value): - dict.clear(self) - dict.update(self, value) - - def __iter__(self): - # Work around https://bugs.python.org/issue43246. - # (`return super().__iter__()` also works here, which makes this look - # even more like it should be a no-op, yet it isn't.) - return dict.__iter__(self) - - def __getitem__(self, key): - """Return the first data value for this key; - raises KeyError if not found. - - :param key: The key to be looked up. - :raise KeyError: if the key does not exist. - """ - - if key in self: - lst = dict.__getitem__(self, key) - if len(lst) > 0: - return lst[0] - raise exceptions.BadRequestKeyError(key) - - def __setitem__(self, key, value): - """Like :meth:`add` but removes an existing key first. - - :param key: the key for the value. - :param value: the value to set. - """ - dict.__setitem__(self, key, [value]) - - def add(self, key, value): - """Adds a new value for the key. - - .. versionadded:: 0.6 - - :param key: the key for the value. - :param value: the value to add. - """ - dict.setdefault(self, key, []).append(value) - - def getlist(self, key, type=None): - """Return the list of items for a given key. If that key is not in the - `MultiDict`, the return value will be an empty list. Just like `get`, - `getlist` accepts a `type` parameter. All items will be converted - with the callable defined there. - - :param key: The key to be looked up. - :param type: A callable that is used to cast the value in the - :class:`MultiDict`. If a :exc:`ValueError` is raised - by this callable the value will be removed from the list. - :return: a :class:`list` of all the values for the key. - """ - try: - rv = dict.__getitem__(self, key) - except KeyError: - return [] - if type is None: - return list(rv) - result = [] - for item in rv: - try: - result.append(type(item)) - except ValueError: - pass - return result - - def setlist(self, key, new_list): - """Remove the old values for a key and add new ones. Note that the list - you pass the values in will be shallow-copied before it is inserted in - the dictionary. - - >>> d = MultiDict() - >>> d.setlist('foo', ['1', '2']) - >>> d['foo'] - '1' - >>> d.getlist('foo') - ['1', '2'] - - :param key: The key for which the values are set. - :param new_list: An iterable with the new values for the key. Old values - are removed first. - """ - dict.__setitem__(self, key, list(new_list)) - - def setdefault(self, key, default=None): - """Returns the value for the key if it is in the dict, otherwise it - returns `default` and sets that value for `key`. - - :param key: The key to be looked up. - :param default: The default value to be returned if the key is not - in the dict. If not further specified it's `None`. - """ - if key not in self: - self[key] = default - else: - default = self[key] - return default - - def setlistdefault(self, key, default_list=None): - """Like `setdefault` but sets multiple values. The list returned - is not a copy, but the list that is actually used internally. This - means that you can put new values into the dict by appending items - to the list: - - >>> d = MultiDict({"foo": 1}) - >>> d.setlistdefault("foo").extend([2, 3]) - >>> d.getlist("foo") - [1, 2, 3] - - :param key: The key to be looked up. - :param default_list: An iterable of default values. It is either copied - (in case it was a list) or converted into a list - before returned. - :return: a :class:`list` - """ - if key not in self: - default_list = list(default_list or ()) - dict.__setitem__(self, key, default_list) - else: - default_list = dict.__getitem__(self, key) - return default_list - - def items(self, multi=False): - """Return an iterator of ``(key, value)`` pairs. - - :param multi: If set to `True` the iterator returned will have a pair - for each value of each key. Otherwise it will only - contain pairs for the first value of each key. - """ - for key, values in dict.items(self): - if multi: - for value in values: - yield key, value - else: - yield key, values[0] - - def lists(self): - """Return a iterator of ``(key, values)`` pairs, where values is the list - of all values associated with the key.""" - for key, values in dict.items(self): - yield key, list(values) - - def values(self): - """Returns an iterator of the first value on every key's value list.""" - for values in dict.values(self): - yield values[0] - - def listvalues(self): - """Return an iterator of all values associated with a key. Zipping - :meth:`keys` and this is the same as calling :meth:`lists`: - - >>> d = MultiDict({"foo": [1, 2, 3]}) - >>> zip(d.keys(), d.listvalues()) == d.lists() - True - """ - return dict.values(self) - - def copy(self): - """Return a shallow copy of this object.""" - return self.__class__(self) - - def deepcopy(self, memo=None): - """Return a deep copy of this object.""" - return self.__class__(deepcopy(self.to_dict(flat=False), memo)) - - def to_dict(self, flat=True): - """Return the contents as regular dict. If `flat` is `True` the - returned dict will only have the first item present, if `flat` is - `False` all values will be returned as lists. - - :param flat: If set to `False` the dict returned will have lists - with all the values in it. Otherwise it will only - contain the first value for each key. - :return: a :class:`dict` - """ - if flat: - return dict(self.items()) - return dict(self.lists()) - - def update(self, mapping): - """update() extends rather than replaces existing key lists: - - >>> a = MultiDict({'x': 1}) - >>> b = MultiDict({'x': 2, 'y': 3}) - >>> a.update(b) - >>> a - MultiDict([('y', 3), ('x', 1), ('x', 2)]) - - If the value list for a key in ``other_dict`` is empty, no new values - will be added to the dict and the key will not be created: - - >>> x = {'empty_list': []} - >>> y = MultiDict() - >>> y.update(x) - >>> y - MultiDict([]) - """ - for key, value in iter_multi_items(mapping): - MultiDict.add(self, key, value) - - def pop(self, key, default=_missing): - """Pop the first item for a list on the dict. Afterwards the - key is removed from the dict, so additional values are discarded: - - >>> d = MultiDict({"foo": [1, 2, 3]}) - >>> d.pop("foo") - 1 - >>> "foo" in d - False - - :param key: the key to pop. - :param default: if provided the value to return if the key was - not in the dictionary. - """ - try: - lst = dict.pop(self, key) - - if len(lst) == 0: - raise exceptions.BadRequestKeyError(key) - - return lst[0] - except KeyError: - if default is not _missing: - return default - - raise exceptions.BadRequestKeyError(key) from None - - def popitem(self): - """Pop an item from the dict.""" - try: - item = dict.popitem(self) - - if len(item[1]) == 0: - raise exceptions.BadRequestKeyError(item[0]) - - return (item[0], item[1][0]) - except KeyError as e: - raise exceptions.BadRequestKeyError(e.args[0]) from None - - def poplist(self, key): - """Pop the list for a key from the dict. If the key is not in the dict - an empty list is returned. - - .. versionchanged:: 0.5 - If the key does no longer exist a list is returned instead of - raising an error. - """ - return dict.pop(self, key, []) - - def popitemlist(self): - """Pop a ``(key, list)`` tuple from the dict.""" - try: - return dict.popitem(self) - except KeyError as e: - raise exceptions.BadRequestKeyError(e.args[0]) from None - - def __copy__(self): - return self.copy() - - def __deepcopy__(self, memo): - return self.deepcopy(memo=memo) - - def __repr__(self): - return f"{type(self).__name__}({list(self.items(multi=True))!r})" - - -class _omd_bucket: - """Wraps values in the :class:`OrderedMultiDict`. This makes it - possible to keep an order over multiple different keys. It requires - a lot of extra memory and slows down access a lot, but makes it - possible to access elements in O(1) and iterate in O(n). - """ - - __slots__ = ("prev", "key", "value", "next") - - def __init__(self, omd, key, value): - self.prev = omd._last_bucket - self.key = key - self.value = value - self.next = None - - if omd._first_bucket is None: - omd._first_bucket = self - if omd._last_bucket is not None: - omd._last_bucket.next = self - omd._last_bucket = self - - def unlink(self, omd): - if self.prev: - self.prev.next = self.next - if self.next: - self.next.prev = self.prev - if omd._first_bucket is self: - omd._first_bucket = self.next - if omd._last_bucket is self: - omd._last_bucket = self.prev - - -class OrderedMultiDict(MultiDict): - """Works like a regular :class:`MultiDict` but preserves the - order of the fields. To convert the ordered multi dict into a - list you can use the :meth:`items` method and pass it ``multi=True``. - - In general an :class:`OrderedMultiDict` is an order of magnitude - slower than a :class:`MultiDict`. - - .. admonition:: note - - Due to a limitation in Python you cannot convert an ordered - multi dict into a regular dict by using ``dict(multidict)``. - Instead you have to use the :meth:`to_dict` method, otherwise - the internal bucket objects are exposed. - """ - - def __init__(self, mapping=None): - dict.__init__(self) - self._first_bucket = self._last_bucket = None - if mapping is not None: - OrderedMultiDict.update(self, mapping) - - def __eq__(self, other): - if not isinstance(other, MultiDict): - return NotImplemented - if isinstance(other, OrderedMultiDict): - iter1 = iter(self.items(multi=True)) - iter2 = iter(other.items(multi=True)) - try: - for k1, v1 in iter1: - k2, v2 = next(iter2) - if k1 != k2 or v1 != v2: - return False - except StopIteration: - return False - try: - next(iter2) - except StopIteration: - return True - return False - if len(self) != len(other): - return False - for key, values in self.lists(): - if other.getlist(key) != values: - return False - return True - - __hash__ = None - - def __reduce_ex__(self, protocol): - return type(self), (list(self.items(multi=True)),) - - def __getstate__(self): - return list(self.items(multi=True)) - - def __setstate__(self, values): - dict.clear(self) - for key, value in values: - self.add(key, value) - - def __getitem__(self, key): - if key in self: - return dict.__getitem__(self, key)[0].value - raise exceptions.BadRequestKeyError(key) - - def __setitem__(self, key, value): - self.poplist(key) - self.add(key, value) - - def __delitem__(self, key): - self.pop(key) - - def keys(self): - return (key for key, value in self.items()) - - def __iter__(self): - return iter(self.keys()) - - def values(self): - return (value for key, value in self.items()) - - def items(self, multi=False): - ptr = self._first_bucket - if multi: - while ptr is not None: - yield ptr.key, ptr.value - ptr = ptr.next - else: - returned_keys = set() - while ptr is not None: - if ptr.key not in returned_keys: - returned_keys.add(ptr.key) - yield ptr.key, ptr.value - ptr = ptr.next - - def lists(self): - returned_keys = set() - ptr = self._first_bucket - while ptr is not None: - if ptr.key not in returned_keys: - yield ptr.key, self.getlist(ptr.key) - returned_keys.add(ptr.key) - ptr = ptr.next - - def listvalues(self): - for _key, values in self.lists(): - yield values - - def add(self, key, value): - dict.setdefault(self, key, []).append(_omd_bucket(self, key, value)) - - def getlist(self, key, type=None): - try: - rv = dict.__getitem__(self, key) - except KeyError: - return [] - if type is None: - return [x.value for x in rv] - result = [] - for item in rv: - try: - result.append(type(item.value)) - except ValueError: - pass - return result - - def setlist(self, key, new_list): - self.poplist(key) - for value in new_list: - self.add(key, value) - - def setlistdefault(self, key, default_list=None): - raise TypeError("setlistdefault is unsupported for ordered multi dicts") - - def update(self, mapping): - for key, value in iter_multi_items(mapping): - OrderedMultiDict.add(self, key, value) - - def poplist(self, key): - buckets = dict.pop(self, key, ()) - for bucket in buckets: - bucket.unlink(self) - return [x.value for x in buckets] - - def pop(self, key, default=_missing): - try: - buckets = dict.pop(self, key) - except KeyError: - if default is not _missing: - return default - - raise exceptions.BadRequestKeyError(key) from None - - for bucket in buckets: - bucket.unlink(self) - - return buckets[0].value - - def popitem(self): - try: - key, buckets = dict.popitem(self) - except KeyError as e: - raise exceptions.BadRequestKeyError(e.args[0]) from None - - for bucket in buckets: - bucket.unlink(self) - - return key, buckets[0].value - - def popitemlist(self): - try: - key, buckets = dict.popitem(self) - except KeyError as e: - raise exceptions.BadRequestKeyError(e.args[0]) from None - - for bucket in buckets: - bucket.unlink(self) - - return key, [x.value for x in buckets] - - -def _options_header_vkw(value, kw): - return http.dump_options_header( - value, {k.replace("_", "-"): v for k, v in kw.items()} - ) - - -def _unicodify_header_value(value): - if isinstance(value, bytes): - value = value.decode("latin-1") - if not isinstance(value, str): - value = str(value) - return value - - -class Headers: - """An object that stores some headers. It has a dict-like interface, - but is ordered, can store the same key multiple times, and iterating - yields ``(key, value)`` pairs instead of only keys. - - This data structure is useful if you want a nicer way to handle WSGI - headers which are stored as tuples in a list. - - From Werkzeug 0.3 onwards, the :exc:`KeyError` raised by this class is - also a subclass of the :class:`~exceptions.BadRequest` HTTP exception - and will render a page for a ``400 BAD REQUEST`` if caught in a - catch-all for HTTP exceptions. - - Headers is mostly compatible with the Python :class:`wsgiref.headers.Headers` - class, with the exception of `__getitem__`. :mod:`wsgiref` will return - `None` for ``headers['missing']``, whereas :class:`Headers` will raise - a :class:`KeyError`. - - To create a new ``Headers`` object, pass it a list, dict, or - other ``Headers`` object with default values. These values are - validated the same way values added later are. - - :param defaults: The list of default values for the :class:`Headers`. - - .. versionchanged:: 2.1.0 - Default values are validated the same as values added later. - - .. versionchanged:: 0.9 - This data structure now stores unicode values similar to how the - multi dicts do it. The main difference is that bytes can be set as - well which will automatically be latin1 decoded. - - .. versionchanged:: 0.9 - The :meth:`linked` function was removed without replacement as it - was an API that does not support the changes to the encoding model. - """ - - def __init__(self, defaults=None): - self._list = [] - if defaults is not None: - self.extend(defaults) - - def __getitem__(self, key, _get_mode=False): - if not _get_mode: - if isinstance(key, int): - return self._list[key] - elif isinstance(key, slice): - return self.__class__(self._list[key]) - if not isinstance(key, str): - raise exceptions.BadRequestKeyError(key) - ikey = key.lower() - for k, v in self._list: - if k.lower() == ikey: - return v - # micro optimization: if we are in get mode we will catch that - # exception one stack level down so we can raise a standard - # key error instead of our special one. - if _get_mode: - raise KeyError() - raise exceptions.BadRequestKeyError(key) - - def __eq__(self, other): - def lowered(item): - return (item[0].lower(),) + item[1:] - - return other.__class__ is self.__class__ and set( - map(lowered, other._list) - ) == set(map(lowered, self._list)) - - __hash__ = None - - def get(self, key, default=None, type=None, as_bytes=False): - """Return the default value if the requested data doesn't exist. - If `type` is provided and is a callable it should convert the value, - return it or raise a :exc:`ValueError` if that is not possible. In - this case the function will return the default as if the value was not - found: - - >>> d = Headers([('Content-Length', '42')]) - >>> d.get('Content-Length', type=int) - 42 - - .. versionadded:: 0.9 - Added support for `as_bytes`. - - :param key: The key to be looked up. - :param default: The default value to be returned if the key can't - be looked up. If not further specified `None` is - returned. - :param type: A callable that is used to cast the value in the - :class:`Headers`. If a :exc:`ValueError` is raised - by this callable the default value is returned. - :param as_bytes: return bytes instead of strings. - """ - try: - rv = self.__getitem__(key, _get_mode=True) - except KeyError: - return default - if as_bytes: - rv = rv.encode("latin1") - if type is None: - return rv - try: - return type(rv) - except ValueError: - return default - - def getlist(self, key, type=None, as_bytes=False): - """Return the list of items for a given key. If that key is not in the - :class:`Headers`, the return value will be an empty list. Just like - :meth:`get`, :meth:`getlist` accepts a `type` parameter. All items will - be converted with the callable defined there. - - .. versionadded:: 0.9 - Added support for `as_bytes`. - - :param key: The key to be looked up. - :param type: A callable that is used to cast the value in the - :class:`Headers`. If a :exc:`ValueError` is raised - by this callable the value will be removed from the list. - :return: a :class:`list` of all the values for the key. - :param as_bytes: return bytes instead of strings. - """ - ikey = key.lower() - result = [] - for k, v in self: - if k.lower() == ikey: - if as_bytes: - v = v.encode("latin1") - if type is not None: - try: - v = type(v) - except ValueError: - continue - result.append(v) - return result - - def get_all(self, name): - """Return a list of all the values for the named field. - - This method is compatible with the :mod:`wsgiref` - :meth:`~wsgiref.headers.Headers.get_all` method. - """ - return self.getlist(name) - - def items(self, lower=False): - for key, value in self: - if lower: - key = key.lower() - yield key, value - - def keys(self, lower=False): - for key, _ in self.items(lower): - yield key - - def values(self): - for _, value in self.items(): - yield value - - def extend(self, *args, **kwargs): - """Extend headers in this object with items from another object - containing header items as well as keyword arguments. - - To replace existing keys instead of extending, use - :meth:`update` instead. - - If provided, the first argument can be another :class:`Headers` - object, a :class:`MultiDict`, :class:`dict`, or iterable of - pairs. - - .. versionchanged:: 1.0 - Support :class:`MultiDict`. Allow passing ``kwargs``. - """ - if len(args) > 1: - raise TypeError(f"update expected at most 1 arguments, got {len(args)}") - - if args: - for key, value in iter_multi_items(args[0]): - self.add(key, value) - - for key, value in iter_multi_items(kwargs): - self.add(key, value) - - def __delitem__(self, key, _index_operation=True): - if _index_operation and isinstance(key, (int, slice)): - del self._list[key] - return - key = key.lower() - new = [] - for k, v in self._list: - if k.lower() != key: - new.append((k, v)) - self._list[:] = new - - def remove(self, key): - """Remove a key. - - :param key: The key to be removed. - """ - return self.__delitem__(key, _index_operation=False) - - def pop(self, key=None, default=_missing): - """Removes and returns a key or index. - - :param key: The key to be popped. If this is an integer the item at - that position is removed, if it's a string the value for - that key is. If the key is omitted or `None` the last - item is removed. - :return: an item. - """ - if key is None: - return self._list.pop() - if isinstance(key, int): - return self._list.pop(key) - try: - rv = self[key] - self.remove(key) - except KeyError: - if default is not _missing: - return default - raise - return rv - - def popitem(self): - """Removes a key or index and returns a (key, value) item.""" - return self.pop() - - def __contains__(self, key): - """Check if a key is present.""" - try: - self.__getitem__(key, _get_mode=True) - except KeyError: - return False - return True - - def __iter__(self): - """Yield ``(key, value)`` tuples.""" - return iter(self._list) - - def __len__(self): - return len(self._list) - - def add(self, _key, _value, **kw): - """Add a new header tuple to the list. - - Keyword arguments can specify additional parameters for the header - value, with underscores converted to dashes:: - - >>> d = Headers() - >>> d.add('Content-Type', 'text/plain') - >>> d.add('Content-Disposition', 'attachment', filename='foo.png') - - The keyword argument dumping uses :func:`dump_options_header` - behind the scenes. - - .. versionadded:: 0.4.1 - keyword arguments were added for :mod:`wsgiref` compatibility. - """ - if kw: - _value = _options_header_vkw(_value, kw) - _key = _unicodify_header_value(_key) - _value = _unicodify_header_value(_value) - self._validate_value(_value) - self._list.append((_key, _value)) - - def _validate_value(self, value): - if not isinstance(value, str): - raise TypeError("Value should be a string.") - if "\n" in value or "\r" in value: - raise ValueError( - "Detected newline in header value. This is " - "a potential security problem" - ) - - def add_header(self, _key, _value, **_kw): - """Add a new header tuple to the list. - - An alias for :meth:`add` for compatibility with the :mod:`wsgiref` - :meth:`~wsgiref.headers.Headers.add_header` method. - """ - self.add(_key, _value, **_kw) - - def clear(self): - """Clears all headers.""" - del self._list[:] - - def set(self, _key, _value, **kw): - """Remove all header tuples for `key` and add a new one. The newly - added key either appears at the end of the list if there was no - entry or replaces the first one. - - Keyword arguments can specify additional parameters for the header - value, with underscores converted to dashes. See :meth:`add` for - more information. - - .. versionchanged:: 0.6.1 - :meth:`set` now accepts the same arguments as :meth:`add`. - - :param key: The key to be inserted. - :param value: The value to be inserted. - """ - if kw: - _value = _options_header_vkw(_value, kw) - _key = _unicodify_header_value(_key) - _value = _unicodify_header_value(_value) - self._validate_value(_value) - if not self._list: - self._list.append((_key, _value)) - return - listiter = iter(self._list) - ikey = _key.lower() - for idx, (old_key, _old_value) in enumerate(listiter): - if old_key.lower() == ikey: - # replace first occurrence - self._list[idx] = (_key, _value) - break - else: - self._list.append((_key, _value)) - return - self._list[idx + 1 :] = [t for t in listiter if t[0].lower() != ikey] - - def setlist(self, key, values): - """Remove any existing values for a header and add new ones. - - :param key: The header key to set. - :param values: An iterable of values to set for the key. - - .. versionadded:: 1.0 - """ - if values: - values_iter = iter(values) - self.set(key, next(values_iter)) - - for value in values_iter: - self.add(key, value) - else: - self.remove(key) - - def setdefault(self, key, default): - """Return the first value for the key if it is in the headers, - otherwise set the header to the value given by ``default`` and - return that. - - :param key: The header key to get. - :param default: The value to set for the key if it is not in the - headers. - """ - if key in self: - return self[key] - - self.set(key, default) - return default - - def setlistdefault(self, key, default): - """Return the list of values for the key if it is in the - headers, otherwise set the header to the list of values given - by ``default`` and return that. - - Unlike :meth:`MultiDict.setlistdefault`, modifying the returned - list will not affect the headers. - - :param key: The header key to get. - :param default: An iterable of values to set for the key if it - is not in the headers. - - .. versionadded:: 1.0 - """ - if key not in self: - self.setlist(key, default) - - return self.getlist(key) - - def __setitem__(self, key, value): - """Like :meth:`set` but also supports index/slice based setting.""" - if isinstance(key, (slice, int)): - if isinstance(key, int): - value = [value] - value = [ - (_unicodify_header_value(k), _unicodify_header_value(v)) - for (k, v) in value - ] - for (_, v) in value: - self._validate_value(v) - if isinstance(key, int): - self._list[key] = value[0] - else: - self._list[key] = value - else: - self.set(key, value) - - def update(self, *args, **kwargs): - """Replace headers in this object with items from another - headers object and keyword arguments. - - To extend existing keys instead of replacing, use :meth:`extend` - instead. - - If provided, the first argument can be another :class:`Headers` - object, a :class:`MultiDict`, :class:`dict`, or iterable of - pairs. - - .. versionadded:: 1.0 - """ - if len(args) > 1: - raise TypeError(f"update expected at most 1 arguments, got {len(args)}") - - if args: - mapping = args[0] - - if isinstance(mapping, (Headers, MultiDict)): - for key in mapping.keys(): - self.setlist(key, mapping.getlist(key)) - elif isinstance(mapping, dict): - for key, value in mapping.items(): - if isinstance(value, (list, tuple)): - self.setlist(key, value) - else: - self.set(key, value) - else: - for key, value in mapping: - self.set(key, value) - - for key, value in kwargs.items(): - if isinstance(value, (list, tuple)): - self.setlist(key, value) - else: - self.set(key, value) - - def to_wsgi_list(self): - """Convert the headers into a list suitable for WSGI. - - :return: list - """ - return list(self) - - def copy(self): - return self.__class__(self._list) - - def __copy__(self): - return self.copy() - - def __str__(self): - """Returns formatted headers suitable for HTTP transmission.""" - strs = [] - for key, value in self.to_wsgi_list(): - strs.append(f"{key}: {value}") - strs.append("\r\n") - return "\r\n".join(strs) - - def __repr__(self): - return f"{type(self).__name__}({list(self)!r})" - - -class ImmutableHeadersMixin: - """Makes a :class:`Headers` immutable. We do not mark them as - hashable though since the only usecase for this datastructure - in Werkzeug is a view on a mutable structure. - - .. versionadded:: 0.5 - - :private: - """ - - def __delitem__(self, key, **kwargs): - is_immutable(self) - - def __setitem__(self, key, value): - is_immutable(self) - - def set(self, _key, _value, **kw): - is_immutable(self) - - def setlist(self, key, values): - is_immutable(self) - - def add(self, _key, _value, **kw): - is_immutable(self) - - def add_header(self, _key, _value, **_kw): - is_immutable(self) - - def remove(self, key): - is_immutable(self) - - def extend(self, *args, **kwargs): - is_immutable(self) - - def update(self, *args, **kwargs): - is_immutable(self) - - def insert(self, pos, value): - is_immutable(self) - - def pop(self, key=None, default=_missing): - is_immutable(self) - - def popitem(self): - is_immutable(self) - - def setdefault(self, key, default): - is_immutable(self) - - def setlistdefault(self, key, default): - is_immutable(self) - - -class EnvironHeaders(ImmutableHeadersMixin, Headers): - """Read only version of the headers from a WSGI environment. This - provides the same interface as `Headers` and is constructed from - a WSGI environment. - - From Werkzeug 0.3 onwards, the `KeyError` raised by this class is also a - subclass of the :exc:`~exceptions.BadRequest` HTTP exception and will - render a page for a ``400 BAD REQUEST`` if caught in a catch-all for - HTTP exceptions. - """ - - def __init__(self, environ): - self.environ = environ - - def __eq__(self, other): - return self.environ is other.environ - - __hash__ = None - - def __getitem__(self, key, _get_mode=False): - # _get_mode is a no-op for this class as there is no index but - # used because get() calls it. - if not isinstance(key, str): - raise KeyError(key) - key = key.upper().replace("-", "_") - if key in ("CONTENT_TYPE", "CONTENT_LENGTH"): - return _unicodify_header_value(self.environ[key]) - return _unicodify_header_value(self.environ[f"HTTP_{key}"]) - - def __len__(self): - # the iter is necessary because otherwise list calls our - # len which would call list again and so forth. - return len(list(iter(self))) - - def __iter__(self): - for key, value in self.environ.items(): - if key.startswith("HTTP_") and key not in ( - "HTTP_CONTENT_TYPE", - "HTTP_CONTENT_LENGTH", - ): - yield ( - key[5:].replace("_", "-").title(), - _unicodify_header_value(value), - ) - elif key in ("CONTENT_TYPE", "CONTENT_LENGTH") and value: - yield (key.replace("_", "-").title(), _unicodify_header_value(value)) - - def copy(self): - raise TypeError(f"cannot create {type(self).__name__!r} copies") - - -class CombinedMultiDict(ImmutableMultiDictMixin, MultiDict): - """A read only :class:`MultiDict` that you can pass multiple :class:`MultiDict` - instances as sequence and it will combine the return values of all wrapped - dicts: - - >>> from werkzeug.datastructures import CombinedMultiDict, MultiDict - >>> post = MultiDict([('foo', 'bar')]) - >>> get = MultiDict([('blub', 'blah')]) - >>> combined = CombinedMultiDict([get, post]) - >>> combined['foo'] - 'bar' - >>> combined['blub'] - 'blah' - - This works for all read operations and will raise a `TypeError` for - methods that usually change data which isn't possible. - - From Werkzeug 0.3 onwards, the `KeyError` raised by this class is also a - subclass of the :exc:`~exceptions.BadRequest` HTTP exception and will - render a page for a ``400 BAD REQUEST`` if caught in a catch-all for HTTP - exceptions. - """ - - def __reduce_ex__(self, protocol): - return type(self), (self.dicts,) - - def __init__(self, dicts=None): - self.dicts = list(dicts) or [] - - @classmethod - def fromkeys(cls, keys, value=None): - raise TypeError(f"cannot create {cls.__name__!r} instances by fromkeys") - - def __getitem__(self, key): - for d in self.dicts: - if key in d: - return d[key] - raise exceptions.BadRequestKeyError(key) - - def get(self, key, default=None, type=None): - for d in self.dicts: - if key in d: - if type is not None: - try: - return type(d[key]) - except ValueError: - continue - return d[key] - return default - - def getlist(self, key, type=None): - rv = [] - for d in self.dicts: - rv.extend(d.getlist(key, type)) - return rv - - def _keys_impl(self): - """This function exists so __len__ can be implemented more efficiently, - saving one list creation from an iterator. - """ - rv = set() - rv.update(*self.dicts) - return rv - - def keys(self): - return self._keys_impl() - - def __iter__(self): - return iter(self.keys()) - - def items(self, multi=False): - found = set() - for d in self.dicts: - for key, value in d.items(multi): - if multi: - yield key, value - elif key not in found: - found.add(key) - yield key, value - - def values(self): - for _key, value in self.items(): - yield value - - def lists(self): - rv = {} - for d in self.dicts: - for key, values in d.lists(): - rv.setdefault(key, []).extend(values) - return list(rv.items()) - - def listvalues(self): - return (x[1] for x in self.lists()) - - def copy(self): - """Return a shallow mutable copy of this object. - - This returns a :class:`MultiDict` representing the data at the - time of copying. The copy will no longer reflect changes to the - wrapped dicts. - - .. versionchanged:: 0.15 - Return a mutable :class:`MultiDict`. - """ - return MultiDict(self) - - def to_dict(self, flat=True): - """Return the contents as regular dict. If `flat` is `True` the - returned dict will only have the first item present, if `flat` is - `False` all values will be returned as lists. - - :param flat: If set to `False` the dict returned will have lists - with all the values in it. Otherwise it will only - contain the first item for each key. - :return: a :class:`dict` - """ - if flat: - return dict(self.items()) - - return dict(self.lists()) - - def __len__(self): - return len(self._keys_impl()) - - def __contains__(self, key): - for d in self.dicts: - if key in d: - return True - return False - - def __repr__(self): - return f"{type(self).__name__}({self.dicts!r})" - - -class FileMultiDict(MultiDict): - """A special :class:`MultiDict` that has convenience methods to add - files to it. This is used for :class:`EnvironBuilder` and generally - useful for unittesting. - - .. versionadded:: 0.5 - """ - - def add_file(self, name, file, filename=None, content_type=None): - """Adds a new file to the dict. `file` can be a file name or - a :class:`file`-like or a :class:`FileStorage` object. - - :param name: the name of the field. - :param file: a filename or :class:`file`-like object - :param filename: an optional filename - :param content_type: an optional content type - """ - if isinstance(file, FileStorage): - value = file - else: - if isinstance(file, str): - if filename is None: - filename = file - file = open(file, "rb") - if filename and content_type is None: - content_type = ( - mimetypes.guess_type(filename)[0] or "application/octet-stream" - ) - value = FileStorage(file, filename, name, content_type) - - self.add(name, value) - - -class ImmutableDict(ImmutableDictMixin, dict): - """An immutable :class:`dict`. - - .. versionadded:: 0.5 - """ - - def __repr__(self): - return f"{type(self).__name__}({dict.__repr__(self)})" - - def copy(self): - """Return a shallow mutable copy of this object. Keep in mind that - the standard library's :func:`copy` function is a no-op for this class - like for any other python immutable type (eg: :class:`tuple`). - """ - return dict(self) - - def __copy__(self): - return self - - -class ImmutableMultiDict(ImmutableMultiDictMixin, MultiDict): - """An immutable :class:`MultiDict`. - - .. versionadded:: 0.5 - """ - - def copy(self): - """Return a shallow mutable copy of this object. Keep in mind that - the standard library's :func:`copy` function is a no-op for this class - like for any other python immutable type (eg: :class:`tuple`). - """ - return MultiDict(self) - - def __copy__(self): - return self - - -class ImmutableOrderedMultiDict(ImmutableMultiDictMixin, OrderedMultiDict): - """An immutable :class:`OrderedMultiDict`. - - .. versionadded:: 0.6 - """ - - def _iter_hashitems(self): - return enumerate(self.items(multi=True)) - - def copy(self): - """Return a shallow mutable copy of this object. Keep in mind that - the standard library's :func:`copy` function is a no-op for this class - like for any other python immutable type (eg: :class:`tuple`). - """ - return OrderedMultiDict(self) - - def __copy__(self): - return self - - -class Accept(ImmutableList): - """An :class:`Accept` object is just a list subclass for lists of - ``(value, quality)`` tuples. It is automatically sorted by specificity - and quality. - - All :class:`Accept` objects work similar to a list but provide extra - functionality for working with the data. Containment checks are - normalized to the rules of that header: - - >>> a = CharsetAccept([('ISO-8859-1', 1), ('utf-8', 0.7)]) - >>> a.best - 'ISO-8859-1' - >>> 'iso-8859-1' in a - True - >>> 'UTF8' in a - True - >>> 'utf7' in a - False - - To get the quality for an item you can use normal item lookup: - - >>> print a['utf-8'] - 0.7 - >>> a['utf7'] - 0 - - .. versionchanged:: 0.5 - :class:`Accept` objects are forced immutable now. - - .. versionchanged:: 1.0.0 - :class:`Accept` internal values are no longer ordered - alphabetically for equal quality tags. Instead the initial - order is preserved. - - """ - - def __init__(self, values=()): - if values is None: - list.__init__(self) - self.provided = False - elif isinstance(values, Accept): - self.provided = values.provided - list.__init__(self, values) - else: - self.provided = True - values = sorted( - values, key=lambda x: (self._specificity(x[0]), x[1]), reverse=True - ) - list.__init__(self, values) - - def _specificity(self, value): - """Returns a tuple describing the value's specificity.""" - return (value != "*",) - - def _value_matches(self, value, item): - """Check if a value matches a given accept item.""" - return item == "*" or item.lower() == value.lower() - - def __getitem__(self, key): - """Besides index lookup (getting item n) you can also pass it a string - to get the quality for the item. If the item is not in the list, the - returned quality is ``0``. - """ - if isinstance(key, str): - return self.quality(key) - return list.__getitem__(self, key) - - def quality(self, key): - """Returns the quality of the key. - - .. versionadded:: 0.6 - In previous versions you had to use the item-lookup syntax - (eg: ``obj[key]`` instead of ``obj.quality(key)``) - """ - for item, quality in self: - if self._value_matches(key, item): - return quality - return 0 - - def __contains__(self, value): - for item, _quality in self: - if self._value_matches(value, item): - return True - return False - - def __repr__(self): - pairs_str = ", ".join(f"({x!r}, {y})" for x, y in self) - return f"{type(self).__name__}([{pairs_str}])" - - def index(self, key): - """Get the position of an entry or raise :exc:`ValueError`. - - :param key: The key to be looked up. - - .. versionchanged:: 0.5 - This used to raise :exc:`IndexError`, which was inconsistent - with the list API. - """ - if isinstance(key, str): - for idx, (item, _quality) in enumerate(self): - if self._value_matches(key, item): - return idx - raise ValueError(key) - return list.index(self, key) - - def find(self, key): - """Get the position of an entry or return -1. - - :param key: The key to be looked up. - """ - try: - return self.index(key) - except ValueError: - return -1 - - def values(self): - """Iterate over all values.""" - for item in self: - yield item[0] - - def to_header(self): - """Convert the header set into an HTTP header string.""" - result = [] - for value, quality in self: - if quality != 1: - value = f"{value};q={quality}" - result.append(value) - return ",".join(result) - - def __str__(self): - return self.to_header() - - def _best_single_match(self, match): - for client_item, quality in self: - if self._value_matches(match, client_item): - # self is sorted by specificity descending, we can exit - return client_item, quality - return None - - def best_match(self, matches, default=None): - """Returns the best match from a list of possible matches based - on the specificity and quality of the client. If two items have the - same quality and specificity, the one is returned that comes first. - - :param matches: a list of matches to check for - :param default: the value that is returned if none match - """ - result = default - best_quality = -1 - best_specificity = (-1,) - for server_item in matches: - match = self._best_single_match(server_item) - if not match: - continue - client_item, quality = match - specificity = self._specificity(client_item) - if quality <= 0 or quality < best_quality: - continue - # better quality or same quality but more specific => better match - if quality > best_quality or specificity > best_specificity: - result = server_item - best_quality = quality - best_specificity = specificity - return result - - @property - def best(self): - """The best match as value.""" - if self: - return self[0][0] - - -_mime_split_re = re.compile(r"/|(?:\s*;\s*)") - - -def _normalize_mime(value): - return _mime_split_re.split(value.lower()) - - -class MIMEAccept(Accept): - """Like :class:`Accept` but with special methods and behavior for - mimetypes. - """ - - def _specificity(self, value): - return tuple(x != "*" for x in _mime_split_re.split(value)) - - def _value_matches(self, value, item): - # item comes from the client, can't match if it's invalid. - if "/" not in item: - return False - - # value comes from the application, tell the developer when it - # doesn't look valid. - if "/" not in value: - raise ValueError(f"invalid mimetype {value!r}") - - # Split the match value into type, subtype, and a sorted list of parameters. - normalized_value = _normalize_mime(value) - value_type, value_subtype = normalized_value[:2] - value_params = sorted(normalized_value[2:]) - - # "*/*" is the only valid value that can start with "*". - if value_type == "*" and value_subtype != "*": - raise ValueError(f"invalid mimetype {value!r}") - - # Split the accept item into type, subtype, and parameters. - normalized_item = _normalize_mime(item) - item_type, item_subtype = normalized_item[:2] - item_params = sorted(normalized_item[2:]) - - # "*/not-*" from the client is invalid, can't match. - if item_type == "*" and item_subtype != "*": - return False - - return ( - (item_type == "*" and item_subtype == "*") - or (value_type == "*" and value_subtype == "*") - ) or ( - item_type == value_type - and ( - item_subtype == "*" - or value_subtype == "*" - or (item_subtype == value_subtype and item_params == value_params) - ) - ) - - @property - def accept_html(self): - """True if this object accepts HTML.""" - return ( - "text/html" in self or "application/xhtml+xml" in self or self.accept_xhtml - ) - - @property - def accept_xhtml(self): - """True if this object accepts XHTML.""" - return "application/xhtml+xml" in self or "application/xml" in self - - @property - def accept_json(self): - """True if this object accepts JSON.""" - return "application/json" in self - - -_locale_delim_re = re.compile(r"[_-]") - - -def _normalize_lang(value): - """Process a language tag for matching.""" - return _locale_delim_re.split(value.lower()) - - -class LanguageAccept(Accept): - """Like :class:`Accept` but with normalization for language tags.""" - - def _value_matches(self, value, item): - return item == "*" or _normalize_lang(value) == _normalize_lang(item) - - def best_match(self, matches, default=None): - """Given a list of supported values, finds the best match from - the list of accepted values. - - Language tags are normalized for the purpose of matching, but - are returned unchanged. - - If no exact match is found, this will fall back to matching - the first subtag (primary language only), first with the - accepted values then with the match values. This partial is not - applied to any other language subtags. - - The default is returned if no exact or fallback match is found. - - :param matches: A list of supported languages to find a match. - :param default: The value that is returned if none match. - """ - # Look for an exact match first. If a client accepts "en-US", - # "en-US" is a valid match at this point. - result = super().best_match(matches) - - if result is not None: - return result - - # Fall back to accepting primary tags. If a client accepts - # "en-US", "en" is a valid match at this point. Need to use - # re.split to account for 2 or 3 letter codes. - fallback = Accept( - [(_locale_delim_re.split(item[0], 1)[0], item[1]) for item in self] - ) - result = fallback.best_match(matches) - - if result is not None: - return result - - # Fall back to matching primary tags. If the client accepts - # "en", "en-US" is a valid match at this point. - fallback_matches = [_locale_delim_re.split(item, 1)[0] for item in matches] - result = super().best_match(fallback_matches) - - # Return a value from the original match list. Find the first - # original value that starts with the matched primary tag. - if result is not None: - return next(item for item in matches if item.startswith(result)) - - return default - - -class CharsetAccept(Accept): - """Like :class:`Accept` but with normalization for charsets.""" - - def _value_matches(self, value, item): - def _normalize(name): - try: - return codecs.lookup(name).name - except LookupError: - return name.lower() - - return item == "*" or _normalize(value) == _normalize(item) - - -def cache_control_property(key, empty, type): - """Return a new property object for a cache header. Useful if you - want to add support for a cache extension in a subclass. - - .. versionchanged:: 2.0 - Renamed from ``cache_property``. - """ - return property( - lambda x: x._get_cache_value(key, empty, type), - lambda x, v: x._set_cache_value(key, v, type), - lambda x: x._del_cache_value(key), - f"accessor for {key!r}", - ) - - -class _CacheControl(UpdateDictMixin, dict): - """Subclass of a dict that stores values for a Cache-Control header. It - has accessors for all the cache-control directives specified in RFC 2616. - The class does not differentiate between request and response directives. - - Because the cache-control directives in the HTTP header use dashes the - python descriptors use underscores for that. - - To get a header of the :class:`CacheControl` object again you can convert - the object into a string or call the :meth:`to_header` method. If you plan - to subclass it and add your own items have a look at the sourcecode for - that class. - - .. versionchanged:: 2.1.0 - Setting int properties such as ``max_age`` will convert the - value to an int. - - .. versionchanged:: 0.4 - - Setting `no_cache` or `private` to boolean `True` will set the implicit - none-value which is ``*``: - - >>> cc = ResponseCacheControl() - >>> cc.no_cache = True - >>> cc -
This page displays all available information about the WSGI server and @@ -139,7 +92,7 @@ """ -def iter_sys_path() -> t.Iterator[t.Tuple[str, bool, bool]]: +def iter_sys_path() -> t.Iterator[tuple[str, bool, bool]]: if os.name == "posix": def strip(x: str) -> str: @@ -159,7 +112,21 @@ def strip(x: str) -> str: yield strip(os.path.normpath(path)), not os.path.isdir(path), path != item -def render_testapp(req: Request) -> bytes: +@Request.application +def test_app(req: Request) -> Response: + """Simple test application that dumps the environment. You can use + it to check if Werkzeug is working properly: + + .. sourcecode:: pycon + + >>> from werkzeug.serving import run_simple + >>> from werkzeug.testapp import test_app + >>> run_simple('localhost', 3000, test_app) + * Running on http://localhost:3000/ + + The application displays important information from the WSGI environment, + the Python interpreter and the installed libraries. + """ try: import pkg_resources except ImportError: @@ -167,7 +134,7 @@ def render_testapp(req: Request) -> bytes: else: eggs = sorted( pkg_resources.working_set, - key=lambda x: x.project_name.lower(), # type: ignore + key=lambda x: x.project_name.lower(), ) python_eggs = [] for egg in eggs: @@ -187,52 +154,38 @@ def render_testapp(req: Request) -> bytes: sys_path = [] for item, virtual, expanded in iter_sys_path(): - class_ = [] + css = [] if virtual: - class_.append("virtual") + css.append("virtual") if expanded: - class_.append("exp") - class_ = f' class="{" ".join(class_)}"' if class_ else "" - sys_path.append(f"
You should be redirected automatically to the target URL: "
- f'{display_location}. If'
- " not, click the link.\n",
+ f'{html_location}. If not, click the link.\n',
code,
mimetype="text/html",
)
@@ -289,7 +279,7 @@ def redirect(
return response
-def append_slash_redirect(environ: "WSGIEnvironment", code: int = 308) -> "Response":
+def append_slash_redirect(environ: WSGIEnvironment, code: int = 308) -> Response:
"""Redirect to the current URL with a slash appended.
If the current URL is ``/user/42``, the redirect URL will be
@@ -327,21 +317,19 @@ def append_slash_redirect(environ: "WSGIEnvironment", code: int = 308) -> "Respo
def send_file(
- path_or_file: t.Union[os.PathLike, str, t.IO[bytes]],
- environ: "WSGIEnvironment",
- mimetype: t.Optional[str] = None,
+ path_or_file: os.PathLike[str] | str | t.IO[bytes],
+ environ: WSGIEnvironment,
+ mimetype: str | None = None,
as_attachment: bool = False,
- download_name: t.Optional[str] = None,
+ download_name: str | None = None,
conditional: bool = True,
- etag: t.Union[bool, str] = True,
- last_modified: t.Optional[t.Union[datetime, int, float]] = None,
- max_age: t.Optional[
- t.Union[int, t.Callable[[t.Optional[str]], t.Optional[int]]]
- ] = None,
+ etag: bool | str = True,
+ last_modified: datetime | int | float | None = None,
+ max_age: None | (int | t.Callable[[str | None], int | None]) = None,
use_x_sendfile: bool = False,
- response_class: t.Optional[t.Type["Response"]] = None,
- _root_path: t.Optional[t.Union[os.PathLike, str]] = None,
-) -> "Response":
+ response_class: type[Response] | None = None,
+ _root_path: os.PathLike[str] | str | None = None,
+) -> Response:
"""Send the contents of a file to the client.
The first argument can be a file path or a file-like object. Paths
@@ -352,7 +340,7 @@ def send_file(
Never pass file paths provided by a user. The path is assumed to be
trusted, so a user could craft a path to access a file you didn't
- intend.
+ intend. Use :func:`send_from_directory` to safely serve user-provided paths.
If the WSGI server sets a ``file_wrapper`` in ``environ``, it is
used, otherwise Werkzeug's built-in wrapper is used. Alternatively,
@@ -419,16 +407,16 @@ def send_file(
response_class = Response
- path: t.Optional[str] = None
- file: t.Optional[t.IO[bytes]] = None
- size: t.Optional[int] = None
- mtime: t.Optional[float] = None
+ path: str | None = None
+ file: t.IO[bytes] | None = None
+ size: int | None = None
+ mtime: float | None = None
headers = Headers()
if isinstance(path_or_file, (os.PathLike, str)) or hasattr(
path_or_file, "__fspath__"
):
- path_or_file = t.cast(t.Union[os.PathLike, str], path_or_file)
+ path_or_file = t.cast("t.Union[os.PathLike[str], str]", path_or_file)
# Flask will pass app.root_path, allowing its send_file wrapper
# to not have to deal with paths.
@@ -470,7 +458,8 @@ def send_file(
except UnicodeEncodeError:
simple = unicodedata.normalize("NFKD", download_name)
simple = simple.encode("ascii", "ignore").decode("ascii")
- quoted = url_quote(download_name, safe="")
+ # safe = RFC 5987 attr-char
+ quoted = quote(download_name, safe="!#$&+-.^_`|~")
names = {"filename": simple, "filename*": f"UTF-8''{quoted}"}
else:
names = {"filename": download_name}
@@ -526,7 +515,7 @@ def send_file(
if isinstance(etag, str):
rv.set_etag(etag)
elif etag and path is not None:
- check = adler32(path.encode("utf-8")) & 0xFFFFFFFF
+ check = adler32(path.encode()) & 0xFFFFFFFF
rv.set_etag(f"{mtime}-{size}-{check}")
if conditional:
@@ -547,11 +536,11 @@ def send_file(
def send_from_directory(
- directory: t.Union[os.PathLike, str],
- path: t.Union[os.PathLike, str],
- environ: "WSGIEnvironment",
+ directory: os.PathLike[str] | str,
+ path: os.PathLike[str] | str,
+ environ: WSGIEnvironment,
**kwargs: t.Any,
-) -> "Response":
+) -> Response:
"""Send a file from within a directory using :func:`send_file`.
This is a secure way to serve files from a folder, such as static
@@ -562,33 +551,30 @@ def send_from_directory(
If the final path does not point to an existing regular file,
returns a 404 :exc:`~werkzeug.exceptions.NotFound` error.
- :param directory: The directory that ``path`` must be located under.
- :param path: The path to the file to send, relative to
- ``directory``.
+ :param directory: The directory that ``path`` must be located under. This *must not*
+ be a value provided by the client, otherwise it becomes insecure.
+ :param path: The path to the file to send, relative to ``directory``. This is the
+ part of the path provided by the client, which is checked for security.
:param environ: The WSGI environ for the current request.
:param kwargs: Arguments to pass to :func:`send_file`.
.. versionadded:: 2.0
Adapted from Flask's implementation.
"""
- path = safe_join(os.fspath(directory), os.fspath(path))
+ path_str = safe_join(os.fspath(directory), os.fspath(path))
- if path is None:
+ if path_str is None:
raise NotFound()
# Flask will pass app.root_path, allowing its send_from_directory
# wrapper to not have to deal with paths.
if "_root_path" in kwargs:
- path = os.path.join(kwargs["_root_path"], path)
+ path_str = os.path.join(kwargs["_root_path"], path_str)
- try:
- if not os.path.isfile(path):
- raise NotFound()
- except ValueError:
- # path contains null byte on Python < 3.8
- raise NotFound() from None
+ if not os.path.isfile(path_str):
+ raise NotFound()
- return send_file(path, environ, **kwargs)
+ return send_file(path_str, environ, **kwargs)
def import_string(import_name: str, silent: bool = False) -> t.Any:
diff --git a/src/werkzeug/wrappers/__init__.py b/src/werkzeug/wrappers/__init__.py
index b8c45d7..b36f228 100644
--- a/src/werkzeug/wrappers/__init__.py
+++ b/src/werkzeug/wrappers/__init__.py
@@ -1,3 +1,3 @@
from .request import Request as Request
from .response import Response as Response
-from .response import ResponseStream
+from .response import ResponseStream as ResponseStream
diff --git a/src/werkzeug/wrappers/request.py b/src/werkzeug/wrappers/request.py
index 57b739c..38053c2 100644
--- a/src/werkzeug/wrappers/request.py
+++ b/src/werkzeug/wrappers/request.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
+import collections.abc as cabc
import functools
import json
-import typing
import typing as t
from io import BytesIO
@@ -11,6 +13,8 @@
from ..datastructures import ImmutableMultiDict
from ..datastructures import iter_multi_items
from ..datastructures import MultiDict
+from ..exceptions import BadRequest
+from ..exceptions import UnsupportedMediaType
from ..formparser import default_stream_factory
from ..formparser import FormDataParser
from ..sansio.request import Request as _SansIORequest
@@ -18,10 +22,8 @@
from ..utils import environ_property
from ..wsgi import _get_server
from ..wsgi import get_input_stream
-from werkzeug.exceptions import BadRequest
if t.TYPE_CHECKING:
- import typing_extensions as te
from _typeshed.wsgi import WSGIApplication
from _typeshed.wsgi import WSGIEnvironment
@@ -49,13 +51,19 @@ class Request(_SansIORequest):
prevent consuming the form data in middleware, which would make
it unavailable to the final application.
+ .. versionchanged:: 3.0
+ The ``charset``, ``url_charset``, and ``encoding_errors`` parameters
+ were removed.
+
+ .. versionchanged:: 2.1
+ Old ``BaseRequest`` and mixin classes were removed.
+
.. versionchanged:: 2.1
Remove the ``disable_data_descriptor`` attribute.
.. versionchanged:: 2.0
Combine ``BaseRequest`` and mixins into a single ``Request``
- class. Using the old classes is deprecated and will be removed
- in Werkzeug 2.1.
+ class.
.. versionchanged:: 0.5
Read-only mode is enforced with immutable classes for all data.
@@ -67,10 +75,8 @@ class Request(_SansIORequest):
#: parsing fails because more than the specified value is transmitted
#: a :exc:`~werkzeug.exceptions.RequestEntityTooLarge` exception is raised.
#:
- #: Have a look at :doc:`/request_data` for more details.
- #:
#: .. versionadded:: 0.5
- max_content_length: t.Optional[int] = None
+ max_content_length: int | None = None
#: the maximum form field size. This is forwarded to the form data
#: parsing function (:func:`parse_form_data`). When set and the
@@ -78,18 +84,23 @@ class Request(_SansIORequest):
#: data in memory for post data is longer than the specified value a
#: :exc:`~werkzeug.exceptions.RequestEntityTooLarge` exception is raised.
#:
- #: Have a look at :doc:`/request_data` for more details.
- #:
#: .. versionadded:: 0.5
- max_form_memory_size: t.Optional[int] = None
+ max_form_memory_size: int | None = None
+
+ #: The maximum number of multipart parts to parse, passed to
+ #: :attr:`form_data_parser_class`. Parsing form data with more than this
+ #: many parts will raise :exc:`~.RequestEntityTooLarge`.
+ #:
+ #: .. versionadded:: 2.2.3
+ max_form_parts = 1000
#: The form data parser that should be used. Can be replaced to customize
#: the form date parsing.
- form_data_parser_class: t.Type[FormDataParser] = FormDataParser
+ form_data_parser_class: type[FormDataParser] = FormDataParser
#: The WSGI environment containing HTTP headers and information from
#: the WSGI server.
- environ: "WSGIEnvironment"
+ environ: WSGIEnvironment
#: Set when creating the request object. If ``True``, reading from
#: the request body will cause a ``RuntimeException``. Useful to
@@ -98,7 +109,7 @@ class Request(_SansIORequest):
def __init__(
self,
- environ: "WSGIEnvironment",
+ environ: WSGIEnvironment,
populate_request: bool = True,
shallow: bool = False,
) -> None:
@@ -106,12 +117,8 @@ def __init__(
method=environ.get("REQUEST_METHOD", "GET"),
scheme=environ.get("wsgi.url_scheme", "http"),
server=_get_server(environ),
- root_path=_wsgi_decoding_dance(
- environ.get("SCRIPT_NAME") or "", self.charset, self.encoding_errors
- ),
- path=_wsgi_decoding_dance(
- environ.get("PATH_INFO") or "", self.charset, self.encoding_errors
- ),
+ root_path=_wsgi_decoding_dance(environ.get("SCRIPT_NAME") or ""),
+ path=_wsgi_decoding_dance(environ.get("PATH_INFO") or ""),
query_string=environ.get("QUERY_STRING", "").encode("latin1"),
headers=EnvironHeaders(environ),
remote_addr=environ.get("REMOTE_ADDR"),
@@ -123,7 +130,7 @@ def __init__(
self.environ["werkzeug.request"] = self
@classmethod
- def from_values(cls, *args: t.Any, **kwargs: t.Any) -> "Request":
+ def from_values(cls, *args: t.Any, **kwargs: t.Any) -> Request:
"""Create a new request object based on the values provided. If
environ is given missing values are filled from there. This method is
useful for small scripts when you need to simulate a request from an URL.
@@ -143,8 +150,6 @@ def from_values(cls, *args: t.Any, **kwargs: t.Any) -> "Request":
"""
from ..test import EnvironBuilder
- charset = kwargs.pop("charset", cls.charset)
- kwargs["charset"] = charset
builder = EnvironBuilder(*args, **kwargs)
try:
return builder.get_request(cls)
@@ -152,9 +157,7 @@ def from_values(cls, *args: t.Any, **kwargs: t.Any) -> "Request":
builder.close()
@classmethod
- def application(
- cls, f: t.Callable[["Request"], "WSGIApplication"]
- ) -> "WSGIApplication":
+ def application(cls, f: t.Callable[[Request], WSGIApplication]) -> WSGIApplication:
"""Decorate a function as responder that accepts the request as
the last argument. This works like the :func:`responder`
decorator but the function is passed the request object as the
@@ -180,23 +183,23 @@ def my_wsgi_app(request):
from ..exceptions import HTTPException
@functools.wraps(f)
- def application(*args): # type: ignore
+ def application(*args: t.Any) -> cabc.Iterable[bytes]:
request = cls(args[-2])
with request:
try:
resp = f(*args[:-2] + (request,))
except HTTPException as e:
- resp = e.get_response(args[-2])
+ resp = t.cast("WSGIApplication", e.get_response(args[-2]))
return resp(*args[-2:])
return t.cast("WSGIApplication", application)
def _get_file_stream(
self,
- total_content_length: t.Optional[int],
- content_type: t.Optional[str],
- filename: t.Optional[str] = None,
- content_length: t.Optional[int] = None,
+ total_content_length: int | None,
+ content_type: str | None,
+ filename: str | None = None,
+ content_length: int | None = None,
) -> t.IO[bytes]:
"""Called to get a stream for the file upload.
@@ -240,12 +243,11 @@ def make_form_data_parser(self) -> FormDataParser:
.. versionadded:: 0.8
"""
return self.form_data_parser_class(
- self._get_file_stream,
- self.charset,
- self.encoding_errors,
- self.max_form_memory_size,
- self.max_content_length,
- self.parameter_storage_class,
+ stream_factory=self._get_file_stream,
+ max_form_memory_size=self.max_form_memory_size,
+ max_content_length=self.max_content_length,
+ max_form_parts=self.max_form_parts,
+ cls=self.parameter_storage_class,
)
def _load_form_data(self) -> None:
@@ -304,7 +306,7 @@ def close(self) -> None:
for _key, value in iter_multi_items(files or ()):
value.close()
- def __enter__(self) -> "Request":
+ def __enter__(self) -> Request:
return self
def __exit__(self, exc_type, exc_value, tb) -> None: # type: ignore
@@ -312,21 +314,30 @@ def __exit__(self, exc_type, exc_value, tb) -> None: # type: ignore
@cached_property
def stream(self) -> t.IO[bytes]:
- """
- If the incoming form data was not encoded with a known mimetype
- the data is stored unmodified in this stream for consumption. Most
- of the time it is a better idea to use :attr:`data` which will give
- you that data as a string. The stream only returns the data once.
+ """The WSGI input stream, with safety checks. This stream can only be consumed
+ once.
+
+ Use :meth:`get_data` to get the full data as bytes or text. The :attr:`data`
+ attribute will contain the full bytes only if they do not represent form data.
+ The :attr:`form` attribute will contain the parsed form data in that case.
+
+ Unlike :attr:`input_stream`, this stream guards against infinite streams or
+ reading past :attr:`content_length` or :attr:`max_content_length`.
+
+ If ``max_content_length`` is set, it can be enforced on streams if
+ ``wsgi.input_terminated`` is set. Otherwise, an empty stream is returned.
- Unlike :attr:`input_stream` this stream is properly guarded that you
- can't accidentally read past the length of the input. Werkzeug will
- internally always refer to this stream to read data which makes it
- possible to wrap this object with a stream that does filtering.
+ If the limit is reached before the underlying stream is exhausted (such as a
+ file that is too large, or an infinite stream), the remaining contents of the
+ stream cannot be read safely. Depending on how the server handles this, clients
+ may show a "connection reset" failure instead of seeing the 413 response.
+
+ .. versionchanged:: 2.3
+ Check ``max_content_length`` preemptively and while reading.
.. versionchanged:: 0.9
- This stream is now always available but might be consumed by the
- form parser later on. Previously the stream was only set if no
- parsing happened.
+ The stream is always set (but may be consumed) even if form parsing was
+ accessed first.
"""
if self.shallow:
raise RuntimeError(
@@ -334,46 +345,49 @@ def stream(self) -> t.IO[bytes]:
" from the input stream is disabled."
)
- return get_input_stream(self.environ)
+ return get_input_stream(
+ self.environ, max_content_length=self.max_content_length
+ )
input_stream = environ_property[t.IO[bytes]](
"wsgi.input",
- doc="""The WSGI input stream.
+ doc="""The raw WSGI input stream, without any safety checks.
+
+ This is dangerous to use. It does not guard against infinite streams or reading
+ past :attr:`content_length` or :attr:`max_content_length`.
- In general it's a bad idea to use this one because you can
- easily read past the boundary. Use the :attr:`stream`
- instead.""",
+ Use :attr:`stream` instead.
+ """,
)
@cached_property
def data(self) -> bytes:
- """
- Contains the incoming request data as string in case it came with
- a mimetype Werkzeug does not handle.
+ """The raw data read from :attr:`stream`. Will be empty if the request
+ represents form data.
+
+ To get the raw data even if it represents form data, use :meth:`get_data`.
"""
return self.get_data(parse_form_data=True)
- @typing.overload
+ @t.overload
def get_data( # type: ignore
self,
cache: bool = True,
- as_text: "te.Literal[False]" = False,
+ as_text: t.Literal[False] = False,
parse_form_data: bool = False,
- ) -> bytes:
- ...
+ ) -> bytes: ...
- @typing.overload
+ @t.overload
def get_data(
self,
cache: bool = True,
- as_text: "te.Literal[True]" = ...,
+ as_text: t.Literal[True] = ...,
parse_form_data: bool = False,
- ) -> str:
- ...
+ ) -> str: ...
def get_data(
self, cache: bool = True, as_text: bool = False, parse_form_data: bool = False
- ) -> t.Union[bytes, str]:
+ ) -> bytes | str:
"""This reads the buffered incoming data from the client into one
bytes object. By default this is cached but that behavior can be
changed by setting `cache` to `False`.
@@ -406,11 +420,11 @@ def get_data(
if cache:
self._cached_data = rv
if as_text:
- rv = rv.decode(self.charset, self.encoding_errors)
+ rv = rv.decode(errors="replace")
return rv
@cached_property
- def form(self) -> "ImmutableMultiDict[str, str]":
+ def form(self) -> ImmutableMultiDict[str, str]:
"""The form parameters. By default an
:class:`~werkzeug.datastructures.ImmutableMultiDict`
is returned from this function. This can be changed by setting
@@ -429,7 +443,7 @@ def form(self) -> "ImmutableMultiDict[str, str]":
return self.form
@cached_property
- def values(self) -> "CombinedMultiDict[str, str]":
+ def values(self) -> CombinedMultiDict[str, str]:
"""A :class:`werkzeug.datastructures.CombinedMultiDict` that
combines :attr:`args` and :attr:`form`.
@@ -458,7 +472,7 @@ def values(self) -> "CombinedMultiDict[str, str]":
return CombinedMultiDict(args)
@cached_property
- def files(self) -> "ImmutableMultiDict[str, FileStorage]":
+ def files(self) -> ImmutableMultiDict[str, FileStorage]:
""":class:`~werkzeug.datastructures.MultiDict` object containing
all uploaded files. Each key in :attr:`files` is the name from the
````. Each value in :attr:`files` is a
@@ -525,14 +539,17 @@ def url_root(self) -> str:
json_module = json
@property
- def json(self) -> t.Optional[t.Any]:
+ def json(self) -> t.Any | None:
"""The parsed JSON data if :attr:`mimetype` indicates JSON
(:mimetype:`application/json`, see :attr:`is_json`).
Calls :meth:`get_json` with default arguments.
If the request content type is not ``application/json``, this
- will raise a 400 Bad Request error.
+ will raise a 415 Unsupported Media Type error.
+
+ .. versionchanged:: 2.3
+ Raise a 415 error instead of 400.
.. versionchanged:: 2.1
Raise a 400 error if the content type is incorrect.
@@ -541,18 +558,28 @@ def json(self) -> t.Optional[t.Any]:
# Cached values for ``(silent=False, silent=True)``. Initialized
# with sentinel values.
- _cached_json: t.Tuple[t.Any, t.Any] = (Ellipsis, Ellipsis)
+ _cached_json: tuple[t.Any, t.Any] = (Ellipsis, Ellipsis)
+
+ @t.overload
+ def get_json(
+ self, force: bool = ..., silent: t.Literal[False] = ..., cache: bool = ...
+ ) -> t.Any: ...
+
+ @t.overload
+ def get_json(
+ self, force: bool = ..., silent: bool = ..., cache: bool = ...
+ ) -> t.Any | None: ...
def get_json(
self, force: bool = False, silent: bool = False, cache: bool = True
- ) -> t.Optional[t.Any]:
+ ) -> t.Any | None:
"""Parse :attr:`data` as JSON.
If the mimetype does not indicate JSON
(:mimetype:`application/json`, see :attr:`is_json`), or parsing
fails, :meth:`on_json_loading_failed` is called and
its return value is used as the return value. By default this
- raises a 400 Bad Request error.
+ raises a 415 Unsupported Media Type resp.
:param force: Ignore the mimetype and always try to parse JSON.
:param silent: Silence mimetype and parsing errors, and
@@ -560,6 +587,9 @@ def get_json(
:param cache: Store the parsed JSON to return for subsequent
calls.
+ .. versionchanged:: 2.3
+ Raise a 415 error instead of 400.
+
.. versionchanged:: 2.1
Raise a 400 error if the content type is incorrect.
"""
@@ -595,7 +625,7 @@ def get_json(
return rv
- def on_json_loading_failed(self, e: t.Optional[ValueError]) -> t.Any:
+ def on_json_loading_failed(self, e: ValueError | None) -> t.Any:
"""Called if :meth:`get_json` fails and isn't silenced.
If this method returns a value, it is used as the return value
@@ -604,11 +634,14 @@ def on_json_loading_failed(self, e: t.Optional[ValueError]) -> t.Any:
:param e: If parsing failed, this is the exception. It will be
``None`` if the content type wasn't ``application/json``.
+
+ .. versionchanged:: 2.3
+ Raise a 415 error instead of 400.
"""
if e is not None:
raise BadRequest(f"Failed to decode JSON object: {e}")
- raise BadRequest(
+ raise UnsupportedMediaType(
"Did not attempt to load JSON data because the request"
" Content-Type was not 'application/json'."
)
diff --git a/src/werkzeug/wrappers/response.py b/src/werkzeug/wrappers/response.py
index 7e888cb..7f01287 100644
--- a/src/werkzeug/wrappers/response.py
+++ b/src/werkzeug/wrappers/response.py
@@ -1,69 +1,41 @@
+from __future__ import annotations
+
import json
-import typing
import typing as t
-import warnings
from http import HTTPStatus
+from urllib.parse import urljoin
-from .._internal import _to_bytes
+from .._internal import _get_environ
from ..datastructures import Headers
+from ..http import generate_etag
+from ..http import http_date
+from ..http import is_resource_modified
+from ..http import parse_etags
+from ..http import parse_range_header
from ..http import remove_entity_headers
from ..sansio.response import Response as _SansIOResponse
from ..urls import iri_to_uri
-from ..urls import url_join
from ..utils import cached_property
+from ..wsgi import _RangeWrapper
from ..wsgi import ClosingIterator
from ..wsgi import get_current_url
-from werkzeug._internal import _get_environ
-from werkzeug.http import generate_etag
-from werkzeug.http import http_date
-from werkzeug.http import is_resource_modified
-from werkzeug.http import parse_etags
-from werkzeug.http import parse_range_header
-from werkzeug.wsgi import _RangeWrapper
if t.TYPE_CHECKING:
- import typing_extensions as te
from _typeshed.wsgi import StartResponse
from _typeshed.wsgi import WSGIApplication
from _typeshed.wsgi import WSGIEnvironment
- from .request import Request
-
-def _warn_if_string(iterable: t.Iterable) -> None:
- """Helper for the response objects to check if the iterable returned
- to the WSGI server is not a string.
- """
- if isinstance(iterable, str):
- warnings.warn(
- "Response iterable was set to a string. This will appear to"
- " work but means that the server will send the data to the"
- " client one character at a time. This is almost never"
- " intended behavior, use 'response.data' to assign strings"
- " to the response object.",
- stacklevel=2,
- )
+ from .request import Request
-def _iter_encoded(
- iterable: t.Iterable[t.Union[str, bytes]], charset: str
-) -> t.Iterator[bytes]:
+def _iter_encoded(iterable: t.Iterable[str | bytes]) -> t.Iterator[bytes]:
for item in iterable:
if isinstance(item, str):
- yield item.encode(charset)
+ yield item.encode()
else:
yield item
-def _clean_accept_ranges(accept_ranges: t.Union[bool, str]) -> str:
- if accept_ranges is True:
- return "bytes"
- elif accept_ranges is False:
- return "none"
- elif isinstance(accept_ranges, str):
- return accept_ranges
- raise ValueError("Invalid accept_ranges value")
-
-
class Response(_SansIOResponse):
"""Represents an outgoing WSGI HTTP response with body, status, and
headers. Has properties and methods for using the functionality
@@ -123,10 +95,12 @@ def application(environ, start_response):
checks. Use :func:`~werkzeug.utils.send_file` instead of setting
this manually.
+ .. versionchanged:: 2.1
+ Old ``BaseResponse`` and mixin classes were removed.
+
.. versionchanged:: 2.0
Combine ``BaseResponse`` and mixins into a single ``Response``
- class. Using the old classes is deprecated and will be removed
- in Werkzeug 2.1.
+ class.
.. versionchanged:: 0.5
The ``direct_passthrough`` parameter was added.
@@ -165,22 +139,17 @@ def application(environ, start_response):
#: Do not set to a plain string or bytes, that will cause sending
#: the response to be very inefficient as it will iterate one byte
#: at a time.
- response: t.Union[t.Iterable[str], t.Iterable[bytes]]
+ response: t.Iterable[str] | t.Iterable[bytes]
def __init__(
self,
- response: t.Optional[
- t.Union[t.Iterable[bytes], bytes, t.Iterable[str], str]
- ] = None,
- status: t.Optional[t.Union[int, str, HTTPStatus]] = None,
- headers: t.Optional[
- t.Union[
- t.Mapping[str, t.Union[str, int, t.Iterable[t.Union[str, int]]]],
- t.Iterable[t.Tuple[str, t.Union[str, int]]],
- ]
- ] = None,
- mimetype: t.Optional[str] = None,
- content_type: t.Optional[str] = None,
+ response: t.Iterable[bytes] | bytes | t.Iterable[str] | str | None = None,
+ status: int | str | HTTPStatus | None = None,
+ headers: t.Mapping[str, str | t.Iterable[str]]
+ | t.Iterable[tuple[str, str]]
+ | None = None,
+ mimetype: str | None = None,
+ content_type: str | None = None,
direct_passthrough: bool = False,
) -> None:
super().__init__(
@@ -196,7 +165,7 @@ def __init__(
#: :func:`~werkzeug.utils.send_file` instead of setting this
#: manually.
self.direct_passthrough = direct_passthrough
- self._on_close: t.List[t.Callable[[], t.Any]] = []
+ self._on_close: list[t.Callable[[], t.Any]] = []
# we set the response after the headers so that if a class changes
# the charset attribute, the data is set in the correct charset.
@@ -227,8 +196,8 @@ def __repr__(self) -> str:
@classmethod
def force_type(
- cls, response: "Response", environ: t.Optional["WSGIEnvironment"] = None
- ) -> "Response":
+ cls, response: Response, environ: WSGIEnvironment | None = None
+ ) -> Response:
"""Enforce that the WSGI response is a response object of the current
type. Werkzeug will use the :class:`Response` internally in many
situations like the exceptions. If you call :meth:`get_response` on an
@@ -272,8 +241,8 @@ def force_type(
@classmethod
def from_app(
- cls, app: "WSGIApplication", environ: "WSGIEnvironment", buffered: bool = False
- ) -> "Response":
+ cls, app: WSGIApplication, environ: WSGIEnvironment, buffered: bool = False
+ ) -> Response:
"""Create a new response object from an application output. This
works best if you pass it an application that returns a generator all
the time. Sometimes applications may use the `write()` callable
@@ -290,15 +259,13 @@ def from_app(
return cls(*run_wsgi_app(app, environ, buffered))
- @typing.overload
- def get_data(self, as_text: "te.Literal[False]" = False) -> bytes:
- ...
+ @t.overload
+ def get_data(self, as_text: t.Literal[False] = False) -> bytes: ...
- @typing.overload
- def get_data(self, as_text: "te.Literal[True]") -> str:
- ...
+ @t.overload
+ def get_data(self, as_text: t.Literal[True]) -> str: ...
- def get_data(self, as_text: bool = False) -> t.Union[bytes, str]:
+ def get_data(self, as_text: bool = False) -> bytes | str:
"""The string representation of the response body. Whenever you call
this property the response iterable is encoded and flattened. This
can lead to unwanted behavior if you stream big data.
@@ -315,23 +282,19 @@ def get_data(self, as_text: bool = False) -> t.Union[bytes, str]:
rv = b"".join(self.iter_encoded())
if as_text:
- return rv.decode(self.charset)
+ return rv.decode()
return rv
- def set_data(self, value: t.Union[bytes, str]) -> None:
+ def set_data(self, value: bytes | str) -> None:
"""Sets a new string as response. The value must be a string or
bytes. If a string is set it's encoded to the charset of the
response (utf-8 by default).
.. versionadded:: 0.9
"""
- # if a string is set, it's encoded directly so that we
- # can set the content length
if isinstance(value, str):
- value = value.encode(self.charset)
- else:
- value = bytes(value)
+ value = value.encode()
self.response = [value]
if self.automatically_set_content_length:
self.headers["Content-Length"] = str(len(value))
@@ -342,7 +305,7 @@ def set_data(self, value: t.Union[bytes, str]) -> None:
doc="A descriptor that calls :meth:`get_data` and :meth:`set_data`.",
)
- def calculate_content_length(self) -> t.Optional[int]:
+ def calculate_content_length(self) -> int | None:
"""Returns the content length if available or `None` otherwise."""
try:
self._ensure_sequence()
@@ -398,12 +361,10 @@ def iter_encoded(self) -> t.Iterator[bytes]:
value of this method is used as application iterator unless
:attr:`direct_passthrough` was activated.
"""
- if __debug__:
- _warn_if_string(self.response)
# Encode in a separate function so that self.response is fetched
# early. This allows us to wrap the response with the return
# value from get_app_iter or iter_encoded.
- return _iter_encoded(self.response, self.charset)
+ return _iter_encoded(self.response)
@property
def is_streamed(self) -> bool:
@@ -439,11 +400,11 @@ def close(self) -> None:
Can now be used in a with statement.
"""
if hasattr(self.response, "close"):
- self.response.close() # type: ignore
+ self.response.close()
for func in self._on_close:
func()
- def __enter__(self) -> "Response":
+ def __enter__(self) -> Response:
return self
def __exit__(self, exc_type, exc_value, tb): # type: ignore
@@ -463,8 +424,7 @@ def freeze(self) -> None:
Removed the ``no_etag`` parameter.
.. versionchanged:: 2.0
- An ``ETag`` header is added, the ``no_etag`` parameter is
- deprecated and will be removed in Werkzeug 2.1.
+ An ``ETag`` header is always added.
.. versionchanged:: 0.6
The ``Content-Length`` header is set.
@@ -475,7 +435,7 @@ def freeze(self) -> None:
self.headers["Content-Length"] = str(sum(map(len, self.response)))
self.add_etag()
- def get_wsgi_headers(self, environ: "WSGIEnvironment") -> Headers:
+ def get_wsgi_headers(self, environ: WSGIEnvironment) -> Headers:
"""This is automatically called right before the response is started
and returns headers modified for the given environment. It returns a
copy of the headers from the response with some modifications applied
@@ -500,9 +460,9 @@ def get_wsgi_headers(self, environ: "WSGIEnvironment") -> Headers:
object.
"""
headers = Headers(self.headers)
- location: t.Optional[str] = None
- content_location: t.Optional[str] = None
- content_length: t.Optional[t.Union[str, int]] = None
+ location: str | None = None
+ content_location: str | None = None
+ content_length: str | int | None = None
status = self.status_code
# iterate over the headers to find all values in one go. Because
@@ -517,24 +477,19 @@ def get_wsgi_headers(self, environ: "WSGIEnvironment") -> Headers:
elif ikey == "content-length":
content_length = value
- # make sure the location header is an absolute URL
if location is not None:
- old_location = location
- if isinstance(location, str):
- # Safe conversion is necessary here as we might redirect
- # to a broken URI scheme (for instance itms-services).
- location = iri_to_uri(location, safe_conversion=True)
+ location = iri_to_uri(location)
if self.autocorrect_location_header:
+ # Make the location header an absolute URL.
current_url = get_current_url(environ, strip_querystring=True)
- if isinstance(current_url, str):
- current_url = iri_to_uri(current_url)
- location = url_join(current_url, location)
- if location != old_location:
- headers["Location"] = location
+ current_url = iri_to_uri(current_url)
+ location = urljoin(current_url, location)
+
+ headers["Location"] = location
# make sure the content location is a URL
- if content_location is not None and isinstance(content_location, str):
+ if content_location is not None:
headers["Content-Location"] = iri_to_uri(content_location)
if 100 <= status < 200 or status == 204:
@@ -557,18 +512,12 @@ def get_wsgi_headers(self, environ: "WSGIEnvironment") -> Headers:
and status not in (204, 304)
and not (100 <= status < 200)
):
- try:
- content_length = sum(len(_to_bytes(x, "ascii")) for x in self.response)
- except UnicodeError:
- # Something other than bytes, can't safely figure out
- # the length of the response.
- pass
- else:
- headers["Content-Length"] = str(content_length)
+ content_length = sum(len(x) for x in self.iter_encoded())
+ headers["Content-Length"] = str(content_length)
return headers
- def get_app_iter(self, environ: "WSGIEnvironment") -> t.Iterable[bytes]:
+ def get_app_iter(self, environ: WSGIEnvironment) -> t.Iterable[bytes]:
"""Returns the application iterator for the given environ. Depending
on the request method and the current status code the return value
might be an empty response rather than the one from the response.
@@ -590,16 +539,14 @@ def get_app_iter(self, environ: "WSGIEnvironment") -> t.Iterable[bytes]:
):
iterable: t.Iterable[bytes] = ()
elif self.direct_passthrough:
- if __debug__:
- _warn_if_string(self.response)
return self.response # type: ignore
else:
iterable = self.iter_encoded()
return ClosingIterator(iterable, self.close)
def get_wsgi_response(
- self, environ: "WSGIEnvironment"
- ) -> t.Tuple[t.Iterable[bytes], str, t.List[t.Tuple[str, str]]]:
+ self, environ: WSGIEnvironment
+ ) -> tuple[t.Iterable[bytes], str, list[tuple[str, str]]]:
"""Returns the final WSGI response as tuple. The first item in
the tuple is the application iterator, the second the status and
the third the list of headers. The response returned is created
@@ -617,7 +564,7 @@ def get_wsgi_response(
return app_iter, self.status, headers.to_wsgi_list()
def __call__(
- self, environ: "WSGIEnvironment", start_response: "StartResponse"
+ self, environ: WSGIEnvironment, start_response: StartResponse
) -> t.Iterable[bytes]:
"""Process this response as WSGI application.
@@ -637,7 +584,7 @@ def __call__(
json_module = json
@property
- def json(self) -> t.Optional[t.Any]:
+ def json(self) -> t.Any | None:
"""The parsed JSON data if :attr:`mimetype` indicates JSON
(:mimetype:`application/json`, see :attr:`is_json`).
@@ -645,7 +592,13 @@ def json(self) -> t.Optional[t.Any]:
"""
return self.get_json()
- def get_json(self, force: bool = False, silent: bool = False) -> t.Optional[t.Any]:
+ @t.overload
+ def get_json(self, force: bool = ..., silent: t.Literal[False] = ...) -> t.Any: ...
+
+ @t.overload
+ def get_json(self, force: bool = ..., silent: bool = ...) -> t.Any | None: ...
+
+ def get_json(self, force: bool = False, silent: bool = False) -> t.Any | None:
"""Parse :attr:`data` as JSON. Useful during testing.
If the mimetype does not indicate JSON
@@ -674,7 +627,7 @@ def get_json(self, force: bool = False, silent: bool = False) -> t.Optional[t.An
# Stream
@cached_property
- def stream(self) -> "ResponseStream":
+ def stream(self) -> ResponseStream:
"""The response iterable as write-only stream."""
return ResponseStream(self)
@@ -683,7 +636,7 @@ def _wrap_range_response(self, start: int, length: int) -> None:
if self.status_code == 206:
self.response = _RangeWrapper(self.response, start, length) # type: ignore
- def _is_range_request_processable(self, environ: "WSGIEnvironment") -> bool:
+ def _is_range_request_processable(self, environ: WSGIEnvironment) -> bool:
"""Return ``True`` if `Range` header is present and if underlying
resource is considered unchanged when compared with `If-Range` header.
"""
@@ -700,9 +653,9 @@ def _is_range_request_processable(self, environ: "WSGIEnvironment") -> bool:
def _process_range_request(
self,
- environ: "WSGIEnvironment",
- complete_length: t.Optional[int] = None,
- accept_ranges: t.Optional[t.Union[bool, str]] = None,
+ environ: WSGIEnvironment,
+ complete_length: int | None,
+ accept_ranges: bool | str,
) -> bool:
"""Handle Range Request related headers (RFC7233). If `Accept-Ranges`
header is valid, and Range Request is processable, we set the headers
@@ -720,13 +673,16 @@ def _process_range_request(
from ..exceptions import RequestedRangeNotSatisfiable
if (
- accept_ranges is None
+ not accept_ranges
or complete_length is None
or complete_length == 0
or not self._is_range_request_processable(environ)
):
return False
+ if accept_ranges is True:
+ accept_ranges = "bytes"
+
parsed_range = parse_range_header(environ.get("HTTP_RANGE"))
if parsed_range is None:
@@ -739,7 +695,7 @@ def _process_range_request(
raise RequestedRangeNotSatisfiable(complete_length)
content_length = range_tuple[1] - range_tuple[0]
- self.headers["Content-Length"] = content_length
+ self.headers["Content-Length"] = str(content_length)
self.headers["Accept-Ranges"] = accept_ranges
self.content_range = content_range_header # type: ignore
self.status_code = 206
@@ -748,10 +704,10 @@ def _process_range_request(
def make_conditional(
self,
- request_or_environ: t.Union["WSGIEnvironment", "Request"],
- accept_ranges: t.Union[bool, str] = False,
- complete_length: t.Optional[int] = None,
- ) -> "Response":
+ request_or_environ: WSGIEnvironment | Request,
+ accept_ranges: bool | str = False,
+ complete_length: int | None = None,
+ ) -> Response:
"""Make the response conditional to the request. This method works
best if an etag was defined for the response already. The `add_etag`
method can be used to do that. If called without etag just the date
@@ -777,8 +733,7 @@ def make_conditional(
:param accept_ranges: This parameter dictates the value of
`Accept-Ranges` header. If ``False`` (default),
the header is not set. If ``True``, it will be set
- to ``"bytes"``. If ``None``, it will be set to
- ``"none"``. If it's a string, it will use this
+ to ``"bytes"``. If it's a string, it will use this
value.
:param complete_length: Will be used only in valid Range Requests.
It will set `Content-Range` complete length
@@ -800,7 +755,6 @@ def make_conditional(
# wsgiref.
if "date" not in self.headers:
self.headers["Date"] = http_date()
- accept_ranges = _clean_accept_ranges(accept_ranges)
is206 = self._process_range_request(environ, complete_length, accept_ranges)
if not is206 and not is_resource_modified(
environ,
@@ -818,7 +772,7 @@ def make_conditional(
):
length = self.calculate_content_length()
if length is not None:
- self.headers["Content-Length"] = length
+ self.headers["Content-Length"] = str(length)
return self
def add_etag(self, overwrite: bool = False, weak: bool = False) -> None:
@@ -874,4 +828,4 @@ def tell(self) -> int:
@property
def encoding(self) -> str:
- return self.response.charset
+ return "utf-8"
diff --git a/src/werkzeug/wsgi.py b/src/werkzeug/wsgi.py
index 24ece0b..01d40af 100644
--- a/src/werkzeug/wsgi.py
+++ b/src/werkzeug/wsgi.py
@@ -1,28 +1,21 @@
+from __future__ import annotations
+
import io
-import re
import typing as t
-import warnings
from functools import partial
from functools import update_wrapper
-from itertools import chain
-from ._internal import _make_encode_wrapper
-from ._internal import _to_bytes
-from ._internal import _to_str
+from .exceptions import ClientDisconnected
+from .exceptions import RequestEntityTooLarge
from .sansio import utils as _sansio_utils
from .sansio.utils import host_is_trusted # noqa: F401 # Imported as part of API
-from .urls import _URLTuple
-from .urls import uri_to_iri
-from .urls import url_join
-from .urls import url_parse
-from .urls import url_quote
if t.TYPE_CHECKING:
from _typeshed.wsgi import WSGIApplication
from _typeshed.wsgi import WSGIEnvironment
-def responder(f: t.Callable[..., "WSGIApplication"]) -> "WSGIApplication":
+def responder(f: t.Callable[..., WSGIApplication]) -> WSGIApplication:
"""Marks a function as responder. Decorate a function with it and it
will automatically call the return value as WSGI application.
@@ -36,11 +29,11 @@ def application(environ, start_response):
def get_current_url(
- environ: "WSGIEnvironment",
+ environ: WSGIEnvironment,
root_only: bool = False,
strip_querystring: bool = False,
host_only: bool = False,
- trusted_hosts: t.Optional[t.Iterable[str]] = None,
+ trusted_hosts: t.Iterable[str] | None = None,
) -> str:
"""Recreate the URL for a request from the parts in a WSGI
environment.
@@ -74,15 +67,15 @@ def get_current_url(
def _get_server(
- environ: "WSGIEnvironment",
-) -> t.Optional[t.Tuple[str, t.Optional[int]]]:
+ environ: WSGIEnvironment,
+) -> tuple[str, int | None] | None:
name = environ.get("SERVER_NAME")
if name is None:
return None
try:
- port: t.Optional[int] = int(environ.get("SERVER_PORT", None))
+ port: int | None = int(environ.get("SERVER_PORT", None))
except (TypeError, ValueError):
# unix socket
port = None
@@ -91,7 +84,7 @@ def _get_server(
def get_host(
- environ: "WSGIEnvironment", trusted_hosts: t.Optional[t.Iterable[str]] = None
+ environ: WSGIEnvironment, trusted_hosts: t.Iterable[str] | None = None
) -> str:
"""Return the host for the given WSGI environment.
@@ -118,337 +111,101 @@ def get_host(
)
-def get_content_length(environ: "WSGIEnvironment") -> t.Optional[int]:
- """Returns the content length from the WSGI environment as
- integer. If it's not available or chunked transfer encoding is used,
- ``None`` is returned.
+def get_content_length(environ: WSGIEnvironment) -> int | None:
+ """Return the ``Content-Length`` header value as an int. If the header is not given
+ or the ``Transfer-Encoding`` header is ``chunked``, ``None`` is returned to indicate
+ a streaming request. If the value is not an integer, or negative, 0 is returned.
- .. versionadded:: 0.9
+ :param environ: The WSGI environ to get the content length from.
- :param environ: the WSGI environ to fetch the content length from.
+ .. versionadded:: 0.9
"""
return _sansio_utils.get_content_length(
http_content_length=environ.get("CONTENT_LENGTH"),
- http_transfer_encoding=environ.get("HTTP_TRANSFER_ENCODING", ""),
+ http_transfer_encoding=environ.get("HTTP_TRANSFER_ENCODING"),
)
def get_input_stream(
- environ: "WSGIEnvironment", safe_fallback: bool = True
+ environ: WSGIEnvironment,
+ safe_fallback: bool = True,
+ max_content_length: int | None = None,
) -> t.IO[bytes]:
- """Returns the input stream from the WSGI environment and wraps it
- in the most sensible way possible. The stream returned is not the
- raw WSGI stream in most cases but one that is safe to read from
- without taking into account the content length.
-
- If content length is not set, the stream will be empty for safety reasons.
- If the WSGI server supports chunked or infinite streams, it should set
- the ``wsgi.input_terminated`` value in the WSGI environ to indicate that.
-
- .. versionadded:: 0.9
-
- :param environ: the WSGI environ to fetch the stream from.
- :param safe_fallback: use an empty stream as a safe fallback when the
- content length is not set. Disabling this allows infinite streams,
- which can be a denial-of-service risk.
- """
- stream = t.cast(t.IO[bytes], environ["wsgi.input"])
- content_length = get_content_length(environ)
+ """Return the WSGI input stream, wrapped so that it may be read safely without going
+ past the ``Content-Length`` header value or ``max_content_length``.
- # A wsgi extension that tells us if the input is terminated. In
- # that case we return the stream unchanged as we know we can safely
- # read it until the end.
- if environ.get("wsgi.input_terminated"):
- return stream
+ If ``Content-Length`` exceeds ``max_content_length``, a
+ :exc:`RequestEntityTooLarge`` ``413 Content Too Large`` error is raised.
- # If the request doesn't specify a content length, returning the stream is
- # potentially dangerous because it could be infinite, malicious or not. If
- # safe_fallback is true, return an empty stream instead for safety.
- if content_length is None:
- return io.BytesIO() if safe_fallback else stream
+ If the WSGI server sets ``environ["wsgi.input_terminated"]``, it indicates that the
+ server handles terminating the stream, so it is safe to read directly. For example,
+ a server that knows how to handle chunked requests safely would set this.
- # Otherwise limit the stream to the content length
- return t.cast(t.IO[bytes], LimitedStream(stream, content_length))
+ If ``max_content_length`` is set, it can be enforced on streams if
+ ``wsgi.input_terminated`` is set. Otherwise, an empty stream is returned unless the
+ user explicitly disables this safe fallback.
+ If the limit is reached before the underlying stream is exhausted (such as a file
+ that is too large, or an infinite stream), the remaining contents of the stream
+ cannot be read safely. Depending on how the server handles this, clients may show a
+ "connection reset" failure instead of seeing the 413 response.
-def get_query_string(environ: "WSGIEnvironment") -> str:
- """Returns the ``QUERY_STRING`` from the WSGI environment. This also
- takes care of the WSGI decoding dance. The string returned will be
- restricted to ASCII characters.
+ :param environ: The WSGI environ containing the stream.
+ :param safe_fallback: Return an empty stream when ``Content-Length`` is not set.
+ Disabling this allows infinite streams, which can be a denial-of-service risk.
+ :param max_content_length: The maximum length that content-length or streaming
+ requests may not exceed.
- :param environ: WSGI environment to get the query string from.
+ .. versionchanged:: 2.3.2
+ ``max_content_length`` is only applied to streaming requests if the server sets
+ ``wsgi.input_terminated``.
- .. deprecated:: 2.2
- Will be removed in Werkzeug 2.3.
+ .. versionchanged:: 2.3
+ Check ``max_content_length`` and raise an error if it is exceeded.
.. versionadded:: 0.9
"""
- warnings.warn(
- "'get_query_string' is deprecated and will be removed in Werkzeug 2.3.",
- DeprecationWarning,
- stacklevel=2,
- )
- qs = environ.get("QUERY_STRING", "").encode("latin1")
- # QUERY_STRING really should be ascii safe but some browsers
- # will send us some unicode stuff (I am looking at you IE).
- # In that case we want to urllib quote it badly.
- return url_quote(qs, safe=":&%=+$!*'(),")
+ stream = t.cast(t.IO[bytes], environ["wsgi.input"])
+ content_length = get_content_length(environ)
+ if content_length is not None and max_content_length is not None:
+ if content_length > max_content_length:
+ raise RequestEntityTooLarge()
+
+ # A WSGI server can set this to indicate that it terminates the input stream. In
+ # that case the stream is safe without wrapping, or can enforce a max length.
+ if "wsgi.input_terminated" in environ:
+ if max_content_length is not None:
+ # If this is moved above, it can cause the stream to hang if a read attempt
+ # is made when the client sends no data. For example, the development server
+ # does not handle buffering except for chunked encoding.
+ return t.cast(
+ t.IO[bytes], LimitedStream(stream, max_content_length, is_max=True)
+ )
-def get_path_info(
- environ: "WSGIEnvironment", charset: str = "utf-8", errors: str = "replace"
-) -> str:
- """Return the ``PATH_INFO`` from the WSGI environment and decode it
- unless ``charset`` is ``None``.
+ return stream
- :param environ: WSGI environment to get the path from.
- :param charset: The charset for the path info, or ``None`` if no
- decoding should be performed.
- :param errors: The decoding error handling.
+ # No limit given, return an empty stream unless the user explicitly allows the
+ # potentially infinite stream. An infinite stream is dangerous if it's not expected,
+ # as it can tie up a worker indefinitely.
+ if content_length is None:
+ return io.BytesIO() if safe_fallback else stream
- .. versionadded:: 0.9
- """
- path = environ.get("PATH_INFO", "").encode("latin1")
- return _to_str(path, charset, errors, allow_none_charset=True) # type: ignore
+ return t.cast(t.IO[bytes], LimitedStream(stream, content_length))
-def get_script_name(
- environ: "WSGIEnvironment", charset: str = "utf-8", errors: str = "replace"
-) -> str:
- """Return the ``SCRIPT_NAME`` from the WSGI environment and decode
- it unless `charset` is set to ``None``.
+def get_path_info(environ: WSGIEnvironment) -> str:
+ """Return ``PATH_INFO`` from the WSGI environment.
:param environ: WSGI environment to get the path from.
- :param charset: The charset for the path, or ``None`` if no decoding
- should be performed.
- :param errors: The decoding error handling.
- .. deprecated:: 2.2
- Will be removed in Werkzeug 2.3.
+ .. versionchanged:: 3.0
+ The ``charset`` and ``errors`` parameters were removed.
.. versionadded:: 0.9
"""
- warnings.warn(
- "'get_script_name' is deprecated and will be removed in Werkzeug 2.3.",
- DeprecationWarning,
- stacklevel=2,
- )
- path = environ.get("SCRIPT_NAME", "").encode("latin1")
- return _to_str(path, charset, errors, allow_none_charset=True) # type: ignore
-
-
-def pop_path_info(
- environ: "WSGIEnvironment", charset: str = "utf-8", errors: str = "replace"
-) -> t.Optional[str]:
- """Removes and returns the next segment of `PATH_INFO`, pushing it onto
- `SCRIPT_NAME`. Returns `None` if there is nothing left on `PATH_INFO`.
-
- If the `charset` is set to `None` bytes are returned.
-
- If there are empty segments (``'/foo//bar``) these are ignored but
- properly pushed to the `SCRIPT_NAME`:
-
- >>> env = {'SCRIPT_NAME': '/foo', 'PATH_INFO': '/a/b'}
- >>> pop_path_info(env)
- 'a'
- >>> env['SCRIPT_NAME']
- '/foo/a'
- >>> pop_path_info(env)
- 'b'
- >>> env['SCRIPT_NAME']
- '/foo/a/b'
-
- .. deprecated:: 2.2
- Will be removed in Werkzeug 2.3.
-
- .. versionadded:: 0.5
-
- .. versionchanged:: 0.9
- The path is now decoded and a charset and encoding
- parameter can be provided.
-
- :param environ: the WSGI environment that is modified.
- :param charset: The ``encoding`` parameter passed to
- :func:`bytes.decode`.
- :param errors: The ``errors`` paramater passed to
- :func:`bytes.decode`.
- """
- warnings.warn(
- "'pop_path_info' is deprecated and will be removed in Werkzeug 2.3.",
- DeprecationWarning,
- stacklevel=2,
- )
-
- path = environ.get("PATH_INFO")
- if not path:
- return None
-
- script_name = environ.get("SCRIPT_NAME", "")
-
- # shift multiple leading slashes over
- old_path = path
- path = path.lstrip("/")
- if path != old_path:
- script_name += "/" * (len(old_path) - len(path))
-
- if "/" not in path:
- environ["PATH_INFO"] = ""
- environ["SCRIPT_NAME"] = script_name + path
- rv = path.encode("latin1")
- else:
- segment, path = path.split("/", 1)
- environ["PATH_INFO"] = f"/{path}"
- environ["SCRIPT_NAME"] = script_name + segment
- rv = segment.encode("latin1")
-
- return _to_str(rv, charset, errors, allow_none_charset=True) # type: ignore
-
-
-def peek_path_info(
- environ: "WSGIEnvironment", charset: str = "utf-8", errors: str = "replace"
-) -> t.Optional[str]:
- """Returns the next segment on the `PATH_INFO` or `None` if there
- is none. Works like :func:`pop_path_info` without modifying the
- environment:
-
- >>> env = {'SCRIPT_NAME': '/foo', 'PATH_INFO': '/a/b'}
- >>> peek_path_info(env)
- 'a'
- >>> peek_path_info(env)
- 'a'
-
- If the `charset` is set to `None` bytes are returned.
-
- .. deprecated:: 2.2
- Will be removed in Werkzeug 2.3.
-
- .. versionadded:: 0.5
-
- .. versionchanged:: 0.9
- The path is now decoded and a charset and encoding
- parameter can be provided.
-
- :param environ: the WSGI environment that is checked.
- """
- warnings.warn(
- "'peek_path_info' is deprecated and will be removed in Werkzeug 2.3.",
- DeprecationWarning,
- stacklevel=2,
- )
-
- segments = environ.get("PATH_INFO", "").lstrip("/").split("/", 1)
- if segments:
- return _to_str( # type: ignore
- segments[0].encode("latin1"), charset, errors, allow_none_charset=True
- )
- return None
-
-
-def extract_path_info(
- environ_or_baseurl: t.Union[str, "WSGIEnvironment"],
- path_or_url: t.Union[str, _URLTuple],
- charset: str = "utf-8",
- errors: str = "werkzeug.url_quote",
- collapse_http_schemes: bool = True,
-) -> t.Optional[str]:
- """Extracts the path info from the given URL (or WSGI environment) and
- path. The path info returned is a string. The URLs might also be IRIs.
-
- If the path info could not be determined, `None` is returned.
-
- Some examples:
-
- >>> extract_path_info('http://example.com/app', '/app/hello')
- '/hello'
- >>> extract_path_info('http://example.com/app',
- ... 'https://example.com/app/hello')
- '/hello'
- >>> extract_path_info('http://example.com/app',
- ... 'https://example.com/app/hello',
- ... collapse_http_schemes=False) is None
- True
-
- Instead of providing a base URL you can also pass a WSGI environment.
-
- :param environ_or_baseurl: a WSGI environment dict, a base URL or
- base IRI. This is the root of the
- application.
- :param path_or_url: an absolute path from the server root, a
- relative path (in which case it's the path info)
- or a full URL.
- :param charset: the charset for byte data in URLs
- :param errors: the error handling on decode
- :param collapse_http_schemes: if set to `False` the algorithm does
- not assume that http and https on the
- same server point to the same
- resource.
-
- .. deprecated:: 2.2
- Will be removed in Werkzeug 2.3.
-
- .. versionchanged:: 0.15
- The ``errors`` parameter defaults to leaving invalid bytes
- quoted instead of replacing them.
-
- .. versionadded:: 0.6
-
- """
- warnings.warn(
- "'extract_path_info' is deprecated and will be removed in Werkzeug 2.3.",
- DeprecationWarning,
- stacklevel=2,
- )
-
- def _normalize_netloc(scheme: str, netloc: str) -> str:
- parts = netloc.split("@", 1)[-1].split(":", 1)
- port: t.Optional[str]
-
- if len(parts) == 2:
- netloc, port = parts
- if (scheme == "http" and port == "80") or (
- scheme == "https" and port == "443"
- ):
- port = None
- else:
- netloc = parts[0]
- port = None
-
- if port is not None:
- netloc += f":{port}"
-
- return netloc
-
- # make sure whatever we are working on is a IRI and parse it
- path = uri_to_iri(path_or_url, charset, errors)
- if isinstance(environ_or_baseurl, dict):
- environ_or_baseurl = get_current_url(environ_or_baseurl, root_only=True)
- base_iri = uri_to_iri(environ_or_baseurl, charset, errors)
- base_scheme, base_netloc, base_path = url_parse(base_iri)[:3]
- cur_scheme, cur_netloc, cur_path = url_parse(url_join(base_iri, path))[:3]
-
- # normalize the network location
- base_netloc = _normalize_netloc(base_scheme, base_netloc)
- cur_netloc = _normalize_netloc(cur_scheme, cur_netloc)
-
- # is that IRI even on a known HTTP scheme?
- if collapse_http_schemes:
- for scheme in base_scheme, cur_scheme:
- if scheme not in ("http", "https"):
- return None
- else:
- if not (base_scheme in ("http", "https") and base_scheme == cur_scheme):
- return None
-
- # are the netlocs compatible?
- if base_netloc != cur_netloc:
- return None
-
- # are we below the application path?
- base_path = base_path.rstrip("/")
- if not cur_path.startswith(base_path):
- return None
-
- return f"/{cur_path[len(base_path) :].lstrip('/')}"
+ path: bytes = environ.get("PATH_INFO", "").encode("latin1")
+ return path.decode(errors="replace")
class ClosingIterator:
@@ -476,9 +233,8 @@ class ClosingIterator:
def __init__(
self,
iterable: t.Iterable[bytes],
- callbacks: t.Optional[
- t.Union[t.Callable[[], None], t.Iterable[t.Callable[[], None]]]
- ] = None,
+ callbacks: None
+ | (t.Callable[[], None] | t.Iterable[t.Callable[[], None]]) = None,
) -> None:
iterator = iter(iterable)
self._next = t.cast(t.Callable[[], bytes], partial(next, iterator))
@@ -493,7 +249,7 @@ def __init__(
callbacks.insert(0, iterable_close)
self._callbacks = callbacks
- def __iter__(self) -> "ClosingIterator":
+ def __iter__(self) -> ClosingIterator:
return self
def __next__(self) -> bytes:
@@ -505,7 +261,7 @@ def close(self) -> None:
def wrap_file(
- environ: "WSGIEnvironment", file: t.IO[bytes], buffer_size: int = 8192
+ environ: WSGIEnvironment, file: t.IO[bytes], buffer_size: int = 8192
) -> t.Iterable[bytes]:
"""Wraps a file. This uses the WSGI server's file wrapper if available
or otherwise the generic :class:`FileWrapper`.
@@ -564,12 +320,12 @@ def seek(self, *args: t.Any) -> None:
if hasattr(self.file, "seek"):
self.file.seek(*args)
- def tell(self) -> t.Optional[int]:
+ def tell(self) -> int | None:
if hasattr(self.file, "tell"):
return self.file.tell()
return None
- def __iter__(self) -> "FileWrapper":
+ def __iter__(self) -> FileWrapper:
return self
def __next__(self) -> bytes:
@@ -598,9 +354,9 @@ class _RangeWrapper:
def __init__(
self,
- iterable: t.Union[t.Iterable[bytes], t.IO[bytes]],
+ iterable: t.Iterable[bytes] | t.IO[bytes],
start_byte: int = 0,
- byte_range: t.Optional[int] = None,
+ byte_range: int | None = None,
):
self.iterable = iter(iterable)
self.byte_range = byte_range
@@ -611,12 +367,10 @@ def __init__(
self.end_byte = start_byte + byte_range
self.read_length = 0
- self.seekable = (
- hasattr(iterable, "seekable") and iterable.seekable() # type: ignore
- )
+ self.seekable = hasattr(iterable, "seekable") and iterable.seekable()
self.end_reached = False
- def __iter__(self) -> "_RangeWrapper":
+ def __iter__(self) -> _RangeWrapper:
return self
def _next_chunk(self) -> bytes:
@@ -628,7 +382,7 @@ def _next_chunk(self) -> bytes:
self.end_reached = True
raise
- def _first_iteration(self) -> t.Tuple[t.Optional[bytes], int]:
+ def _first_iteration(self) -> tuple[bytes | None, int]:
chunk = None
if self.seekable:
self.iterable.seek(self.start_byte) # type: ignore
@@ -665,356 +419,177 @@ def __next__(self) -> bytes:
def close(self) -> None:
if hasattr(self.iterable, "close"):
- self.iterable.close() # type: ignore
-
-
-def _make_chunk_iter(
- stream: t.Union[t.Iterable[bytes], t.IO[bytes]],
- limit: t.Optional[int],
- buffer_size: int,
-) -> t.Iterator[bytes]:
- """Helper for the line and chunk iter functions."""
- if isinstance(stream, (bytes, bytearray, str)):
- raise TypeError(
- "Passed a string or byte object instead of true iterator or stream."
- )
- if not hasattr(stream, "read"):
- for item in stream:
- if item:
- yield item
- return
- stream = t.cast(t.IO[bytes], stream)
- if not isinstance(stream, LimitedStream) and limit is not None:
- stream = t.cast(t.IO[bytes], LimitedStream(stream, limit))
- _read = stream.read
- while True:
- item = _read(buffer_size)
- if not item:
- break
- yield item
-
-
-def make_line_iter(
- stream: t.Union[t.Iterable[bytes], t.IO[bytes]],
- limit: t.Optional[int] = None,
- buffer_size: int = 10 * 1024,
- cap_at_buffer: bool = False,
-) -> t.Iterator[bytes]:
- """Safely iterates line-based over an input stream. If the input stream
- is not a :class:`LimitedStream` the `limit` parameter is mandatory.
-
- This uses the stream's :meth:`~file.read` method internally as opposite
- to the :meth:`~file.readline` method that is unsafe and can only be used
- in violation of the WSGI specification. The same problem applies to the
- `__iter__` function of the input stream which calls :meth:`~file.readline`
- without arguments.
-
- If you need line-by-line processing it's strongly recommended to iterate
- over the input stream using this helper function.
-
- .. versionchanged:: 0.8
- This function now ensures that the limit was reached.
+ self.iterable.close()
- .. versionadded:: 0.9
- added support for iterators as input stream.
-
- .. versionadded:: 0.11.10
- added support for the `cap_at_buffer` parameter.
-
- :param stream: the stream or iterate to iterate over.
- :param limit: the limit in bytes for the stream. (Usually
- content length. Not necessary if the `stream`
- is a :class:`LimitedStream`.
- :param buffer_size: The optional buffer size.
- :param cap_at_buffer: if this is set chunks are split if they are longer
- than the buffer size. Internally this is implemented
- that the buffer size might be exhausted by a factor
- of two however.
- """
- _iter = _make_chunk_iter(stream, limit, buffer_size)
-
- first_item = next(_iter, "")
- if not first_item:
- return
-
- s = _make_encode_wrapper(first_item)
- empty = t.cast(bytes, s(""))
- cr = t.cast(bytes, s("\r"))
- lf = t.cast(bytes, s("\n"))
- crlf = t.cast(bytes, s("\r\n"))
-
- _iter = t.cast(t.Iterator[bytes], chain((first_item,), _iter))
-
- def _iter_basic_lines() -> t.Iterator[bytes]:
- _join = empty.join
- buffer: t.List[bytes] = []
- while True:
- new_data = next(_iter, "")
- if not new_data:
- break
- new_buf: t.List[bytes] = []
- buf_size = 0
- for item in t.cast(
- t.Iterator[bytes], chain(buffer, new_data.splitlines(True))
- ):
- new_buf.append(item)
- buf_size += len(item)
- if item and item[-1:] in crlf:
- yield _join(new_buf)
- new_buf = []
- elif cap_at_buffer and buf_size >= buffer_size:
- rv = _join(new_buf)
- while len(rv) >= buffer_size:
- yield rv[:buffer_size]
- rv = rv[buffer_size:]
- new_buf = [rv]
- buffer = new_buf
- if buffer:
- yield _join(buffer)
-
- # This hackery is necessary to merge 'foo\r' and '\n' into one item
- # of 'foo\r\n' if we were unlucky and we hit a chunk boundary.
- previous = empty
- for item in _iter_basic_lines():
- if item == lf and previous[-1:] == cr:
- previous += item
- item = empty
- if previous:
- yield previous
- previous = item
- if previous:
- yield previous
-
-
-def make_chunk_iter(
- stream: t.Union[t.Iterable[bytes], t.IO[bytes]],
- separator: bytes,
- limit: t.Optional[int] = None,
- buffer_size: int = 10 * 1024,
- cap_at_buffer: bool = False,
-) -> t.Iterator[bytes]:
- """Works like :func:`make_line_iter` but accepts a separator
- which divides chunks. If you want newline based processing
- you should use :func:`make_line_iter` instead as it
- supports arbitrary newline markers.
-
- .. versionadded:: 0.8
- .. versionadded:: 0.9
- added support for iterators as input stream.
-
- .. versionadded:: 0.11.10
- added support for the `cap_at_buffer` parameter.
-
- :param stream: the stream or iterate to iterate over.
- :param separator: the separator that divides chunks.
- :param limit: the limit in bytes for the stream. (Usually
- content length. Not necessary if the `stream`
- is otherwise already limited).
- :param buffer_size: The optional buffer size.
- :param cap_at_buffer: if this is set chunks are split if they are longer
- than the buffer size. Internally this is implemented
- that the buffer size might be exhausted by a factor
- of two however.
- """
- _iter = _make_chunk_iter(stream, limit, buffer_size)
-
- first_item = next(_iter, b"")
- if not first_item:
- return
-
- _iter = t.cast(t.Iterator[bytes], chain((first_item,), _iter))
- if isinstance(first_item, str):
- separator = _to_str(separator)
- _split = re.compile(f"({re.escape(separator)})").split
- _join = "".join
- else:
- separator = _to_bytes(separator)
- _split = re.compile(b"(" + re.escape(separator) + b")").split
- _join = b"".join
-
- buffer: t.List[bytes] = []
- while True:
- new_data = next(_iter, b"")
- if not new_data:
- break
- chunks = _split(new_data)
- new_buf: t.List[bytes] = []
- buf_size = 0
- for item in chain(buffer, chunks):
- if item == separator:
- yield _join(new_buf)
- new_buf = []
- buf_size = 0
- else:
- buf_size += len(item)
- new_buf.append(item)
-
- if cap_at_buffer and buf_size >= buffer_size:
- rv = _join(new_buf)
- while len(rv) >= buffer_size:
- yield rv[:buffer_size]
- rv = rv[buffer_size:]
- new_buf = [rv]
- buf_size = len(rv)
-
- buffer = new_buf
- if buffer:
- yield _join(buffer)
-
-
-class LimitedStream(io.IOBase):
- """Wraps a stream so that it doesn't read more than n bytes. If the
- stream is exhausted and the caller tries to get more bytes from it
- :func:`on_exhausted` is called which by default returns an empty
- string. The return value of that function is forwarded
- to the reader function. So if it returns an empty string
- :meth:`read` will return an empty string as well.
-
- The limit however must never be higher than what the stream can
- output. Otherwise :meth:`readlines` will try to read past the
- limit.
-
- .. admonition:: Note on WSGI compliance
-
- calls to :meth:`readline` and :meth:`readlines` are not
- WSGI compliant because it passes a size argument to the
- readline methods. Unfortunately the WSGI PEP is not safely
- implementable without a size argument to :meth:`readline`
- because there is no EOF marker in the stream. As a result
- of that the use of :meth:`readline` is discouraged.
-
- For the same reason iterating over the :class:`LimitedStream`
- is not portable. It internally calls :meth:`readline`.
-
- We strongly suggest using :meth:`read` only or using the
- :func:`make_line_iter` which safely iterates line-based
- over a WSGI input stream.
-
- :param stream: the stream to wrap.
- :param limit: the limit for the stream, must not be longer than
- what the string can provide if the stream does not
- end with `EOF` (like `wsgi.input`)
+class LimitedStream(io.RawIOBase):
+ """Wrap a stream so that it doesn't read more than a given limit. This is used to
+ limit ``wsgi.input`` to the ``Content-Length`` header value or
+ :attr:`.Request.max_content_length`.
+
+ When attempting to read after the limit has been reached, :meth:`on_exhausted` is
+ called. When the limit is a maximum, this raises :exc:`.RequestEntityTooLarge`.
+
+ If reading from the stream returns zero bytes or raises an error,
+ :meth:`on_disconnect` is called, which raises :exc:`.ClientDisconnected`. When the
+ limit is a maximum and zero bytes were read, no error is raised, since it may be the
+ end of the stream.
+
+ If the limit is reached before the underlying stream is exhausted (such as a file
+ that is too large, or an infinite stream), the remaining contents of the stream
+ cannot be read safely. Depending on how the server handles this, clients may show a
+ "connection reset" failure instead of seeing the 413 response.
+
+ :param stream: The stream to read from. Must be a readable binary IO object.
+ :param limit: The limit in bytes to not read past. Should be either the
+ ``Content-Length`` header value or ``request.max_content_length``.
+ :param is_max: Whether the given ``limit`` is ``request.max_content_length`` instead
+ of the ``Content-Length`` header value. This changes how exhausted and
+ disconnect events are handled.
+
+ .. versionchanged:: 2.3
+ Handle ``max_content_length`` differently than ``Content-Length``.
+
+ .. versionchanged:: 2.3
+ Implements ``io.RawIOBase`` rather than ``io.IOBase``.
"""
- def __init__(self, stream: t.IO[bytes], limit: int) -> None:
- self._read = stream.read
- self._readline = stream.readline
+ def __init__(self, stream: t.IO[bytes], limit: int, is_max: bool = False) -> None:
+ self._stream = stream
self._pos = 0
self.limit = limit
-
- def __iter__(self) -> "LimitedStream":
- return self
+ self._limit_is_max = is_max
@property
def is_exhausted(self) -> bool:
- """If the stream is exhausted this attribute is `True`."""
+ """Whether the current stream position has reached the limit."""
return self._pos >= self.limit
- def on_exhausted(self) -> bytes:
- """This is called when the stream tries to read past the limit.
- The return value of this function is returned from the reading
- function.
- """
- # Read null bytes from the stream so that we get the
- # correct end of stream marker.
- return self._read(0)
-
- def on_disconnect(self) -> bytes:
- """What should happen if a disconnect is detected? The return
- value of this function is returned from read functions in case
- the client went away. By default a
- :exc:`~werkzeug.exceptions.ClientDisconnected` exception is raised.
- """
- from .exceptions import ClientDisconnected
+ def on_exhausted(self) -> None:
+ """Called when attempting to read after the limit has been reached.
- raise ClientDisconnected()
+ The default behavior is to do nothing, unless the limit is a maximum, in which
+ case it raises :exc:`.RequestEntityTooLarge`.
- def exhaust(self, chunk_size: int = 1024 * 64) -> None:
- """Exhaust the stream. This consumes all the data left until the
- limit is reached.
+ .. versionchanged:: 2.3
+ Raises ``RequestEntityTooLarge`` if the limit is a maximum.
- :param chunk_size: the size for a chunk. It will read the chunk
- until the stream is exhausted and throw away
- the results.
+ .. versionchanged:: 2.3
+ Any return value is ignored.
"""
- to_read = self.limit - self._pos
- chunk = chunk_size
- while to_read > 0:
- chunk = min(to_read, chunk)
- self.read(chunk)
- to_read -= chunk
+ if self._limit_is_max:
+ raise RequestEntityTooLarge()
+
+ def on_disconnect(self, error: Exception | None = None) -> None:
+ """Called when an attempted read receives zero bytes before the limit was
+ reached. This indicates that the client disconnected before sending the full
+ request body.
+
+ The default behavior is to raise :exc:`.ClientDisconnected`, unless the limit is
+ a maximum and no error was raised.
- def read(self, size: t.Optional[int] = None) -> bytes:
- """Read `size` bytes or if size is not provided everything is read.
+ .. versionchanged:: 2.3
+ Added the ``error`` parameter. Do nothing if the limit is a maximum and no
+ error was raised.
- :param size: the number of bytes read.
+ .. versionchanged:: 2.3
+ Any return value is ignored.
"""
- if self._pos >= self.limit:
- return self.on_exhausted()
- if size is None or size == -1: # -1 is for consistence with file
- size = self.limit
- to_read = min(self.limit - self._pos, size)
- try:
- read = self._read(to_read)
- except (OSError, ValueError):
- return self.on_disconnect()
- if to_read and len(read) != to_read:
- return self.on_disconnect()
- self._pos += len(read)
- return read
-
- def readline(self, size: t.Optional[int] = None) -> bytes:
- """Reads one line from the stream."""
- if self._pos >= self.limit:
- return self.on_exhausted()
- if size is None:
- size = self.limit - self._pos
- else:
- size = min(size, self.limit - self._pos)
- try:
- line = self._readline(size)
- except (ValueError, OSError):
- return self.on_disconnect()
- if size and not line:
- return self.on_disconnect()
- self._pos += len(line)
- return line
-
- def readlines(self, size: t.Optional[int] = None) -> t.List[bytes]:
- """Reads a file into a list of strings. It calls :meth:`readline`
- until the file is read to the end. It does support the optional
- `size` argument if the underlying stream supports it for
- `readline`.
+ if not self._limit_is_max or error is not None:
+ raise ClientDisconnected()
+
+ # If the limit is a maximum, then we may have read zero bytes because the
+ # streaming body is complete. There's no way to distinguish that from the
+ # client disconnecting early.
+
+ def exhaust(self) -> bytes:
+ """Exhaust the stream by reading until the limit is reached or the client
+ disconnects, returning the remaining data.
+
+ .. versionchanged:: 2.3
+ Return the remaining data.
+
+ .. versionchanged:: 2.2.3
+ Handle case where wrapped stream returns fewer bytes than requested.
"""
- last_pos = self._pos
- result = []
- if size is not None:
- end = min(self.limit, last_pos + size)
+ if not self.is_exhausted:
+ return self.readall()
+
+ return b""
+
+ def readinto(self, b: bytearray) -> int | None: # type: ignore[override]
+ size = len(b)
+ remaining = self.limit - self._pos
+
+ if remaining <= 0:
+ self.on_exhausted()
+ return 0
+
+ if hasattr(self._stream, "readinto"):
+ # Use stream.readinto if it's available.
+ if size <= remaining:
+ # The size fits in the remaining limit, use the buffer directly.
+ try:
+ out_size: int | None = self._stream.readinto(b)
+ except (OSError, ValueError) as e:
+ self.on_disconnect(error=e)
+ return 0
+ else:
+ # Use a temp buffer with the remaining limit as the size.
+ temp_b = bytearray(remaining)
+
+ try:
+ out_size = self._stream.readinto(temp_b)
+ except (OSError, ValueError) as e:
+ self.on_disconnect(error=e)
+ return 0
+
+ if out_size:
+ b[:out_size] = temp_b
else:
- end = self.limit
- while True:
- if size is not None:
- size -= last_pos - self._pos
- if self._pos >= end:
+ # WSGI requires that stream.read is available.
+ try:
+ data = self._stream.read(min(size, remaining))
+ except (OSError, ValueError) as e:
+ self.on_disconnect(error=e)
+ return 0
+
+ out_size = len(data)
+ b[:out_size] = data
+
+ if not out_size:
+ # Read zero bytes from the stream.
+ self.on_disconnect()
+ return 0
+
+ self._pos += out_size
+ return out_size
+
+ def readall(self) -> bytes:
+ if self.is_exhausted:
+ self.on_exhausted()
+ return b""
+
+ out = bytearray()
+
+ # The parent implementation uses "while True", which results in an extra read.
+ while not self.is_exhausted:
+ data = self.read(1024 * 64)
+
+ # Stream may return empty before a max limit is reached.
+ if not data:
break
- result.append(self.readline(size))
- if size is not None:
- last_pos = self._pos
- return result
+
+ out.extend(data)
+
+ return bytes(out)
def tell(self) -> int:
- """Returns the position of the stream.
+ """Return the current stream position.
.. versionadded:: 0.9
"""
return self._pos
- def __next__(self) -> bytes:
- line = self.readline()
- if not line:
- raise StopIteration()
- return line
-
def readable(self) -> bool:
return True
diff --git a/tests/conftest.py b/tests/conftest.py
index 7ce0896..b73202c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -41,7 +41,9 @@ def __init__(self, kwargs):
self.log = None
def tail_log(self, path):
- self.log = open(path)
+ # surrogateescape allows for handling of file streams
+ # containing junk binary values as normal text streams
+ self.log = open(path, errors="surrogateescape")
self.log.read()
def connect(self, **kwargs):
diff --git a/tests/live_apps/data_app.py b/tests/live_apps/data_app.py
index a7158c7..9b2e78b 100644
--- a/tests/live_apps/data_app.py
+++ b/tests/live_apps/data_app.py
@@ -5,13 +5,13 @@
@Request.application
-def app(request):
+def app(request: Request) -> Response:
return Response(
json.dumps(
{
"environ": request.environ,
- "form": request.form,
- "files": {k: v.read().decode("utf8") for k, v in request.files.items()},
+ "form": request.form.to_dict(),
+ "files": {k: v.read().decode() for k, v in request.files.items()},
},
default=lambda x: str(x),
),
diff --git a/tests/middleware/test_dispatcher.py b/tests/middleware/test_dispatcher.py
index 5a25a6c..b7b9a77 100644
--- a/tests/middleware/test_dispatcher.py
+++ b/tests/middleware/test_dispatcher.py
@@ -1,4 +1,3 @@
-from werkzeug._internal import _to_bytes
from werkzeug.middleware.dispatcher import DispatcherMiddleware
from werkzeug.test import create_environ
from werkzeug.test import run_wsgi_app
@@ -11,7 +10,7 @@ def null_application(environ, start_response):
def dummy_application(environ, start_response):
start_response("200 OK", [("Content-Type", "text/plain")])
- yield _to_bytes(environ["SCRIPT_NAME"])
+ yield environ["SCRIPT_NAME"].encode()
app = DispatcherMiddleware(
null_application,
@@ -27,7 +26,7 @@ def dummy_application(environ, start_response):
environ = create_environ(p)
app_iter, status, headers = run_wsgi_app(app, environ)
assert status == "200 OK"
- assert b"".join(app_iter).strip() == _to_bytes(name)
+ assert b"".join(app_iter).strip() == name.encode()
app_iter, status, headers = run_wsgi_app(app, create_environ("/missing"))
assert status == "404 NOT FOUND"
diff --git a/tests/middleware/test_profiler.py b/tests/middleware/test_profiler.py
new file mode 100644
index 0000000..585aeb5
--- /dev/null
+++ b/tests/middleware/test_profiler.py
@@ -0,0 +1,50 @@
+import datetime
+import os
+from unittest.mock import ANY
+from unittest.mock import MagicMock
+from unittest.mock import patch
+
+from werkzeug.middleware.profiler import Profile
+from werkzeug.middleware.profiler import ProfilerMiddleware
+from werkzeug.test import Client
+
+
+def dummy_application(environ, start_response):
+ start_response("200 OK", [("Content-Type", "text/plain")])
+ return [b"Foo"]
+
+
+def test_filename_format_function():
+ # This should be called once with the generated file name
+ mock_capture_name = MagicMock()
+
+ def filename_format(env):
+ now = datetime.datetime.fromtimestamp(env["werkzeug.profiler"]["time"])
+ timestamp = now.strftime("%Y-%m-%d:%H:%M:%S")
+ path = (
+ "_".join(token for token in env["PATH_INFO"].split("/") if token) or "ROOT"
+ )
+ elapsed = env["werkzeug.profiler"]["elapsed"]
+ name = f"{timestamp}.{env['REQUEST_METHOD']}.{path}.{elapsed:.0f}ms.prof"
+ mock_capture_name(name=name)
+ return name
+
+ client = Client(
+ ProfilerMiddleware(
+ dummy_application,
+ stream=None,
+ profile_dir="profiles",
+ filename_format=filename_format,
+ )
+ )
+
+ # Replace the Profile class with a function that simulates an __init__()
+ # call and returns our mock instance.
+ mock_profile = MagicMock(wraps=Profile())
+ mock_profile.dump_stats = MagicMock()
+ with patch("werkzeug.middleware.profiler.Profile", lambda: mock_profile):
+ client.get("/foo/bar")
+
+ mock_capture_name.assert_called_once_with(name=ANY)
+ name = mock_capture_name.mock_calls[0].kwargs["name"]
+ mock_profile.dump_stats.assert_called_once_with(os.path.join("profiles", name))
diff --git a/tests/sansio/test_multipart.py b/tests/sansio/test_multipart.py
index f9c48b4..cf36fef 100644
--- a/tests/sansio/test_multipart.py
+++ b/tests/sansio/test_multipart.py
@@ -1,3 +1,5 @@
+import pytest
+
from werkzeug.datastructures import Headers
from werkzeug.sansio.multipart import Data
from werkzeug.sansio.multipart import Epilogue
@@ -22,15 +24,11 @@ def test_decoder_simple() -> None:
asdasd
-----------------------------9704338192090380615194531385$--
- """.replace(
- "\n", "\r\n"
- ).encode(
- "utf-8"
- )
+ """.replace("\n", "\r\n").encode()
decoder.receive_data(data)
decoder.receive_data(None)
events = [decoder.next_event()]
- while not isinstance(events[-1], Epilogue) and len(events) < 6:
+ while not isinstance(events[-1], Epilogue):
events.append(decoder.next_event())
assert events == [
Preamble(data=b""),
@@ -56,6 +54,57 @@ def test_decoder_simple() -> None:
assert data == result
+@pytest.mark.parametrize(
+ "data_start",
+ [
+ b"A",
+ b"\n",
+ b"\r",
+ b"\r\n",
+ b"\n\r",
+ b"A\n",
+ b"A\r",
+ b"A\r\n",
+ b"A\n\r",
+ ],
+)
+@pytest.mark.parametrize("data_end", [b"", b"\r\n--foo"])
+def test_decoder_data_start_with_different_newline_positions(
+ data_start: bytes, data_end: bytes
+) -> None:
+ boundary = b"foo"
+ data = (
+ b"\r\n--foo\r\n"
+ b'Content-Disposition: form-data; name="test"; filename="testfile"\r\n'
+ b"Content-Type: application/octet-stream\r\n\r\n"
+ b"" + data_start + b"\r\nBCDE" + data_end
+ )
+ decoder = MultipartDecoder(boundary)
+ decoder.receive_data(data)
+ events = [decoder.next_event()]
+ # We want to check up to data start event
+ while not isinstance(events[-1], Data):
+ events.append(decoder.next_event())
+ expected = data_start if data_end == b"" else data_start + b"\r\nBCDE"
+ assert events == [
+ Preamble(data=b""),
+ File(
+ name="test",
+ filename="testfile",
+ headers=Headers(
+ [
+ (
+ "Content-Disposition",
+ 'form-data; name="test"; filename="testfile"',
+ ),
+ ("Content-Type", "application/octet-stream"),
+ ]
+ ),
+ ),
+ Data(data=expected, more_data=True),
+ ]
+
+
def test_chunked_boundaries() -> None:
boundary = b"--boundary"
decoder = MultipartDecoder(boundary)
@@ -78,3 +127,54 @@ def test_chunked_boundaries() -> None:
assert not event.more_data
decoder.receive_data(None)
assert isinstance(decoder.next_event(), Epilogue)
+
+
+def test_empty_field() -> None:
+ boundary = b"foo"
+ decoder = MultipartDecoder(boundary)
+ data = """
+--foo
+Content-Disposition: form-data; name="text"
+Content-Type: text/plain; charset="UTF-8"
+
+Some Text
+--foo
+Content-Disposition: form-data; name="empty"
+Content-Type: text/plain; charset="UTF-8"
+
+--foo--
+ """.replace("\n", "\r\n").encode()
+ decoder.receive_data(data)
+ decoder.receive_data(None)
+ events = [decoder.next_event()]
+ while not isinstance(events[-1], Epilogue):
+ events.append(decoder.next_event())
+ assert events == [
+ Preamble(data=b""),
+ Field(
+ name="text",
+ headers=Headers(
+ [
+ ("Content-Disposition", 'form-data; name="text"'),
+ ("Content-Type", 'text/plain; charset="UTF-8"'),
+ ]
+ ),
+ ),
+ Data(data=b"Some Text", more_data=False),
+ Field(
+ name="empty",
+ headers=Headers(
+ [
+ ("Content-Disposition", 'form-data; name="empty"'),
+ ("Content-Type", 'text/plain; charset="UTF-8"'),
+ ]
+ ),
+ ),
+ Data(data=b"", more_data=False),
+ Epilogue(data=b" "),
+ ]
+ encoder = MultipartEncoder(boundary)
+ result = b""
+ for event in events:
+ result += encoder.send_event(event)
+ assert data == result
diff --git a/tests/sansio/test_request.py b/tests/sansio/test_request.py
index 310b244..4f4bbd6 100644
--- a/tests/sansio/test_request.py
+++ b/tests/sansio/test_request.py
@@ -12,6 +12,10 @@
(Headers({"Transfer-Encoding": "chunked", "Content-Length": "6"}), None),
(Headers({"Transfer-Encoding": "something", "Content-Length": "6"}), 6),
(Headers({"Content-Length": "6"}), 6),
+ (Headers({"Content-Length": "-6"}), 0),
+ (Headers({"Content-Length": "+123"}), 0),
+ (Headers({"Content-Length": "1_23"}), 0),
+ (Headers({"Content-Length": "🯱🯲🯳"}), 0),
(Headers(), None),
],
)
diff --git a/tests/sansio/test_utils.py b/tests/sansio/test_utils.py
index 8c8faa6..d43de66 100644
--- a/tests/sansio/test_utils.py
+++ b/tests/sansio/test_utils.py
@@ -1,7 +1,8 @@
-import typing as t
+from __future__ import annotations
import pytest
+from werkzeug.sansio.utils import get_content_length
from werkzeug.sansio.utils import get_host
@@ -25,8 +26,28 @@
)
def test_get_host(
scheme: str,
- host_header: t.Optional[str],
- server: t.Optional[t.Tuple[str, t.Optional[int]]],
+ host_header: str | None,
+ server: tuple[str, int | None] | None,
expected: str,
) -> None:
assert get_host(scheme, host_header, server) == expected
+
+
+@pytest.mark.parametrize(
+ ("http_content_length", "http_transfer_encoding", "expected"),
+ [
+ ("2", None, 2),
+ (" 2", None, 2),
+ ("2 ", None, 2),
+ (None, None, None),
+ (None, "chunked", None),
+ ("a", None, 0),
+ ("-2", None, 0),
+ ],
+)
+def test_get_content_length(
+ http_content_length: str | None,
+ http_transfer_encoding: str | None,
+ expected: int | None,
+) -> None:
+ assert get_content_length(http_content_length, http_transfer_encoding) == expected
diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py
index 7f63b64..64330e1 100644
--- a/tests/test_datastructures.py
+++ b/tests/test_datastructures.py
@@ -63,7 +63,7 @@ def create_instance(module=None):
d = create_instance()
s = pickle.dumps(d, protocol)
ud = pickle.loads(s)
- assert type(ud) == type(d)
+ assert type(ud) == type(d) # noqa: E721
assert ud == d
alternative = pickle.dumps(create_instance("werkzeug"), protocol)
assert pickle.loads(alternative) == d
@@ -550,8 +550,9 @@ def test_value_conversion(self):
assert d.get("foo", type=int) == 1
def test_return_default_when_conversion_is_not_possible(self):
- d = self.storage_class(foo="bar")
+ d = self.storage_class(foo="bar", baz=None)
assert d.get("foo", default=-1, type=int) == -1
+ assert d.get("baz", default=-1, type=int) == -1
def test_propagate_exceptions_in_conversion(self):
d = self.storage_class(foo="bar")
@@ -731,16 +732,6 @@ def test_slicing(self):
h[:] = [(k, v) for k, v in h if k.startswith("X-")]
assert list(h) == [("X-Foo-Poo", "bleh"), ("X-Forwarded-For", "192.168.0.123")]
- def test_bytes_operations(self):
- h = self.storage_class()
- h.set("X-Foo-Poo", "bleh")
- h.set("X-Whoops", b"\xff")
- h.set(b"X-Bytes", b"something")
-
- assert h.get("x-foo-poo", as_bytes=True) == b"bleh"
- assert h.get("x-whoops", as_bytes=True) == b"\xff"
- assert h.get("x-bytes") == "something"
-
def test_extend(self):
h = self.storage_class([("a", "0"), ("b", "1"), ("c", "2")])
h.extend(ds.Headers([("a", "3"), ("a", "4")]))
@@ -791,13 +782,6 @@ def test_to_wsgi_list(self):
assert key == "Key"
assert value == "Value"
- def test_to_wsgi_list_bytes(self):
- h = self.storage_class()
- h.set(b"Key", b"Value")
- for key, value in h.to_wsgi_list():
- assert key == "Key"
- assert value == "Value"
-
def test_equality(self):
# test equality, given keys are case insensitive
h1 = self.storage_class()
@@ -853,13 +837,6 @@ def test_return_type_is_str(self):
assert headers["Foo"] == "\xe2\x9c\x93"
assert next(iter(headers)) == ("Foo", "\xe2\x9c\x93")
- def test_bytes_operations(self):
- foo_val = "\xff"
- h = self.storage_class({"HTTP_X_FOO": foo_val})
-
- assert h.get("x-foo", as_bytes=True) == b"\xff"
- assert h.get("x-foo") == "\xff"
-
class TestHeaderSet:
storage_class = ds.HeaderSet
diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py
index d8fed96..ad20b3f 100644
--- a/tests/test_exceptions.py
+++ b/tests/test_exceptions.py
@@ -7,6 +7,7 @@
from werkzeug import exceptions
from werkzeug.datastructures import Headers
from werkzeug.datastructures import WWWAuthenticate
+from werkzeug.exceptions import default_exceptions
from werkzeug.exceptions import HTTPException
from werkzeug.wrappers import Response
@@ -96,10 +97,8 @@ def test_method_not_allowed_methods():
def test_unauthorized_www_authenticate():
- basic = WWWAuthenticate()
- basic.set_basic("test")
- digest = WWWAuthenticate()
- digest.set_digest("test", "test")
+ basic = WWWAuthenticate("basic", {"realm": "test"})
+ digest = WWWAuthenticate("digest", {"realm": "test", "nonce": "test"})
exc = exceptions.Unauthorized(www_authenticate=basic)
h = Headers(exc.get_headers({}))
@@ -140,7 +139,7 @@ def test_retry_after_mixin(cls, value, expect):
@pytest.mark.parametrize(
"cls",
sorted(
- (e for e in HTTPException.__subclasses__() if e.code and e.code >= 400),
+ (e for e in default_exceptions.values() if e.code and e.code >= 400),
key=lambda e: e.code, # type: ignore
),
)
@@ -160,7 +159,7 @@ def test_description_none():
@pytest.mark.parametrize(
"cls",
sorted(
- (e for e in HTTPException.__subclasses__() if e.code),
+ (e for e in default_exceptions.values() if e.code),
key=lambda e: e.code, # type: ignore
),
)
diff --git a/tests/test_formparser.py b/tests/test_formparser.py
index 49010b4..1ecb012 100644
--- a/tests/test_formparser.py
+++ b/tests/test_formparser.py
@@ -69,6 +69,17 @@ def test_limiting(self):
req.max_form_memory_size = 400
assert req.form["foo"] == "Hello World"
+ input_stream = io.BytesIO(b"foo=123456")
+ req = Request.from_values(
+ input_stream=input_stream,
+ content_type="application/x-www-form-urlencoded",
+ method="POST",
+ )
+ req.max_content_length = 4
+ pytest.raises(RequestEntityTooLarge, lambda: req.form["foo"])
+ # content-length was set, so request could exit early without reading anything
+ assert input_stream.read() == b"foo=123456"
+
data = (
b"--foo\r\nContent-Disposition: form-field; name=foo\r\n\r\n"
b"Hello World\r\n"
@@ -81,24 +92,17 @@ def test_limiting(self):
content_type="multipart/form-data; boundary=foo",
method="POST",
)
- req.max_content_length = 4
- pytest.raises(RequestEntityTooLarge, lambda: req.form["foo"])
+ req.max_content_length = 400
+ assert req.form["foo"] == "Hello World"
- # when the request entity is too large, the input stream should be
- # drained so that firefox (and others) do not report connection reset
- # when run through gunicorn
- # a sufficiently large stream is necessary for block-based reads
- input_stream = io.BytesIO(b"foo=" + b"x" * 128 * 1024)
req = Request.from_values(
- input_stream=input_stream,
+ input_stream=io.BytesIO(data),
content_length=len(data),
content_type="multipart/form-data; boundary=foo",
method="POST",
)
- req.max_content_length = 4
+ req.max_form_memory_size = 7
pytest.raises(RequestEntityTooLarge, lambda: req.form["foo"])
- # ensure that the stream is exhausted
- assert input_stream.read() == b""
req = Request.from_values(
input_stream=io.BytesIO(data),
@@ -106,7 +110,7 @@ def test_limiting(self):
content_type="multipart/form-data; boundary=foo",
method="POST",
)
- req.max_content_length = 400
+ req.max_form_memory_size = 400
assert req.form["foo"] == "Hello World"
req = Request.from_values(
@@ -115,17 +119,15 @@ def test_limiting(self):
content_type="multipart/form-data; boundary=foo",
method="POST",
)
- req.max_form_memory_size = 7
+ req.max_form_parts = 1
pytest.raises(RequestEntityTooLarge, lambda: req.form["foo"])
- req = Request.from_values(
- input_stream=io.BytesIO(data),
- content_length=len(data),
- content_type="multipart/form-data; boundary=foo",
- method="POST",
- )
- req.max_form_memory_size = 400
- assert req.form["foo"] == "Hello World"
+ def test_x_www_urlencoded_max_form_parts(self):
+ r = Request.from_values(method="POST", data={"a": 1, "b": 2})
+ r.max_form_parts = 1
+
+ assert r.form["a"] == "1"
+ assert r.form["b"] == "2"
def test_missing_multipart_boundary(self):
data = (
@@ -271,7 +273,7 @@ def test_basic(self):
content_type=f'multipart/form-data; boundary="{boundary}"',
content_length=len(data),
) as response:
- assert response.get_data() == repr(text).encode("utf-8")
+ assert response.get_data() == repr(text).encode()
@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
def test_ie7_unc_path(self):
diff --git a/tests/test_http.py b/tests/test_http.py
index 3760dc1..bbd51ba 100644
--- a/tests/test_http.py
+++ b/tests/test_http.py
@@ -1,4 +1,5 @@
import base64
+import urllib.parse
from datetime import date
from datetime import datetime
from datetime import timedelta
@@ -9,6 +10,8 @@
from werkzeug import datastructures
from werkzeug import http
from werkzeug._internal import _wsgi_encoding_dance
+from werkzeug.datastructures import Authorization
+from werkzeug.datastructures import WWWAuthenticate
from werkzeug.test import create_environ
@@ -21,6 +24,10 @@ def test_accept(self):
pytest.raises(ValueError, a.index, "de")
assert a.to_header() == "en-us,ru;q=0.5"
+ def test_accept_parameter_with_space(self):
+ a = http.parse_accept_header('application/x-special; z="a b";q=0.5')
+ assert a['application/x-special; z="a b"'] == 0.5
+
def test_mime_accept(self):
a = http.parse_accept_header(
"text/xml,application/xml,"
@@ -88,9 +95,17 @@ def test_set_header(self):
hs.add("Foo")
assert hs.to_header() == 'foo, Bar, "Blah baz", Hehe'
- def test_list_header(self):
- hl = http.parse_list_header("foo baz, blah")
- assert hl == ["foo baz", "blah"]
+ @pytest.mark.parametrize(
+ ("value", "expect"),
+ [
+ ("a b", ["a b"]),
+ ("a b, c", ["a b", "c"]),
+ ('a b, "c, d"', ["a b", "c, d"]),
+ ('"a\\"b", c', ['a"b', "c"]),
+ ],
+ )
+ def test_list_header(self, value, expect):
+ assert http.parse_list_header(value) == expect
def test_dict_header(self):
d = http.parse_dict_header('foo="bar baz", blah=42')
@@ -133,33 +148,30 @@ def test_csp_header(self):
assert csp.img_src is None
def test_authorization_header(self):
- a = http.parse_authorization_header("Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==")
+ a = Authorization.from_header("Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==")
assert a.type == "basic"
assert a.username == "Aladdin"
assert a.password == "open sesame"
- a = http.parse_authorization_header(
- "Basic 0YDRg9GB0YHQutC40IE60JHRg9C60LLRiw=="
- )
+ a = Authorization.from_header("Basic 0YDRg9GB0YHQutC40IE60JHRg9C60LLRiw==")
assert a.type == "basic"
assert a.username == "русскиЁ"
assert a.password == "Буквы"
- a = http.parse_authorization_header("Basic 5pmu6YCa6K+dOuS4reaWhw==")
+ a = Authorization.from_header("Basic 5pmu6YCa6K+dOuS4reaWhw==")
assert a.type == "basic"
assert a.username == "普通话"
assert a.password == "中文"
- a = http.parse_authorization_header(
- '''Digest username="Mufasa",
- realm="testrealm@host.invalid",
- nonce="dcd98b7102dd2f0e8b11d0f600bfb0c093",
- uri="/dir/index.html",
- qop=auth,
- nc=00000001,
- cnonce="0a4f113b",
- response="6629fae49393a05397450978507c4ef1",
- opaque="5ccc069c403ebaf9f0171e9517f40e41"'''
+ a = Authorization.from_header(
+ 'Digest username="Mufasa",'
+ ' realm="testrealm@host.invalid",'
+ ' nonce="dcd98b7102dd2f0e8b11d0f600bfb0c093",'
+ ' uri="/dir/index.html",'
+ " qop=auth, nc=00000001,"
+ ' cnonce="0a4f113b",'
+ ' response="6629fae49393a05397450978507c4ef1",'
+ ' opaque="5ccc069c403ebaf9f0171e9517f40e41"'
)
assert a.type == "digest"
assert a.username == "Mufasa"
@@ -172,13 +184,13 @@ def test_authorization_header(self):
assert a.response == "6629fae49393a05397450978507c4ef1"
assert a.opaque == "5ccc069c403ebaf9f0171e9517f40e41"
- a = http.parse_authorization_header(
- '''Digest username="Mufasa",
- realm="testrealm@host.invalid",
- nonce="dcd98b7102dd2f0e8b11d0f600bfb0c093",
- uri="/dir/index.html",
- response="e257afa1414a3340d93d30955171dd0e",
- opaque="5ccc069c403ebaf9f0171e9517f40e41"'''
+ a = Authorization.from_header(
+ 'Digest username="Mufasa",'
+ ' realm="testrealm@host.invalid",'
+ ' nonce="dcd98b7102dd2f0e8b11d0f600bfb0c093",'
+ ' uri="/dir/index.html",'
+ ' response="e257afa1414a3340d93d30955171dd0e",'
+ ' opaque="5ccc069c403ebaf9f0171e9517f40e41"'
)
assert a.type == "digest"
assert a.username == "Mufasa"
@@ -188,41 +200,87 @@ def test_authorization_header(self):
assert a.response == "e257afa1414a3340d93d30955171dd0e"
assert a.opaque == "5ccc069c403ebaf9f0171e9517f40e41"
- assert http.parse_authorization_header("") is None
- assert http.parse_authorization_header(None) is None
- assert http.parse_authorization_header("foo") is None
+ assert Authorization.from_header("") is None
+ assert Authorization.from_header(None) is None
+ assert Authorization.from_header("foo").type == "foo"
+
+ def test_authorization_token_padding(self):
+ # padded with =
+ token = base64.b64encode(b"This has base64 padding").decode()
+ a = Authorization.from_header(f"Token {token}")
+ assert a.type == "token"
+ assert a.token == token
+
+ # padded with ==
+ token = base64.b64encode(b"This has base64 padding..").decode()
+ a = Authorization.from_header(f"Token {token}")
+ assert a.type == "token"
+ assert a.token == token
+
+ def test_authorization_basic_incorrect_padding(self):
+ assert Authorization.from_header("Basic foo") is None
def test_bad_authorization_header_encoding(self):
"""If the base64 encoded bytes can't be decoded as UTF-8"""
content = base64.b64encode(b"\xffser:pass").decode()
- assert http.parse_authorization_header(f"Basic {content}") is None
+ assert Authorization.from_header(f"Basic {content}") is None
+
+ def test_authorization_eq(self):
+ basic1 = Authorization.from_header("Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==")
+ basic2 = Authorization(
+ "basic", {"username": "Aladdin", "password": "open sesame"}
+ )
+ assert basic1 == basic2
+ bearer1 = Authorization.from_header("Bearer abc")
+ bearer2 = Authorization("bearer", token="abc")
+ assert bearer1 == bearer2
+ assert basic1 != bearer1
+ assert basic1 != object()
def test_www_authenticate_header(self):
- wa = http.parse_www_authenticate_header('Basic realm="WallyWorld"')
+ wa = WWWAuthenticate.from_header('Basic realm="WallyWorld"')
assert wa.type == "basic"
assert wa.realm == "WallyWorld"
wa.realm = "Foo Bar"
assert wa.to_header() == 'Basic realm="Foo Bar"'
- wa = http.parse_www_authenticate_header(
- '''Digest
- realm="testrealm@host.com",
- qop="auth,auth-int",
- nonce="dcd98b7102dd2f0e8b11d0f600bfb0c093",
- opaque="5ccc069c403ebaf9f0171e9517f40e41"'''
+ wa = WWWAuthenticate.from_header(
+ 'Digest realm="testrealm@host.com",'
+ ' qop="auth,auth-int",'
+ ' nonce="dcd98b7102dd2f0e8b11d0f600bfb0c093",'
+ ' opaque="5ccc069c403ebaf9f0171e9517f40e41"'
)
assert wa.type == "digest"
assert wa.realm == "testrealm@host.com"
- assert "auth" in wa.qop
- assert "auth-int" in wa.qop
+ assert wa.parameters["qop"] == "auth,auth-int"
assert wa.nonce == "dcd98b7102dd2f0e8b11d0f600bfb0c093"
assert wa.opaque == "5ccc069c403ebaf9f0171e9517f40e41"
- wa = http.parse_www_authenticate_header("broken")
- assert wa.type == "broken"
-
- assert not http.parse_www_authenticate_header("").type
- assert not http.parse_www_authenticate_header("")
+ assert WWWAuthenticate.from_header("broken").type == "broken"
+ assert WWWAuthenticate.from_header("") is None
+
+ def test_www_authenticate_token_padding(self):
+ # padded with =
+ token = base64.b64encode(b"This has base64 padding").decode()
+ a = WWWAuthenticate.from_header(f"Token {token}")
+ assert a.type == "token"
+ assert a.token == token
+
+ # padded with ==
+ token = base64.b64encode(b"This has base64 padding..").decode()
+ a = WWWAuthenticate.from_header(f"Token {token}")
+ assert a.type == "token"
+ assert a.token == token
+
+ def test_www_authenticate_eq(self):
+ basic1 = WWWAuthenticate.from_header("Basic realm=abc")
+ basic2 = WWWAuthenticate("basic", {"realm": "abc"})
+ assert basic1 == basic2
+ token1 = WWWAuthenticate.from_header("Token abc")
+ token2 = WWWAuthenticate("token", token="abc")
+ assert token1 == token2
+ assert basic1 != token1
+ assert basic1 != object()
def test_etags(self):
assert http.quote_etag("foo") == '"foo"'
@@ -274,68 +332,63 @@ def test_remove_hop_by_hop_headers(self):
http.remove_hop_by_hop_headers(headers2)
assert headers2 == datastructures.Headers([("Foo", "bar")])
- def test_parse_options_header(self):
- assert http.parse_options_header(None) == ("", {})
- assert http.parse_options_header("") == ("", {})
- assert http.parse_options_header(r'something; foo="other\"thing"') == (
- "something",
- {"foo": 'other"thing'},
- )
- assert http.parse_options_header(r'something; foo="other\"thing"; meh=42') == (
- "something",
- {"foo": 'other"thing', "meh": "42"},
- )
- assert http.parse_options_header(
- r'something; foo="other\"thing"; meh=42; bleh'
- ) == ("something", {"foo": 'other"thing', "meh": "42", "bleh": None})
- assert http.parse_options_header(
- 'something; foo="other;thing"; meh=42; bleh'
- ) == ("something", {"foo": "other;thing", "meh": "42", "bleh": None})
- assert http.parse_options_header('something; foo="otherthing"; meh=; bleh') == (
- "something",
- {"foo": "otherthing", "meh": None, "bleh": None},
- )
- # Issue #404
- assert http.parse_options_header(
- 'multipart/form-data; name="foo bar"; filename="bar foo"'
- ) == ("multipart/form-data", {"name": "foo bar", "filename": "bar foo"})
- # Examples from RFC
- assert http.parse_options_header("audio/*; q=0.2, audio/basic") == (
- "audio/*",
- {"q": "0.2"},
- )
-
- assert http.parse_options_header(
- "text/plain; q=0.5, text/html\n text/x-dvi; q=0.8, text/x-c"
- ) == ("text/plain", {"q": "0.5"})
- # Issue #932
- assert http.parse_options_header(
- "form-data; name=\"a_file\"; filename*=UTF-8''"
- '"%c2%a3%20and%20%e2%82%ac%20rates"'
- ) == ("form-data", {"name": "a_file", "filename": "\xa3 and \u20ac rates"})
- assert http.parse_options_header(
- "form-data; name*=UTF-8''\"%C5%AAn%C4%ADc%C5%8Dde%CC%BD\"; "
- 'filename="some_file.txt"'
- ) == (
- "form-data",
- {"name": "\u016an\u012dc\u014dde\u033d", "filename": "some_file.txt"},
- )
+ @pytest.mark.parametrize(
+ ("value", "expect"),
+ [
+ (None, ""),
+ ("", ""),
+ (";a=b", ""),
+ ("v", "v"),
+ ("v;", "v"),
+ ],
+ )
+ def test_parse_options_header_empty(self, value, expect):
+ assert http.parse_options_header(value) == (expect, {})
- def test_parse_options_header_value_with_quotes(self):
- assert http.parse_options_header(
- 'form-data; name="file"; filename="t\'es\'t.txt"'
- ) == ("form-data", {"name": "file", "filename": "t'es't.txt"})
- assert http.parse_options_header(
- "form-data; name=\"file\"; filename*=UTF-8''\"'🐍'.txt\""
- ) == ("form-data", {"name": "file", "filename": "'🐍'.txt"})
+ @pytest.mark.parametrize(
+ ("value", "expect"),
+ [
+ ("v;a=b;c=d;", {"a": "b", "c": "d"}),
+ ("v; ; a=b ; ", {"a": "b"}),
+ ("v;a", {}),
+ ("v;a=", {}),
+ ("v;=b", {}),
+ ('v;a="b"', {"a": "b"}),
+ ("v;a=µ", {}),
+ ('v;a="\';\'";b="µ";', {"a": "';'", "b": "µ"}),
+ ('v;a="b c"', {"a": "b c"}),
+ # HTTP headers use \" for internal "
+ ('v;a="b\\"c";d=e', {"a": 'b"c', "d": "e"}),
+ # HTTP headers use \\ for internal \
+ ('v;a="c:\\\\"', {"a": "c:\\"}),
+ # Invalid trailing slash in quoted part is left as-is.
+ ('v;a="c:\\"', {"a": "c:\\"}),
+ ('v;a="b\\\\\\"c"', {"a": 'b\\"c'}),
+ # multipart form data uses %22 for internal "
+ ('v;a="b%22c"', {"a": 'b"c'}),
+ ("v;a*=b", {"a": "b"}),
+ ("v;a*=ASCII'en'b", {"a": "b"}),
+ ("v;a*=US-ASCII''%62", {"a": "b"}),
+ ("v;a*=UTF-8''%C2%B5", {"a": "µ"}),
+ ("v;a*=US-ASCII''%C2%B5", {"a": "��"}),
+ ("v;a*=BAD''%62", {"a": "%62"}),
+ ("v;a*=UTF-8'''%F0%9F%90%8D'.txt", {"a": "'🐍'.txt"}),
+ ('v;a="🐍.txt"', {"a": "🐍.txt"}),
+ ("v;a*0=b;a*1=c;d=e", {"a": "bc", "d": "e"}),
+ ("v;a*0*=b", {"a": "b"}),
+ ("v;a*0*=UTF-8''b;a*1=c;a*2*=%C2%B5", {"a": "bcµ"}),
+ ],
+ )
+ def test_parse_options_header(self, value, expect) -> None:
+ assert http.parse_options_header(value) == ("v", expect)
def test_parse_options_header_broken_values(self):
# Issue #995
assert http.parse_options_header(" ") == ("", {})
- assert http.parse_options_header(" , ") == ("", {})
+ assert http.parse_options_header(" , ") == (",", {})
assert http.parse_options_header(" ; ") == ("", {})
- assert http.parse_options_header(" ,; ") == ("", {})
- assert http.parse_options_header(" , a ") == ("", {})
+ assert http.parse_options_header(" ,; ") == (",", {})
+ assert http.parse_options_header(" , a ") == (", a", {})
assert http.parse_options_header(" ; a ") == ("", {})
def test_parse_options_header_case_insensitive(self):
@@ -344,16 +397,12 @@ def test_parse_options_header_case_insensitive(self):
def test_dump_options_header(self):
assert http.dump_options_header("foo", {"bar": 42}) == "foo; bar=42"
- assert http.dump_options_header("foo", {"bar": 42, "fizz": None}) in (
- "foo; bar=42; fizz",
- "foo; fizz; bar=42",
- )
+ assert "fizz" not in http.dump_options_header("foo", {"bar": 42, "fizz": None})
def test_dump_header(self):
assert http.dump_header([1, 2, 3]) == "1, 2, 3"
- assert http.dump_header([1, 2, 3], allow_token=False) == '"1", "2", "3"'
- assert http.dump_header({"foo": "bar"}, allow_token=False) == 'foo="bar"'
assert http.dump_header({"foo": "bar"}) == "foo=bar"
+ assert http.dump_header({"foo*": "UTF-8''bar"}) == "foo*=UTF-8''bar"
def test_is_resource_modified(self):
env = create_environ()
@@ -411,7 +460,8 @@ def test_is_resource_modified_for_range_requests(self):
def test_parse_cookie(self):
cookies = http.parse_cookie(
"dismiss-top=6; CP=null*; PHPSESSID=0a539d42abc001cdc762809248d4beed;"
- 'a=42; b="\\";"; ; fo234{=bar;blub=Blah; "__Secure-c"=d'
+ 'a=42; b="\\";"; ; fo234{=bar;blub=Blah; "__Secure-c"=d;'
+ "==__Host-eq=bad;__Host-eq=good;"
)
assert cookies.to_dict() == {
"CP": "null*",
@@ -422,6 +472,7 @@ def test_parse_cookie(self):
"fo234{": "bar",
"blub": "Blah",
'"__Secure-c"': "d",
+ "__Host-eq": "good",
}
def test_dump_cookie(self):
@@ -435,7 +486,7 @@ def test_dump_cookie(self):
'foo="bar baz blub"',
}
assert http.dump_cookie("key", "xxx/") == "key=xxx/; Path=/"
- assert http.dump_cookie("key", "xxx=") == "key=xxx=; Path=/"
+ assert http.dump_cookie("key", "xxx=", path=None) == "key=xxx="
def test_bad_cookies(self):
cookies = http.parse_cookie(
@@ -458,9 +509,9 @@ def test_empty_keys_are_ignored(self):
def test_cookie_quoting(self):
val = http.dump_cookie("foo", "?foo")
- assert val == 'foo="?foo"; Path=/'
- assert http.parse_cookie(val).to_dict() == {"foo": "?foo", "Path": "/"}
- assert http.parse_cookie(r'foo="foo\054bar"').to_dict(), {"foo": "foo,bar"}
+ assert val == "foo=?foo; Path=/"
+ assert http.parse_cookie(val)["foo"] == "?foo"
+ assert http.parse_cookie(r'foo="foo\054bar"')["foo"] == "foo,bar"
def test_parse_set_cookie_directive(self):
val = 'foo="?foo"; version="0.1";'
@@ -482,7 +533,7 @@ def test_cookie_unicode_dumping(self):
def test_cookie_unicode_keys(self):
# Yes, this is technically against the spec but happens
val = http.dump_cookie("fö", "fö")
- assert val == _wsgi_encoding_dance('fö="f\\303\\266"; Path=/', "utf-8")
+ assert val == _wsgi_encoding_dance('fö="f\\303\\266"; Path=/')
cookies = http.parse_cookie(val)
assert cookies["fö"] == "fö"
@@ -495,38 +546,30 @@ def test_cookie_domain_encoding(self):
val = http.dump_cookie("foo", "bar", domain="\N{SNOWMAN}.com")
assert val == "foo=bar; Domain=xn--n3h.com; Path=/"
- val = http.dump_cookie("foo", "bar", domain=".\N{SNOWMAN}.com")
- assert val == "foo=bar; Domain=.xn--n3h.com; Path=/"
-
- val = http.dump_cookie("foo", "bar", domain=".foo.com")
- assert val == "foo=bar; Domain=.foo.com; Path=/"
+ val = http.dump_cookie("foo", "bar", domain="foo.com")
+ assert val == "foo=bar; Domain=foo.com; Path=/"
- def test_cookie_maxsize(self, recwarn):
+ def test_cookie_maxsize(self):
val = http.dump_cookie("foo", "bar" * 1360 + "b")
- assert len(recwarn) == 0
assert len(val) == 4093
- http.dump_cookie("foo", "bar" * 1360 + "ba")
- assert len(recwarn) == 1
- w = recwarn.pop()
- assert "cookie is too large" in str(w.message)
+ with pytest.warns(UserWarning, match="cookie is too large"):
+ http.dump_cookie("foo", "bar" * 1360 + "ba")
- http.dump_cookie("foo", b"w" * 502, max_size=512)
- assert len(recwarn) == 1
- w = recwarn.pop()
- assert "the limit is 512 bytes" in str(w.message)
+ with pytest.warns(UserWarning, match="the limit is 512 bytes"):
+ http.dump_cookie("foo", "w" * 501, max_size=512)
@pytest.mark.parametrize(
("samesite", "expected"),
(
- ("strict", "foo=bar; Path=/; SameSite=Strict"),
- ("lax", "foo=bar; Path=/; SameSite=Lax"),
- ("none", "foo=bar; Path=/; SameSite=None"),
- (None, "foo=bar; Path=/"),
+ ("strict", "foo=bar; SameSite=Strict"),
+ ("lax", "foo=bar; SameSite=Lax"),
+ ("none", "foo=bar; SameSite=None"),
+ (None, "foo=bar"),
),
)
def test_cookie_samesite_attribute(self, samesite, expected):
- value = http.dump_cookie("foo", "bar", samesite=samesite)
+ value = http.dump_cookie("foo", "bar", samesite=samesite, path=None)
assert value == expected
def test_cookie_samesite_invalid(self):
@@ -619,6 +662,9 @@ def test_content_range_parsing(self):
rv = http.parse_content_range_header("bytes 0-98/*asdfsa")
assert rv is None
+ rv = http.parse_content_range_header("bytes */-1")
+ assert rv is None
+
rv = http.parse_content_range_header("bytes 0-99/100")
assert rv.to_header() == "bytes 0-99/100"
rv.start = None
@@ -656,7 +702,7 @@ def test_best_match_works(self):
],
)
def test_authorization_to_header(value: str) -> None:
- parsed = http.parse_authorization_header(value)
+ parsed = Authorization.from_header(value)
assert parsed is not None
assert parsed.to_header() == value
@@ -715,3 +761,32 @@ def test_parse_date(value, expect):
)
def test_http_date(value, expect):
assert http.http_date(value) == expect
+
+
+@pytest.mark.parametrize("value", [".5", "+0.5", "0.5_1", "🯰.🯵"])
+def test_accept_invalid_float(value):
+ quoted = urllib.parse.quote(value)
+
+ if quoted == value:
+ q = f"q={value}"
+ else:
+ q = f"q*=UTF-8''{value}"
+
+ a = http.parse_accept_header(f"en,jp;{q}")
+ assert list(a.values()) == ["en"]
+
+
+def test_accept_valid_int_one_zero():
+ assert http.parse_accept_header("en;q=1") == http.parse_accept_header("en;q=1.0")
+ assert http.parse_accept_header("en;q=0") == http.parse_accept_header("en;q=0.0")
+ assert http.parse_accept_header("en;q=5") == http.parse_accept_header("en;q=5.0")
+
+
+@pytest.mark.parametrize("value", ["🯱🯲🯳", "+1-", "1-1_23"])
+def test_range_invalid_int(value):
+ assert http.parse_range_header(value) is None
+
+
+@pytest.mark.parametrize("value", ["*/🯱🯲🯳", "1-+2/3", "1_23-125/*"])
+def test_content_range_invalid_int(value):
+ assert http.parse_content_range_header(f"bytes {value}") is None
diff --git a/tests/test_internal.py b/tests/test_internal.py
index 6e673fd..edae35b 100644
--- a/tests/test_internal.py
+++ b/tests/test_internal.py
@@ -1,21 +1,8 @@
-from warnings import filterwarnings
-from warnings import resetwarnings
-
-import pytest
-
-from werkzeug import _internal as internal
from werkzeug.test import create_environ
from werkzeug.wrappers import Request
from werkzeug.wrappers import Response
-def test_easteregg():
- req = Request.from_values("/?macgybarchakku")
- resp = Response.force_type(internal._easteregg(None), req)
- assert b"About Werkzeug" in resp.get_data()
- assert b"the Swiss Army knife of Python web development" in resp.get_data()
-
-
def test_wrapper_internals():
req = Request.from_values(data={"foo": "bar"}, method="POST")
req._load_form_data()
@@ -34,23 +21,10 @@ def test_wrapper_internals():
resp.response = iter(["Test"])
assert repr(resp) == " You should be redirected automatically to the target URL: "
- b'/f\xc3\xbc\xc3\xbcb\xc3\xa4r. '
- b"If not, click the link.\n"
- )
+@pytest.mark.parametrize(
+ ("url", "code", "expect"),
+ [
+ ("http://example.com", None, "http://example.com"),
+ ("/füübär", 305, "/f%C3%BC%C3%BCb%C3%A4r"),
+ ("http://☃.example.com/", 307, "http://xn--n3h.example.com/"),
+ ("itms-services://?url=abc", None, "itms-services://?url=abc"),
+ ],
+)
+def test_redirect(url: str, code: int | None, expect: str) -> None:
+ environ = EnvironBuilder().get_environ()
- resp = utils.redirect("http://☃.net/", 307)
- assert resp.headers["Location"] == "http://xn--n3h.net/"
- assert resp.status_code == 307
- assert resp.get_data() == (
- b"\n"
- b"\n"
- b" You should be redirected automatically to the target URL: "
- b'http://\xe2\x98\x83.net/. '
- b"If not, click the link.\n"
- )
+ if code is None:
+ resp = utils.redirect(url)
+ assert resp.status_code == 302
+ else:
+ resp = utils.redirect(url, code)
+ assert resp.status_code == code
- resp = utils.redirect("http://example.com/", 305)
- assert resp.headers["Location"] == "http://example.com/"
- assert resp.status_code == 305
- assert resp.get_data() == (
- b"\n"
- b"\n"
- b" You should be redirected automatically to the target URL: "
- b'http://example.com/. '
- b"If not, click the link.\n"
- )
+ assert resp.headers["Location"] == url
+ assert resp.get_wsgi_headers(environ)["Location"] == expect
+ assert resp.get_data(as_text=True).count(url) == 2
def test_redirect_xss():
@@ -190,6 +176,7 @@ def test_assign():
def test_import_string():
from datetime import date
+
from werkzeug.debug import DebuggedApplication
assert utils.import_string("datetime.date") is date
diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py
index b769a38..f756944 100644
--- a/tests/test_wrappers.py
+++ b/tests/test_wrappers.py
@@ -20,9 +20,11 @@
from werkzeug.datastructures import LanguageAccept
from werkzeug.datastructures import MIMEAccept
from werkzeug.datastructures import MultiDict
+from werkzeug.datastructures import WWWAuthenticate
from werkzeug.exceptions import BadRequest
from werkzeug.exceptions import RequestedRangeNotSatisfiable
from werkzeug.exceptions import SecurityError
+from werkzeug.exceptions import UnsupportedMediaType
from werkzeug.http import COEP
from werkzeug.http import COOP
from werkzeug.http import generate_etag
@@ -136,11 +138,12 @@ def test_url_request_descriptors():
def test_url_request_descriptors_query_quoting():
- next = "http%3A%2F%2Fwww.example.com%2F%3Fnext%3D%2Fbaz%23my%3Dhash"
- req = wrappers.Request.from_values(f"/bar?next={next}", "http://example.com/")
+ quoted = "http%3A%2F%2Fwww.example.com%2F%3Fnext%3D%2Fbaz%23my%3Dhash"
+ unquoted = "http://www.example.com/?next%3D/baz%23my%3Dhash"
+ req = wrappers.Request.from_values(f"/bar?next={quoted}", "http://example.com/")
assert req.path == "/bar"
- assert req.full_path == f"/bar?next={next}"
- assert req.url == f"http://example.com/bar?next={next}"
+ assert req.full_path == f"/bar?next={quoted}"
+ assert req.url == f"http://example.com/bar?next={unquoted}"
def test_url_request_descriptors_hosts():
@@ -349,13 +352,6 @@ def test_response_init_status_empty_string():
assert "Empty status argument" in str(info.value)
-def test_response_init_status_tuple():
- with pytest.raises(TypeError) as info:
- wrappers.Response(None, tuple())
-
- assert "Invalid status argument" in str(info.value)
-
-
def test_type_forcing():
def wsgi_application(environ, start_response):
start_response("200 OK", [("Content-Type", "text/html")])
@@ -686,27 +682,26 @@ def test_etag_response_freezing():
def test_authenticate():
resp = wrappers.Response()
- resp.www_authenticate.type = "basic"
resp.www_authenticate.realm = "Testing"
- assert resp.headers["WWW-Authenticate"] == 'Basic realm="Testing"'
- resp.www_authenticate.realm = None
- resp.www_authenticate.type = None
+ assert resp.headers["WWW-Authenticate"] == "Basic realm=Testing"
+ del resp.www_authenticate
assert "WWW-Authenticate" not in resp.headers
def test_authenticate_quoted_qop():
# Example taken from https://github.com/pallets/werkzeug/issues/633
resp = wrappers.Response()
- resp.www_authenticate.set_digest("REALM", "NONCE", qop=("auth", "auth-int"))
+ resp.www_authenticate = WWWAuthenticate(
+ "digest", {"realm": "REALM", "nonce": "NONCE", "qop": "auth, auth-int"}
+ )
- actual = set(f"{resp.headers['WWW-Authenticate']},".split())
- expected = set('Digest nonce="NONCE", realm="REALM", qop="auth, auth-int",'.split())
+ actual = resp.headers["WWW-Authenticate"]
+ expected = 'Digest realm="REALM", nonce="NONCE", qop="auth, auth-int"'
assert actual == expected
- resp.www_authenticate.set_digest("REALM", "NONCE", qop=("auth",))
-
- actual = set(f"{resp.headers['WWW-Authenticate']},".split())
- expected = set('Digest nonce="NONCE", realm="REALM", qop="auth",'.split())
+ resp.www_authenticate.parameters["qop"] = "auth"
+ actual = resp.headers["WWW-Authenticate"]
+ expected = 'Digest realm="REALM", nonce="NONCE", qop="auth"'
assert actual == expected
@@ -875,12 +870,6 @@ def test_file_closing_with():
assert foo.closed is True
-def test_url_charset_reflection():
- req = wrappers.Request.from_values()
- req.charset = "utf-7"
- assert req.url_charset == "utf-7"
-
-
def test_response_streamed():
r = wrappers.Response()
assert not r.is_streamed
@@ -1048,25 +1037,25 @@ class MyRequest(wrappers.Request):
parameter_storage_class = dict
req = MyRequest.from_values("/?foo=baz", headers={"Cookie": "foo=bar"})
- assert type(req.cookies) is dict
+ assert type(req.cookies) is dict # noqa: E721
assert req.cookies == {"foo": "bar"}
- assert type(req.access_route) is list
+ assert type(req.access_route) is list # noqa: E721
- assert type(req.args) is dict
- assert type(req.values) is CombinedMultiDict
+ assert type(req.args) is dict # noqa: E721
+ assert type(req.values) is CombinedMultiDict # noqa: E721
assert req.values["foo"] == "baz"
req = wrappers.Request.from_values(headers={"Cookie": "foo=bar;foo=baz"})
- assert type(req.cookies) is ImmutableMultiDict
+ assert type(req.cookies) is ImmutableMultiDict # noqa: E721
assert req.cookies.to_dict() == {"foo": "bar"}
# it is possible to have multiple cookies with the same name
assert req.cookies.getlist("foo") == ["bar", "baz"]
- assert type(req.access_route) is ImmutableList
+ assert type(req.access_route) is ImmutableList # noqa: E721
MyRequest.list_storage_class = tuple
req = MyRequest.from_values()
- assert type(req.access_route) is tuple
+ assert type(req.access_route) is tuple # noqa: E721
def test_response_headers_passthrough():
@@ -1165,6 +1154,7 @@ class MyResponse(wrappers.Response):
("auto", "location", "expect"),
(
(False, "/test", "/test"),
+ (False, "/\\\\test.example?q", "/%5C%5Ctest.example?q"),
(True, "/test", "http://localhost/test"),
(True, "test", "http://localhost/a/b/test"),
(True, "./test", "http://localhost/a/b/test"),
@@ -1206,14 +1196,6 @@ def test_malformed_204_response_has_no_content_length():
assert b"".join(app_iter) == b"" # ensure data will not be sent
-def test_modified_url_encoding():
- class ModifiedRequest(wrappers.Request):
- url_charset = "euc-kr"
-
- req = ModifiedRequest.from_values(query_string={"foo": "정상처리"}, charset="euc-kr")
- assert req.args["foo"] == "정상처리"
-
-
def test_request_method_case_sensitivity():
req = wrappers.Request(
{"REQUEST_METHOD": "get", "SERVER_NAME": "eggs", "SERVER_PORT": "80"}
@@ -1350,7 +1332,7 @@ def test_bad_content_type(self):
value = [1, 2, 3]
request = wrappers.Request.from_values(json=value, content_type="text/plain")
- with pytest.raises(BadRequest):
+ with pytest.raises(UnsupportedMediaType):
request.get_json()
assert request.get_json(silent=True) is None
diff --git a/tests/test_wsgi.py b/tests/test_wsgi.py
index b0f71bc..7f4d2e9 100644
--- a/tests/test_wsgi.py
+++ b/tests/test_wsgi.py
@@ -1,6 +1,9 @@
+from __future__ import annotations
+
import io
import json
import os
+import typing as t
import pytest
@@ -84,7 +87,6 @@ def foo(environ, start_response):
def test_path_info_and_script_name_fetching():
env = create_environ("/\N{SNOWMAN}", "http://example.com/\N{COMET}/")
assert wsgi.get_path_info(env) == "/\N{SNOWMAN}"
- assert wsgi.get_path_info(env, charset=None) == "/\N{SNOWMAN}".encode()
def test_limited_stream():
@@ -117,11 +119,10 @@ def on_exhausted(self):
stream = wsgi.LimitedStream(io_, 9)
assert stream.readlines() == [b"123456\n", b"ab"]
- io_ = io.BytesIO(b"123456\nabcdefg")
+ io_ = io.BytesIO(b"123\n456\nabcdefg")
stream = wsgi.LimitedStream(io_, 9)
- assert stream.readlines(2) == [b"12"]
- assert stream.readlines(2) == [b"34"]
- assert stream.readlines() == [b"56\n", b"ab"]
+ assert stream.readlines(2) == [b"123\n"]
+ assert stream.readlines() == [b"456\n", b"a"]
io_ = io.BytesIO(b"123456\nabcdefg")
stream = wsgi.LimitedStream(io_, 9)
@@ -146,13 +147,8 @@ def on_exhausted(self):
stream = wsgi.LimitedStream(io_, 0)
assert stream.read(-1) == b""
- io_ = io.StringIO("123456")
- stream = wsgi.LimitedStream(io_, 0)
- assert stream.read(-1) == ""
-
- io_ = io.StringIO("123\n456\n")
- stream = wsgi.LimitedStream(io_, 8)
- assert list(stream) == ["123\n", "456\n"]
+ stream = wsgi.LimitedStream(io.BytesIO(b"123\n456\n"), 8)
+ assert list(stream) == [b"123\n", b"456\n"]
def test_limited_stream_json_load():
@@ -165,21 +161,63 @@ def test_limited_stream_json_load():
def test_limited_stream_disconnection():
- io_ = io.BytesIO(b"A bit of content")
-
- # disconnect detection on out of bytes
- stream = wsgi.LimitedStream(io_, 255)
+ # disconnect because stream returns zero bytes
+ stream = wsgi.LimitedStream(io.BytesIO(), 255)
with pytest.raises(ClientDisconnected):
stream.read()
- # disconnect detection because file close
- io_ = io.BytesIO(b"x" * 255)
- io_.close()
- stream = wsgi.LimitedStream(io_, 255)
+ # disconnect because stream is closed
+ data = io.BytesIO(b"x" * 255)
+ data.close()
+ stream = wsgi.LimitedStream(data, 255)
+
with pytest.raises(ClientDisconnected):
stream.read()
+def test_limited_stream_read_with_raw_io():
+ class OneByteStream(t.BinaryIO):
+ def __init__(self, buf: bytes) -> None:
+ self.buf = buf
+ self.pos = 0
+
+ def read(self, size: int | None = None) -> bytes:
+ """Return one byte at a time regardless of requested size."""
+
+ if size is None or size == -1:
+ raise ValueError("expected read to be called with specific limit")
+
+ if size == 0 or len(self.buf) < self.pos:
+ return b""
+
+ b = self.buf[self.pos : self.pos + 1]
+ self.pos += 1
+ return b
+
+ stream = wsgi.LimitedStream(OneByteStream(b"foo"), 4)
+ assert stream.read(5) == b"f"
+ assert stream.read(5) == b"o"
+ assert stream.read(5) == b"o"
+
+ # The stream has fewer bytes (3) than the limit (4), therefore the read returns 0
+ # bytes before the limit is reached.
+ with pytest.raises(ClientDisconnected):
+ stream.read(5)
+
+ stream = wsgi.LimitedStream(OneByteStream(b"foo123"), 3)
+ assert stream.read(5) == b"f"
+ assert stream.read(5) == b"o"
+ assert stream.read(5) == b"o"
+ # The limit was reached, therefore the wrapper is exhausted, not disconnected.
+ assert stream.read(5) == b""
+
+ stream = wsgi.LimitedStream(OneByteStream(b"foo"), 3)
+ assert stream.read() == b"foo"
+
+ stream = wsgi.LimitedStream(OneByteStream(b"foo"), 2)
+ assert stream.read() == b"fo"
+
+
def test_get_host_fallback():
assert (
wsgi.get_host(
@@ -218,123 +256,6 @@ def test_get_current_url_invalid_utf8():
assert rv == "http://localhost/?foo=bar&baz=blah&meh=%CF"
-def test_multi_part_line_breaks():
- data = "abcdef\r\nghijkl\r\nmnopqrstuvwxyz\r\nABCDEFGHIJK"
- test_stream = io.StringIO(data)
- lines = list(wsgi.make_line_iter(test_stream, limit=len(data), buffer_size=16))
- assert lines == ["abcdef\r\n", "ghijkl\r\n", "mnopqrstuvwxyz\r\n", "ABCDEFGHIJK"]
-
- data = "abc\r\nThis line is broken by the buffer length.\r\nFoo bar baz"
- test_stream = io.StringIO(data)
- lines = list(wsgi.make_line_iter(test_stream, limit=len(data), buffer_size=24))
- assert lines == [
- "abc\r\n",
- "This line is broken by the buffer length.\r\n",
- "Foo bar baz",
- ]
-
-
-def test_multi_part_line_breaks_bytes():
- data = b"abcdef\r\nghijkl\r\nmnopqrstuvwxyz\r\nABCDEFGHIJK"
- test_stream = io.BytesIO(data)
- lines = list(wsgi.make_line_iter(test_stream, limit=len(data), buffer_size=16))
- assert lines == [
- b"abcdef\r\n",
- b"ghijkl\r\n",
- b"mnopqrstuvwxyz\r\n",
- b"ABCDEFGHIJK",
- ]
-
- data = b"abc\r\nThis line is broken by the buffer length.\r\nFoo bar baz"
- test_stream = io.BytesIO(data)
- lines = list(wsgi.make_line_iter(test_stream, limit=len(data), buffer_size=24))
- assert lines == [
- b"abc\r\n",
- b"This line is broken by the buffer length.\r\n",
- b"Foo bar baz",
- ]
-
-
-def test_multi_part_line_breaks_problematic():
- data = "abc\rdef\r\nghi"
- for _ in range(1, 10):
- test_stream = io.StringIO(data)
- lines = list(wsgi.make_line_iter(test_stream, limit=len(data), buffer_size=4))
- assert lines == ["abc\r", "def\r\n", "ghi"]
-
-
-def test_iter_functions_support_iterators():
- data = ["abcdef\r\nghi", "jkl\r\nmnopqrstuvwxyz\r", "\nABCDEFGHIJK"]
- lines = list(wsgi.make_line_iter(data))
- assert lines == ["abcdef\r\n", "ghijkl\r\n", "mnopqrstuvwxyz\r\n", "ABCDEFGHIJK"]
-
-
-def test_make_chunk_iter():
- data = ["abcdefXghi", "jklXmnopqrstuvwxyzX", "ABCDEFGHIJK"]
- rv = list(wsgi.make_chunk_iter(data, "X"))
- assert rv == ["abcdef", "ghijkl", "mnopqrstuvwxyz", "ABCDEFGHIJK"]
-
- data = "abcdefXghijklXmnopqrstuvwxyzXABCDEFGHIJK"
- test_stream = io.StringIO(data)
- rv = list(wsgi.make_chunk_iter(test_stream, "X", limit=len(data), buffer_size=4))
- assert rv == ["abcdef", "ghijkl", "mnopqrstuvwxyz", "ABCDEFGHIJK"]
-
-
-def test_make_chunk_iter_bytes():
- data = [b"abcdefXghi", b"jklXmnopqrstuvwxyzX", b"ABCDEFGHIJK"]
- rv = list(wsgi.make_chunk_iter(data, "X"))
- assert rv == [b"abcdef", b"ghijkl", b"mnopqrstuvwxyz", b"ABCDEFGHIJK"]
-
- data = b"abcdefXghijklXmnopqrstuvwxyzXABCDEFGHIJK"
- test_stream = io.BytesIO(data)
- rv = list(wsgi.make_chunk_iter(test_stream, "X", limit=len(data), buffer_size=4))
- assert rv == [b"abcdef", b"ghijkl", b"mnopqrstuvwxyz", b"ABCDEFGHIJK"]
-
- data = b"abcdefXghijklXmnopqrstuvwxyzXABCDEFGHIJK"
- test_stream = io.BytesIO(data)
- rv = list(
- wsgi.make_chunk_iter(
- test_stream, "X", limit=len(data), buffer_size=4, cap_at_buffer=True
- )
- )
- assert rv == [
- b"abcd",
- b"ef",
- b"ghij",
- b"kl",
- b"mnop",
- b"qrst",
- b"uvwx",
- b"yz",
- b"ABCD",
- b"EFGH",
- b"IJK",
- ]
-
-
-def test_lines_longer_buffer_size():
- data = "1234567890\n1234567890\n"
- for bufsize in range(1, 15):
- lines = list(
- wsgi.make_line_iter(io.StringIO(data), limit=len(data), buffer_size=bufsize)
- )
- assert lines == ["1234567890\n", "1234567890\n"]
-
-
-def test_lines_longer_buffer_size_cap():
- data = "1234567890\n1234567890\n"
- for bufsize in range(1, 15):
- lines = list(
- wsgi.make_line_iter(
- io.StringIO(data),
- limit=len(data),
- buffer_size=bufsize,
- cap_at_buffer=True,
- )
- )
- assert len(lines[0]) == bufsize or lines[0].endswith("\n")
-
-
def test_range_wrapper():
response = Response(b"Hello World")
range_wrapper = _RangeWrapper(response.response, 6, 4)
diff --git a/tox.ini b/tox.ini
index 056ca0d..f7bc0b3 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,19 +1,24 @@
[tox]
envlist =
- py3{11,10,9,8,7},pypy3{8,7}
+ py3{12,11,10,9,8}
+ pypy310
style
typing
docs
skip_missing_interpreters = true
[testenv]
+package = wheel
+wheel_build_env = .pkg
+constrain_package_deps = true
+use_frozen_constraints = true
deps = -r requirements/tests.txt
commands = pytest -v --tb=short --basetemp={envtmpdir} {posargs}
[testenv:style]
deps = pre-commit
skip_install = true
-commands = pre-commit run --all-files --show-diff-on-failure
+commands = pre-commit run --all-files
[testenv:typing]
deps = -r requirements/typing.txt
@@ -21,4 +26,18 @@ commands = mypy
[testenv:docs]
deps = -r requirements/docs.txt
-commands = sphinx-build -W -b html -d {envtmpdir}/doctrees docs {envtmpdir}/html
+commands = sphinx-build -E -W -b dirhtml docs docs/_build/dirhtml
+
+[testenv:update-requirements]
+deps =
+ pip-tools
+ pre-commit
+skip_install = true
+change_dir = requirements
+commands =
+ pre-commit autoupdate -j4
+ pip-compile -U build.in
+ pip-compile -U docs.in
+ pip-compile -U tests.in
+ pip-compile -U typing.in
+ pip-compile -U dev.in
Redirecting...
\n"
- b"Redirecting...
\n"
- b"Redirecting...
\n"
- b"