diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000..b5c2efd --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,17 @@ +{ + "name": "pallets/werkzeug", + "image": "mcr.microsoft.com/devcontainers/python:3", + "customizations": { + "vscode": { + "settings": { + "python.defaultInterpreterPath": "${workspaceFolder}/.venv", + "python.terminal.activateEnvInCurrentTerminal": true, + "python.terminal.launchArgs": [ + "-X", + "dev" + ] + } + } + }, + "onCreateCommand": ".devcontainer/on-create-command.sh" +} diff --git a/.devcontainer/on-create-command.sh b/.devcontainer/on-create-command.sh new file mode 100755 index 0000000..eaebea6 --- /dev/null +++ b/.devcontainer/on-create-command.sh @@ -0,0 +1,7 @@ +#!/bin/bash +set -e +python3 -m venv --upgrade-deps .venv +. .venv/bin/activate +pip install -r requirements/dev.txt +pip install -e . +pre-commit install --install-hooks diff --git a/.editorconfig b/.editorconfig index e32c802..2ff985a 100644 --- a/.editorconfig +++ b/.editorconfig @@ -9,5 +9,5 @@ end_of_line = lf charset = utf-8 max_line_length = 88 -[*.{yml,yaml,json,js,css,html}] +[*.{css,html,js,json,jsx,scss,ts,tsx,yaml,yml}] indent_size = 2 diff --git a/.gitignore b/.gitignore index 36f3670..bbeb14f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,26 +1,11 @@ -MANIFEST -build -dist -/src/Werkzeug.egg-info -*.pyc -*.pyo -env -.DS_Store -docs/_build -bench/a -bench/b -.tox -.coverage -.coverage.* -coverage_out -htmlcov -.cache -.xprocess -.hypothesis -test_uwsgi_failed -.idea +.idea/ +.vscode/ +.venv*/ +venv*/ +__pycache__/ +dist/ +.coverage* +htmlcov/ .pytest_cache/ -venv/ -.vscode -.mypy_cache/ -.dmypy.json +.tox/ +docs/_build/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 55f8c13..8289161 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,44 +1,16 @@ ci: - autoupdate_branch: "2.2.x" autoupdate_schedule: monthly repos: - - repo: https://github.com/asottile/pyupgrade - rev: v2.37.3 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.3.5 hooks: - - id: pyupgrade - args: ["--py37-plus"] - - repo: https://github.com/asottile/reorder_python_imports - rev: v3.8.2 - hooks: - - id: reorder-python-imports - name: Reorder Python imports (src, tests) - files: "^(?!examples/)" - args: ["--application-directories", ".:src"] - additional_dependencies: ["setuptools>60.9"] - - id: reorder-python-imports - name: Reorder Python imports (examples) - files: "^examples/" - args: ["--application-directories", "examples"] - additional_dependencies: ["setuptools>60.9"] - - repo: https://github.com/psf/black - rev: 22.6.0 - hooks: - - id: black - - repo: https://github.com/PyCQA/flake8 - rev: 5.0.4 - hooks: - - id: flake8 - additional_dependencies: - - flake8-bugbear - - flake8-implicit-str-concat - - repo: https://github.com/peterdemin/pip-compile-multi - rev: v2.4.6 - hooks: - - id: pip-compile-multi-verify + - id: ruff + - id: ruff-format - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.6.0 hooks: + - id: check-merge-conflict + - id: debug-statements - id: fix-byte-order-marker - id: trailing-whitespace - id: end-of-file-fixer - exclude: "^tests/.*.http$" diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 346900b..865c685 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -1,8 +1,8 @@ version: 2 build: - os: ubuntu-20.04 + os: ubuntu-22.04 tools: - python: "3.10" + python: '3.12' python: install: - requirements: requirements/docs.txt diff 
--git a/CHANGES.rst b/CHANGES.rst index 18e68af..f6158e7 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,5 +1,313 @@ .. currentmodule:: werkzeug +Version 3.0.3 +------------- + +Released 2024-05-05 + +- Only allow ``localhost``, ``.localhost``, ``127.0.0.1``, or the specified + hostname when running the dev server, to make debugger requests. Additional + hosts can be added by using the debugger middleware directly. The debugger + UI makes requests using the full URL rather than only the path. + :ghsa:`2g68-c3qc-8985` +- Make reloader more robust when ``""`` is in ``sys.path``. :pr:`2823` +- Better TLS cert format with ``adhoc`` dev certs. :pr:`2891` +- Inform Python < 3.12 how to handle ``itms-services`` URIs correctly, rather + than using an overly-broad workaround in Werkzeug that caused some redirect + URIs to be passed on without encoding. :issue:`2828` +- Type annotation for ``Rule.endpoint`` and other uses of ``endpoint`` is + ``Any``. :issue:`2836` + + +Version 3.0.2 +------------- + +Released 2024-04-01 + +- Ensure setting ``merge_slashes`` to ``False`` results in ``NotFound`` for + repeated-slash requests against single slash routes. :issue:`2834` +- Fix handling of ``TypeError`` in ``TypeConversionDict.get()`` to match + ``ValueError``. :issue:`2843` +- Fix ``response_wrapper`` type check in test client. :issue:`2831` +- Make the return type of ``MultiPartParser.parse`` more precise. + :issue:`2840` +- Raise an error if converter arguments cannot be parsed. :issue:`2822` + + +Version 3.0.1 +------------- + +Released 2023-10-24 + +- Fix slow multipart parsing for large parts potentially enabling DoS attacks. + + +Version 3.0.0 +------------- + +Released 2023-09-30 + +- Remove previously deprecated code. :pr:`2768` +- Deprecate the ``__version__`` attribute. Use feature detection, or + ``importlib.metadata.version("werkzeug")``, instead. :issue:`2770` +- ``generate_password_hash`` uses scrypt by default. :issue:`2769` +- Add the ``"werkzeug.profiler"`` item to the WSGI ``environ`` dictionary + passed to ``ProfilerMiddleware``'s ``filename_format`` function. It contains + the ``elapsed`` and ``time`` values for the profiled request. :issue:`2775` +- Explicitly mark the ``PathConverter`` as non-path-isolating. :pr:`2784` + + +Version 2.3.8 +------------- + +Released 2023-11-08 + +- Fix slow multipart parsing for large parts potentially enabling DoS + attacks. + + +Version 2.3.7 +------------- + +Released 2023-08-14 + +- Use ``flit_core`` instead of ``setuptools`` as build backend. +- Fix parsing of multipart bodies. :issue:`2734` +- Adjust index of last newline in data start. :issue:`2761` +- Parsing ints from header values strips spacing first. :issue:`2734` +- Fix empty file streaming when testing. :issue:`2740` +- Clearer error message when URL rule does not start with slash. :pr:`2750` +- ``Accept`` ``q`` value can be a float without a decimal part. :issue:`2751` + + +Version 2.3.6 +------------- + +Released 2023-06-08 + +- ``FileStorage.content_length`` does not fail if the form data did not provide a + value. :issue:`2726` + + +Version 2.3.5 +------------- + +Released 2023-06-07 + +- Python 3.12 compatibility. :issue:`2704` +- Fix handling of invalid base64 values in ``Authorization.from_header``. :issue:`2717` +- The debugger escapes the exception message in the page title. :pr:`2719` +- When binding ``routing.Map``, a long IDNA ``server_name`` with a port does not fail + encoding. :issue:`2700` +- ``iri_to_uri`` shows a deprecation warning instead of an error when passing bytes.
+ :issue:`2708` +- When parsing numbers in HTTP request headers such as ``Content-Length``, only ASCII + digits are accepted rather than any format that Python's ``int`` and ``float`` + accept. :issue:`2716` + + +Version 2.3.4 +------------- + +Released 2023-05-08 + +- ``Authorization.from_header`` and ``WWWAuthenticate.from_header`` detect tokens + that end with base64 padding (``=``). :issue:`2685` +- Remove usage of ``warnings.catch_warnings``. :issue:`2690` +- Remove ``max_form_parts`` restriction from standard form data parsing and only use + it for multipart content. :pr:`2694` +- ``Response`` will avoid converting the ``Location`` header in some cases to preserve + invalid URL schemes like ``itms-services``. :issue:`2691` + + +Version 2.3.3 +------------- + +Released 2023-05-01 + +- Fix parsing of large multipart bodies. Remove invalid leading newline, and restore + parsing speed. :issue:`2658, 2675` +- The cookie ``Path`` attribute is set to ``/`` by default again, to prevent clients + from falling back to RFC 6265's ``default-path`` behavior. :issue:`2672, 2679` + + +Version 2.3.2 +------------- + +Released 2023-04-28 + +- Parse the cookie ``Expires`` attribute correctly in the test client. :issue:`2669` +- ``max_content_length`` can only be enforced on streaming requests if the server + sets ``wsgi.input_terminated``. :issue:`2668` + + +Version 2.3.1 +------------- + +Released 2023-04-27 + +- Percent-encode plus (+) when building URLs and in test requests. :issue:`2657` +- Cookie values don't quote characters defined in RFC 6265. :issue:`2659` +- Include ``pyi`` files for ``datastructures`` type annotations. :issue:`2660` +- ``Authorization`` and ``WWWAuthenticate`` objects can be compared for equality. + :issue:`2665` + + +Version 2.3.0 +------------- + +Released 2023-04-25 + +- Drop support for Python 3.7. :pr:`2648` +- Remove previously deprecated code. :pr:`2592` +- Passing bytes where strings are expected is deprecated, as well as the ``charset`` + and ``errors`` parameters in many places. Anywhere that was annotated, documented, + or tested to accept bytes shows a warning. Removing this artifact of the transition + from Python 2 to 3 removes a significant amount of overhead in instance checks and + encoding cycles. In general, always work with UTF-8; the modern HTML, URL, and HTTP + standards all strongly recommend this. :issue:`2602` +- Deprecate the ``werkzeug.urls`` module, except for the ``uri_to_iri`` and + ``iri_to_uri`` functions. Use the ``urllib.parse`` library instead. :issue:`2600` +- Update which characters are considered safe when using percent encoding in URLs, + based on the WhatWG URL Standard. :issue:`2601` +- Update which characters are considered safe when using percent encoding for Unicode + filenames in downloads. :issue:`2598` +- Deprecate the ``safe_conversion`` parameter of ``iri_to_uri``. The ``Location`` + header is converted to IRI using the same process as everywhere else. :issue:`2609` +- Deprecate ``werkzeug.wsgi.make_line_iter`` and ``make_chunk_iter``. :pr:`2613` +- Use modern packaging metadata with ``pyproject.toml`` instead of ``setup.cfg``. + :pr:`2574` +- ``Request.get_json()`` will raise a ``415 Unsupported Media Type`` error if the + ``Content-Type`` header is not ``application/json``, instead of a generic 400. + :issue:`2550` +- A URL converter's ``part_isolating`` defaults to ``False`` if its ``regex`` contains + a ``/``. :issue:`2582` +- A custom converter's regex can have capturing groups without breaking the router.
+ :pr:`2596` +- The reloader can pick up arguments to ``python`` like ``-X dev``, and does not + require heuristics to determine how to reload the command. Only available + on Python >= 3.10. :issue:`2589` +- The Watchdog reloader ignores file opened events. Bump the minimum version of + Watchdog to 2.3.0. :issue:`2603` +- When using a Unix socket for the development server, the path can start with a dot. + :issue:`2595` +- Increase default work factor for PBKDF2 to 600,000 iterations. :issue:`2611` +- ``parse_options_header`` is 2-3 times faster. It conforms to :rfc:`9110`; some + invalid parts that were previously accepted are now ignored. :issue:`1628` +- The ``is_filename`` parameter to ``unquote_header_value`` is deprecated. :pr:`2614` +- Deprecate the ``extra_chars`` parameter and passing bytes to ``quote_header_value``, + the ``allow_token`` parameter to ``dump_header``, and the ``cls`` parameter and + passing bytes to ``parse_dict_header``. :pr:`2618` +- Improve ``parse_accept_header`` implementation. Parse according to :rfc:`9110`. + Discard items with invalid ``q`` values. :issue:`1623` +- ``quote_header_value`` quotes the empty string. :pr:`2618` +- ``dump_options_header`` skips ``None`` values rather than using a bare key. + :pr:`2618` +- ``dump_header`` and ``dump_options_header`` will not quote a value if the key ends + with an asterisk ``*``. +- ``parse_dict_header`` will decode values with charsets. :pr:`2618` +- Refactor the ``Authorization`` and ``WWWAuthenticate`` header data structures. + :issue:`1769`, :pr:`2619` + + - Both classes have ``type``, ``parameters``, and ``token`` attributes. The + ``token`` attribute supports auth schemes that use a single opaque token rather + than ``key=value`` parameters, such as ``Bearer``. + - Neither class is a ``dict`` anymore, although they still implement getting, + setting, and deleting ``auth[key]`` and ``auth.key`` syntax, as well as + ``auth.get(key)`` and ``key in auth``. + - Both classes have a ``from_header`` class method. ``parse_authorization_header`` + and ``parse_www_authenticate_header`` are deprecated. + - The methods ``WWWAuthenticate.set_basic`` and ``set_digest`` are deprecated. + Instead, an instance should be created and assigned to + ``response.www_authenticate``. + - A list of instances can be assigned to ``response.www_authenticate`` to set + multiple header values. However, accessing the property only returns the first + instance. + +- Refactor ``parse_cookie`` and ``dump_cookie``. :pr:`2637` + + - ``parse_cookie`` is up to 40% faster; ``dump_cookie`` is up to 60% faster. + - Passing bytes to ``parse_cookie`` and ``dump_cookie`` is deprecated. The + ``dump_cookie`` ``charset`` parameter is deprecated. + - ``dump_cookie`` allows ``domain`` values that do not include a dot ``.``, and + strips off a leading dot. + - ``dump_cookie`` does not set ``path="/"`` unnecessarily by default. + +- Refactor the test client cookie implementation. :issue:`1060, 1680` + + - The ``cookie_jar`` attribute is deprecated. ``http.cookiejar`` is no longer used + for storage. + - Domain and path matching is used when sending cookies in requests. The + ``domain`` and ``path`` parameters default to ``localhost`` and ``/``. + - Added a ``get_cookie`` method to inspect cookies. + - Cookies have ``decoded_key`` and ``decoded_value`` attributes to match what the + app sees rather than the encoded values a client would see. + - The first positional ``server_name`` parameter to ``set_cookie`` and + ``delete_cookie`` is deprecated.
Use the ``domain`` parameter instead. + - Other parameters to ``delete_cookie`` besides ``domain``, ``path``, and + ``value`` are deprecated. + +- If ``request.max_content_length`` is set, it is checked immediately when accessing + the stream, and while reading from the stream in general, rather than only during + form parsing. :issue:`1513` +- The development server, which must not be used in production, will exhaust the + request stream up to 10GB or 1000 reads. This allows clients to see a 413 error if + ``max_content_length`` is exceeded, instead of a "connection reset" failure. + :pr:`2620` +- The development server discards header keys that contain underscores ``_``, as they + are ambiguous with dashes ``-`` in WSGI. :pr:`2622` +- ``secure_filename`` looks for more Windows reserved file names. :pr:`2623` +- Update type annotation for ``best_match`` to make ``default`` parameter clearer. + :issue:`2625` +- Multipart parser handles empty fields correctly. :issue:`2632` +- The ``Map`` ``charset`` parameter and ``Request.url_charset`` property are + deprecated. Percent encoding in URLs must always represent UTF-8 bytes. Invalid + bytes are left percent encoded rather than replaced. :issue:`2602` +- The ``Request.charset``, ``Request.encoding_errors``, ``Response.charset``, and + ``Client.charset`` attributes are deprecated. Request and response data must always + use UTF-8. :issue:`2602` +- Header values that have charset information only allow ASCII, UTF-8, and ISO-8859-1. + :pr:`2614, 2641` +- Update type annotation for ``ProfilerMiddleware`` ``stream`` parameter. + :issue:`2642` +- Use postponed evaluation of annotations. :pr:`2644` +- The development server escapes ASCII control characters in decoded URLs before + logging the request to the terminal. :pr:`2652` +- The ``FormDataParser`` ``parse_functions`` attribute and ``get_parse_func`` method, + and the invalid ``application/x-url-encoded`` content type, are deprecated. + :pr:`2653` +- ``generate_password_hash`` supports scrypt. Plain hash methods are deprecated; only + scrypt and pbkdf2 are supported. :issue:`2654` + + +Version 2.2.3 +------------- + +Released 2023-02-14 + +- Ensure that URL rules using path converters will redirect with strict slashes when + the trailing slash is missing. :issue:`2533` +- Type signature for ``get_json`` specifies that return type is not optional when + ``silent=False``. :issue:`2508` +- ``parse_content_range_header`` returns ``None`` for a value like ``bytes */-1`` + where the length is invalid, instead of raising an ``AssertionError``. :issue:`2531` +- Address remaining ``ResourceWarning`` related to the socket used by ``run_simple``. + Remove ``prepare_socket``, which now happens when creating the server. :issue:`2421` +- Update pre-existing headers for ``multipart/form-data`` requests with the test + client. :issue:`2549` +- Fix handling of header extended parameters such that they are no longer quoted. + :issue:`2529` +- ``LimitedStream.read`` works correctly when wrapping a stream that may not return + the requested size in one ``read`` call. :issue:`2558` +- A cookie header that starts with ``=`` is treated as an empty key and discarded, + rather than stripping the leading ``==``. +- Specify a maximum number of multipart parts, default 1000, after which a + ``RequestEntityTooLarge`` exception is raised on parsing. This mitigates a DoS + attack where a large number of form/file parts would result in disproportionate + resource use.
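The ``Authorization``/``WWWAuthenticate`` refactor in the 2.3.0 entries above is easiest to see in code. Here is a minimal sketch: the ``from_header`` class method, the ``type``/``parameters``/``token`` attributes, and assigning to ``response.www_authenticate`` all come from the changelog, while the exact constructor signature and the lowercase normalization of the scheme are assumptions about the current API rather than facts stated in this patch.

.. code-block:: python

    from werkzeug.datastructures import Authorization, WWWAuthenticate
    from werkzeug.wrappers import Response

    # Token-style scheme: the opaque value lands on .token, not .parameters.
    auth = Authorization.from_header("Bearer abc123")
    assert auth is not None
    assert auth.token == "abc123"  # scheme assumed to normalize to auth.type == "bearer"

    # Challenge a client by building and assigning an instance, instead of the
    # deprecated WWWAuthenticate.set_basic()/set_digest() helpers.
    response = Response("authentication required", status=401)
    response.www_authenticate = WWWAuthenticate("basic", {"realm": "example"})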
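Likewise, the reworked test-client cookie handling described above (``get_cookie``, domain/path matching, ``decoded_value``) could look roughly like this in a test. The demo app is invented for illustration; the ``domain``/``path`` defaults of ``localhost`` and ``/`` are taken from the changelog entry.

.. code-block:: python

    from werkzeug.test import Client
    from werkzeug.wrappers import Request, Response

    @Request.application
    def app(request: Request) -> Response:
        response = Response(request.cookies.get("lang", "none"))
        response.set_cookie("lang", "en")
        return response

    client = Client(app)
    client.get("/")  # the first response sets the cookie

    # Inspect what the app would see, without poking at a cookie jar.
    cookie = client.get_cookie("lang")  # matched with domain="localhost", path="/"
    assert cookie is not None
    assert cookie.decoded_value == "en"

    # The stored cookie is sent back on later matching requests.
    assert client.get("/").get_data(as_text=True) == "en"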
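The password hashing entries ("``generate_password_hash`` supports scrypt" here, "uses scrypt by default" in 3.0.0) amount to the following sketch. The ``pbkdf2:`` prefix check is an assumption about Werkzeug's ``method$salt$hash`` output format, and scrypt also requires an OpenSSL build that provides ``hashlib.scrypt``.

.. code-block:: python

    from werkzeug.security import check_password_hash, generate_password_hash

    hashed = generate_password_hash("correct horse")  # scrypt by default as of 3.0
    assert check_password_hash(hashed, "correct horse")

    # pbkdf2 remains supported; plain (unsalted) hash methods are deprecated.
    legacy = generate_password_hash("correct horse", method="pbkdf2")
    assert legacy.startswith("pbkdf2:")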
+ + + Version 2.2.2 ------------- @@ -23,6 +331,7 @@ Released 2022-08-08 ``run_simple``. :issue:`2421` + Version 2.2.1 ------------- @@ -54,8 +363,9 @@ Released 2022-07-23 debug console. :pr:`2439` - Fix compatibility with Python 3.11 by ensuring that ``end_lineno`` and ``end_col_offset`` are present on AST nodes. :issue:`2425` -- Add a new faster matching router based on a state - machine. :pr:`2433` +- Add a new faster URL matching router based on a state machine. If a custom converter + needs to match a ``/`` it must set the class variable ``part_isolating = False``. + :pr:`2433` - Fix branch leaf path masking branch paths when strict-slashes is disabled. :issue:`1074` - Names within options headers are always converted to lowercase. This @@ -775,7 +1085,7 @@ Released 2019-03-19 (:pr:`1358`) - :func:`http.parse_cookie` ignores empty segments rather than producing a cookie with no key or value. (:issue:`1245`, :pr:`1301`) -- :func:`~http.parse_authorization_header` (and +- ``http.parse_authorization_header`` (and :class:`~datastructures.Authorization`, :attr:`~wrappers.Request.authorization`) treats the authorization header as UTF-8. On Python 2, basic auth username and password are @@ -1540,8 +1850,8 @@ Version 0.9.2 (bugfix release, released on July 18th 2013) -- Added `unsafe` parameter to :func:`~werkzeug.urls.url_quote`. -- Fixed an issue with :func:`~werkzeug.urls.url_quote_plus` not quoting +- Added ``unsafe`` parameter to ``urls.url_quote``. +- Fixed an issue with ``urls.url_quote_plus`` not quoting `'+'` correctly. - Ported remaining parts of :class:`~werkzeug.contrib.RedisCache` to Python 3.3. @@ -1590,9 +1900,8 @@ Released on June 13th 2013, codename Planierraupe. certificates easily and load them from files. - Refactored test client to invoke the open method on the class for redirects. This makes subclassing more powerful. -- :func:`werkzeug.wsgi.make_chunk_iter` and - :func:`werkzeug.wsgi.make_line_iter` now support processing of - iterators and streams. +- ``wsgi.make_chunk_iter`` and ``make_line_iter`` now support processing + of iterators and streams. - URL generation by the routing system now no longer quotes ``+``. - URL fixing now no longer quotes certain reserved characters. @@ -1690,7 +1999,7 @@ Version 0.8.3 (bugfix release, released on February 5th 2012) -- Fixed another issue with :func:`werkzeug.wsgi.make_line_iter` +- Fixed another issue with ``wsgi.make_line_iter`` where lines longer than the buffer size were not handled properly. - Restore stdout after debug console finished executing so @@ -1758,7 +2067,7 @@ Released on September 29th 2011, codename Lötkolben - Werkzeug now uses a new method to check that the length of incoming data is complete and will raise IO errors by itself if the server fails to do so. -- :func:`~werkzeug.wsgi.make_line_iter` now requires a limit that is +- ``wsgi.make_line_iter`` now requires a limit that is not higher than the length the stream can provide. - Refactored form parsing into a form parser class that makes it possible to hook into individual parts of the parsing process for debugging and @@ -1958,7 +2267,7 @@ Released on Feb 19th 2010, codename Hammer. - the form data parser will now look at the filename instead of the content type to figure out if it should treat the upload as regular form data or file upload. This fixes a bug with Google Chrome. -- improved performance of ``make_line_iter`` and the multipart parser + for binary uploads.
- fixed :attr:`~werkzeug.BaseResponse.is_streamed` - fixed a path quoting bug in `EnvironBuilder` that caused PATH_INFO and @@ -2087,7 +2396,7 @@ Released on April 24th, codename Schlagbohrer. - added :mod:`werkzeug.contrib.lint` - added `passthrough_errors` to `run_simple`. - added `secure_filename` -- added :func:`make_line_iter` +- added ``make_line_iter`` - :class:`MultiDict` copies now instead of revealing internal lists to the caller for `getlist` and iteration functions that return lists. diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 9f40800..97486de 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -7,19 +7,17 @@ Thank you for considering contributing to Werkzeug! Support questions ----------------- -Please don't use the issue tracker for this. The issue tracker is a -tool to address bugs and feature requests in Werkzeug itself. Use one of -the following resources for questions about using Werkzeug or issues -with your own code: - -- The ``#get-help`` channel on our Discord chat: - https://discord.gg/pallets -- The mailing list flask@python.org for long term discussion or larger - issues. +Please don't use the issue tracker for this. The issue tracker is a tool to address bugs +and feature requests in Werkzeug itself. Use one of the following resources for +questions about using Werkzeug or issues with your own code: + +- The ``#questions`` channel on our Discord chat: https://discord.gg/pallets - Ask on `Stack Overflow`_. Search with Google first using: ``site:stackoverflow.com werkzeug {search term, exception message, etc.}`` +- Ask on our `GitHub Discussions`_ for long term discussion or larger questions. .. _Stack Overflow: https://stackoverflow.com/questions/tagged/werkzeug?tab=Frequent +.. _GitHub Discussions: https://github.com/pallets/werkzeug/discussions Reporting issues @@ -66,9 +64,30 @@ Include the following in your patch: .. _pre-commit: https://pre-commit.com -First time setup -~~~~~~~~~~~~~~~~ +First time setup using GitHub Codespaces +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +`GitHub Codespaces`_ creates a development environment that is already set up for the +project. By default it opens in Visual Studio Code for the Web, but this can +be changed in your GitHub profile settings to use Visual Studio Code or JetBrains +PyCharm on your local computer. + +- Make sure you have a `GitHub account`_. +- From the project's repository page, click the green "Code" button and then "Create + codespace on main". +- The codespace will be set up, then Visual Studio Code will open. However, you'll + need to wait a bit longer for the Python extension to be installed. You'll know it's + ready when the terminal at the bottom shows that the virtualenv was activated. +- Check out a branch and `start coding`_. + +.. _GitHub Codespaces: https://docs.github.com/en/codespaces +.. _devcontainer: https://docs.github.com/en/codespaces/setting-up-your-project-for-codespaces/adding-a-dev-container-configuration/introduction-to-dev-containers + +First time setup in your local environment +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- Make sure you have a `GitHub account`_. - Download and install the `latest version of git`_. - Configure git with your `username`_ and `email`_. @@ -77,99 +96,93 @@ First time setup $ git config --global user.name 'your name' $ git config --global user.email 'your email' -- Make sure you have a `GitHub account`_. - Fork Werkzeug to your GitHub account by clicking the `Fork`_ button. -- `Clone`_ the main repository locally. 
+- `Clone`_ your fork locally, replacing ``your-username`` in the command below with + your actual username. .. code-block:: text - $ git clone https://github.com/pallets/werkzeug + $ git clone https://github.com/your-username/werkzeug $ cd werkzeug -- Add your fork as a remote to push your work to. Replace - ``{username}`` with your username. This names the remote "fork", the - default Pallets remote is "origin". - - .. code-block:: text - - $ git remote add fork https://github.com/{username}/werkzeug +- Create a virtualenv. Use the latest version of Python. -- Create a virtualenv. + - Linux/macOS - .. code-block:: text - - $ python3 -m venv env - $ . env/bin/activate - - On Windows, activating is different. + .. code-block:: text - .. code-block:: text - - > env\Scripts\activate + $ python3 -m venv .venv + $ . .venv/bin/activate -- Upgrade pip and setuptools. + - Windows - .. code-block:: text + .. code-block:: text - $ python -m pip install --upgrade pip setuptools + > py -3 -m venv .venv + > .venv\Scripts\activate -- Install the development dependencies, then install Werkzeug in - editable mode. +- Install the development dependencies, then install Werkzeug in editable mode. .. code-block:: text + $ python -m pip install -U pip $ pip install -r requirements/dev.txt && pip install -e . - Install the pre-commit hooks. .. code-block:: text - $ pre-commit install + $ pre-commit install --install-hooks +.. _GitHub account: https://github.com/join .. _latest version of git: https://git-scm.com/downloads .. _username: https://docs.github.com/en/github/using-git/setting-your-username-in-git .. _email: https://docs.github.com/en/github/setting-up-and-managing-your-github-user-account/setting-your-commit-email-address -.. _GitHub account: https://github.com/join .. _Fork: https://github.com/pallets/werkzeug/fork .. _Clone: https://docs.github.com/en/github/getting-started-with-github/fork-a-repo#step-2-create-a-local-clone-of-your-fork +.. _start coding: + Start coding ~~~~~~~~~~~~ -- Create a branch to identify the issue you would like to work on. If - you're submitting a bug or documentation fix, branch off of the - latest ".x" branch. +- Create a branch to identify the issue you would like to work on. If you're + submitting a bug or documentation fix, branch off of the latest ".x" branch. .. code-block:: text $ git fetch origin - $ git checkout -b your-branch-name origin/2.0.x + $ git checkout -b your-branch-name origin/2.2.x - If you're submitting a feature addition or change, branch off of the - "main" branch. + If you're submitting a feature addition or change, branch off of the "main" branch. .. code-block:: text $ git fetch origin $ git checkout -b your-branch-name origin/main -- Using your favorite editor, make your changes, - `committing as you go`_. -- Include tests that cover any code changes you make. Make sure the - test fails without your patch. Run the tests as described below. -- Push your commits to your fork on GitHub and - `create a pull request`_. Link to the issue being addressed with - ``fixes #123`` in the pull request. +- Using your favorite editor, make your changes, `committing as you go`_. + + - If you are in a codespace, you will be prompted to `create a fork`_ the first + time you make a commit. Enter ``Y`` to continue. + +- Include tests that cover any code changes you make. Make sure the test fails without + your patch. Run the tests as described below. +- Push your commits to your fork on GitHub and `create a pull request`_. 
Link to the + issue being addressed with ``fixes #123`` in the pull request description. .. code-block:: text - $ git push --set-upstream fork your-branch-name + $ git push --set-upstream origin your-branch-name -.. _committing as you go: https://dont-be-afraid-to-commit.readthedocs.io/en/latest/git/commandlinegit.html#commit-your-changes +.. _committing as you go: https://afraid-to-commit.readthedocs.io/en/latest/git/commandlinegit.html#commit-your-changes +.. _create a fork: https://docs.github.com/en/codespaces/developing-in-codespaces/using-source-control-in-your-codespace#about-automatic-forking .. _create a pull request: https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/creating-a-pull-request +.. _Running the tests: + Running the tests ~~~~~~~~~~~~~~~~~ diff --git a/LICENSE.rst b/LICENSE.txt similarity index 100% rename from LICENSE.rst rename to LICENSE.txt diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index 8942481..0000000 --- a/MANIFEST.in +++ /dev/null @@ -1,12 +0,0 @@ -include CHANGES.rst -include tox.ini -include requirements/*.txt -graft artwork -graft docs -prune docs/_build -graft examples -graft tests -include src/werkzeug/py.typed -include src/werkzeug/*.pyi -graft src/werkzeug/debug/shared -global-exclude *.pyc diff --git a/README.rst b/README.md similarity index 54% rename from README.rst rename to README.md index f1592a5..011c0c4 100644 --- a/README.rst +++ b/README.md @@ -1,9 +1,8 @@ -Werkzeug -======== +# Werkzeug *werkzeug* German noun: "tool". Etymology: *werk* ("work"), *zeug* ("stuff") -Werkzeug is a comprehensive `WSGI`_ web application library. It began as +Werkzeug is a comprehensive [WSGI][] web application library. It began as a simple collection of various utilities for WSGI applications and has become one of the most advanced WSGI utility libraries. @@ -31,61 +30,40 @@ choose a template engine, database adapter, and even how to handle requests. It can be used to build all sorts of end user applications such as blogs, wikis, or bulletin boards. -`Flask`_ wraps Werkzeug, using it to handle the details of WSGI while +[Flask][] wraps Werkzeug, using it to handle the details of WSGI while providing more structure and patterns for defining powerful applications. -.. _WSGI: https://wsgi.readthedocs.io/en/latest/ -.. _Flask: https://www.palletsprojects.com/p/flask/ +[WSGI]: https://wsgi.readthedocs.io/en/latest/ +[Flask]: https://www.palletsprojects.com/p/flask/ -Installing ----------- +## A Simple Example -Install and update using `pip`_: +```python +# save this as app.py +from werkzeug.wrappers import Request, Response -.. code-block:: text +@Request.application +def application(request: Request) -> Response: + return Response("Hello, World!") - pip install -U Werkzeug +if __name__ == "__main__": + from werkzeug.serving import run_simple + run_simple("127.0.0.1", 5000, application) +``` -.. _pip: https://pip.pypa.io/en/stable/getting-started/ +``` +$ python -m app + * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit) +``` -A Simple Example ----------------- - -.. code-block:: python - - from werkzeug.wrappers import Request, Response - - @Request.application - def application(request): - return Response('Hello, World!') - - if __name__ == '__main__': - from werkzeug.serving import run_simple - run_simple('localhost', 4000, application) - - -Donate ------- +## Donate The Pallets organization develops and supports Werkzeug and other popular packages. 
In order to grow the community of contributors and users, and allow the maintainers to devote more time to the projects, -`please donate today`_. - -.. _please donate today: https://palletsprojects.com/donate - - -Links ------ +[please donate today][]. -- Documentation: https://werkzeug.palletsprojects.com/ -- Changes: https://werkzeug.palletsprojects.com/changes/ -- PyPI Releases: https://pypi.org/project/Werkzeug/ -- Source Code: https://github.com/pallets/werkzeug/ -- Issue Tracker: https://github.com/pallets/werkzeug/issues/ -- Website: https://palletsprojects.com/p/werkzeug/ -- Twitter: https://twitter.com/PalletsTeam -- Chat: https://discord.gg/pallets +[please donate today]: https://palletsprojects.com/donate diff --git a/artwork/logo.png b/artwork/logo.png deleted file mode 100644 index 61666ab..0000000 Binary files a/artwork/logo.png and /dev/null differ diff --git a/artwork/logo.svg b/artwork/logo.svg deleted file mode 100644 index bd65219..0000000 --- a/artwork/logo.svg +++ /dev/null @@ -1,88 +0,0 @@ [88 deleted lines of SVG markup, mangled during extraction; only an "image/svg+xml" fragment survived, so the hunk body is omitted here] diff --git a/debian/.salsa-ci.yml b/debian/.salsa-ci.yml new file mode 100644 index 0000000..a5957e7 --- /dev/null +++ b/debian/.salsa-ci.yml @@ -0,0 +1,8 @@ +--- +include: + - https://salsa.debian.org/salsa-ci-team/pipeline/raw/master/salsa-ci.yml + - https://salsa.debian.org/salsa-ci-team/pipeline/raw/master/pipeline-jobs.yml + +variables: + SALSA_CI_DISABLE_BLHC: 1 + SALSA_CI_DISABLE_BUILD_PACKAGE_ANY: 1 diff --git a/debian/README.source b/debian/README.source new file mode 100644 index 0000000..c9ddf90 --- /dev/null +++ b/debian/README.source @@ -0,0 +1,29 @@ +This package is maintained with git-buildpackage(1). It follows DEP-14 for +branch naming (e.g. using debian/master for the current version in Debian +unstable due to Debian Python team policy). + +It uses pristine-tar(1) to store enough information in git to generate bit +identical tarballs when building the package without having downloaded an +upstream tarball first. + +When working with patches it is recommended to use "gbp pq import" to import +the patches, modify the source and then use "gbp pq export --commit" to commit +the modifications. + +The changelog is generated using "gbp dch" so if you submit any changes don't +bother to add changelog entries but rather provide a nice git commit message +that can then end up in the changelog. + +It is recommended to build the package with pbuilder using: + + gbp buildpackage --git-pbuilder + +For information on how to set up a pbuilder environment see the git-pbuilder(1) +manpage. In short: + + DIST=sid git-pbuilder create + gbp clone https://salsa.debian.org/python-team/packages/python-werkzeug.git + cd python-werkzeug + gbp buildpackage --git-pbuilder + + -- Carsten Schoenert Sun, 17 Dec 2023 14:49:00 +0200 diff --git a/debian/changelog b/debian/changelog index cdb335e..d20a26c 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,99 @@ +python-werkzeug (3.0.3-1) unstable; urgency=medium + + * Team upload + * [32a577e] New upstream version 3.0.3 + Fixes CVE-2024-34069 + (Closes: #1070711) + * [9f49688] Rebuild patch queue from patch-queue branch + Adjusted patch: + docs-Use-intersphix-with-Debian-packages.patch + * [98b2263] d/copyright: Update year data + * [730be75] d/control: Drop python3-sphinx-issues from B-D + * [a4bd5f5] d/control: Bump Standards-Version to 4.7.0 + No further changes needed.
+ + -- Carsten Schoenert Thu, 09 May 2024 12:58:49 +0200 + +python-werkzeug (3.0.2-1) unstable; urgency=medium + + * Team upload + + [ Ondřej Nový ] + * [44b776b] Remove myself from Uploaders. + + [ Carsten Schoenert ] + * [ae61a23] New upstream version 3.0.2 + * [f5bd2de] Rebuild patch queue from patch-queue branch + Dropped patch: + Fix-test-failure-with-Pytest-8.0.patch + + -- Carsten Schoenert Fri, 22 Mar 2024 19:43:39 +0100 + +python-werkzeug (3.0.1-3) unstable; urgency=medium + + * Team upload + * Rebuild patch queue from patch-queue branch + Added patch: + Fix-test-failure-with-Pytest-8.0.patch + (Closes: #1063983) + + -- Carsten Schoenert Sat, 24 Feb 2024 08:29:01 +0100 + +python-werkzeug (3.0.1-2) unstable; urgency=medium + + * Team upload + + [ Julian Gilbey ] + * Add python3-markupsafe Build-Depends + + [ Carsten Schoenert ] + * Upload to unstable + Fixes CVE-2023-46136 + (Closes: #1054553, #1058244) + + + -- Carsten Schoenert Fri, 09 Feb 2024 19:32:22 +0100 + +python-werkzeug (3.0.1-1) experimental; urgency=medium + + * Team upload + * [b97bb12] d/gbp.conf: Add some basic defaults for git-buildpackage + * [9ba65b6] d/watch: Update and move to version 4 + * [8e91a92] d/gbp.conf: Adjust to debian/experimental + * [1d55463] d/copyright: Drop Files-Excluded, not needed any more + * [cffc6d3] New upstream version 3.0.1 + Fixes CVE-2023-46136 + (Closes: #1054553, #1058244) + * [80a2855] Rebuild patch queue from patch-queue branch + Added patches: + docs-Use-intersphix-with-Debian-packages.patch + Renamed patch: + preserve-any-existing-PYTHONPATH-in-tests.patch + -> tests-Preserve-any-existing-PYTHONPATH-in-tests.patch + Dropped patches (included upstream): + 0003-don-t-strip-leading-when-parsing-cookie.patch + 0004-limit-the-maximum-number-of-multipart-form-parts.patch + Dropped patch (not needed any more): + remove-test_exclude_patterns-test.patch + * [385989d] d/control: Bump Standards-Version to 4.6.2 + No further changes needed.
+ * [3e6cbd1] d/control: Update B-D as upstream has moved to flit + * [80afaf5] d/control: Set nodoc build profile for python3-doc + * [7468a1a] d/rules: Drop --with option in default target + * [ca345a4] d/rules: Improve build of Sphinx based documentation + * [ed1f0ae] d/rules: Remove file ICON_LICENSE.md in bin package + * [98e1776] d/rules: Add removal of .mypy_cache to dh_clean + * [9edc6e8] autopkgtest: Use more specific test dependencies + * [2f48efc] d/copyright: Update content and copyright holders + * [d3cbf83] d/u/metadata: Small updates + * [daeb503] d/control: Reflect GitHub project website as Homepage + * [e6c630c] d/python3-werkzeug.links: Remove sequencer, it's obsolete + * [1db0b73] d/p-w-d.links: Remove linking to obsolete packages + * [64bae1b] d/README.source: Adding a README file about source specifics + * [774528e] d/salsa-ci.yml: Adding trigger file for Salsa CI + + -- Carsten Schoenert Wed, 20 Dec 2023 19:22:33 +0100 + python-werkzeug (2.2.2-3) unstable; urgency=medium [ Robin Gustafsson ] diff --git a/debian/control b/debian/control index bc5b1f2..2efabf2 100644 --- a/debian/control +++ b/debian/control @@ -4,26 +4,27 @@ Priority: optional Maintainer: Debian Python Team Uploaders: Thomas Goirand , - Ondřej Nový , -Standards-Version: 4.6.1 +Standards-Version: 4.7.0 Build-Depends: debhelper-compat (= 13), - dh-python, + dh-sequence-python3, + dh-sequence-sphinxdoc, + flit, + pybuild-plugin-pyproject, python3-all, python3-cryptography , - python3-doc, + python3-doc , python3-ephemeral-port-reserve , python3-greenlet , + python3-markupsafe, python3-pallets-sphinx-themes , python3-pytest , python3-pytest-timeout , python3-pytest-xprocess , - python3-setuptools, python3-sphinx , - python3-sphinx-issues , python3-sphinxcontrib-log-cabinet , python3-watchdog , -Homepage: http://werkzeug.pocoo.org/ +Homepage: https://github.com/pallets/werkzeug/ Vcs-Git: https://salsa.debian.org/python-team/packages/python-werkzeug.git Vcs-Browser: https://salsa.debian.org/python-team/packages/python-werkzeug Testsuite: autopkgtest-pkg-python diff --git a/debian/copyright b/debian/copyright index f07855c..fd67950 100644 --- a/debian/copyright +++ b/debian/copyright @@ -1,13 +1,25 @@ Format: https://www.debian.org/doc/packaging-manuals/copyright-format/1.0/ Upstream-Name: Werkzeug Upstream-Contact: Armin Ronacher -Upstream-Source: http://werkzeug.pocoo.org/download -Files-Excluded: src/werkzeug/debug/shared/ubuntu.ttf - src/werkzeug/debug/shared/FONT_LICENSE - src/werkzeug/debug/shared/jquery.js +Upstream-Source: https://github.com/pallets/werkzeug/ Files: * -Copyright: Copyright 2007 Pallets +Copyright: Copyright 2007-2024 Pallets +License: BSD-3-clause + +Files: debian/* +Copyright: Copyright 2009, Noah Slater + (c) 2016-2020, Ondřej Nový + 2022, Thomas Goirand + 2023-2024, Carsten Schoenert +License: GAP + +Files: src/werkzeug/debug/shared/*.png +Copyright: Mark James +License: CC-BY-SA-2.5 or CC-BY-SA-3.0 +Comment: Originally on http://www.famfamfam.com/lab/icons/silk/ (not usable + any more) + License: BSD-3-clause Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -36,9 +48,589 @@ License: BSD-3-clause NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-Files: debian/* -Copyright: Copyright 2009, Noah Slater - (c) 2016-2020, Ondřej Nový +License: CC-BY-SA-2.5 + THE WORK (AS DEFINED BELOW) IS PROVIDED UNDER THE TERMS OF THIS CREATIVE COMMONS + PUBLIC LICENSE ("CCPL" OR "LICENSE"). THE WORK IS PROTECTED BY COPYRIGHT AND/OR + OTHER APPLICABLE LAW. ANY USE OF THE WORK OTHER THAN AS AUTHORIZED UNDER THIS + LICENSE OR COPYRIGHT LAW IS PROHIBITED. + . + BY EXERCISING ANY RIGHTS TO THE WORK PROVIDED HERE, YOU ACCEPT AND AGREE TO BE + BOUND BY THE TERMS OF THIS LICENSE. THE LICENSOR GRANTS YOU THE RIGHTS CONTAINED + HERE IN CONSIDERATION OF YOUR ACCEPTANCE OF SUCH TERMS AND CONDITIONS. + . + 1. Definitions + . + a. "Collective Work" means a work, such as a periodical issue, anthology or + encyclopedia, in which the Work in its entirety in unmodified form, along + with a number of other contributions, constituting separate and + independent works in themselves, are assembled into a collective whole. A + work that constitutes a Collective Work will not be considered a + Derivative Work (as defined below) for the purposes of this License. + b. "Derivative Work" means a work based upon the Work or upon the Work and + other pre-existing works, such as a translation, musical arrangement, + dramatization, fictionalization, motion picture version, sound recording, + art reproduction, abridgment, condensation, or any other form in which the + Work may be recast, transformed, or adapted, except that a work that + constitutes a Collective Work will not be considered a Derivative Work + for the purpose of this License. For the avoidance of doubt, where the + Work is a musical composition or sound recording, the synchronization of + the Work in timed-relation with a moving image ("synching") will be + considered a Derivative Work for the purpose of this License. + c. "Licensor" means the individual or entity that offers the Work under the + terms of this License. + d. "Original Author" means the individual or entity who created the Work. + e. "Work" means the copyrightable work of authorship offered under the terms + of this License. + f. "You" means an individual or entity exercising rights under this License + who has not previously violated the terms of this License with respect to + the Work, or who has received express permission from the Licensor to + exercise rights under this License despite a previous violation. + g. "License Elements" means the following high-level license attributes as + selected by Licensor and indicated in the title of this License: + Attribution, ShareAlike. + . + 2. Fair Use Rights. + . + Nothing in this license is intended to reduce, limit, or restrict any rights + arising from fair use, first sale or other limitations on the exclusive rights + of the copyright owner under copyright law or other applicable laws. + . + 3. License Grant. + . + Subject to the terms and conditions of this License, Licensor hereby grants + You a worldwide, royalty-free, non-exclusive, perpetual (for the duration of + the applicable copyright) license to exercise the rights in the Work as stated + below: + . + a. to reproduce the Work, to incorporate the Work into one or more Collective + Works, and to reproduce the Work as incorporated in the Collective Works; + b. to create and reproduce Derivative Works; + c. to distribute copies or phonorecords of, display publicly, perform + publicly, and perform publicly by means of a digital audio transmission + the Work including as incorporated in Collective Works; + d. 
to distribute copies or phonorecords of, display publicly, perform + publicly, and perform publicly by means of a digital audio transmission + Derivative Works. + e. For the avoidance of doubt, where the work is a musical composition: + i. Performance Royalties Under Blanket Licenses. Licensor waives the + exclusive right to collect, whether individually or via a + performance rights society (e.g. ASCAP, BMI, SESAC), royalties for + the public performance or public digital performance (e.g. webcast) + of the Work. + ii. Mechanical Rights and Statutory Royalties. Licensor waives the + exclusive right to collect, whether individually or via a music + rights society or designated agent (e.g. Harry Fox Agency), + royalties for any phonorecord You create from the Work ("cover + version") and distribute, subject to the compulsory license created + by 17 USC Section 115 of the US Copyright Act (or the equivalent in + other jurisdictions). + f. Webcasting Rights and Statutory Royalties. For the avoidance of doubt, + where the Work is a sound recording, Licensor waives the exclusive right + to collect, whether individually or via a performance-rights society (e.g. + SoundExchange), royalties for the public digital performance (e.g. + webcast) of the Work, subject to the compulsory license created by 17 USC + Section 114 of the US Copyright Act (or the equivalent in other + jurisdictions). + . + The above rights may be exercised in all media and formats whether now known or + hereafter devised. The above rights include the right to make such modifications + as are technically necessary to exercise the rights in other media and formats. + All rights not expressly granted by Licensor are hereby reserved. + . + 4. Restrictions. + . + The license granted in Section 3 above is expressly made subject to and limited + by the following restrictions: + . + a. You may distribute, publicly display, publicly perform, or publicly + digitally perform the Work only under the terms of this License, and You + must include a copy of, or the Uniform Resource Identifier for, this + License with every copy or phonorecord of the Work You distribute, + publicly display, publicly perform, or publicly digitally perform. You may + not offer or impose any terms on the Work that alter or restrict the terms + of this License or the recipients' exercise of the rights granted + hereunder. You may not sublicense the Work. You must keep intact all + notices that refer to this License and to the disclaimer of warranties. + You may not distribute, publicly display, publicly perform, or publicly + digitally perform the Work with any technological measures that control + access or use of the Work in a manner inconsistent with the terms of this + License Agreement. The above applies to the Work as incorporated in a + Collective Work, but this does not require the Collective Work apart from + the Work itself to be made subject to the terms of this License. If You + create a Collective Work, upon notice from any Licensor You must, to the + extent practicable, remove from the Collective Work any credit as required + by clause 4(c), as requested. If You create a Derivative Work, upon notice + from any Licensor You must, to the extent practicable, remove from the + Derivative Work any credit as required by clause 4(c), as requested. + b. 
You may distribute, publicly display, publicly perform, or publicly + digitally perform a Derivative Work only under the terms of this License, + a later version of this License with the same License Elements as this + License, or a Creative Commons iCommons license that contains the same + License Elements as this License (e.g. Attribution-ShareAlike 2.5 Japan). + You must include a copy of, or the Uniform Resource Identifier for, this + License or other license specified in the previous sentence with every + copy or phonorecord of each Derivative Work You distribute, publicly + display, publicly perform, or publicly digitally perform. You may not + offer or impose any terms on the Derivative Works that alter or restrict + the terms of this License or the recipients' exercise of the rights + granted hereunder, and You must keep intact all notices that refer to + this License and to the disclaimer of warranties. You may not distribute, + publicly display, publicly perform, or publicly digitally perform the + Derivative Work with any technological measures that control access or + use of the Work in a manner inconsistent with the terms of this License + Agreement. The above applies to the Derivative Work as incorporated in + a Collective Work, but this does not require the Collective Work apart + from the Derivative Work itself to be made subject to the terms of this + License. + c. If you distribute, publicly display, publicly perform, or publicly + digitally perform the Work or any Derivative Works or Collective Works, + You must keep intact all copyright notices for the Work and provide, + reasonable to the medium or means You are utilizing: (i) the name of the + Original Author (or pseudonym, if applicable) if supplied, and/or (ii) if + the Original Author and/or Licensor designate another party or parties + (e.g. a sponsor institute, publishing entity, journal) for attribution + in Licensor's copyright notice, terms of service or by other reasonable + means, the name of such party or parties; the title of the Work if + supplied; to the extent reasonably practicable, the Uniform Resource + Identifier, if any, that Licensor specifies to be associated with the + Work, unless such URI does not refer to the copyright notice or licensing + information for the Work; and in the case of a Derivative Work, a credit + identifying the use of the Work in the Derivative Work (e.g., "French + translation of the Work by Original Author," or "Screenplay based on + original Work by Original Author"). Such credit may be implemented in any + reasonable manner; provided, however, that in the case of a Derivative + Work or Collective Work, at a minimum such credit will appear where any + other comparable authorship credit appears and in a manner at least as + prominent as such other comparable authorship credit. + . + 5. Representations, Warranties and Disclaimer + . + UNLESS OTHERWISE AGREED TO BY THE PARTIES IN WRITING, LICENSOR OFFERS THE + WORK AS-IS AND MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING + THE MATERIALS, EXPRESS, IMPLIED, STATUTORY OR OTHERWISE, INCLUDING, WITHOUT + LIMITATION, WARRANTIES OF TITLE, MERCHANTIBILITY, FITNESS FOR A PARTICULAR + PURPOSE, NONINFRINGEMENT, OR THE ABSENCE OF LATENT OR OTHER DEFECTS, ACCURACY, + OR THE PRESENCE OF ABSENCE OF ERRORS, WHETHER OR NOT DISCOVERABLE. SOME + JURISDICTIONS DO NOT ALLOW THE EXCLUSION OF IMPLIED WARRANTIES, SO SUCH + EXCLUSION MAY NOT APPLY TO YOU. + . + 6. Limitation on Liability. + . 
+ EXCEPT TO THE EXTENT REQUIRED BY APPLICABLE LAW, IN NO EVENT WILL LICENSOR BE + LIABLE TO YOU ON ANY LEGAL THEORY FOR ANY SPECIAL, INCIDENTAL, CONSEQUENTIAL, + PUNITIVE OR EXEMPLARY DAMAGES ARISING OUT OF THIS LICENSE OR THE USE OF THE + WORK, EVEN IF LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. + . + 7. Termination + . + a. This License and the rights granted hereunder will terminate automatically + upon any breach by You of the terms of this License. Individuals or + entities who have received Derivative Works or Collective Works from You + under this License, however, will not have their licenses terminated + provided such individuals or entities remain in full compliance + with those licenses. Sections 1, 2, 5, 6, 7, and 8 will survive any + termination of this License. + b. Subject to the above terms and conditions, the license granted here is + perpetual (for the duration of the applicable copyright in the Work). + Notwithstanding the above, Licensor reserves the right to release the + Work under different license terms or to stop distributing the Work at + any time; provided, however that any such election will not serve to + withdraw this License (or any other license that has been, or is required + to be, granted under the terms of this License), and this License will + continue in full force and effect unless terminated as stated above. + . + 8. Miscellaneous + . + a. Each time You distribute or publicly digitally perform the Work or a + Collective Work, the Licensor offers to the recipient a license to the + Work on the same terms and conditions as the license granted to You under + this License. + b. Each time You distribute or publicly digitally perform a Derivative Work, + Licensor offers to the recipient a license to the original Work on the + same terms and conditions as the license granted to You under this + License. + c. If any provision of this License is invalid or unenforceable under + applicable law, it shall not affect the validity or enforceability of the + remainder of the terms of this License, and without further action by + the parties to this agreement, such provision shall be reformed to the + minimum extent necessary to make such provision valid and enforceable. + d. No term or provision of this License shall be deemed waived and no breach + consented to unless such waiver or consent shall be in writing and signed + by the party to be charged with such waiver or consent. + e. This License constitutes the entire agreement between the parties with + respect to the Work licensed here. There are no understandings, + agreements or representations with respect to the Work not specified + here. Licensor shall not be bound by any additional provisions that may + appear in any communication from You. This License may not be modified + without the mutual written agreement of the Licensor and You. + +License: CC-BY-SA-3.0 + THE WORK (AS DEFINED BELOW) IS PROVIDED UNDER THE TERMS OF THIS + CREATIVE COMMONS PUBLIC LICENSE ("CCPL" OR "LICENSE"). THE WORK IS + PROTECTED BY COPYRIGHT AND/OR OTHER APPLICABLE LAW. ANY USE OF THE + WORK OTHER THAN AS AUTHORIZED UNDER THIS LICENSE OR COPYRIGHT LAW IS + PROHIBITED. + . + BY EXERCISING ANY RIGHTS TO THE WORK PROVIDED HERE, YOU ACCEPT AND + AGREE TO BE BOUND BY THE TERMS OF THIS LICENSE. TO THE EXTENT THIS + LICENSE MAY BE CONSIDERED TO BE A CONTRACT, THE LICENSOR GRANTS YOU + THE RIGHTS CONTAINED HERE IN CONSIDERATION OF YOUR ACCEPTANCE OF SUCH + TERMS AND CONDITIONS. + . + . + 1. Definitions + . 
+ "Adaptation" means a work based upon the Work, or upon the Work and + other pre-existing works, such as a translation, adaptation, + derivative work, arrangement of music or other alterations of a + literary or artistic work, or phonogram or performance and includes + cinematographic adaptations or any other form in which the Work may be + recast, transformed, or adapted including in any form recognizably + derived from the original, except that a work that constitutes a + Collection will not be considered an Adaptation for the purpose of + this License. For the avoidance of doubt, where the Work is a musical + work, performance or phonogram, the synchronization of the Work in + timed-relation with a moving image ("synching") will be considered an + Adaptation for the purpose of this License. + . + "Collection" means a collection of literary or artistic works, such as + encyclopedias and anthologies, or performances, phonograms or + broadcasts, or other works or subject matter other than works listed + in Section 1(f) below, which, by reason of the selection and + arrangement of their contents, constitute intellectual creations, in + which the Work is included in its entirety in unmodified form along + with one or more other contributions, each constituting separate and + independent works in themselves, which together are assembled into a + collective whole. A work that constitutes a Collection will not be + considered an Adaptation (as defined below) for the purposes of this + License. + . + "Creative Commons Compatible License" means a license that is listed + at http://creativecommons.org/compatiblelicenses that has been + approved by Creative Commons as being essentially equivalent to this + License, including, at a minimum, because that license: (i) contains + terms that have the same purpose, meaning and effect as the License + Elements of this License; and, (ii) explicitly permits the relicensing + of adaptations of works made available under that license under this + License or a Creative Commons jurisdiction license with the same + License Elements as this License. + . + "Distribute" means to make available to the public the original and + copies of the Work or Adaptation, as appropriate, through sale or + other transfer of ownership. + . + "License Elements" means the following high-level license attributes + as selected by Licensor and indicated in the title of this License: + Attribution, ShareAlike. + . + "Licensor" means the individual, individuals, entity or entities that + offer(s) the Work under the terms of this License. + . + "Original Author" means, in the case of a literary or artistic work, + the individual, individuals, entity or entities who created the Work + or if no individual or entity can be identified, the publisher; and in + addition (i) in the case of a performance the actors, singers, + musicians, dancers, and other persons who act, sing, deliver, declaim, + play in, interpret or otherwise perform literary or artistic works or + expressions of folklore; (ii) in the case of a phonogram the producer + being the person or legal entity who first fixes the sounds of a + performance or other sounds; and, (iii) in the case of broadcasts, the + organization that transmits the broadcast. + . 
+ "Work" means the literary and/or artistic work offered under the terms + of this License including without limitation any production in the + literary, scientific and artistic domain, whatever may be the mode or + form of its expression including digital form, such as a book, + pamphlet and other writing; a lecture, address, sermon or other work + of the same nature; a dramatic or dramatico-musical work; a + choreographic work or entertainment in dumb show; a musical + composition with or without words; a cinematographic work to which are + assimilated works expressed by a process analogous to cinematography; + a work of drawing, painting, architecture, sculpture, engraving or + lithography; a photographic work to which are assimilated works + expressed by a process analogous to photography; a work of applied + art; an illustration, map, plan, sketch or three-dimensional work + relative to geography, topography, architecture or science; a + performance; a broadcast; a phonogram; a compilation of data to the + extent it is protected as a copyrightable work; or a work performed by + a variety or circus performer to the extent it is not otherwise + considered a literary or artistic work. + . + "You" means an individual or entity exercising rights under this + License who has not previously violated the terms of this License with + respect to the Work, or who has received express permission from the + Licensor to exercise rights under this License despite a previous + violation. + . + "Publicly Perform" means to perform public recitations of the Work and + to communicate to the public those public recitations, by any means or + process, including by wire or wireless means or public digital + performances; to make available to the public Works in such a way that + members of the public may access these Works from a place and at a + place individually chosen by them; to perform the Work to the public + by any means or process and the communication to the public of the + performances of the Work, including by public digital performance; to + broadcast and rebroadcast the Work by any means including signs, + sounds or images. + . + "Reproduce" means to make copies of the Work by any means including + without limitation by sound or visual recordings and the right of + fixation and reproducing fixations of the Work, including storage of a + protected performance or phonogram in digital form or other electronic + medium. + . + . + 2. Fair Dealing Rights. Nothing in this License is intended to reduce, + limit, or restrict any uses free from copyright or rights arising from + limitations or exceptions that are provided for in connection with the + copyright protection under copyright law or other applicable laws. + . + . + 3. License Grant. Subject to the terms and conditions of this License, + Licensor hereby grants You a worldwide, royalty-free, non-exclusive, + perpetual (for the duration of the applicable copyright) license to + exercise the rights in the Work as stated below: + . + - to Reproduce the Work, to incorporate the Work into one or more + Collections, and to Reproduce the Work as incorporated in the + Collections; + . + - to create and Reproduce Adaptations provided that any such + Adaptation, including any translation in any medium, takes + reasonable steps to clearly label, demarcate or otherwise identify + that changes were made to the original Work. 
For example, a + translation could be marked "The original work was translated from + English to Spanish," or a modification could indicate "The original + work has been modified."; + . + - to Distribute and Publicly Perform the Work including as + incorporated in Collections; and, + . + - to Distribute and Publicly Perform Adaptations. + . + - For the avoidance of doubt: + - Non-waivable Compulsory License Schemes. In those jurisdictions + in which the right to collect royalties through any statutory or + compulsory licensing scheme cannot be waived, the Licensor + reserves the exclusive right to collect such royalties for any + exercise by You of the rights granted under this License; + - Waivable Compulsory License Schemes. In those jurisdictions in + which the right to collect royalties through any statutory or + compulsory licensing scheme can be waived, the Licensor waives + the exclusive right to collect such royalties for any exercise + by You of the rights granted under this License; and, + - Voluntary License Schemes. The Licensor waives the right to + collect royalties, whether individually or, in the event that + the Licensor is a member of a collecting society that + administers voluntary licensing schemes, via that society, from + any exercise by You of the rights granted under this License. + . + The above rights may be exercised in all media and formats whether now + known or hereafter devised. The above rights include the right to make + such modifications as are technically necessary to exercise the rights + in other media and formats. Subject to Section 8(f), all rights not + expressly granted by Licensor are hereby reserved. + . + . + 4. Restrictions. The license granted in Section 3 above is expressly + made subject to and limited by the following restrictions: + . + - You may Distribute or Publicly Perform the Work only under the + terms of this License. You must include a copy of, or the Uniform + Resource Identifier (URI) for, this License with every copy of the + Work You Distribute or Publicly Perform. You may not offer or + impose any terms on the Work that restrict the terms of this + License or the ability of the recipient of the Work to exercise the + rights granted to that recipient under the terms of the + License. You may not sublicense the Work. You must keep intact all + notices that refer to this License and to the disclaimer of + warranties with every copy of the Work You Distribute or Publicly + Perform. When You Distribute or Publicly Perform the Work, You may + not impose any effective technological measures on the Work that + restrict the ability of a recipient of the Work from You to + exercise the rights granted to that recipient under the terms of + the License. This Section 4(a) applies to the Work as incorporated + in a Collection, but this does not require the Collection apart + from the Work itself to be made subject to the terms of this + License. If You create a Collection, upon notice from any Licensor + You must, to the extent practicable, remove from the Collection any + credit as required by Section 4(c), as requested. If You create an + Adaptation, upon notice from any Licensor You must, to the extent + practicable, remove from the Adaptation any credit as required by + Section 4(c), as requested. + . 
+ - You may Distribute or Publicly Perform an Adaptation only under the + terms of: (i) this License; (ii) a later version of this License + with the same License Elements as this License; (iii) a Creative + Commons jurisdiction license (either this or a later license + version) that contains the same License Elements as this License + (e.g., Attribution-ShareAlike 3.0 US)); (iv) a Creative Commons + Compatible License. If you license the Adaptation under one of the + licenses mentioned in (iv), you must comply with the terms of that + license. If you license the Adaptation under the terms of any of + the licenses mentioned in (i), (ii) or (iii) (the "Applicable + License"), you must comply with the terms of the Applicable License + generally and the following provisions: (I) You must include a copy + of, or the URI for, the Applicable License with every copy of each + Adaptation You Distribute or Publicly Perform; (II) You may not + offer or impose any terms on the Adaptation that restrict the terms + of the Applicable License or the ability of the recipient of the + Adaptation to exercise the rights granted to that recipient under + the terms of the Applicable License; (III) You must keep intact all + notices that refer to the Applicable License and to the disclaimer + of warranties with every copy of the Work as included in the + Adaptation You Distribute or Publicly Perform; (IV) when You + Distribute or Publicly Perform the Adaptation, You may not impose + any effective technological measures on the Adaptation that + restrict the ability of a recipient of the Adaptation from You to + exercise the rights granted to that recipient under the terms of + the Applicable License. This Section 4(b) applies to the Adaptation + as incorporated in a Collection, but this does not require the + Collection apart from the Adaptation itself to be made subject to + the terms of the Applicable License. + . + - If You Distribute, or Publicly Perform the Work or any Adaptations + or Collections, You must, unless a request has been made pursuant + to Section 4(a), keep intact all copyright notices for the Work and + provide, reasonable to the medium or means You are utilizing: (i) + the name of the Original Author (or pseudonym, if applicable) if + supplied, and/or if the Original Author and/or Licensor designate + another party or parties (e.g., a sponsor institute, publishing + entity, journal) for attribution ("Attribution Parties") in + Licensor's copyright notice, terms of service or by other + reasonable means, the name of such party or parties; (ii) the title + of the Work if supplied; (iii) to the extent reasonably + practicable, the URI, if any, that Licensor specifies to be + associated with the Work, unless such URI does not refer to the + copyright notice or licensing information for the Work; and (iv) , + consistent with Ssection 3(b), in the case of an Adaptation, a + credit identifying the use of the Work in the Adaptation (e.g., + "French translation of the Work by Original Author," or "Screenplay + based on original Work by Original Author"). The credit required by + this Section 4(c) may be implemented in any reasonable manner; + provided, however, that in the case of a Adaptation or Collection, + at a minimum such credit will appear, if a credit for all + contributing authors of the Adaptation or Collection appears, then + as part of these credits and in a manner at least as prominent as + the credits for the other contributing authors. 
For the avoidance + of doubt, You may only use the credit required by this Section for + the purpose of attribution in the manner set out above and, by + exercising Your rights under this License, You may not implicitly + or explicitly assert or imply any connection with, sponsorship or + endorsement by the Original Author, Licensor and/or Attribution + Parties, as appropriate, of You or Your use of the Work, without + the separate, express prior written permission of the Original + Author, Licensor and/or Attribution Parties. + . + - Except as otherwise agreed in writing by the Licensor or as may be + otherwise permitted by applicable law, if You Reproduce, Distribute + or Publicly Perform the Work either by itself or as part of any + Adaptations or Collections, You must not distort, mutilate, modify + or take other derogatory action in relation to the Work which would + be prejudicial to the Original Author's honor or + reputation. Licensor agrees that in those jurisdictions + (e.g. Japan), in which any exercise of the right granted in Section + 3(b) of this License (the right to make Adaptations) would be + deemed to be a distortion, mutilation, modification or other + derogatory action prejudicial to the Original Author's honor and + reputation, the Licensor will waive or not assert, as appropriate, + this Section, to the fullest extent permitted by the applicable + national law, to enable You to reasonably exercise Your right under + Section 3(b) of this License (right to make Adaptations) but not + otherwise. + . + . + 5. Representations, Warranties and Disclaimer + . + UNLESS OTHERWISE MUTUALLY AGREED TO BY THE PARTIES IN WRITING, + LICENSOR OFFERS THE WORK AS-IS AND MAKES NO REPRESENTATIONS OR + WARRANTIES OF ANY KIND CONCERNING THE WORK, EXPRESS, IMPLIED, + STATUTORY OR OTHERWISE, INCLUDING, WITHOUT LIMITATION, WARRANTIES OF + TITLE, MERCHANTIBILITY, FITNESS FOR A PARTICULAR PURPOSE, + NONINFRINGEMENT, OR THE ABSENCE OF LATENT OR OTHER DEFECTS, ACCURACY, + OR THE PRESENCE OF ABSENCE OF ERRORS, WHETHER OR NOT + DISCOVERABLE. SOME JURISDICTIONS DO NOT ALLOW THE EXCLUSION OF IMPLIED + WARRANTIES, SO SUCH EXCLUSION MAY NOT APPLY TO YOU. + . + . + 6. Limitation on Liability. + . + EXCEPT TO THE EXTENT REQUIRED BY APPLICABLE LAW, IN NO EVENT WILL + LICENSOR BE LIABLE TO YOU ON ANY LEGAL THEORY FOR ANY SPECIAL, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE OR EXEMPLARY DAMAGES ARISING OUT + OF THIS LICENSE OR THE USE OF THE WORK, EVEN IF LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. + . + . + 7. Termination + . + - This License and the rights granted hereunder will terminate + automatically upon any breach by You of the terms of this + License. Individuals or entities who have received Adaptations or + Collections from You under this License, however, will not have + their licenses terminated provided such individuals or entities + remain in full compliance with those licenses. Sections 1, 2, 5, 6, + 7, and 8 will survive any termination of this License. + . + - Subject to the above terms and conditions, the license granted here + is perpetual (for the duration of the applicable copyright in the + Work). 
Notwithstanding the above, Licensor reserves the right to + release the Work under different license terms or to stop + distributing the Work at any time; provided, however that any such + election will not serve to withdraw this License (or any other + license that has been, or is required to be, granted under the + terms of this License), and this License will continue in full + force and effect unless terminated as stated above. + . + . + 8. Miscellaneous + . + - Each time You Distribute or Publicly Perform the Work or a + Collection, the Licensor offers to the recipient a license to the + Work on the same terms and conditions as the license granted to You + under this License. + . + - Each time You Distribute or Publicly Perform an Adaptation, + Licensor offers to the recipient a license to the original Work on + the same terms and conditions as the license granted to You under + this License. + . + - If any provision of this License is invalid or unenforceable under + applicable law, it shall not affect the validity or enforceability + of the remainder of the terms of this License, and without further + action by the parties to this agreement, such provision shall be + reformed to the minimum extent necessary to make such provision + valid and enforceable. + . + - No term or provision of this License shall be deemed waived and no + breach consented to unless such waiver or consent shall be in + writing and signed by the party to be charged with such waiver or + consent. + . + - This License constitutes the entire agreement between the parties + with respect to the Work licensed here. There are no + understandings, agreements or representations with respect to the + Work not specified here. Licensor shall not be bound by any + additional provisions that may appear in any communication from + You. This License may not be modified without the mutual written + agreement of the Licensor and You. + . + - The rights granted under, and the subject matter referenced, in + this License were drafted utilizing the terminology of the Berne + Convention for the Protection of Literary and Artistic Works (as + amended on September 28, 1979), the Rome Convention of 1961, the + WIPO Copyright Treaty of 1996, the WIPO Performances and Phonograms + Treaty of 1996 and the Universal Copyright Convention (as revised + on July 24, 1971). These rights and subject matter take effect in + the relevant jurisdiction in which the License terms are sought to + be enforced according to the corresponding provisions of the + implementation of those treaty provisions in the applicable + national law. If the standard suite of rights granted under + applicable copyright law includes additional rights not granted + under this License, such additional rights are deemed to be + included in the License; this License is not intended to restrict + the license of any rights under applicable law. 
+ License: GAP Copying and distribution of this package, with or without modification, are permitted in any medium without royalty provided the copyright notice and this diff --git a/debian/gbp.conf b/debian/gbp.conf new file mode 100644 index 0000000..2eb0837 --- /dev/null +++ b/debian/gbp.conf @@ -0,0 +1,8 @@ +[DEFAULT] +pristine-tar = True +compression = xz +debian-branch = debian/master +upstream-branch = upstream + +[pq] +patch-numbers = False diff --git a/debian/patches/0003-don-t-strip-leading-when-parsing-cookie.patch b/debian/patches/0003-don-t-strip-leading-when-parsing-cookie.patch deleted file mode 100644 index ce17c44..0000000 --- a/debian/patches/0003-don-t-strip-leading-when-parsing-cookie.patch +++ /dev/null @@ -1,83 +0,0 @@ -Description: CVE-2023-23934: don't strip leading = when parsing cookie - Applied-Upstream: 2.2.3 -Author: David Lord -Date: Tue, 31 Jan 2023 14:29:34 -0800 -Origin: upstream, https://github.com/pallets/werkzeug/commit/cf275f42acad1b5950c50ffe8ef58fe62cdce028 -Bug-Debian: https://bugs.debian.org/1031370 -Last-Update: 2023-04-21 - -diff --git a/src/werkzeug/_internal.py b/src/werkzeug/_internal.py -index 4636647..f95207a 100644 ---- a/src/werkzeug/_internal.py -+++ b/src/werkzeug/_internal.py -@@ -34,7 +34,7 @@ _quote_re = re.compile(rb"[\\].") - _legal_cookie_chars_re = rb"[\w\d!#%&\'~_`><@,:/\$\*\+\-\.\^\|\)\(\?\}\{\=]" - _cookie_re = re.compile( - rb""" -- (?P[^=;]+) -+ (?P[^=;]*) - (?:\s*=\s* - (?P - "(?:[^\\"]|\\.)*" | -@@ -382,16 +382,21 @@ def _cookie_parse_impl(b: bytes) -> t.Iterator[t.Tuple[bytes, bytes]]: - """Lowlevel cookie parsing facility that operates on bytes.""" - i = 0 - n = len(b) -+ b += b";" - - while i < n: -- match = _cookie_re.search(b + b";", i) -+ match = _cookie_re.match(b, i) -+ - if not match: - break - -- key = match.group("key").strip() -- value = match.group("val") or b"" - i = match.end(0) -+ key = match.group("key").strip() -+ -+ if not key: -+ continue - -+ value = match.group("val") or b"" - yield key, _cookie_unquote(value) - - -diff --git a/src/werkzeug/sansio/http.py b/src/werkzeug/sansio/http.py -index 8288882..6b22738 100644 ---- a/src/werkzeug/sansio/http.py -+++ b/src/werkzeug/sansio/http.py -@@ -126,10 +126,6 @@ def parse_cookie( - def _parse_pairs() -> t.Iterator[t.Tuple[str, str]]: - for key, val in _cookie_parse_impl(cookie): # type: ignore - key_str = _to_str(key, charset, errors, allow_none_charset=True) -- -- if not key_str: -- continue -- - val_str = _to_str(val, charset, errors, allow_none_charset=True) - yield key_str, val_str - -diff --git a/tests/test_http.py b/tests/test_http.py -index 3760dc1..999549e 100644 ---- a/tests/test_http.py -+++ b/tests/test_http.py -@@ -411,7 +411,8 @@ class TestHTTPUtility: - def test_parse_cookie(self): - cookies = http.parse_cookie( - "dismiss-top=6; CP=null*; PHPSESSID=0a539d42abc001cdc762809248d4beed;" -- 'a=42; b="\\";"; ; fo234{=bar;blub=Blah; "__Secure-c"=d' -+ 'a=42; b="\\";"; ; fo234{=bar;blub=Blah; "__Secure-c"=d;' -+ "==__Host-eq=bad;__Host-eq=good;" - ) - assert cookies.to_dict() == { - "CP": "null*", -@@ -422,6 +423,7 @@ class TestHTTPUtility: - "fo234{": "bar", - "blub": "Blah", - '"__Secure-c"': "d", -+ "__Host-eq": "good", - } - - def test_dump_cookie(self): diff --git a/debian/patches/0004-limit-the-maximum-number-of-multipart-form-parts.patch b/debian/patches/0004-limit-the-maximum-number-of-multipart-form-parts.patch deleted file mode 100644 index bc7395b..0000000 --- a/debian/patches/0004-limit-the-maximum-number-of-multipart-form-parts.patch +++ 
/dev/null @@ -1,201 +0,0 @@ -Description: CVE-2023-25577: limit the maximum number of multipart form parts - Applied-Upstream: 2.2.3 -Author: David Lord -Date: Tue, 14 Feb 2023 09:08:57 -0800 -Origin: upstream, https://github.com/pallets/werkzeug/commit/517cac5a804e8c4dc4ed038bb20dacd038e7a9f1 -Bug-Debian: https://bugs.debian.org/1031370 -Last-Update: 2023-04-21 - -diff --git a/docs/request_data.rst b/docs/request_data.rst -index 83c6278..e55841e 100644 ---- a/docs/request_data.rst -+++ b/docs/request_data.rst -@@ -73,23 +73,26 @@ read the stream *or* call :meth:`~Request.get_data`. - Limiting Request Data - --------------------- - --To avoid being the victim of a DDOS attack you can set the maximum --accepted content length and request field sizes. The :class:`Request` --class has two attributes for that: :attr:`~Request.max_content_length` --and :attr:`~Request.max_form_memory_size`. -- --The first one can be used to limit the total content length. For example --by setting it to ``1024 * 1024 * 16`` the request won't accept more than --16MB of transmitted data. -- --Because certain data can't be moved to the hard disk (regular post data) --whereas temporary files can, there is a second limit you can set. The --:attr:`~Request.max_form_memory_size` limits the size of `POST` --transmitted form data. By setting it to ``1024 * 1024 * 2`` you can make --sure that all in memory-stored fields are not more than 2MB in size. -- --This however does *not* affect in-memory stored files if the --`stream_factory` used returns a in-memory file. -+The :class:`Request` class provides a few attributes to control how much data is -+processed from the request body. This can help mitigate DoS attacks that craft the -+request in such a way that the server uses too many resources to handle it. Each of -+these limits will raise a :exc:`~werkzeug.exceptions.RequestEntityTooLarge` if they are -+exceeded. -+ -+- :attr:`~Request.max_content_length` Stop reading request data after this number -+ of bytes. It's better to configure this in the WSGI server or HTTP server, rather -+ than the WSGI application. -+- :attr:`~Request.max_form_memory_size` Stop reading request data if any form part is -+ larger than this number of bytes. While file parts can be moved to disk, regular -+ form field data is stored in memory only. -+- :attr:`~Request.max_form_parts` Stop reading request data if more than this number -+ of parts are sent in multipart form data. This is useful to stop a very large number -+ of very small parts, especially file parts. The default is 1000. -+ -+Using Werkzeug to set these limits is only one layer of protection. WSGI servers -+and HTTPS servers should set their own limits on size and timeouts. The operating system -+or container manager should set limits on memory and processing time for server -+processes. - - - How to extend Parsing? -diff --git a/src/werkzeug/formparser.py b/src/werkzeug/formparser.py -index 10d58ca..bebb2fc 100644 ---- a/src/werkzeug/formparser.py -+++ b/src/werkzeug/formparser.py -@@ -179,6 +179,8 @@ class FormDataParser: - :param cls: an optional dict class to use. If this is not specified - or `None` the default :class:`MultiDict` is used. - :param silent: If set to False parsing errors will not be caught. -+ :param max_form_parts: The maximum number of parts to be parsed. If this is -+ exceeded, a :exc:`~exceptions.RequestEntityTooLarge` exception is raised. 
- """ - - def __init__( -@@ -190,6 +192,8 @@ class FormDataParser: - max_content_length: t.Optional[int] = None, - cls: t.Optional[t.Type[MultiDict]] = None, - silent: bool = True, -+ *, -+ max_form_parts: t.Optional[int] = None, - ) -> None: - if stream_factory is None: - stream_factory = default_stream_factory -@@ -199,6 +203,7 @@ class FormDataParser: - self.errors = errors - self.max_form_memory_size = max_form_memory_size - self.max_content_length = max_content_length -+ self.max_form_parts = max_form_parts - - if cls is None: - cls = MultiDict -@@ -281,6 +286,7 @@ class FormDataParser: - self.errors, - max_form_memory_size=self.max_form_memory_size, - cls=self.cls, -+ max_form_parts=self.max_form_parts, - ) - boundary = options.get("boundary", "").encode("ascii") - -@@ -346,10 +352,12 @@ class MultiPartParser: - max_form_memory_size: t.Optional[int] = None, - cls: t.Optional[t.Type[MultiDict]] = None, - buffer_size: int = 64 * 1024, -+ max_form_parts: t.Optional[int] = None, - ) -> None: - self.charset = charset - self.errors = errors - self.max_form_memory_size = max_form_memory_size -+ self.max_form_parts = max_form_parts - - if stream_factory is None: - stream_factory = default_stream_factory -@@ -409,7 +417,9 @@ class MultiPartParser: - [None], - ) - -- parser = MultipartDecoder(boundary, self.max_form_memory_size) -+ parser = MultipartDecoder( -+ boundary, self.max_form_memory_size, max_parts=self.max_form_parts -+ ) - - fields = [] - files = [] -diff --git a/src/werkzeug/sansio/multipart.py b/src/werkzeug/sansio/multipart.py -index d8abeb3..2684e5d 100644 ---- a/src/werkzeug/sansio/multipart.py -+++ b/src/werkzeug/sansio/multipart.py -@@ -87,10 +87,13 @@ class MultipartDecoder: - self, - boundary: bytes, - max_form_memory_size: Optional[int] = None, -+ *, -+ max_parts: Optional[int] = None, - ) -> None: - self.buffer = bytearray() - self.complete = False - self.max_form_memory_size = max_form_memory_size -+ self.max_parts = max_parts - self.state = State.PREAMBLE - self.boundary = boundary - -@@ -118,6 +121,7 @@ class MultipartDecoder: - re.MULTILINE, - ) - self._search_position = 0 -+ self._parts_decoded = 0 - - def last_newline(self) -> int: - try: -@@ -191,6 +195,10 @@ class MultipartDecoder: - ) - self.state = State.DATA - self._search_position = 0 -+ self._parts_decoded += 1 -+ -+ if self.max_parts is not None and self._parts_decoded > self.max_parts: -+ raise RequestEntityTooLarge() - else: - # Update the search start position to be equal to the - # current buffer length (already searched) minus a -diff --git a/src/werkzeug/wrappers/request.py b/src/werkzeug/wrappers/request.py -index 57b739c..a6d5429 100644 ---- a/src/werkzeug/wrappers/request.py -+++ b/src/werkzeug/wrappers/request.py -@@ -83,6 +83,13 @@ class Request(_SansIORequest): - #: .. versionadded:: 0.5 - max_form_memory_size: t.Optional[int] = None - -+ #: The maximum number of multipart parts to parse, passed to -+ #: :attr:`form_data_parser_class`. Parsing form data with more than this -+ #: many parts will raise :exc:`~.RequestEntityTooLarge`. -+ #: -+ #: .. versionadded:: 2.2.3 -+ max_form_parts = 1000 -+ - #: The form data parser that should be used. Can be replaced to customize - #: the form date parsing. 
- form_data_parser_class: t.Type[FormDataParser] = FormDataParser -@@ -246,6 +253,7 @@ class Request(_SansIORequest): - self.max_form_memory_size, - self.max_content_length, - self.parameter_storage_class, -+ max_form_parts=self.max_form_parts, - ) - - def _load_form_data(self) -> None: -diff --git a/tests/test_formparser.py b/tests/test_formparser.py -index 49010b4..4c518b1 100644 ---- a/tests/test_formparser.py -+++ b/tests/test_formparser.py -@@ -127,6 +127,15 @@ class TestFormParser: - req.max_form_memory_size = 400 - assert req.form["foo"] == "Hello World" - -+ req = Request.from_values( -+ input_stream=io.BytesIO(data), -+ content_length=len(data), -+ content_type="multipart/form-data; boundary=foo", -+ method="POST", -+ ) -+ req.max_form_parts = 1 -+ pytest.raises(RequestEntityTooLarge, lambda: req.form["foo"]) -+ - def test_missing_multipart_boundary(self): - data = ( - b"--foo\r\nContent-Disposition: form-field; name=foo\r\n\r\n" diff --git a/debian/patches/docs-Use-intersphix-with-Debian-packages.patch b/debian/patches/docs-Use-intersphix-with-Debian-packages.patch new file mode 100644 index 0000000..88c245b --- /dev/null +++ b/debian/patches/docs-Use-intersphix-with-Debian-packages.patch @@ -0,0 +1,22 @@ +From: Carsten Schoenert +Date: Sat, 16 Dec 2023 11:04:19 +0100 +Subject: docs: Use intersphix with Debian packages + +Forwarded: not-needed +--- + docs/conf.py | 2 +- + 1 file changed, 1 insertion(+), 1 deletion(-) + +diff --git a/docs/conf.py b/docs/conf.py +index d58c17e..813e1c5 100644 +--- a/docs/conf.py ++++ b/docs/conf.py +@@ -28,7 +28,7 @@ extlinks = { + "ghsa": ("https://github.com/advisories/%s", "GHSA-%s"), + } + intersphinx_mapping = { +- "python": ("https://docs.python.org/3/", None), ++ "python": ("/usr/share/doc/python3-doc/html", None), + } + + # HTML ----------------------------------------------------------------- diff --git a/debian/patches/remove-test_exclude_patterns-test.patch b/debian/patches/remove-test_exclude_patterns-test.patch deleted file mode 100644 index fbc4f2c..0000000 --- a/debian/patches/remove-test_exclude_patterns-test.patch +++ /dev/null @@ -1,26 +0,0 @@ -Description: Remove test_exclude_patterns test - Under the sbuild environment, the asert doesn't work and sys.prefix gets - wrong. So I'm just removing this test. -Author: Thomas Goirand -Forwarded: not-needed -Last-Update: 2022-09-14 - ---- python-werkzeug-2.2.2.orig/tests/test_serving.py -+++ python-werkzeug-2.2.2/tests/test_serving.py -@@ -125,16 +125,6 @@ def test_windows_get_args_for_reloading( - assert rv == argv - - --@pytest.mark.parametrize("find", [_find_stat_paths, _find_watchdog_paths]) --def test_exclude_patterns(find): -- # Imported paths under sys.prefix will be included by default. -- paths = find(set(), set()) -- assert any(p.startswith(sys.prefix) for p in paths) -- # Those paths should be excluded due to the pattern. 
-- paths = find(set(), {f"{sys.prefix}*"}) -- assert not any(p.startswith(sys.prefix) for p in paths) -- -- - @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") - @pytest.mark.dev_server - def test_wrong_protocol(standard_app): diff --git a/debian/patches/series b/debian/patches/series index a9ce589..1308be9 100644 --- a/debian/patches/series +++ b/debian/patches/series @@ -1,4 +1,2 @@ -preserve-any-existing-PYTHONPATH-in-tests.patch -remove-test_exclude_patterns-test.patch -0003-don-t-strip-leading-when-parsing-cookie.patch -0004-limit-the-maximum-number-of-multipart-form-parts.patch +tests-Preserve-any-existing-PYTHONPATH-in-tests.patch +docs-Use-intersphix-with-Debian-packages.patch diff --git a/debian/patches/preserve-any-existing-PYTHONPATH-in-tests.patch b/debian/patches/tests-Preserve-any-existing-PYTHONPATH-in-tests.patch similarity index 83% rename from debian/patches/preserve-any-existing-PYTHONPATH-in-tests.patch rename to debian/patches/tests-Preserve-any-existing-PYTHONPATH-in-tests.patch index f70e6fc..6706bdf 100644 --- a/debian/patches/preserve-any-existing-PYTHONPATH-in-tests.patch +++ b/debian/patches/tests-Preserve-any-existing-PYTHONPATH-in-tests.patch @@ -1,17 +1,17 @@ -From b88042cfb32866a00d39b678bb224eb55ecf53c1 Mon Sep 17 00:00:00 2001 From: Lumir Balhar Date: Tue, 22 Jun 2021 22:10:17 +0200 -Subject: [PATCH] Preserve any existing PYTHONPATH in tests +Subject: tests: Preserve any existing PYTHONPATH in tests +Forwarded: not-needed --- tests/conftest.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py -index 4ad1ff23..7200d286 100644 +index b73202c..fd666c9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py -@@ -118,9 +118,15 @@ def dev_server(xprocess, request, tmp_path): +@@ -103,9 +103,15 @@ def dev_server(xprocess, request, tmp_path): class Starter(ProcessStarter): args = [sys.executable, run_path, name, json.dumps(kwargs)] # Extend the existing env, otherwise Windows and CI fails. @@ -29,6 +29,3 @@ index 4ad1ff23..7200d286 100644 @cached_property def pattern(self): --- -2.31.1 - diff --git a/debian/python-werkzeug-doc.links b/debian/python-werkzeug-doc.links index 2c69103..21eac33 100644 --- a/debian/python-werkzeug-doc.links +++ b/debian/python-werkzeug-doc.links @@ -1,6 +1,3 @@ -/usr/share/doc/python-werkzeug-doc/examples /usr/share/doc/python-werkzeug/examples /usr/share/doc/python-werkzeug-doc/examples /usr/share/doc/python3-werkzeug/examples -/usr/share/doc/python-werkzeug-doc/html /usr/share/doc/python-werkzeug/html /usr/share/doc/python-werkzeug-doc/html /usr/share/doc/python3-werkzeug/html -/usr/share/doc/python-werkzeug-doc/html/_sources /usr/share/doc/python-werkzeug/rst /usr/share/doc/python-werkzeug-doc/html/_sources /usr/share/doc/python3-werkzeug/rst diff --git a/debian/python3-werkzeug.links b/debian/python3-werkzeug.links deleted file mode 100644 index 01ca329..0000000 --- a/debian/python3-werkzeug.links +++ /dev/null @@ -1 +0,0 @@ -/usr/share/javascript/jquery/jquery.js /usr/lib/python3/dist-packages/werkzeug/debug/shared/jquery.js diff --git a/debian/rules b/debian/rules index 00c01ef..8a7d89c 100755 --- a/debian/rules +++ b/debian/rules @@ -1,23 +1,26 @@ #!/usr/bin/make -f +# -*- makefile -*- -# Copyright 2009, Noah Slater +#export DH_VERBOSE = 1 -# Copying and distribution of this file, with or without modification, are -# permitted in any medium without royalty provided the copyright notice and this -# notice are preserved. 
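[The renamed tests-Preserve-any-existing-PYTHONPATH patch above keeps whatever PYTHONPATH the build environment already exports when the test suite spawns the dev server. The patch body is largely elided by the rename diff, so the following is only a sketch of that extend-don't-replace pattern, with invented names:

    import os

    def build_test_env(*extra_paths: str) -> dict:
        """Copy the parent environment and prepend extra import paths,
        keeping any PYTHONPATH already set (e.g. by sbuild or CI)."""
        env = os.environ.copy()
        parts = [p for p in (*extra_paths, env.get("PYTHONPATH")) if p]
        env["PYTHONPATH"] = os.pathsep.join(parts)
        return env

The point of the patch is exactly this copy-then-extend step: replacing the environment wholesale breaks Windows, CI, and sbuild runs that rely on a pre-set PYTHONPATH.]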
+include /usr/share/dpkg/pkg-info.mk + +BUILD_DATE = $(shell LC_ALL=C date -u "+%d %B %Y" -d "@$(SOURCE_DATE_EPOCH)") +SPHINXOPTS := -E -N -D html_last_updated_fmt="$(BUILD_DATE)" export PYBUILD_NAME=werkzeug export PYBUILD_TEST_PYTEST=1 +export PYBUILD_BEFORE_INSTALL=rm -rf {build_dir}/werkzeug/debug/shared/ICON_LICENSE.md %: - dh $@ --with python3,sphinxdoc --buildsystem pybuild + dh $@ --buildsystem=pybuild override_dh_auto_clean: make -C docs clean rm -rf build Werkzeug.egg-info/ #find $(CURDIR) \( -name '\._*' -o -name '\.DS_Store' \) -delete find . -iname '__pycache__' -exec rm -rf {} \; || true - rm -rf .pytest_cache + rm -rf .pytest_cache .mypy_cache debian/doctrees dh_auto_clean override_dh_fixperms: @@ -32,7 +35,8 @@ override_dh_installexamples: dh_installexamples --doc-main-package=python-werkzeug-doc -ppython-werkzeug-doc override_dh_sphinxdoc: -ifeq (,$(findstring nodocs, $(DEB_BUILD_OPTIONS))) - PYTHONPATH=src python3 -m sphinx -b html docs/ debian/python-werkzeug-doc/usr/share/doc/python-werkzeug-doc/html/ +ifeq (,$(findstring nodoc, $(DEB_BUILD_OPTIONS))) + PYTHONPATH=`dirname $$(find .pybuild/ -type d -name "werkzeug*dist-info" | head -n1)` \ + python3 -m sphinx -b html docs/ -d debian/doctrees $(SPHINXOPTS) debian/python-werkzeug-doc/usr/share/doc/python-werkzeug-doc/html/ dh_sphinxdoc endif diff --git a/debian/tests/control b/debian/tests/control index bfa14fd..f555610 100644 --- a/debian/tests/control +++ b/debian/tests/control @@ -1,5 +1,12 @@ Tests: upstream Depends: + python3-all, + python3-cryptography, + python3-ephemeral-port-reserve, + python3-greenlet, + python3-pytest, + python3-pytest-timeout, + python3-pytest-xprocess, + python3-watchdog, @, - @builddeps@, Restrictions: allow-stderr diff --git a/debian/tests/upstream b/debian/tests/upstream index ad3ff8b..3478756 100755 --- a/debian/tests/upstream +++ b/debian/tests/upstream @@ -2,13 +2,20 @@ set -eu export LC_ALL=C.UTF-8 -pyvers=$(py3versions -r 2>/dev/null) cp -a tests "$AUTOPKGTEST_TMP" cd "$AUTOPKGTEST_TMP" -for py in ${pyvers}; do - echo "-=-=-=-=-=-=-=- running tests for ${py} -=-=-=-=-=-=-=-=-" - printf '$ %s\n' "${py} -m pytest tests" - ${py} -m pytest tests +for py3vers in $(py3versions -s); do + echo + echo "***************************" + echo "*** Testing with ${py3vers}" + echo "***************************" + echo + cd ${AUTOPKGTEST_TMP} && \ + printf 'Content of current working folder:\n\n' && \ + ls -la && \ + printf 'Running tests...\n\n' && \ + PYTHONPATH=. 
${py3vers} -m pytest && \ + rm -rf .pytest_cache || exit 1 done diff --git a/debian/upstream/metadata b/debian/upstream/metadata index 2c2d29c..aee0873 100644 --- a/debian/upstream/metadata +++ b/debian/upstream/metadata @@ -1,4 +1,6 @@ +--- Bug-Database: https://github.com/pallets/werkzeug/issues Bug-Submit: https://github.com/pallets/werkzeug/issues/new +FAQ: https://werkzeug.palletsprojects.com/ Repository: https://github.com/pallets/werkzeug.git Repository-Browse: https://github.com/pallets/werkzeug diff --git a/debian/watch b/debian/watch index 8b2c992..81a92cb 100644 --- a/debian/watch +++ b/debian/watch @@ -1,6 +1,8 @@ -version=3 -opts=uversionmangle=s/(rc|a|b|c)/~$1/,\ -dversionmangle=auto,\ -repack,\ -filenamemangle=s/.+\/v?(\d\S*)\.tar\.gz/werkzeug-$1\.tar\.gz/ \ -https://github.com/pallets/werkzeug/tags .*/v?(\d\S*)\.tar\.gz +version=4 + +opts="mode=git, \ + compression=xz, \ + uversionmangle=s/(\d)[_\.\-\+]?((RC|rc|pre|dev|beta|alpha)\.?\d*)$/$1~$2/, \ + filenamemangle=s/.+\/v?(\d\S*)\.tar\.gz/werkzeug-$1\.tar\.gz/" \ +https://github.com/pallets/werkzeug.git \ + refs/tags/@ANY_VERSION@ diff --git a/docs/_static/favicon.ico b/docs/_static/favicon.ico deleted file mode 100644 index a3b079a..0000000 Binary files a/docs/_static/favicon.ico and /dev/null differ diff --git a/docs/_static/shortcut-icon.png b/docs/_static/shortcut-icon.png new file mode 100644 index 0000000..37cf028 Binary files /dev/null and b/docs/_static/shortcut-icon.png differ diff --git a/docs/_static/werkzeug-horizontal.png b/docs/_static/werkzeug-horizontal.png new file mode 100644 index 0000000..0581470 Binary files /dev/null and b/docs/_static/werkzeug-horizontal.png differ diff --git a/docs/_static/werkzeug-vertical.png b/docs/_static/werkzeug-vertical.png new file mode 100644 index 0000000..be2a7a3 Binary files /dev/null and b/docs/_static/werkzeug-vertical.png differ diff --git a/docs/_static/werkzeug.png b/docs/_static/werkzeug.png deleted file mode 100644 index 9cedb06..0000000 Binary files a/docs/_static/werkzeug.png and /dev/null differ diff --git a/docs/conf.py b/docs/conf.py index 96e998b..d58c17e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -10,30 +10,37 @@ # General -------------------------------------------------------------- -master_doc = "index" +default_role = "code" extensions = [ "sphinx.ext.autodoc", + "sphinx.ext.extlinks", "sphinx.ext.intersphinx", - "pallets_sphinx_themes", - "sphinx_issues", "sphinxcontrib.log_cabinet", + "pallets_sphinx_themes", ] autoclass_content = "both" +autodoc_member_order = "bysource" autodoc_typehints = "description" -intersphinx_mapping = {"python": ("https://docs.python.org/3/", None)} -issues_github_path = "pallets/werkzeug" +autodoc_preserve_defaults = True +extlinks = { + "issue": ("https://github.com/pallets/werkzeug/issues/%s", "#%s"), + "pr": ("https://github.com/pallets/werkzeug/pull/%s", "#%s"), + "ghsa": ("https://github.com/advisories/%s", "GHSA-%s"), +} +intersphinx_mapping = { + "python": ("https://docs.python.org/3/", None), +} # HTML ----------------------------------------------------------------- html_theme = "werkzeug" +html_theme_options = {"index_sidebar_logo": False} html_context = { "project_links": [ ProjectLink("Donate", "https://palletsprojects.com/donate"), ProjectLink("PyPI Releases", "https://pypi.org/project/Werkzeug/"), ProjectLink("Source Code", "https://github.com/pallets/werkzeug/"), ProjectLink("Issue Tracker", "https://github.com/pallets/werkzeug/issues/"), - ProjectLink("Website", 
"https://palletsprojects.com/p/werkzeug/"), - ProjectLink("Twitter", "https://twitter.com/PalletsTeam"), ProjectLink("Chat", "https://discord.gg/pallets"), ] } @@ -43,13 +50,7 @@ } singlehtml_sidebars = {"index": ["project.html", "localtoc.html", "ethicalads.html"]} html_static_path = ["_static"] -html_favicon = "_static/favicon.ico" -html_logo = "_static/werkzeug.png" +html_favicon = "_static/shortcut-icon.png" +html_logo = "_static/werkzeug-vertical.png" html_title = f"Werkzeug Documentation ({version})" html_show_sourcelink = False - -# LaTeX ---------------------------------------------------------------- - -latex_documents = [ - (master_doc, f"Werkzeug-{version}.tex", html_title, author, "manual") -] diff --git a/docs/debug.rst b/docs/debug.rst index 25a9f0b..d842135 100644 --- a/docs/debug.rst +++ b/docs/debug.rst @@ -16,7 +16,8 @@ interactive debug console to execute code in any frame. The debugger allows the execution of arbitrary code which makes it a major security risk. **The debugger must never be used on production machines. We cannot stress this enough. Do not enable the debugger - in production.** + in production.** Production means anything that is not development, + and anything that is publicly accessible. .. note:: @@ -72,10 +73,9 @@ argument to get a detailed list of all the attributes it has. Debugger PIN ------------ -Starting with Werkzeug 0.11 the debug console is protected by a PIN. -This is a security helper to make it less likely for the debugger to be -exploited if you forget to disable it when deploying to production. The -PIN based authentication is enabled by default. +The debug console is protected by a PIN. This is a security helper to make it +less likely for the debugger to be exploited if you forget to disable it when +deploying to production. The PIN based authentication is enabled by default. The first time a console is opened, a dialog will prompt for a PIN that is printed to the command line. The PIN is generated in a stable way @@ -92,6 +92,31 @@ intended to make it harder for an attacker to exploit the debugger. Never enable the debugger in production.** +Allowed Hosts +------------- + +The debug console will only be served if the request comes from a trusted host. +If a request comes from a browser page that is not served on a trusted URL, a +400 error will be returned. + +By default, ``localhost``, any ``.localhost`` subdomain, and ``127.0.0.1`` are +trusted. ``run_simple`` will trust its ``hostname`` argument as well. To change +this further, use the debug middleware directly rather than through +``use_debugger=True``. + +.. code-block:: python + + if os.environ.get("USE_DEBUGGER") in {"1", "true"}: + app = DebuggedApplication(app, evalex=True) + app.trusted_hosts = [...] + + run_simple("localhost", 8080, app) + +**This feature is not meant to entirely secure the debugger. It is +intended to make it harder for an attacker to exploit the debugger. +Never enable the debugger in production.** + + Pasting Errors -------------- diff --git a/docs/http.rst b/docs/http.rst index cbf4e04..790de31 100644 --- a/docs/http.rst +++ b/docs/http.rst @@ -53,10 +53,6 @@ by :rfc:`2616`, Werkzeug implements some custom data structures that are .. autofunction:: parse_cache_control_header -.. autofunction:: parse_authorization_header - -.. autofunction:: parse_www_authenticate_header - .. autofunction:: parse_if_range_header .. 
autofunction:: parse_range_header diff --git a/docs/index.rst b/docs/index.rst index c4f0019..4bc4e30 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,6 +1,12 @@ +.. rst-class:: hide-header + Werkzeug ======== +.. image:: _static/werkzeug-horizontal.png + :align: center + :target: https://werkzeug.palletsprojects.com + *werkzeug* German noun: "tool". Etymology: *werk* ("work"), *zeug* ("stuff") @@ -72,7 +78,6 @@ Additional Information :maxdepth: 2 terms - unicode request_data license changes diff --git a/docs/installation.rst b/docs/installation.rst index 9c5aa7f..7138f08 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -6,13 +6,7 @@ Python Version -------------- We recommend using the latest version of Python. Werkzeug supports -Python 3.7 and newer. - - -Dependencies ------------- - -Werkzeug does not have any direct dependencies. +Python 3.8 and newer. Optional dependencies diff --git a/docs/license.rst b/docs/license.rst index a53a98c..2a445f9 100644 --- a/docs/license.rst +++ b/docs/license.rst @@ -1,4 +1,5 @@ BSD-3-Clause License ==================== -.. include:: ../LICENSE.rst +.. literalinclude:: ../LICENSE.txt + :language: text diff --git a/docs/middleware/index.rst b/docs/middleware/index.rst index 70cddee..3d7ede4 100644 --- a/docs/middleware/index.rst +++ b/docs/middleware/index.rst @@ -1 +1,20 @@ -.. automodule:: werkzeug.middleware +Middleware +========== + +A WSGI middleware is a WSGI application that wraps another application +in order to observe or change its behavior. Werkzeug provides some +middleware for common use cases. + +.. toctree:: + :maxdepth: 1 + + proxy_fix + shared_data + dispatcher + http_proxy + lint + profiler + +The :doc:`interactive debugger ` is also a middleware that can +be applied manually, although it is typically used automatically with +the :doc:`development server `. diff --git a/docs/quickstart.rst b/docs/quickstart.rst index 1568892..0f3714e 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -43,9 +43,7 @@ there: >>> request = Request(environ) Now you can access the important variables and Werkzeug will parse them -for you and decode them where it makes sense. The default charset for -requests is set to `utf-8` but you can change that by subclassing -:class:`Request`. +for you and decode them where it makes sense. >>> request.path '/foo' diff --git a/docs/request_data.rst b/docs/request_data.rst index 83c6278..b1c97b2 100644 --- a/docs/request_data.rst +++ b/docs/request_data.rst @@ -73,23 +73,31 @@ read the stream *or* call :meth:`~Request.get_data`. Limiting Request Data --------------------- -To avoid being the victim of a DDOS attack you can set the maximum -accepted content length and request field sizes. The :class:`Request` -class has two attributes for that: :attr:`~Request.max_content_length` -and :attr:`~Request.max_form_memory_size`. - -The first one can be used to limit the total content length. For example -by setting it to ``1024 * 1024 * 16`` the request won't accept more than -16MB of transmitted data. - -Because certain data can't be moved to the hard disk (regular post data) -whereas temporary files can, there is a second limit you can set. The -:attr:`~Request.max_form_memory_size` limits the size of `POST` -transmitted form data. By setting it to ``1024 * 1024 * 2`` you can make -sure that all in memory-stored fields are not more than 2MB in size. - -This however does *not* affect in-memory stored files if the -`stream_factory` used returns a in-memory file. 
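[Both the removed paragraphs above and the rewritten text that follows describe the same two request attributes. A minimal sketch of applying them, using the 16 MB and 2 MB figures quoted in the old text; a subclass is shown for illustration, though the attributes can also be set per instance:

    from werkzeug.wrappers import Request

    class LimitedRequest(Request):
        # Refuse to read request bodies larger than 16 MB.
        max_content_length = 1024 * 1024 * 16
        # Keep any single in-memory form field under 2 MB; file parts
        # can still be spooled to disk by the default stream_factory.
        max_form_memory_size = 1024 * 1024 * 2
]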
+The :class:`Request` class provides a few attributes to control how much data is +processed from the request body. This can help mitigate DoS attacks that craft the +request in such a way that the server uses too many resources to handle it. Each of +these limits will raise a :exc:`~werkzeug.exceptions.RequestEntityTooLarge` if they are +exceeded. + +- :attr:`~Request.max_content_length` Stop reading request data after this number + of bytes. It's better to configure this in the WSGI server or HTTP server, rather + than the WSGI application. +- :attr:`~Request.max_form_memory_size` Stop reading request data if any form part is + larger than this number of bytes. While file parts can be moved to disk, regular + form field data is stored in memory only. +- :attr:`~Request.max_form_parts` Stop reading request data if more than this number + of parts are sent in multipart form data. This is useful to stop a very large number + of very small parts, especially file parts. The default is 1000. + +Using Werkzeug to set these limits is only one layer of protection. WSGI servers +and HTTP servers should set their own limits on size and timeouts. The operating system +or container manager should set limits on memory and processing time for server +processes. + +If a 413 Content Too Large error is returned before the entire request is read, clients +may show a "connection reset" failure instead of the 413 error. This is based on how the +WSGI/HTTP server and client handle connections; it's not something the WSGI application +(Werkzeug) has control over. How to extend Parsing? diff --git a/docs/test.rst b/docs/test.rst index efb449a..d31ac59 100644 --- a/docs/test.rst +++ b/docs/test.rst @@ -18,8 +18,8 @@ requests. >>> response = c.get("/") >>> response.status_code 200 ->>> resp.headers -Headers([('Content-Type', 'text/html; charset=utf-8'), ('Content-Length', '6658')]) +>>> response.headers +Headers([('Content-Type', 'text/html; charset=utf-8'), ('Content-Length', '5211')]) >>> response.get_data(as_text=True) '...' @@ -102,6 +102,10 @@ API :members: :member-order: bysource +.. autoclass:: Cookie + :members: + :member-order: bysource + .. autoclass:: EnvironBuilder :members: :member-order: bysource diff --git a/docs/unicode.rst b/docs/unicode.rst deleted file mode 100644 index 30f76f5..0000000 --- a/docs/unicode.rst +++ /dev/null @@ -1,76 +0,0 @@ -Unicode -======= - -.. currentmodule:: werkzeug - -Werkzeug uses strings internally everwhere text data is assumed, even if -the HTTP standard is not Unicode aware. Basically all incoming data is -decoded from the charset (UTF-8 by default) so that you don't work with -bytes directly. Outgoing data is encoded into the target charset. - - -Unicode in Python ------------------ - -Imagine you have the German Umlaut ``ö``. In ASCII you cannot represent -that character, but in the ``latin-1`` and ``utf-8`` character sets you -can represent it, but they look different when encoded: - ->>> "ö".encode("latin1") -b'\xf6' ->>> "ö".encode("utf-8") -b'\xc3\xb6' - -An ``ö`` looks different depending on the encoding which makes it hard -to work with it as bytes. Instead, Python treats strings as Unicode text -and stores the information ``LATIN SMALL LETTER O WITH DIAERESIS`` -instead of the bytes for ``ö`` in a specific encoding. The length of a -string with 1 character will be 1, where the length of the bytes might -be some other value.
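[The removed page's point just above is easy to verify; a quick sketch of the length difference it describes:

    # len() of a str counts characters; len() of its encoding counts bytes.
    s = "ö"
    assert len(s) == 1
    assert s.encode("latin-1") == b"\xf6"    # 1 byte
    assert s.encode("utf-8") == b"\xc3\xb6"  # 2 bytes
]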
- - -Unicode in HTTP ---------------- - -However, the HTTP spec was written in a time where ASCII bytes were the -common way data was represented. To work around this for the modern -web, Werkzeug decodes and encodes incoming and outgoing data -automatically. Data sent from the browser to the web application is -decoded from UTF-8 bytes into a string. Data sent from the application -back to the browser is encoded back to UTF-8. - - -Error Handling --------------- - -Functions that do internal encoding or decoding accept an ``errors`` -keyword argument that is passed to :meth:`str.decode` and -:meth:`str.encode`. The default is ``'replace'`` so that errors are easy -to spot. It might be useful to set it to ``'strict'`` in order to catch -the error and report the bad data to the client. - - -Request and Response Objects ----------------------------- - -In most cases, you should stick with Werkzeug's default encoding of -UTF-8. If you have a specific reason to, you can subclass -:class:`wrappers.Request` and :class:`wrappers.Response` to change the -encoding and error handling. - -.. code-block:: python - - from werkzeug.wrappers.request import Request - from werkzeug.wrappers.response import Response - - class Latin1Request(Request): - charset = "latin1" - encoding_errors = "strict" - - class Latin1Response(Response): - charset = "latin1" - -The error handling can only be changed for the request. Werkzeug will -always raise errors when encoding to bytes in the response. It's your -responsibility to not create data that is not present in the target -charset. This is not an issue for UTF-8. diff --git a/docs/utils.rst b/docs/utils.rst index 0d4e339..6afa4ab 100644 --- a/docs/utils.rst +++ b/docs/utils.rst @@ -23,6 +23,8 @@ General Helpers .. autofunction:: send_file +.. autofunction:: send_from_directory + .. autofunction:: import_string .. autofunction:: find_modules diff --git a/docs/wsgi.rst b/docs/wsgi.rst index a96916b..67b3bb6 100644 --- a/docs/wsgi.rst +++ b/docs/wsgi.rst @@ -22,10 +22,6 @@ iterator and the input stream. .. autoclass:: LimitedStream :members: -.. autofunction:: make_line_iter - -.. autofunction:: make_chunk_iter - .. autofunction:: wrap_file @@ -43,18 +39,6 @@ information or perform common manipulations: .. autofunction:: get_current_url -.. autofunction:: get_query_string - -.. autofunction:: get_script_name - -.. autofunction:: get_path_info - -.. autofunction:: pop_path_info - -.. autofunction:: peek_path_info - -.. autofunction:: extract_path_info - .. 
autofunction:: host_is_trusted diff --git a/examples/couchy/utils.py b/examples/couchy/utils.py index 03d1681..5c39fdf 100644 --- a/examples/couchy/utils.py +++ b/examples/couchy/utils.py @@ -1,6 +1,7 @@ from os import path from random import randrange from random import sample +from urllib.parse import urlsplit from jinja2 import Environment from jinja2 import FileSystemLoader @@ -8,7 +9,6 @@ from werkzeug.local import LocalManager from werkzeug.routing import Map from werkzeug.routing import Rule -from werkzeug.urls import url_parse from werkzeug.utils import cached_property from werkzeug.wrappers import Response @@ -49,7 +49,7 @@ def render_template(template, **context): def validate_url(url): - return url_parse(url)[0] in ALLOWED_SCHEMES + return urlsplit(url)[0] in ALLOWED_SCHEMES def get_random_uid(): diff --git a/examples/shortly/shortly.py b/examples/shortly/shortly.py index 10e957e..5205f22 100644 --- a/examples/shortly/shortly.py +++ b/examples/shortly/shortly.py @@ -1,5 +1,6 @@ """A simple URL shortener using Werkzeug and redis.""" import os +from urllib.parse import urlsplit import redis from jinja2 import Environment @@ -9,7 +10,6 @@ from werkzeug.middleware.shared_data import SharedDataMiddleware from werkzeug.routing import Map from werkzeug.routing import Rule -from werkzeug.urls import url_parse from werkzeug.utils import redirect from werkzeug.wrappers import Request from werkzeug.wrappers import Response @@ -27,12 +27,12 @@ def base36_encode(number): def is_valid_url(url): - parts = url_parse(url) + parts = urlsplit(url) return parts.scheme in ("http", "https") def get_hostname(url): - return url_parse(url).netloc + return urlsplit(url).netloc class Shortly: diff --git a/examples/shorty/utils.py b/examples/shorty/utils.py index 2d9fe0e..4d064e3 100644 --- a/examples/shorty/utils.py +++ b/examples/shorty/utils.py @@ -1,6 +1,7 @@ from os import path from random import randrange from random import sample +from urllib.parse import urlsplit from jinja2 import Environment from jinja2 import FileSystemLoader @@ -11,7 +12,6 @@ from werkzeug.local import LocalManager from werkzeug.routing import Map from werkzeug.routing import Rule -from werkzeug.urls import url_parse from werkzeug.utils import cached_property from werkzeug.wrappers import Response @@ -59,7 +59,7 @@ def render_template(template, **context): def validate_url(url): - return url_parse(url)[0] in ALLOWED_SCHEMES + return urlsplit(url)[0] in ALLOWED_SCHEMES def get_random_uid(): diff --git a/examples/simplewiki/utils.py b/examples/simplewiki/utils.py index 6cafab4..00729c6 100644 --- a/examples/simplewiki/utils.py +++ b/examples/simplewiki/utils.py @@ -1,12 +1,12 @@ from os import path +from urllib.parse import quote +from urllib.parse import urlencode import creoleparser from genshi import Stream from genshi.template import TemplateLoader from werkzeug.local import Local from werkzeug.local import LocalManager -from werkzeug.urls import url_encode -from werkzeug.urls import url_quote from werkzeug.utils import cached_property from werkzeug.wrappers import Request as BaseRequest from werkzeug.wrappers import Response as BaseResponse @@ -58,9 +58,9 @@ def href(*args, **kw): """ result = [f"{request.script_root if request else ''}/"] for idx, arg in enumerate(args): - result.append(f"{'/' if idx else ''}{url_quote(arg)}") + result.append(f"{'/' if idx else ''}{quote(arg)}") if kw: - result.append(f"?{url_encode(kw)}") + result.append(f"?{urlencode(kw)}") return "".join(result) diff --git a/pyproject.toml 
b/pyproject.toml new file mode 100644 index 0000000..eb06882 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,117 @@ +[project] +name = "Werkzeug" +version = "3.0.3" +description = "The comprehensive WSGI web application library." +readme = "README.md" +license = {file = "LICENSE.txt"} +maintainers = [{name = "Pallets", email = "contact@palletsprojects.com"}] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "License :: OSI Approved :: BSD License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Topic :: Internet :: WWW/HTTP :: Dynamic Content", + "Topic :: Internet :: WWW/HTTP :: WSGI", + "Topic :: Internet :: WWW/HTTP :: WSGI :: Application", + "Topic :: Internet :: WWW/HTTP :: WSGI :: Middleware", + "Topic :: Software Development :: Libraries :: Application Frameworks", + "Typing :: Typed", +] +requires-python = ">=3.8" +dependencies = [ + "MarkupSafe>=2.1.1", +] + +[project.urls] +Donate = "https://palletsprojects.com/donate" +Documentation = "https://werkzeug.palletsprojects.com/" +Changes = "https://werkzeug.palletsprojects.com/changes/" +"Source Code" = "https://github.com/pallets/werkzeug/" +"Issue Tracker" = "https://github.com/pallets/werkzeug/issues/" +Chat = "https://discord.gg/pallets" + +[project.optional-dependencies] +watchdog = ["watchdog>=2.3"] + +[build-system] +requires = ["flit_core<4"] +build-backend = "flit_core.buildapi" + +[tool.flit.module] +name = "werkzeug" + +[tool.flit.sdist] +include = [ + "docs/", + "examples/", + "requirements/", + "tests/", + "CHANGES.rst", + "tox.ini", +] +exclude = [ + "docs/_build/", +] + +[tool.pytest.ini_options] +testpaths = ["tests"] +filterwarnings = [ + "error", +] +markers = ["dev_server: tests that start the dev server"] + +[tool.coverage.run] +branch = true +source = ["werkzeug", "tests"] + +[tool.coverage.paths] +source = ["src", "*/site-packages"] + +[tool.mypy] +files = ["src/werkzeug"] +show_error_codes = true +pretty = true +strict = true + +[[tool.mypy.overrides]] +module = [ + "colorama.*", + "cryptography.*", + "eventlet.*", + "gevent.*", + "greenlet.*", + "watchdog.*", + "xprocess.*", +] +ignore_missing_imports = true + +[tool.pyright] +pythonVersion = "3.8" +include = ["src/werkzeug"] + +[tool.ruff] +extend-exclude = ["examples/"] +src = ["src"] +fix = true +show-fixes = true +output-format = "full" + +[tool.ruff.lint] +select = [ + "B", # flake8-bugbear + "E", # pycodestyle error + "F", # pyflakes + "I", # isort + "UP", # pyupgrade + "W", # pycodestyle warning +] +ignore = [ + "E402", # allow circular imports at end of file +] +ignore-init-module-imports = true + +[tool.ruff.lint.isort] +force-single-line = true +order-by-type = false diff --git a/requirements/build.in b/requirements/build.in new file mode 100644 index 0000000..378eac2 --- /dev/null +++ b/requirements/build.in @@ -0,0 +1 @@ +build diff --git a/requirements/build.txt b/requirements/build.txt new file mode 100644 index 0000000..9ecc489 --- /dev/null +++ b/requirements/build.txt @@ -0,0 +1,12 @@ +# +# This file is autogenerated by pip-compile with Python 3.12 +# by the following command: +# +# pip-compile build.in +# +build==1.2.1 + # via -r build.in +packaging==24.0 + # via build +pyproject-hooks==1.0.0 + # via build diff --git a/requirements/dev.in b/requirements/dev.in index 99f5942..1efde82 100644 --- a/requirements/dev.in +++ b/requirements/dev.in @@ -1,6 +1,5 @@ --r docs.in --r tests.in --r typing.in -pip-compile-multi +-r 
docs.txt +-r tests.txt +-r typing.txt pre-commit tox diff --git a/requirements/dev.txt b/requirements/dev.txt index 50e233e..186ceda 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -1,64 +1,202 @@ -# SHA1:54b5b77ec8c7a0064ffa93b2fd16cb0130ba177c # -# This file is autogenerated by pip-compile-multi -# To update, run: +# This file is autogenerated by pip-compile with Python 3.12 +# by the following command: # -# pip-compile-multi +# pip-compile dev.in # --r docs.txt --r tests.txt --r typing.txt -build==0.8.0 - # via pip-tools -cfgv==3.3.1 +alabaster==0.7.16 + # via + # -r docs.txt + # sphinx +babel==2.14.0 + # via + # -r docs.txt + # sphinx +cachetools==5.3.3 + # via tox +certifi==2024.2.2 + # via + # -r docs.txt + # requests +cffi==1.16.0 + # via + # -r tests.txt + # cryptography +cfgv==3.4.0 # via pre-commit -click==8.1.3 +chardet==5.2.0 + # via tox +charset-normalizer==3.3.2 # via - # pip-compile-multi - # pip-tools -distlib==0.3.4 + # -r docs.txt + # requests +colorama==0.4.6 + # via tox +cryptography==42.0.5 + # via -r tests.txt +distlib==0.3.8 # via virtualenv -filelock==3.7.1 +docutils==0.20.1 + # via + # -r docs.txt + # sphinx +ephemeral-port-reserve==1.1.4 + # via -r tests.txt +filelock==3.13.3 # via # tox # virtualenv -greenlet==1.1.2 ; python_version < "3.11" - # via -r requirements/tests.in -identify==2.5.1 +greenlet==3.0.3 + # via -r tests.txt +identify==2.5.35 # via pre-commit -nodeenv==1.7.0 - # via pre-commit -pep517==0.12.0 - # via build -pip-compile-multi==2.4.5 - # via -r requirements/dev.in -pip-tools==6.8.0 - # via pip-compile-multi -platformdirs==2.5.2 - # via virtualenv -pre-commit==2.20.0 - # via -r requirements/dev.in -pyyaml==6.0 - # via pre-commit -six==1.16.0 +idna==3.6 + # via + # -r docs.txt + # requests +imagesize==1.4.1 + # via + # -r docs.txt + # sphinx +iniconfig==2.0.0 + # via + # -r tests.txt + # -r typing.txt + # pytest +jinja2==3.1.3 + # via + # -r docs.txt + # sphinx +markupsafe==2.1.5 + # via + # -r docs.txt + # jinja2 +mypy==1.9.0 + # via -r typing.txt +mypy-extensions==1.0.0 + # via + # -r typing.txt + # mypy +nodeenv==1.8.0 + # via + # -r typing.txt + # pre-commit + # pyright +packaging==24.0 + # via + # -r docs.txt + # -r tests.txt + # -r typing.txt + # pallets-sphinx-themes + # pyproject-api + # pytest + # sphinx + # tox +pallets-sphinx-themes==2.1.1 + # via -r docs.txt +platformdirs==4.2.0 # via # tox # virtualenv -toml==0.10.2 +pluggy==1.4.0 # via - # pre-commit + # -r tests.txt + # -r typing.txt + # pytest # tox -toposort==1.7 - # via pip-compile-multi -tox==3.25.1 - # via -r requirements/dev.in -virtualenv==20.15.1 +pre-commit==3.7.0 + # via -r dev.in +psutil==5.9.8 + # via + # -r tests.txt + # pytest-xprocess +pycparser==2.22 + # via + # -r tests.txt + # cffi +pygments==2.17.2 + # via + # -r docs.txt + # sphinx +pyproject-api==1.6.1 + # via tox +pyright==1.1.357 + # via -r typing.txt +pytest==8.1.1 + # via + # -r tests.txt + # -r typing.txt + # pytest-timeout + # pytest-xprocess +pytest-timeout==2.3.1 + # via -r tests.txt +pytest-xprocess==0.23.0 + # via -r tests.txt +pyyaml==6.0.1 + # via pre-commit +requests==2.31.0 + # via + # -r docs.txt + # sphinx +snowballstemmer==2.2.0 + # via + # -r docs.txt + # sphinx +sphinx==7.2.6 + # via + # -r docs.txt + # pallets-sphinx-themes + # sphinxcontrib-log-cabinet +sphinxcontrib-applehelp==1.0.8 + # via + # -r docs.txt + # sphinx +sphinxcontrib-devhelp==1.0.6 + # via + # -r docs.txt + # sphinx +sphinxcontrib-htmlhelp==2.0.5 + # via + # -r docs.txt + # sphinx +sphinxcontrib-jsmath==1.0.1 
+ # via + # -r docs.txt + # sphinx +sphinxcontrib-log-cabinet==1.0.1 + # via -r docs.txt +sphinxcontrib-qthelp==1.0.7 + # via + # -r docs.txt + # sphinx +sphinxcontrib-serializinghtml==1.1.10 + # via + # -r docs.txt + # sphinx +tox==4.14.2 + # via -r dev.in +types-contextvars==2.4.7.3 + # via -r typing.txt +types-dataclasses==0.6.6 + # via -r typing.txt +types-setuptools==69.2.0.20240317 + # via -r typing.txt +typing-extensions==4.11.0 + # via + # -r typing.txt + # mypy +urllib3==2.2.1 + # via + # -r docs.txt + # requests +virtualenv==20.25.1 # via # pre-commit # tox -wheel==0.37.1 - # via pip-tools +watchdog==4.0.0 + # via + # -r tests.txt + # -r typing.txt # The following packages are considered to be unsafe in a requirements file: -# pip # setuptools diff --git a/requirements/docs.in b/requirements/docs.in index 7ec501b..ba3fd77 100644 --- a/requirements/docs.in +++ b/requirements/docs.in @@ -1,4 +1,3 @@ -Pallets-Sphinx-Themes -Sphinx -sphinx-issues +pallets-sphinx-themes +sphinx sphinxcontrib-log-cabinet diff --git a/requirements/docs.txt b/requirements/docs.txt index 8238e78..ed605ea 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -1,65 +1,57 @@ -# SHA1:45c590f97fe95b8bdc755eef796e91adf5fbe4ea # -# This file is autogenerated by pip-compile-multi -# To update, run: +# This file is autogenerated by pip-compile with Python 3.12 +# by the following command: # -# pip-compile-multi +# pip-compile docs.in # -alabaster==0.7.12 +alabaster==0.7.16 # via sphinx -babel==2.10.3 +babel==2.14.0 # via sphinx -certifi==2022.6.15 +certifi==2024.2.2 # via requests -charset-normalizer==2.1.0 +charset-normalizer==3.3.2 # via requests -docutils==0.18.1 +docutils==0.20.1 # via sphinx -idna==3.3 +idna==3.6 # via requests imagesize==1.4.1 # via sphinx -jinja2==3.1.2 +jinja2==3.1.3 # via sphinx -markupsafe==2.1.1 +markupsafe==2.1.5 # via jinja2 -packaging==21.3 +packaging==24.0 # via # pallets-sphinx-themes # sphinx -pallets-sphinx-themes==2.0.2 - # via -r requirements/docs.in -pygments==2.12.0 +pallets-sphinx-themes==2.1.1 + # via -r docs.in +pygments==2.17.2 # via sphinx -pyparsing==3.0.9 - # via packaging -pytz==2022.1 - # via babel -requests==2.28.1 +requests==2.31.0 # via sphinx snowballstemmer==2.2.0 # via sphinx -sphinx==5.0.2 +sphinx==7.2.6 # via - # -r requirements/docs.in + # -r docs.in # pallets-sphinx-themes - # sphinx-issues # sphinxcontrib-log-cabinet -sphinx-issues==3.0.1 - # via -r requirements/docs.in -sphinxcontrib-applehelp==1.0.2 +sphinxcontrib-applehelp==1.0.8 # via sphinx -sphinxcontrib-devhelp==1.0.2 +sphinxcontrib-devhelp==1.0.6 # via sphinx -sphinxcontrib-htmlhelp==2.0.0 +sphinxcontrib-htmlhelp==2.0.5 # via sphinx sphinxcontrib-jsmath==1.0.1 # via sphinx sphinxcontrib-log-cabinet==1.0.1 - # via -r requirements/docs.in -sphinxcontrib-qthelp==1.0.3 + # via -r docs.in +sphinxcontrib-qthelp==1.0.7 # via sphinx -sphinxcontrib-serializinghtml==1.1.5 +sphinxcontrib-serializinghtml==1.1.10 # via sphinx -urllib3==1.26.10 +urllib3==2.2.1 # via requests diff --git a/requirements/tests.in b/requirements/tests.in index 3ced491..8228f8e 100644 --- a/requirements/tests.in +++ b/requirements/tests.in @@ -1,7 +1,8 @@ pytest pytest-timeout -pytest-xprocess +# pinned for python 3.8 support +pytest-xprocess<1 cryptography -greenlet ; python_version < "3.11" +greenlet watchdog ephemeral-port-reserve diff --git a/requirements/tests.txt b/requirements/tests.txt index 689d8ba..14b6743 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -1,44 +1,35 @@ -# 
SHA1:42b4e3e66395275e048d9a92c294b2c650393866 # -# This file is autogenerated by pip-compile-multi -# To update, run: +# This file is autogenerated by pip-compile with Python 3.12 +# by the following command: # -# pip-compile-multi +# pip-compile tests.in # -attrs==21.4.0 - # via pytest -cffi==1.15.1 +cffi==1.16.0 # via cryptography -cryptography==37.0.4 - # via -r requirements/tests.in +cryptography==42.0.5 + # via -r tests.in ephemeral-port-reserve==1.1.4 - # via -r requirements/tests.in -greenlet==1.1.2 ; python_version < "3.11" - # via -r requirements/tests.in -iniconfig==1.1.1 + # via -r tests.in +greenlet==3.0.3 + # via -r tests.in +iniconfig==2.0.0 # via pytest -packaging==21.3 +packaging==24.0 # via pytest -pluggy==1.0.0 +pluggy==1.4.0 # via pytest -psutil==5.9.1 +psutil==5.9.8 # via pytest-xprocess -py==1.11.0 - # via pytest -pycparser==2.21 +pycparser==2.22 # via cffi -pyparsing==3.0.9 - # via packaging -pytest==7.1.2 +pytest==8.1.1 # via - # -r requirements/tests.in + # -r tests.in # pytest-timeout # pytest-xprocess -pytest-timeout==2.1.0 - # via -r requirements/tests.in -pytest-xprocess==0.19.0 - # via -r requirements/tests.in -tomli==2.0.1 - # via pytest -watchdog==2.1.9 - # via -r requirements/tests.in +pytest-timeout==2.3.1 + # via -r tests.in +pytest-xprocess==0.23.0 + # via -r tests.in +watchdog==4.0.0 + # via -r tests.in diff --git a/requirements/typing.in b/requirements/typing.in index e17c43d..096413b 100644 --- a/requirements/typing.in +++ b/requirements/typing.in @@ -1,4 +1,7 @@ mypy +pyright +pytest types-contextvars types-dataclasses types-setuptools +watchdog diff --git a/requirements/typing.txt b/requirements/typing.txt index 1f6de2c..09c78d7 100644 --- a/requirements/typing.txt +++ b/requirements/typing.txt @@ -1,21 +1,35 @@ -# SHA1:95499f7e92b572adde012b13e1ec99dbbb2f7089 # -# This file is autogenerated by pip-compile-multi -# To update, run: +# This file is autogenerated by pip-compile with Python 3.12 +# by the following command: # -# pip-compile-multi +# pip-compile typing.in # -mypy==0.961 - # via -r requirements/typing.in -mypy-extensions==0.4.3 +iniconfig==2.0.0 + # via pytest +mypy==1.9.0 + # via -r typing.in +mypy-extensions==1.0.0 # via mypy -tomli==2.0.1 - # via mypy -types-contextvars==2.4.7 - # via -r requirements/typing.in +nodeenv==1.8.0 + # via pyright +packaging==24.0 + # via pytest +pluggy==1.4.0 + # via pytest +pyright==1.1.357 + # via -r typing.in +pytest==8.1.1 + # via -r typing.in +types-contextvars==2.4.7.3 + # via -r typing.in types-dataclasses==0.6.6 - # via -r requirements/typing.in -types-setuptools==62.6.1 - # via -r requirements/typing.in -typing-extensions==4.3.0 + # via -r typing.in +types-setuptools==69.2.0.20240317 + # via -r typing.in +typing-extensions==4.11.0 # via mypy +watchdog==4.0.0 + # via -r typing.in + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 2a1c2e4..0000000 --- a/setup.cfg +++ /dev/null @@ -1,130 +0,0 @@ -[metadata] -name = Werkzeug -version = attr: werkzeug.__version__ -url = https://palletsprojects.com/p/werkzeug/ -project_urls = - Donate = https://palletsprojects.com/donate - Documentation = https://werkzeug.palletsprojects.com/ - Changes = https://werkzeug.palletsprojects.com/changes/ - Source Code = https://github.com/pallets/werkzeug/ - Issue Tracker = https://github.com/pallets/werkzeug/issues/ - Twitter = https://twitter.com/PalletsTeam - Chat = https://discord.gg/pallets -license = BSD-3-Clause 
-author = Armin Ronacher -author_email = armin.ronacher@active-4.com -maintainer = Pallets -maintainer_email = contact@palletsprojects.com -description = The comprehensive WSGI web application library. -long_description = file: README.rst -long_description_content_type = text/x-rst -classifiers = - Development Status :: 5 - Production/Stable - Environment :: Web Environment - Intended Audience :: Developers - License :: OSI Approved :: BSD License - Operating System :: OS Independent - Programming Language :: Python - Topic :: Internet :: WWW/HTTP :: Dynamic Content - Topic :: Internet :: WWW/HTTP :: WSGI - Topic :: Internet :: WWW/HTTP :: WSGI :: Application - Topic :: Internet :: WWW/HTTP :: WSGI :: Middleware - Topic :: Software Development :: Libraries :: Application Frameworks - -[options] -packages = find: -package_dir = = src -include_package_data = True -python_requires = >= 3.7 -# Dependencies are in setup.py for GitHub's dependency graph. - -[options.packages.find] -where = src - -[tool:pytest] -testpaths = tests -filterwarnings = - error -markers = - dev_server: tests that start the dev server - -[coverage:run] -branch = True -source = - werkzeug - tests - -[coverage:paths] -source = - src - */site-packages - -[flake8] -# B = bugbear -# E = pycodestyle errors -# F = flake8 pyflakes -# W = pycodestyle warnings -# B9 = bugbear opinions -# ISC = implicit str concat -select = B, E, F, W, B9, ISC -ignore = - # slice notation whitespace, invalid - E203 - # import at top, too many circular import fixes - E402 - # line length, handled by bugbear B950 - E501 - # bare except, handled by bugbear B001 - E722 - # bin op line break, invalid - W503 -# up to 88 allowed by bugbear B950 -max-line-length = 80 -per-file-ignores = - # __init__ exports names - **/__init__.py: F401 - # LocalProxy assigns lambdas - src/werkzeug/local.py: E731 - -[mypy] -files = src/werkzeug -python_version = 3.7 -show_error_codes = True -allow_redefinition = True -disallow_subclassing_any = True -# disallow_untyped_calls = True -disallow_untyped_defs = True -disallow_incomplete_defs = True -no_implicit_optional = True -local_partial_types = True -no_implicit_reexport = True -strict_equality = True -warn_redundant_casts = True -warn_unused_configs = True -warn_unused_ignores = True -warn_return_any = True -# warn_unreachable = True - -[mypy-werkzeug.wrappers] -no_implicit_reexport = False - -[mypy-colorama.*] -ignore_missing_imports = True - -[mypy-cryptography.*] -ignore_missing_imports = True - -[mypy-eventlet.*] -ignore_missing_imports = True - -[mypy-gevent.*] -ignore_missing_imports = True - -[mypy-greenlet.*] -ignore_missing_imports = True - -[mypy-watchdog.*] -ignore_missing_imports = True - -[mypy-xprocess.*] -ignore_missing_imports = True diff --git a/setup.py b/setup.py deleted file mode 100644 index 413ce2a..0000000 --- a/setup.py +++ /dev/null @@ -1,8 +0,0 @@ -from setuptools import setup - -# Metadata goes in setup.cfg. These are here for GitHub's dependency graph. 
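With the PEP 621 ``[project]`` table in pyproject.toml above carrying all of this metadata, the setuptools stub below can be deleted as well: GitHub's dependency graph can read pyproject.toml directly, and anything else that needs the metadata can query the installed distribution through the standard library. A minimal sketch (assumes a Werkzeug 3.0.3 install; the printed values are illustrative):

    import importlib.metadata

    # Version lookup; the __getattr__ shim added to werkzeug/__init__.py
    # below deprecates werkzeug.__version__ in favor of exactly this call.
    print(importlib.metadata.version("werkzeug"))   # "3.0.3"

    # Core metadata fields come straight from the [project] table.
    meta = importlib.metadata.metadata("werkzeug")
    print(meta["Requires-Python"])                  # ">=3.8"
    print(importlib.metadata.requires("werkzeug"))  # ["MarkupSafe>=2.1.1", ...]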
-setup( - name="Werkzeug", - install_requires=["MarkupSafe>=2.1.1"], - extras_require={"watchdog": ["watchdog"]}, -) diff --git a/src/werkzeug/__init__.py b/src/werkzeug/__init__.py index fd7f8d2..57cb753 100644 --- a/src/werkzeug/__init__.py +++ b/src/werkzeug/__init__.py @@ -1,6 +1,25 @@ +from __future__ import annotations + +import typing as t + from .serving import run_simple as run_simple from .test import Client as Client from .wrappers import Request as Request from .wrappers import Response as Response -__version__ = "2.2.2" + +def __getattr__(name: str) -> t.Any: + if name == "__version__": + import importlib.metadata + import warnings + + warnings.warn( + "The '__version__' attribute is deprecated and will be removed in" + " Werkzeug 3.1. Use feature detection or" + " 'importlib.metadata.version(\"werkzeug\")' instead.", + DeprecationWarning, + stacklevel=2, + ) + return importlib.metadata.version("werkzeug") + + raise AttributeError(name) diff --git a/src/werkzeug/_internal.py b/src/werkzeug/_internal.py index 4636647..7dd2fbc 100644 --- a/src/werkzeug/_internal.py +++ b/src/werkzeug/_internal.py @@ -1,50 +1,18 @@ +from __future__ import annotations + import logging -import operator import re -import string import sys -import typing import typing as t -from datetime import date from datetime import datetime from datetime import timezone -from itertools import chain -from weakref import WeakKeyDictionary if t.TYPE_CHECKING: - from _typeshed.wsgi import StartResponse - from _typeshed.wsgi import WSGIApplication from _typeshed.wsgi import WSGIEnvironment - from .wrappers.request import Request # noqa: F401 - -_logger: t.Optional[logging.Logger] = None -_signature_cache = WeakKeyDictionary() # type: ignore -_epoch_ord = date(1970, 1, 1).toordinal() -_legal_cookie_chars = frozenset( - c.encode("ascii") - for c in f"{string.ascii_letters}{string.digits}/=!#$%&'*+-.^_`|~:" -) - -_cookie_quoting_map = {b",": b"\\054", b";": b"\\073", b'"': b'\\"', b"\\": b"\\\\"} -for _i in chain(range(32), range(127, 256)): - _cookie_quoting_map[_i.to_bytes(1, sys.byteorder)] = f"\\{_i:03o}".encode("latin1") - -_octal_re = re.compile(rb"\\[0-3][0-7][0-7]") -_quote_re = re.compile(rb"[\\].") -_legal_cookie_chars_re = rb"[\w\d!#%&\'~_`><@,:/\$\*\+\-\.\^\|\)\(\?\}\{\=]" -_cookie_re = re.compile( - rb""" - (?P[^=;]+) - (?:\s*=\s* - (?P - "(?:[^\\"]|\\.)*" | - (?:.*?) - ) - )? - \s*; -""", - flags=re.VERBOSE, -) + + from .wrappers.request import Request + +_logger: logging.Logger | None = None class _Missing: @@ -58,110 +26,15 @@ def __reduce__(self) -> str: _missing = _Missing() -@typing.overload -def _make_encode_wrapper(reference: str) -> t.Callable[[str], str]: - ... - - -@typing.overload -def _make_encode_wrapper(reference: bytes) -> t.Callable[[str], bytes]: - ... - - -def _make_encode_wrapper(reference: t.AnyStr) -> t.Callable[[str], t.AnyStr]: - """Create a function that will be called with a string argument. If - the reference is bytes, values will be encoded to bytes. 
- """ - if isinstance(reference, str): - return lambda x: x - - return operator.methodcaller("encode", "latin1") - - -def _check_str_tuple(value: t.Tuple[t.AnyStr, ...]) -> None: - """Ensure tuple items are all strings or all bytes.""" - if not value: - return - - item_type = str if isinstance(value[0], str) else bytes - - if any(not isinstance(item, item_type) for item in value): - raise TypeError(f"Cannot mix str and bytes arguments (got {value!r})") - - -_default_encoding = sys.getdefaultencoding() - - -def _to_bytes( - x: t.Union[str, bytes], charset: str = _default_encoding, errors: str = "strict" -) -> bytes: - if x is None or isinstance(x, bytes): - return x - - if isinstance(x, (bytearray, memoryview)): - return bytes(x) - - if isinstance(x, str): - return x.encode(charset, errors) - - raise TypeError("Expected bytes") - - -@typing.overload -def _to_str( # type: ignore - x: None, - charset: t.Optional[str] = ..., - errors: str = ..., - allow_none_charset: bool = ..., -) -> None: - ... - - -@typing.overload -def _to_str( - x: t.Any, - charset: t.Optional[str] = ..., - errors: str = ..., - allow_none_charset: bool = ..., -) -> str: - ... - - -def _to_str( - x: t.Optional[t.Any], - charset: t.Optional[str] = _default_encoding, - errors: str = "strict", - allow_none_charset: bool = False, -) -> t.Optional[t.Union[str, bytes]]: - if x is None or isinstance(x, str): - return x +def _wsgi_decoding_dance(s: str) -> str: + return s.encode("latin1").decode(errors="replace") - if not isinstance(x, (bytes, bytearray)): - return str(x) - if charset is None: - if allow_none_charset: - return x +def _wsgi_encoding_dance(s: str) -> str: + return s.encode().decode("latin1") - return x.decode(charset, errors) # type: ignore - -def _wsgi_decoding_dance( - s: str, charset: str = "utf-8", errors: str = "replace" -) -> str: - return s.encode("latin1").decode(charset, errors) - - -def _wsgi_encoding_dance( - s: str, charset: str = "utf-8", errors: str = "replace" -) -> str: - if isinstance(s, bytes): - return s.decode("latin1", errors) - - return s.encode(charset).decode("latin1", errors) - - -def _get_environ(obj: t.Union["WSGIEnvironment", "Request"]) -> "WSGIEnvironment": +def _get_environ(obj: WSGIEnvironment | Request) -> WSGIEnvironment: env = getattr(obj, "environ", obj) assert isinstance( env, dict @@ -188,7 +61,7 @@ def _has_level_handler(logger: logging.Logger) -> bool: return False -class _ColorStreamHandler(logging.StreamHandler): +class _ColorStreamHandler(logging.StreamHandler): # type: ignore[type-arg] """On Windows, wrap stream with Colorama for ANSI style support.""" def __init__(self) -> None: @@ -224,17 +97,15 @@ def _log(type: str, message: str, *args: t.Any, **kwargs: t.Any) -> None: getattr(_logger, type)(message.rstrip(), *args, **kwargs) -@typing.overload -def _dt_as_utc(dt: None) -> None: - ... +@t.overload +def _dt_as_utc(dt: None) -> None: ... -@typing.overload -def _dt_as_utc(dt: datetime) -> datetime: - ... +@t.overload +def _dt_as_utc(dt: datetime) -> datetime: ... 
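Stepping back to the rewritten ``_wsgi_decoding_dance`` / ``_wsgi_encoding_dance`` above: the configurable charset parameters are gone because PEP 3333 fixes the model. WSGI servers expose request bytes as latin-1 decoded strings, and Werkzeug now always interprets the underlying bytes as UTF-8. A round-trip sketch (the path value is invented for illustration; these helpers are private internals):

    from werkzeug._internal import _wsgi_decoding_dance
    from werkzeug._internal import _wsgi_encoding_dance

    raw = "/café".encode()               # b"/caf\xc3\xa9", the UTF-8 bytes on the wire
    environ_path = raw.decode("latin1")  # "/cafÃ©", the str a WSGI server stores in environ

    assert _wsgi_decoding_dance(environ_path) == "/café"  # recover the real text
    assert _wsgi_encoding_dance("/café") == environ_path  # and map it back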
-def _dt_as_utc(dt: t.Optional[datetime]) -> t.Optional[datetime]: +def _dt_as_utc(dt: datetime | None) -> datetime | None: if dt is None: return dt @@ -257,11 +128,11 @@ class _DictAccessorProperty(t.Generic[_TAccessorValue]): def __init__( self, name: str, - default: t.Optional[_TAccessorValue] = None, - load_func: t.Optional[t.Callable[[str], _TAccessorValue]] = None, - dump_func: t.Optional[t.Callable[[_TAccessorValue], str]] = None, - read_only: t.Optional[bool] = None, - doc: t.Optional[str] = None, + default: _TAccessorValue | None = None, + load_func: t.Callable[[str], _TAccessorValue] | None = None, + dump_func: t.Callable[[_TAccessorValue], str] | None = None, + read_only: bool | None = None, + doc: str | None = None, ) -> None: self.name = name self.default = default @@ -274,19 +145,17 @@ def __init__( def lookup(self, instance: t.Any) -> t.MutableMapping[str, t.Any]: raise NotImplementedError - @typing.overload + @t.overload def __get__( self, instance: None, owner: type - ) -> "_DictAccessorProperty[_TAccessorValue]": - ... + ) -> _DictAccessorProperty[_TAccessorValue]: ... - @typing.overload - def __get__(self, instance: t.Any, owner: type) -> _TAccessorValue: - ... + @t.overload + def __get__(self, instance: t.Any, owner: type) -> _TAccessorValue: ... def __get__( - self, instance: t.Optional[t.Any], owner: type - ) -> t.Union[_TAccessorValue, "_DictAccessorProperty[_TAccessorValue]"]: + self, instance: t.Any | None, owner: type + ) -> _TAccessorValue | _DictAccessorProperty[_TAccessorValue]: if instance is None: return self @@ -324,225 +193,19 @@ def __repr__(self) -> str: return f"<{type(self).__name__} {self.name}>" -def _cookie_quote(b: bytes) -> bytes: - buf = bytearray() - all_legal = True - _lookup = _cookie_quoting_map.get - _push = buf.extend - - for char_int in b: - char = char_int.to_bytes(1, sys.byteorder) - if char not in _legal_cookie_chars: - all_legal = False - char = _lookup(char, char) - _push(char) - - if all_legal: - return bytes(buf) - return bytes(b'"' + buf + b'"') - - -def _cookie_unquote(b: bytes) -> bytes: - if len(b) < 2: - return b - if b[:1] != b'"' or b[-1:] != b'"': - return b - - b = b[1:-1] - - i = 0 - n = len(b) - rv = bytearray() - _push = rv.extend - - while 0 <= i < n: - o_match = _octal_re.search(b, i) - q_match = _quote_re.search(b, i) - if not o_match and not q_match: - rv.extend(b[i:]) - break - j = k = -1 - if o_match: - j = o_match.start(0) - if q_match: - k = q_match.start(0) - if q_match and (not o_match or k < j): - _push(b[i:k]) - _push(b[k + 1 : k + 2]) - i = k + 2 - else: - _push(b[i:j]) - rv.append(int(b[j + 1 : j + 4], 8)) - i = j + 4 - - return bytes(rv) - - -def _cookie_parse_impl(b: bytes) -> t.Iterator[t.Tuple[bytes, bytes]]: - """Lowlevel cookie parsing facility that operates on bytes.""" - i = 0 - n = len(b) - - while i < n: - match = _cookie_re.search(b + b";", i) - if not match: - break +_plain_int_re = re.compile(r"-?\d+", re.ASCII) - key = match.group("key").strip() - value = match.group("val") or b"" - i = match.end(0) - yield key, _cookie_unquote(value) +def _plain_int(value: str) -> int: + """Parse an int only if it is only ASCII digits and ``-``. + This disallows ``+``, ``_``, and non-ASCII digits, which are accepted by ``int`` but + are not allowed in HTTP header values. 
-def _encode_idna(domain: str) -> bytes: - # If we're given bytes, make sure they fit into ASCII - if isinstance(domain, bytes): - domain.decode("ascii") - return domain - - # Otherwise check if it's already ascii, then return - try: - return domain.encode("ascii") - except UnicodeError: - pass - - # Otherwise encode each part separately - return b".".join(p.encode("idna") for p in domain.split(".")) - + Any leading or trailing whitespace is stripped + """ + value = value.strip() + if _plain_int_re.fullmatch(value) is None: + raise ValueError -def _decode_idna(domain: t.Union[str, bytes]) -> str: - # If the input is a string try to encode it to ascii to do the idna - # decoding. If that fails because of a unicode error, then we - # already have a decoded idna domain. - if isinstance(domain, str): - try: - domain = domain.encode("ascii") - except UnicodeError: - return domain # type: ignore - - # Decode each part separately. If a part fails, try to decode it - # with ascii and silently ignore errors. This makes sense because - # the idna codec does not have error handling. - def decode_part(part: bytes) -> str: - try: - return part.decode("idna") - except UnicodeError: - return part.decode("ascii", "ignore") - - return ".".join(decode_part(p) for p in domain.split(b".")) - - -@typing.overload -def _make_cookie_domain(domain: None) -> None: - ... - - -@typing.overload -def _make_cookie_domain(domain: str) -> bytes: - ... - - -def _make_cookie_domain(domain: t.Optional[str]) -> t.Optional[bytes]: - if domain is None: - return None - domain = _encode_idna(domain) - if b":" in domain: - domain = domain.split(b":", 1)[0] - if b"." in domain: - return domain - raise ValueError( - "Setting 'domain' for a cookie on a server running locally (ex: " - "localhost) is not supported by complying browsers. You should " - "have something like: '127.0.0.1 localhost dev.localhost' on " - "your hosts file and then point your server to run on " - "'dev.localhost' and also set 'domain' for 'dev.localhost'" - ) - - -def _easteregg(app: t.Optional["WSGIApplication"] = None) -> "WSGIApplication": - """Like the name says. 
But who knows how it works?""" - - def bzzzzzzz(gyver: bytes) -> str: - import base64 - import zlib - - return zlib.decompress(base64.b64decode(gyver)).decode("ascii") - - gyver = "\n".join( - [ - x + (77 - len(x)) * " " - for x in bzzzzzzz( - b""" -eJyFlzuOJDkMRP06xRjymKgDJCDQStBYT8BCgK4gTwfQ2fcFs2a2FzvZk+hvlcRvRJD148efHt9m -9Xz94dRY5hGt1nrYcXx7us9qlcP9HHNh28rz8dZj+q4rynVFFPdlY4zH873NKCexrDM6zxxRymzz -4QIxzK4bth1PV7+uHn6WXZ5C4ka/+prFzx3zWLMHAVZb8RRUxtFXI5DTQ2n3Hi2sNI+HK43AOWSY -jmEzE4naFp58PdzhPMdslLVWHTGUVpSxImw+pS/D+JhzLfdS1j7PzUMxij+mc2U0I9zcbZ/HcZxc -q1QjvvcThMYFnp93agEx392ZdLJWXbi/Ca4Oivl4h/Y1ErEqP+lrg7Xa4qnUKu5UE9UUA4xeqLJ5 -jWlPKJvR2yhRI7xFPdzPuc6adXu6ovwXwRPXXnZHxlPtkSkqWHilsOrGrvcVWXgGP3daXomCj317 -8P2UOw/NnA0OOikZyFf3zZ76eN9QXNwYdD8f8/LdBRFg0BO3bB+Pe/+G8er8tDJv83XTkj7WeMBJ -v/rnAfdO51d6sFglfi8U7zbnr0u9tyJHhFZNXYfH8Iafv2Oa+DT6l8u9UYlajV/hcEgk1x8E8L/r -XJXl2SK+GJCxtnyhVKv6GFCEB1OO3f9YWAIEbwcRWv/6RPpsEzOkXURMN37J0PoCSYeBnJQd9Giu -LxYQJNlYPSo/iTQwgaihbART7Fcyem2tTSCcwNCs85MOOpJtXhXDe0E7zgZJkcxWTar/zEjdIVCk -iXy87FW6j5aGZhttDBoAZ3vnmlkx4q4mMmCdLtnHkBXFMCReqthSGkQ+MDXLLCpXwBs0t+sIhsDI -tjBB8MwqYQpLygZ56rRHHpw+OAVyGgaGRHWy2QfXez+ZQQTTBkmRXdV/A9LwH6XGZpEAZU8rs4pE -1R4FQ3Uwt8RKEtRc0/CrANUoes3EzM6WYcFyskGZ6UTHJWenBDS7h163Eo2bpzqxNE9aVgEM2CqI -GAJe9Yra4P5qKmta27VjzYdR04Vc7KHeY4vs61C0nbywFmcSXYjzBHdiEjraS7PGG2jHHTpJUMxN -Jlxr3pUuFvlBWLJGE3GcA1/1xxLcHmlO+LAXbhrXah1tD6Ze+uqFGdZa5FM+3eHcKNaEarutAQ0A -QMAZHV+ve6LxAwWnXbbSXEG2DmCX5ijeLCKj5lhVFBrMm+ryOttCAeFpUdZyQLAQkA06RLs56rzG -8MID55vqr/g64Qr/wqwlE0TVxgoiZhHrbY2h1iuuyUVg1nlkpDrQ7Vm1xIkI5XRKLedN9EjzVchu -jQhXcVkjVdgP2O99QShpdvXWoSwkp5uMwyjt3jiWCqWGSiaaPAzohjPanXVLbM3x0dNskJsaCEyz -DTKIs+7WKJD4ZcJGfMhLFBf6hlbnNkLEePF8Cx2o2kwmYF4+MzAxa6i+6xIQkswOqGO+3x9NaZX8 -MrZRaFZpLeVTYI9F/djY6DDVVs340nZGmwrDqTCiiqD5luj3OzwpmQCiQhdRYowUYEA3i1WWGwL4 -GCtSoO4XbIPFeKGU13XPkDf5IdimLpAvi2kVDVQbzOOa4KAXMFlpi/hV8F6IDe0Y2reg3PuNKT3i -RYhZqtkQZqSB2Qm0SGtjAw7RDwaM1roESC8HWiPxkoOy0lLTRFG39kvbLZbU9gFKFRvixDZBJmpi -Xyq3RE5lW00EJjaqwp/v3EByMSpVZYsEIJ4APaHmVtpGSieV5CALOtNUAzTBiw81GLgC0quyzf6c -NlWknzJeCsJ5fup2R4d8CYGN77mu5vnO1UqbfElZ9E6cR6zbHjgsr9ly18fXjZoPeDjPuzlWbFwS -pdvPkhntFvkc13qb9094LL5NrA3NIq3r9eNnop9DizWOqCEbyRBFJTHn6Tt3CG1o8a4HevYh0XiJ -sR0AVVHuGuMOIfbuQ/OKBkGRC6NJ4u7sbPX8bG/n5sNIOQ6/Y/BX3IwRlTSabtZpYLB85lYtkkgm -p1qXK3Du2mnr5INXmT/78KI12n11EFBkJHHp0wJyLe9MvPNUGYsf+170maayRoy2lURGHAIapSpQ -krEDuNoJCHNlZYhKpvw4mspVWxqo415n8cD62N9+EfHrAvqQnINStetek7RY2Urv8nxsnGaZfRr/ -nhXbJ6m/yl1LzYqscDZA9QHLNbdaSTTr+kFg3bC0iYbX/eQy0Bv3h4B50/SGYzKAXkCeOLI3bcAt -mj2Z/FM1vQWgDynsRwNvrWnJHlespkrp8+vO1jNaibm+PhqXPPv30YwDZ6jApe3wUjFQobghvW9p -7f2zLkGNv8b191cD/3vs9Q833z8t""" - ).splitlines() - ] - ) - - def easteregged( - environ: "WSGIEnvironment", start_response: "StartResponse" - ) -> t.Iterable[bytes]: - def injecting_start_response( - status: str, headers: t.List[t.Tuple[str, str]], exc_info: t.Any = None - ) -> t.Callable[[bytes], t.Any]: - headers.append(("X-Powered-By", "Werkzeug")) - return start_response(status, headers, exc_info) - - if app is not None and environ.get("QUERY_STRING") != "macgybarchakku": - return app(environ, injecting_start_response) - injecting_start_response("200 OK", [("Content-Type", "text/html")]) - return [ - f"""\ - - - -About Werkzeug - - - -

-<h1>Werkzeug</h1>
-<p>the Swiss Army knife of Python web development.</p>
-<pre>{gyver}\n\n\n</pre>
- -""".encode( - "latin1" - ) - ] - - return easteregged + return int(value) diff --git a/src/werkzeug/_reloader.py b/src/werkzeug/_reloader.py index 57f3117..d7e91a6 100644 --- a/src/werkzeug/_reloader.py +++ b/src/werkzeug/_reloader.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import fnmatch import os import subprocess @@ -20,7 +22,7 @@ if hasattr(sys, "real_prefix"): # virtualenv < 20 - prefix.add(sys.real_prefix) # type: ignore[attr-defined] + prefix.add(sys.real_prefix) _stat_ignore_scan = tuple(prefix) del prefix @@ -55,13 +57,13 @@ def _iter_module_paths() -> t.Iterator[str]: yield name -def _remove_by_pattern(paths: t.Set[str], exclude_patterns: t.Set[str]) -> None: +def _remove_by_pattern(paths: set[str], exclude_patterns: set[str]) -> None: for pattern in exclude_patterns: paths.difference_update(fnmatch.filter(paths, pattern)) def _find_stat_paths( - extra_files: t.Set[str], exclude_patterns: t.Set[str] + extra_files: set[str], exclude_patterns: set[str] ) -> t.Iterable[str]: """Find paths for the stat reloader to watch. Returns imported module files, Python files under non-system paths. Extra files and @@ -115,7 +117,7 @@ def _find_stat_paths( def _find_watchdog_paths( - extra_files: t.Set[str], exclude_patterns: t.Set[str] + extra_files: set[str], exclude_patterns: set[str] ) -> t.Iterable[str]: """Find paths for the stat reloader to watch. Looks at the same sources as the stat reloader, but watches everything under @@ -139,7 +141,7 @@ def _find_watchdog_paths( def _find_common_roots(paths: t.Iterable[str]) -> t.Iterable[str]: - root: t.Dict[str, dict] = {} + root: dict[str, dict[str, t.Any]] = {} for chunks in sorted((PurePath(x).parts for x in paths), key=len, reverse=True): node = root @@ -151,21 +153,28 @@ def _find_common_roots(paths: t.Iterable[str]) -> t.Iterable[str]: rv = set() - def _walk(node: t.Mapping[str, dict], path: t.Tuple[str, ...]) -> None: + def _walk(node: t.Mapping[str, dict[str, t.Any]], path: tuple[str, ...]) -> None: for prefix, child in node.items(): _walk(child, path + (prefix,)) - if not node: + # If there are no more nodes, and a path has been accumulated, add it. + # Path may be empty if the "" entry is in sys.path. + if not node and path: rv.add(os.path.join(*path)) _walk(root, ()) return rv -def _get_args_for_reloading() -> t.List[str]: +def _get_args_for_reloading() -> list[str]: """Determine how the script was executed, and return the args needed to execute it again in a new process. """ + if sys.version_info >= (3, 10): + # sys.orig_argv, added in Python 3.10, contains the exact args used to invoke + # Python. Still replace argv[0] with sys.executable for accuracy. 
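    # A concrete illustration (hypothetical invocation):
    #   started as:      python -X dev -m myapp --port 8000
    #   sys.orig_argv == ["python", "-X", "dev", "-m", "myapp", "--port", "8000"]
    #   sys.argv      == ["/path/to/myapp/__main__.py", "--port", "8000"]
    # Rebuilding the command from sys.orig_argv keeps interpreter options
    # such as -X dev, which the sys.argv fallback below cannot recover.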
+ return [sys.executable, *sys.orig_argv[1:]] + rv = [sys.executable] py_script = sys.argv[0] args = sys.argv[1:] @@ -221,15 +230,15 @@ class ReloaderLoop: def __init__( self, - extra_files: t.Optional[t.Iterable[str]] = None, - exclude_patterns: t.Optional[t.Iterable[str]] = None, - interval: t.Union[int, float] = 1, + extra_files: t.Iterable[str] | None = None, + exclude_patterns: t.Iterable[str] | None = None, + interval: int | float = 1, ) -> None: - self.extra_files: t.Set[str] = {os.path.abspath(x) for x in extra_files or ()} - self.exclude_patterns: t.Set[str] = set(exclude_patterns or ()) + self.extra_files: set[str] = {os.path.abspath(x) for x in extra_files or ()} + self.exclude_patterns: set[str] = set(exclude_patterns or ()) self.interval = interval - def __enter__(self) -> "ReloaderLoop": + def __enter__(self) -> ReloaderLoop: """Do any setup, then run one step of the watch to populate the initial filesystem state. """ @@ -281,7 +290,7 @@ class StatReloaderLoop(ReloaderLoop): name = "stat" def __enter__(self) -> ReloaderLoop: - self.mtimes: t.Dict[str, float] = {} + self.mtimes: dict[str, float] = {} return super().__enter__() def run_step(self) -> None: @@ -303,17 +312,22 @@ def run_step(self) -> None: class WatchdogReloaderLoop(ReloaderLoop): def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: - from watchdog.observers import Observer + from watchdog.events import EVENT_TYPE_OPENED + from watchdog.events import FileModifiedEvent from watchdog.events import PatternMatchingEventHandler + from watchdog.observers import Observer super().__init__(*args, **kwargs) trigger_reload = self.trigger_reload - class EventHandler(PatternMatchingEventHandler): # type: ignore - def on_any_event(self, event): # type: ignore + class EventHandler(PatternMatchingEventHandler): + def on_any_event(self, event: FileModifiedEvent): # type: ignore + if event.event_type == EVENT_TYPE_OPENED: + return + trigger_reload(event.src_path) - reloader_name = Observer.__name__.lower() + reloader_name = Observer.__name__.lower() # type: ignore[attr-defined] if reloader_name.endswith("observer"): reloader_name = reloader_name[:-8] @@ -326,7 +340,7 @@ def on_any_event(self, event): # type: ignore # the source file (or initial pyc file) as well. Ignore Git and # Mercurial internal changes. 
extra_patterns = [p for p in self.extra_files if not os.path.isdir(p)] - self.event_handler = EventHandler( + self.event_handler = EventHandler( # type: ignore[no-untyped-call] patterns=["*.py", "*.pyc", "*.zip", *extra_patterns], ignore_patterns=[ *[f"*/{d}/*" for d in _ignore_common_dirs], @@ -343,12 +357,12 @@ def trigger_reload(self, filename: str) -> None: self.log_reload(filename) def __enter__(self) -> ReloaderLoop: - self.watches: t.Dict[str, t.Any] = {} - self.observer.start() + self.watches: dict[str, t.Any] = {} + self.observer.start() # type: ignore[no-untyped-call] return super().__enter__() def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore - self.observer.stop() + self.observer.stop() # type: ignore[no-untyped-call] self.observer.join() def run(self) -> None: @@ -364,7 +378,7 @@ def run_step(self) -> None: for path in _find_watchdog_paths(self.extra_files, self.exclude_patterns): if path not in self.watches: try: - self.watches[path] = self.observer.schedule( + self.watches[path] = self.observer.schedule( # type: ignore[no-untyped-call] self.event_handler, path, recursive=True ) except OSError: @@ -379,10 +393,10 @@ def run_step(self) -> None: watch = self.watches.pop(path, None) if watch is not None: - self.observer.unschedule(watch) + self.observer.unschedule(watch) # type: ignore[no-untyped-call] -reloader_loops: t.Dict[str, t.Type[ReloaderLoop]] = { +reloader_loops: dict[str, type[ReloaderLoop]] = { "stat": StatReloaderLoop, "watchdog": WatchdogReloaderLoop, } @@ -416,9 +430,9 @@ def ensure_echo_on() -> None: def run_with_reloader( main_func: t.Callable[[], None], - extra_files: t.Optional[t.Iterable[str]] = None, - exclude_patterns: t.Optional[t.Iterable[str]] = None, - interval: t.Union[int, float] = 1, + extra_files: t.Iterable[str] | None = None, + exclude_patterns: t.Iterable[str] | None = None, + interval: int | float = 1, reloader_type: str = "auto", ) -> None: """Run the given function in an independent Python interpreter.""" diff --git a/src/werkzeug/datastructures.py b/src/werkzeug/datastructures.py deleted file mode 100644 index 43ee8c7..0000000 --- a/src/werkzeug/datastructures.py +++ /dev/null @@ -1,3040 +0,0 @@ -import base64 -import codecs -import mimetypes -import os -import re -from collections.abc import Collection -from collections.abc import MutableSet -from copy import deepcopy -from io import BytesIO -from itertools import repeat -from os import fspath - -from . import exceptions -from ._internal import _missing - - -def is_immutable(self): - raise TypeError(f"{type(self).__name__!r} objects are immutable") - - -def iter_multi_items(mapping): - """Iterates over the items of a mapping yielding keys and values - without dropping any from more complex structures. - """ - if isinstance(mapping, MultiDict): - yield from mapping.items(multi=True) - elif isinstance(mapping, dict): - for key, value in mapping.items(): - if isinstance(value, (tuple, list)): - for v in value: - yield key, v - else: - yield key, value - else: - yield from mapping - - -class ImmutableListMixin: - """Makes a :class:`list` immutable. - - .. 
versionadded:: 0.5 - - :private: - """ - - _hash_cache = None - - def __hash__(self): - if self._hash_cache is not None: - return self._hash_cache - rv = self._hash_cache = hash(tuple(self)) - return rv - - def __reduce_ex__(self, protocol): - return type(self), (list(self),) - - def __delitem__(self, key): - is_immutable(self) - - def __iadd__(self, other): - is_immutable(self) - - def __imul__(self, other): - is_immutable(self) - - def __setitem__(self, key, value): - is_immutable(self) - - def append(self, item): - is_immutable(self) - - def remove(self, item): - is_immutable(self) - - def extend(self, iterable): - is_immutable(self) - - def insert(self, pos, value): - is_immutable(self) - - def pop(self, index=-1): - is_immutable(self) - - def reverse(self): - is_immutable(self) - - def sort(self, key=None, reverse=False): - is_immutable(self) - - -class ImmutableList(ImmutableListMixin, list): - """An immutable :class:`list`. - - .. versionadded:: 0.5 - - :private: - """ - - def __repr__(self): - return f"{type(self).__name__}({list.__repr__(self)})" - - -class ImmutableDictMixin: - """Makes a :class:`dict` immutable. - - .. versionadded:: 0.5 - - :private: - """ - - _hash_cache = None - - @classmethod - def fromkeys(cls, keys, value=None): - instance = super().__new__(cls) - instance.__init__(zip(keys, repeat(value))) - return instance - - def __reduce_ex__(self, protocol): - return type(self), (dict(self),) - - def _iter_hashitems(self): - return self.items() - - def __hash__(self): - if self._hash_cache is not None: - return self._hash_cache - rv = self._hash_cache = hash(frozenset(self._iter_hashitems())) - return rv - - def setdefault(self, key, default=None): - is_immutable(self) - - def update(self, *args, **kwargs): - is_immutable(self) - - def pop(self, key, default=None): - is_immutable(self) - - def popitem(self): - is_immutable(self) - - def __setitem__(self, key, value): - is_immutable(self) - - def __delitem__(self, key): - is_immutable(self) - - def clear(self): - is_immutable(self) - - -class ImmutableMultiDictMixin(ImmutableDictMixin): - """Makes a :class:`MultiDict` immutable. - - .. versionadded:: 0.5 - - :private: - """ - - def __reduce_ex__(self, protocol): - return type(self), (list(self.items(multi=True)),) - - def _iter_hashitems(self): - return self.items(multi=True) - - def add(self, key, value): - is_immutable(self) - - def popitemlist(self): - is_immutable(self) - - def poplist(self, key): - is_immutable(self) - - def setlist(self, key, new_list): - is_immutable(self) - - def setlistdefault(self, key, default_list=None): - is_immutable(self) - - -def _calls_update(name): - def oncall(self, *args, **kw): - rv = getattr(super(UpdateDictMixin, self), name)(*args, **kw) - - if self.on_update is not None: - self.on_update(self) - - return rv - - oncall.__name__ = name - return oncall - - -class UpdateDictMixin(dict): - """Makes dicts call `self.on_update` on modifications. - - .. 
versionadded:: 0.5 - - :private: - """ - - on_update = None - - def setdefault(self, key, default=None): - modified = key not in self - rv = super().setdefault(key, default) - if modified and self.on_update is not None: - self.on_update(self) - return rv - - def pop(self, key, default=_missing): - modified = key in self - if default is _missing: - rv = super().pop(key) - else: - rv = super().pop(key, default) - if modified and self.on_update is not None: - self.on_update(self) - return rv - - __setitem__ = _calls_update("__setitem__") - __delitem__ = _calls_update("__delitem__") - clear = _calls_update("clear") - popitem = _calls_update("popitem") - update = _calls_update("update") - - -class TypeConversionDict(dict): - """Works like a regular dict but the :meth:`get` method can perform - type conversions. :class:`MultiDict` and :class:`CombinedMultiDict` - are subclasses of this class and provide the same feature. - - .. versionadded:: 0.5 - """ - - def get(self, key, default=None, type=None): - """Return the default value if the requested data doesn't exist. - If `type` is provided and is a callable it should convert the value, - return it or raise a :exc:`ValueError` if that is not possible. In - this case the function will return the default as if the value was not - found: - - >>> d = TypeConversionDict(foo='42', bar='blub') - >>> d.get('foo', type=int) - 42 - >>> d.get('bar', -1, type=int) - -1 - - :param key: The key to be looked up. - :param default: The default value to be returned if the key can't - be looked up. If not further specified `None` is - returned. - :param type: A callable that is used to cast the value in the - :class:`MultiDict`. If a :exc:`ValueError` is raised - by this callable the default value is returned. - """ - try: - rv = self[key] - except KeyError: - return default - if type is not None: - try: - rv = type(rv) - except ValueError: - rv = default - return rv - - -class ImmutableTypeConversionDict(ImmutableDictMixin, TypeConversionDict): - """Works like a :class:`TypeConversionDict` but does not support - modifications. - - .. versionadded:: 0.5 - """ - - def copy(self): - """Return a shallow mutable copy of this object. Keep in mind that - the standard library's :func:`copy` function is a no-op for this class - like for any other python immutable type (eg: :class:`tuple`). - """ - return TypeConversionDict(self) - - def __copy__(self): - return self - - -class MultiDict(TypeConversionDict): - """A :class:`MultiDict` is a dictionary subclass customized to deal with - multiple values for the same key which is for example used by the parsing - functions in the wrappers. This is necessary because some HTML form - elements pass multiple values for the same key. - - :class:`MultiDict` implements all standard dictionary methods. - Internally, it saves all values for a key as a list, but the standard dict - access methods will only return the first value for a key. If you want to - gain access to the other values, too, you have to use the `list` methods as - explained below. - - Basic Usage: - - >>> d = MultiDict([('a', 'b'), ('a', 'c')]) - >>> d - MultiDict([('a', 'b'), ('a', 'c')]) - >>> d['a'] - 'b' - >>> d.getlist('a') - ['b', 'c'] - >>> 'a' in d - True - - It behaves like a normal dict thus all dict functions will only return the - first value when multiple values for one key are found. 
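Worth noting before wading through the rest of this large deletion: ``datastructures.py`` is not going away. As of Werkzeug 2.3 it is split into the ``werkzeug.datastructures`` package, and the behaviour documented above carries over unchanged. A quick sketch of the semantics the docstring promises:

    from werkzeug.datastructures import MultiDict

    d = MultiDict([("a", "b"), ("a", "c")])
    assert d["a"] == "b"                 # dict-style access returns the first value
    assert d.getlist("a") == ["b", "c"]  # list-style access returns all of them
    d.add("a", "d")
    assert d.to_dict(flat=False) == {"a": ["b", "c", "d"]}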
- - From Werkzeug 0.3 onwards, the `KeyError` raised by this class is also a - subclass of the :exc:`~exceptions.BadRequest` HTTP exception and will - render a page for a ``400 BAD REQUEST`` if caught in a catch-all for HTTP - exceptions. - - A :class:`MultiDict` can be constructed from an iterable of - ``(key, value)`` tuples, a dict, a :class:`MultiDict` or from Werkzeug 0.2 - onwards some keyword parameters. - - :param mapping: the initial value for the :class:`MultiDict`. Either a - regular dict, an iterable of ``(key, value)`` tuples - or `None`. - """ - - def __init__(self, mapping=None): - if isinstance(mapping, MultiDict): - dict.__init__(self, ((k, l[:]) for k, l in mapping.lists())) - elif isinstance(mapping, dict): - tmp = {} - for key, value in mapping.items(): - if isinstance(value, (tuple, list)): - if len(value) == 0: - continue - value = list(value) - else: - value = [value] - tmp[key] = value - dict.__init__(self, tmp) - else: - tmp = {} - for key, value in mapping or (): - tmp.setdefault(key, []).append(value) - dict.__init__(self, tmp) - - def __getstate__(self): - return dict(self.lists()) - - def __setstate__(self, value): - dict.clear(self) - dict.update(self, value) - - def __iter__(self): - # Work around https://bugs.python.org/issue43246. - # (`return super().__iter__()` also works here, which makes this look - # even more like it should be a no-op, yet it isn't.) - return dict.__iter__(self) - - def __getitem__(self, key): - """Return the first data value for this key; - raises KeyError if not found. - - :param key: The key to be looked up. - :raise KeyError: if the key does not exist. - """ - - if key in self: - lst = dict.__getitem__(self, key) - if len(lst) > 0: - return lst[0] - raise exceptions.BadRequestKeyError(key) - - def __setitem__(self, key, value): - """Like :meth:`add` but removes an existing key first. - - :param key: the key for the value. - :param value: the value to set. - """ - dict.__setitem__(self, key, [value]) - - def add(self, key, value): - """Adds a new value for the key. - - .. versionadded:: 0.6 - - :param key: the key for the value. - :param value: the value to add. - """ - dict.setdefault(self, key, []).append(value) - - def getlist(self, key, type=None): - """Return the list of items for a given key. If that key is not in the - `MultiDict`, the return value will be an empty list. Just like `get`, - `getlist` accepts a `type` parameter. All items will be converted - with the callable defined there. - - :param key: The key to be looked up. - :param type: A callable that is used to cast the value in the - :class:`MultiDict`. If a :exc:`ValueError` is raised - by this callable the value will be removed from the list. - :return: a :class:`list` of all the values for the key. - """ - try: - rv = dict.__getitem__(self, key) - except KeyError: - return [] - if type is None: - return list(rv) - result = [] - for item in rv: - try: - result.append(type(item)) - except ValueError: - pass - return result - - def setlist(self, key, new_list): - """Remove the old values for a key and add new ones. Note that the list - you pass the values in will be shallow-copied before it is inserted in - the dictionary. - - >>> d = MultiDict() - >>> d.setlist('foo', ['1', '2']) - >>> d['foo'] - '1' - >>> d.getlist('foo') - ['1', '2'] - - :param key: The key for which the values are set. - :param new_list: An iterable with the new values for the key. Old values - are removed first. 
- """ - dict.__setitem__(self, key, list(new_list)) - - def setdefault(self, key, default=None): - """Returns the value for the key if it is in the dict, otherwise it - returns `default` and sets that value for `key`. - - :param key: The key to be looked up. - :param default: The default value to be returned if the key is not - in the dict. If not further specified it's `None`. - """ - if key not in self: - self[key] = default - else: - default = self[key] - return default - - def setlistdefault(self, key, default_list=None): - """Like `setdefault` but sets multiple values. The list returned - is not a copy, but the list that is actually used internally. This - means that you can put new values into the dict by appending items - to the list: - - >>> d = MultiDict({"foo": 1}) - >>> d.setlistdefault("foo").extend([2, 3]) - >>> d.getlist("foo") - [1, 2, 3] - - :param key: The key to be looked up. - :param default_list: An iterable of default values. It is either copied - (in case it was a list) or converted into a list - before returned. - :return: a :class:`list` - """ - if key not in self: - default_list = list(default_list or ()) - dict.__setitem__(self, key, default_list) - else: - default_list = dict.__getitem__(self, key) - return default_list - - def items(self, multi=False): - """Return an iterator of ``(key, value)`` pairs. - - :param multi: If set to `True` the iterator returned will have a pair - for each value of each key. Otherwise it will only - contain pairs for the first value of each key. - """ - for key, values in dict.items(self): - if multi: - for value in values: - yield key, value - else: - yield key, values[0] - - def lists(self): - """Return a iterator of ``(key, values)`` pairs, where values is the list - of all values associated with the key.""" - for key, values in dict.items(self): - yield key, list(values) - - def values(self): - """Returns an iterator of the first value on every key's value list.""" - for values in dict.values(self): - yield values[0] - - def listvalues(self): - """Return an iterator of all values associated with a key. Zipping - :meth:`keys` and this is the same as calling :meth:`lists`: - - >>> d = MultiDict({"foo": [1, 2, 3]}) - >>> zip(d.keys(), d.listvalues()) == d.lists() - True - """ - return dict.values(self) - - def copy(self): - """Return a shallow copy of this object.""" - return self.__class__(self) - - def deepcopy(self, memo=None): - """Return a deep copy of this object.""" - return self.__class__(deepcopy(self.to_dict(flat=False), memo)) - - def to_dict(self, flat=True): - """Return the contents as regular dict. If `flat` is `True` the - returned dict will only have the first item present, if `flat` is - `False` all values will be returned as lists. - - :param flat: If set to `False` the dict returned will have lists - with all the values in it. Otherwise it will only - contain the first value for each key. 
- :return: a :class:`dict` - """ - if flat: - return dict(self.items()) - return dict(self.lists()) - - def update(self, mapping): - """update() extends rather than replaces existing key lists: - - >>> a = MultiDict({'x': 1}) - >>> b = MultiDict({'x': 2, 'y': 3}) - >>> a.update(b) - >>> a - MultiDict([('y', 3), ('x', 1), ('x', 2)]) - - If the value list for a key in ``other_dict`` is empty, no new values - will be added to the dict and the key will not be created: - - >>> x = {'empty_list': []} - >>> y = MultiDict() - >>> y.update(x) - >>> y - MultiDict([]) - """ - for key, value in iter_multi_items(mapping): - MultiDict.add(self, key, value) - - def pop(self, key, default=_missing): - """Pop the first item for a list on the dict. Afterwards the - key is removed from the dict, so additional values are discarded: - - >>> d = MultiDict({"foo": [1, 2, 3]}) - >>> d.pop("foo") - 1 - >>> "foo" in d - False - - :param key: the key to pop. - :param default: if provided the value to return if the key was - not in the dictionary. - """ - try: - lst = dict.pop(self, key) - - if len(lst) == 0: - raise exceptions.BadRequestKeyError(key) - - return lst[0] - except KeyError: - if default is not _missing: - return default - - raise exceptions.BadRequestKeyError(key) from None - - def popitem(self): - """Pop an item from the dict.""" - try: - item = dict.popitem(self) - - if len(item[1]) == 0: - raise exceptions.BadRequestKeyError(item[0]) - - return (item[0], item[1][0]) - except KeyError as e: - raise exceptions.BadRequestKeyError(e.args[0]) from None - - def poplist(self, key): - """Pop the list for a key from the dict. If the key is not in the dict - an empty list is returned. - - .. versionchanged:: 0.5 - If the key does no longer exist a list is returned instead of - raising an error. - """ - return dict.pop(self, key, []) - - def popitemlist(self): - """Pop a ``(key, list)`` tuple from the dict.""" - try: - return dict.popitem(self) - except KeyError as e: - raise exceptions.BadRequestKeyError(e.args[0]) from None - - def __copy__(self): - return self.copy() - - def __deepcopy__(self, memo): - return self.deepcopy(memo=memo) - - def __repr__(self): - return f"{type(self).__name__}({list(self.items(multi=True))!r})" - - -class _omd_bucket: - """Wraps values in the :class:`OrderedMultiDict`. This makes it - possible to keep an order over multiple different keys. It requires - a lot of extra memory and slows down access a lot, but makes it - possible to access elements in O(1) and iterate in O(n). - """ - - __slots__ = ("prev", "key", "value", "next") - - def __init__(self, omd, key, value): - self.prev = omd._last_bucket - self.key = key - self.value = value - self.next = None - - if omd._first_bucket is None: - omd._first_bucket = self - if omd._last_bucket is not None: - omd._last_bucket.next = self - omd._last_bucket = self - - def unlink(self, omd): - if self.prev: - self.prev.next = self.next - if self.next: - self.next.prev = self.prev - if omd._first_bucket is self: - omd._first_bucket = self.next - if omd._last_bucket is self: - omd._last_bucket = self.prev - - -class OrderedMultiDict(MultiDict): - """Works like a regular :class:`MultiDict` but preserves the - order of the fields. To convert the ordered multi dict into a - list you can use the :meth:`items` method and pass it ``multi=True``. - - In general an :class:`OrderedMultiDict` is an order of magnitude - slower than a :class:`MultiDict`. - - .. 
admonition:: note - - Due to a limitation in Python you cannot convert an ordered - multi dict into a regular dict by using ``dict(multidict)``. - Instead you have to use the :meth:`to_dict` method, otherwise - the internal bucket objects are exposed. - """ - - def __init__(self, mapping=None): - dict.__init__(self) - self._first_bucket = self._last_bucket = None - if mapping is not None: - OrderedMultiDict.update(self, mapping) - - def __eq__(self, other): - if not isinstance(other, MultiDict): - return NotImplemented - if isinstance(other, OrderedMultiDict): - iter1 = iter(self.items(multi=True)) - iter2 = iter(other.items(multi=True)) - try: - for k1, v1 in iter1: - k2, v2 = next(iter2) - if k1 != k2 or v1 != v2: - return False - except StopIteration: - return False - try: - next(iter2) - except StopIteration: - return True - return False - if len(self) != len(other): - return False - for key, values in self.lists(): - if other.getlist(key) != values: - return False - return True - - __hash__ = None - - def __reduce_ex__(self, protocol): - return type(self), (list(self.items(multi=True)),) - - def __getstate__(self): - return list(self.items(multi=True)) - - def __setstate__(self, values): - dict.clear(self) - for key, value in values: - self.add(key, value) - - def __getitem__(self, key): - if key in self: - return dict.__getitem__(self, key)[0].value - raise exceptions.BadRequestKeyError(key) - - def __setitem__(self, key, value): - self.poplist(key) - self.add(key, value) - - def __delitem__(self, key): - self.pop(key) - - def keys(self): - return (key for key, value in self.items()) - - def __iter__(self): - return iter(self.keys()) - - def values(self): - return (value for key, value in self.items()) - - def items(self, multi=False): - ptr = self._first_bucket - if multi: - while ptr is not None: - yield ptr.key, ptr.value - ptr = ptr.next - else: - returned_keys = set() - while ptr is not None: - if ptr.key not in returned_keys: - returned_keys.add(ptr.key) - yield ptr.key, ptr.value - ptr = ptr.next - - def lists(self): - returned_keys = set() - ptr = self._first_bucket - while ptr is not None: - if ptr.key not in returned_keys: - yield ptr.key, self.getlist(ptr.key) - returned_keys.add(ptr.key) - ptr = ptr.next - - def listvalues(self): - for _key, values in self.lists(): - yield values - - def add(self, key, value): - dict.setdefault(self, key, []).append(_omd_bucket(self, key, value)) - - def getlist(self, key, type=None): - try: - rv = dict.__getitem__(self, key) - except KeyError: - return [] - if type is None: - return [x.value for x in rv] - result = [] - for item in rv: - try: - result.append(type(item.value)) - except ValueError: - pass - return result - - def setlist(self, key, new_list): - self.poplist(key) - for value in new_list: - self.add(key, value) - - def setlistdefault(self, key, default_list=None): - raise TypeError("setlistdefault is unsupported for ordered multi dicts") - - def update(self, mapping): - for key, value in iter_multi_items(mapping): - OrderedMultiDict.add(self, key, value) - - def poplist(self, key): - buckets = dict.pop(self, key, ()) - for bucket in buckets: - bucket.unlink(self) - return [x.value for x in buckets] - - def pop(self, key, default=_missing): - try: - buckets = dict.pop(self, key) - except KeyError: - if default is not _missing: - return default - - raise exceptions.BadRequestKeyError(key) from None - - for bucket in buckets: - bucket.unlink(self) - - return buckets[0].value - - def popitem(self): - try: - key, buckets = 
dict.popitem(self) - except KeyError as e: - raise exceptions.BadRequestKeyError(e.args[0]) from None - - for bucket in buckets: - bucket.unlink(self) - - return key, buckets[0].value - - def popitemlist(self): - try: - key, buckets = dict.popitem(self) - except KeyError as e: - raise exceptions.BadRequestKeyError(e.args[0]) from None - - for bucket in buckets: - bucket.unlink(self) - - return key, [x.value for x in buckets] - - -def _options_header_vkw(value, kw): - return http.dump_options_header( - value, {k.replace("_", "-"): v for k, v in kw.items()} - ) - - -def _unicodify_header_value(value): - if isinstance(value, bytes): - value = value.decode("latin-1") - if not isinstance(value, str): - value = str(value) - return value - - -class Headers: - """An object that stores some headers. It has a dict-like interface, - but is ordered, can store the same key multiple times, and iterating - yields ``(key, value)`` pairs instead of only keys. - - This data structure is useful if you want a nicer way to handle WSGI - headers which are stored as tuples in a list. - - From Werkzeug 0.3 onwards, the :exc:`KeyError` raised by this class is - also a subclass of the :class:`~exceptions.BadRequest` HTTP exception - and will render a page for a ``400 BAD REQUEST`` if caught in a - catch-all for HTTP exceptions. - - Headers is mostly compatible with the Python :class:`wsgiref.headers.Headers` - class, with the exception of `__getitem__`. :mod:`wsgiref` will return - `None` for ``headers['missing']``, whereas :class:`Headers` will raise - a :class:`KeyError`. - - To create a new ``Headers`` object, pass it a list, dict, or - other ``Headers`` object with default values. These values are - validated the same way values added later are. - - :param defaults: The list of default values for the :class:`Headers`. - - .. versionchanged:: 2.1.0 - Default values are validated the same as values added later. - - .. versionchanged:: 0.9 - This data structure now stores unicode values similar to how the - multi dicts do it. The main difference is that bytes can be set as - well which will automatically be latin1 decoded. - - .. versionchanged:: 0.9 - The :meth:`linked` function was removed without replacement as it - was an API that does not support the changes to the encoding model. - """ - - def __init__(self, defaults=None): - self._list = [] - if defaults is not None: - self.extend(defaults) - - def __getitem__(self, key, _get_mode=False): - if not _get_mode: - if isinstance(key, int): - return self._list[key] - elif isinstance(key, slice): - return self.__class__(self._list[key]) - if not isinstance(key, str): - raise exceptions.BadRequestKeyError(key) - ikey = key.lower() - for k, v in self._list: - if k.lower() == ikey: - return v - # micro optimization: if we are in get mode we will catch that - # exception one stack level down so we can raise a standard - # key error instead of our special one. - if _get_mode: - raise KeyError() - raise exceptions.BadRequestKeyError(key) - - def __eq__(self, other): - def lowered(item): - return (item[0].lower(),) + item[1:] - - return other.__class__ is self.__class__ and set( - map(lowered, other._list) - ) == set(map(lowered, self._list)) - - __hash__ = None - - def get(self, key, default=None, type=None, as_bytes=False): - """Return the default value if the requested data doesn't exist. - If `type` is provided and is a callable it should convert the value, - return it or raise a :exc:`ValueError` if that is not possible. 
In - this case the function will return the default as if the value was not - found: - - >>> d = Headers([('Content-Length', '42')]) - >>> d.get('Content-Length', type=int) - 42 - - .. versionadded:: 0.9 - Added support for `as_bytes`. - - :param key: The key to be looked up. - :param default: The default value to be returned if the key can't - be looked up. If not further specified `None` is - returned. - :param type: A callable that is used to cast the value in the - :class:`Headers`. If a :exc:`ValueError` is raised - by this callable the default value is returned. - :param as_bytes: return bytes instead of strings. - """ - try: - rv = self.__getitem__(key, _get_mode=True) - except KeyError: - return default - if as_bytes: - rv = rv.encode("latin1") - if type is None: - return rv - try: - return type(rv) - except ValueError: - return default - - def getlist(self, key, type=None, as_bytes=False): - """Return the list of items for a given key. If that key is not in the - :class:`Headers`, the return value will be an empty list. Just like - :meth:`get`, :meth:`getlist` accepts a `type` parameter. All items will - be converted with the callable defined there. - - .. versionadded:: 0.9 - Added support for `as_bytes`. - - :param key: The key to be looked up. - :param type: A callable that is used to cast the value in the - :class:`Headers`. If a :exc:`ValueError` is raised - by this callable the value will be removed from the list. - :return: a :class:`list` of all the values for the key. - :param as_bytes: return bytes instead of strings. - """ - ikey = key.lower() - result = [] - for k, v in self: - if k.lower() == ikey: - if as_bytes: - v = v.encode("latin1") - if type is not None: - try: - v = type(v) - except ValueError: - continue - result.append(v) - return result - - def get_all(self, name): - """Return a list of all the values for the named field. - - This method is compatible with the :mod:`wsgiref` - :meth:`~wsgiref.headers.Headers.get_all` method. - """ - return self.getlist(name) - - def items(self, lower=False): - for key, value in self: - if lower: - key = key.lower() - yield key, value - - def keys(self, lower=False): - for key, _ in self.items(lower): - yield key - - def values(self): - for _, value in self.items(): - yield value - - def extend(self, *args, **kwargs): - """Extend headers in this object with items from another object - containing header items as well as keyword arguments. - - To replace existing keys instead of extending, use - :meth:`update` instead. - - If provided, the first argument can be another :class:`Headers` - object, a :class:`MultiDict`, :class:`dict`, or iterable of - pairs. - - .. versionchanged:: 1.0 - Support :class:`MultiDict`. Allow passing ``kwargs``. - """ - if len(args) > 1: - raise TypeError(f"update expected at most 1 arguments, got {len(args)}") - - if args: - for key, value in iter_multi_items(args[0]): - self.add(key, value) - - for key, value in iter_multi_items(kwargs): - self.add(key, value) - - def __delitem__(self, key, _index_operation=True): - if _index_operation and isinstance(key, (int, slice)): - del self._list[key] - return - key = key.lower() - new = [] - for k, v in self._list: - if k.lower() != key: - new.append((k, v)) - self._list[:] = new - - def remove(self, key): - """Remove a key. - - :param key: The key to be removed. - """ - return self.__delitem__(key, _index_operation=False) - - def pop(self, key=None, default=_missing): - """Removes and returns a key or index. - - :param key: The key to be popped. 
If this is an integer the item at - that position is removed, if it's a string the value for - that key is. If the key is omitted or `None` the last - item is removed. - :return: an item. - """ - if key is None: - return self._list.pop() - if isinstance(key, int): - return self._list.pop(key) - try: - rv = self[key] - self.remove(key) - except KeyError: - if default is not _missing: - return default - raise - return rv - - def popitem(self): - """Removes a key or index and returns a (key, value) item.""" - return self.pop() - - def __contains__(self, key): - """Check if a key is present.""" - try: - self.__getitem__(key, _get_mode=True) - except KeyError: - return False - return True - - def __iter__(self): - """Yield ``(key, value)`` tuples.""" - return iter(self._list) - - def __len__(self): - return len(self._list) - - def add(self, _key, _value, **kw): - """Add a new header tuple to the list. - - Keyword arguments can specify additional parameters for the header - value, with underscores converted to dashes:: - - >>> d = Headers() - >>> d.add('Content-Type', 'text/plain') - >>> d.add('Content-Disposition', 'attachment', filename='foo.png') - - The keyword argument dumping uses :func:`dump_options_header` - behind the scenes. - - .. versionadded:: 0.4.1 - keyword arguments were added for :mod:`wsgiref` compatibility. - """ - if kw: - _value = _options_header_vkw(_value, kw) - _key = _unicodify_header_value(_key) - _value = _unicodify_header_value(_value) - self._validate_value(_value) - self._list.append((_key, _value)) - - def _validate_value(self, value): - if not isinstance(value, str): - raise TypeError("Value should be a string.") - if "\n" in value or "\r" in value: - raise ValueError( - "Detected newline in header value. This is " - "a potential security problem" - ) - - def add_header(self, _key, _value, **_kw): - """Add a new header tuple to the list. - - An alias for :meth:`add` for compatibility with the :mod:`wsgiref` - :meth:`~wsgiref.headers.Headers.add_header` method. - """ - self.add(_key, _value, **_kw) - - def clear(self): - """Clears all headers.""" - del self._list[:] - - def set(self, _key, _value, **kw): - """Remove all header tuples for `key` and add a new one. The newly - added key either appears at the end of the list if there was no - entry or replaces the first one. - - Keyword arguments can specify additional parameters for the header - value, with underscores converted to dashes. See :meth:`add` for - more information. - - .. versionchanged:: 0.6.1 - :meth:`set` now accepts the same arguments as :meth:`add`. - - :param key: The key to be inserted. - :param value: The value to be inserted. - """ - if kw: - _value = _options_header_vkw(_value, kw) - _key = _unicodify_header_value(_key) - _value = _unicodify_header_value(_value) - self._validate_value(_value) - if not self._list: - self._list.append((_key, _value)) - return - listiter = iter(self._list) - ikey = _key.lower() - for idx, (old_key, _old_value) in enumerate(listiter): - if old_key.lower() == ikey: - # replace first occurrence - self._list[idx] = (_key, _value) - break - else: - self._list.append((_key, _value)) - return - self._list[idx + 1 :] = [t for t in listiter if t[0].lower() != ikey] - - def setlist(self, key, values): - """Remove any existing values for a header and add new ones. - - :param key: The header key to set. - :param values: An iterable of values to set for the key. - - .. 
versionadded:: 1.0 - """ - if values: - values_iter = iter(values) - self.set(key, next(values_iter)) - - for value in values_iter: - self.add(key, value) - else: - self.remove(key) - - def setdefault(self, key, default): - """Return the first value for the key if it is in the headers, - otherwise set the header to the value given by ``default`` and - return that. - - :param key: The header key to get. - :param default: The value to set for the key if it is not in the - headers. - """ - if key in self: - return self[key] - - self.set(key, default) - return default - - def setlistdefault(self, key, default): - """Return the list of values for the key if it is in the - headers, otherwise set the header to the list of values given - by ``default`` and return that. - - Unlike :meth:`MultiDict.setlistdefault`, modifying the returned - list will not affect the headers. - - :param key: The header key to get. - :param default: An iterable of values to set for the key if it - is not in the headers. - - .. versionadded:: 1.0 - """ - if key not in self: - self.setlist(key, default) - - return self.getlist(key) - - def __setitem__(self, key, value): - """Like :meth:`set` but also supports index/slice based setting.""" - if isinstance(key, (slice, int)): - if isinstance(key, int): - value = [value] - value = [ - (_unicodify_header_value(k), _unicodify_header_value(v)) - for (k, v) in value - ] - for (_, v) in value: - self._validate_value(v) - if isinstance(key, int): - self._list[key] = value[0] - else: - self._list[key] = value - else: - self.set(key, value) - - def update(self, *args, **kwargs): - """Replace headers in this object with items from another - headers object and keyword arguments. - - To extend existing keys instead of replacing, use :meth:`extend` - instead. - - If provided, the first argument can be another :class:`Headers` - object, a :class:`MultiDict`, :class:`dict`, or iterable of - pairs. - - .. versionadded:: 1.0 - """ - if len(args) > 1: - raise TypeError(f"update expected at most 1 arguments, got {len(args)}") - - if args: - mapping = args[0] - - if isinstance(mapping, (Headers, MultiDict)): - for key in mapping.keys(): - self.setlist(key, mapping.getlist(key)) - elif isinstance(mapping, dict): - for key, value in mapping.items(): - if isinstance(value, (list, tuple)): - self.setlist(key, value) - else: - self.set(key, value) - else: - for key, value in mapping: - self.set(key, value) - - for key, value in kwargs.items(): - if isinstance(value, (list, tuple)): - self.setlist(key, value) - else: - self.set(key, value) - - def to_wsgi_list(self): - """Convert the headers into a list suitable for WSGI. - - :return: list - """ - return list(self) - - def copy(self): - return self.__class__(self._list) - - def __copy__(self): - return self.copy() - - def __str__(self): - """Returns formatted headers suitable for HTTP transmission.""" - strs = [] - for key, value in self.to_wsgi_list(): - strs.append(f"{key}: {value}") - strs.append("\r\n") - return "\r\n".join(strs) - - def __repr__(self): - return f"{type(self).__name__}({list(self)!r})" - - -class ImmutableHeadersMixin: - """Makes a :class:`Headers` immutable. We do not mark them as - hashable though since the only usecase for this datastructure - in Werkzeug is a view on a mutable structure. - - .. 
versionadded:: 0.5 - - :private: - """ - - def __delitem__(self, key, **kwargs): - is_immutable(self) - - def __setitem__(self, key, value): - is_immutable(self) - - def set(self, _key, _value, **kw): - is_immutable(self) - - def setlist(self, key, values): - is_immutable(self) - - def add(self, _key, _value, **kw): - is_immutable(self) - - def add_header(self, _key, _value, **_kw): - is_immutable(self) - - def remove(self, key): - is_immutable(self) - - def extend(self, *args, **kwargs): - is_immutable(self) - - def update(self, *args, **kwargs): - is_immutable(self) - - def insert(self, pos, value): - is_immutable(self) - - def pop(self, key=None, default=_missing): - is_immutable(self) - - def popitem(self): - is_immutable(self) - - def setdefault(self, key, default): - is_immutable(self) - - def setlistdefault(self, key, default): - is_immutable(self) - - -class EnvironHeaders(ImmutableHeadersMixin, Headers): - """Read-only version of the headers from a WSGI environment. This - provides the same interface as `Headers` and is constructed from - a WSGI environment. - - From Werkzeug 0.3 onwards, the `KeyError` raised by this class is also a - subclass of the :exc:`~exceptions.BadRequest` HTTP exception and will - render a page for a ``400 BAD REQUEST`` if caught in a catch-all for - HTTP exceptions. - """ - - def __init__(self, environ): - self.environ = environ - - def __eq__(self, other): - return self.environ is other.environ - - __hash__ = None - - def __getitem__(self, key, _get_mode=False): - # _get_mode is a no-op for this class as there is no index but - # used because get() calls it. - if not isinstance(key, str): - raise KeyError(key) - key = key.upper().replace("-", "_") - if key in ("CONTENT_TYPE", "CONTENT_LENGTH"): - return _unicodify_header_value(self.environ[key]) - return _unicodify_header_value(self.environ[f"HTTP_{key}"]) - - def __len__(self): - # the iter is necessary because otherwise list calls our - # len which would call list again and so forth. - return len(list(iter(self))) - - def __iter__(self): - for key, value in self.environ.items(): - if key.startswith("HTTP_") and key not in ( - "HTTP_CONTENT_TYPE", - "HTTP_CONTENT_LENGTH", - ): - yield ( - key[5:].replace("_", "-").title(), - _unicodify_header_value(value), - ) - elif key in ("CONTENT_TYPE", "CONTENT_LENGTH") and value: - yield (key.replace("_", "-").title(), _unicodify_header_value(value)) - - def copy(self): - raise TypeError(f"cannot create {type(self).__name__!r} copies") - - -class CombinedMultiDict(ImmutableMultiDictMixin, MultiDict): - """A read-only :class:`MultiDict` that combines multiple :class:`MultiDict` - instances, passed as a sequence, and returns values from all the wrapped - dicts: - - >>> from werkzeug.datastructures import CombinedMultiDict, MultiDict - >>> post = MultiDict([('foo', 'bar')]) - >>> get = MultiDict([('blub', 'blah')]) - >>> combined = CombinedMultiDict([get, post]) - >>> combined['foo'] - 'bar' - >>> combined['blub'] - 'blah' - - This works for all read operations and raises a `TypeError` for - methods that would change the data, which is not possible. - - From Werkzeug 0.3 onwards, the `KeyError` raised by this class is also a - subclass of the :exc:`~exceptions.BadRequest` HTTP exception and will - render a page for a ``400 BAD REQUEST`` if caught in a catch-all for HTTP - exceptions. 
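# Editor's illustrative sketch (not part of the deleted module): the
# CombinedMultiDict lookup order described above, assuming werkzeug is
# installed. The first wrapped dict containing a key wins for item
# access, while getlist() collects values from all wrapped dicts.
from werkzeug.datastructures import CombinedMultiDict, MultiDict

post = MultiDict([("foo", "bar")])
get = MultiDict([("foo", "baz"), ("blub", "blah")])
combined = CombinedMultiDict([get, post])

assert combined["foo"] == "baz"                   # "get" is consulted first
assert combined.getlist("foo") == ["baz", "bar"]  # values from all dicts

try:
    combined["x"] = 1  # read-only: mutating methods raise TypeError
except TypeError:
    pass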
- """ - - def __reduce_ex__(self, protocol): - return type(self), (self.dicts,) - - def __init__(self, dicts=None): - self.dicts = list(dicts) or [] - - @classmethod - def fromkeys(cls, keys, value=None): - raise TypeError(f"cannot create {cls.__name__!r} instances by fromkeys") - - def __getitem__(self, key): - for d in self.dicts: - if key in d: - return d[key] - raise exceptions.BadRequestKeyError(key) - - def get(self, key, default=None, type=None): - for d in self.dicts: - if key in d: - if type is not None: - try: - return type(d[key]) - except ValueError: - continue - return d[key] - return default - - def getlist(self, key, type=None): - rv = [] - for d in self.dicts: - rv.extend(d.getlist(key, type)) - return rv - - def _keys_impl(self): - """This function exists so __len__ can be implemented more efficiently, - saving one list creation from an iterator. - """ - rv = set() - rv.update(*self.dicts) - return rv - - def keys(self): - return self._keys_impl() - - def __iter__(self): - return iter(self.keys()) - - def items(self, multi=False): - found = set() - for d in self.dicts: - for key, value in d.items(multi): - if multi: - yield key, value - elif key not in found: - found.add(key) - yield key, value - - def values(self): - for _key, value in self.items(): - yield value - - def lists(self): - rv = {} - for d in self.dicts: - for key, values in d.lists(): - rv.setdefault(key, []).extend(values) - return list(rv.items()) - - def listvalues(self): - return (x[1] for x in self.lists()) - - def copy(self): - """Return a shallow mutable copy of this object. - - This returns a :class:`MultiDict` representing the data at the - time of copying. The copy will no longer reflect changes to the - wrapped dicts. - - .. versionchanged:: 0.15 - Return a mutable :class:`MultiDict`. - """ - return MultiDict(self) - - def to_dict(self, flat=True): - """Return the contents as regular dict. If `flat` is `True` the - returned dict will only have the first item present, if `flat` is - `False` all values will be returned as lists. - - :param flat: If set to `False` the dict returned will have lists - with all the values in it. Otherwise it will only - contain the first item for each key. - :return: a :class:`dict` - """ - if flat: - return dict(self.items()) - - return dict(self.lists()) - - def __len__(self): - return len(self._keys_impl()) - - def __contains__(self, key): - for d in self.dicts: - if key in d: - return True - return False - - def __repr__(self): - return f"{type(self).__name__}({self.dicts!r})" - - -class FileMultiDict(MultiDict): - """A special :class:`MultiDict` that has convenience methods to add - files to it. This is used for :class:`EnvironBuilder` and generally - useful for unittesting. - - .. versionadded:: 0.5 - """ - - def add_file(self, name, file, filename=None, content_type=None): - """Adds a new file to the dict. `file` can be a file name or - a :class:`file`-like or a :class:`FileStorage` object. - - :param name: the name of the field. 
- :param file: a filename or :class:`file`-like object - :param filename: an optional filename - :param content_type: an optional content type - """ - if isinstance(file, FileStorage): - value = file - else: - if isinstance(file, str): - if filename is None: - filename = file - file = open(file, "rb") - if filename and content_type is None: - content_type = ( - mimetypes.guess_type(filename)[0] or "application/octet-stream" - ) - value = FileStorage(file, filename, name, content_type) - - self.add(name, value) - - -class ImmutableDict(ImmutableDictMixin, dict): - """An immutable :class:`dict`. - - .. versionadded:: 0.5 - """ - - def __repr__(self): - return f"{type(self).__name__}({dict.__repr__(self)})" - - def copy(self): - """Return a shallow mutable copy of this object. Keep in mind that - the standard library's :func:`copy` function is a no-op for this class - as for any other Python immutable type (e.g. :class:`tuple`). - """ - return dict(self) - - def __copy__(self): - return self - - -class ImmutableMultiDict(ImmutableMultiDictMixin, MultiDict): - """An immutable :class:`MultiDict`. - - .. versionadded:: 0.5 - """ - - def copy(self): - """Return a shallow mutable copy of this object. Keep in mind that - the standard library's :func:`copy` function is a no-op for this class - as for any other Python immutable type (e.g. :class:`tuple`). - """ - return MultiDict(self) - - def __copy__(self): - return self - - -class ImmutableOrderedMultiDict(ImmutableMultiDictMixin, OrderedMultiDict): - """An immutable :class:`OrderedMultiDict`. - - .. versionadded:: 0.6 - """ - - def _iter_hashitems(self): - return enumerate(self.items(multi=True)) - - def copy(self): - """Return a shallow mutable copy of this object. Keep in mind that - the standard library's :func:`copy` function is a no-op for this class - as for any other Python immutable type (e.g. :class:`tuple`). - """ - return OrderedMultiDict(self) - - def __copy__(self): - return self - - -class Accept(ImmutableList): - """An :class:`Accept` object is just a list subclass for lists of - ``(value, quality)`` tuples. It is automatically sorted by specificity - and quality. - - All :class:`Accept` objects work similarly to a list but provide extra - functionality for working with the data. Containment checks are - normalized to the rules of that header: - - >>> a = CharsetAccept([('ISO-8859-1', 1), ('utf-8', 0.7)]) - >>> a.best - 'ISO-8859-1' - >>> 'iso-8859-1' in a - True - >>> 'UTF8' in a - True - >>> 'utf7' in a - False - - To get the quality for an item you can use normal item lookup: - - >>> a['utf-8'] - 0.7 - >>> a['utf7'] - 0 - - .. versionchanged:: 0.5 - :class:`Accept` objects are forced immutable now. - - .. versionchanged:: 1.0.0 - :class:`Accept` internal values are no longer ordered - alphabetically for equal quality tags. Instead the initial - order is preserved. 
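# Editor's illustrative sketch (not part of the deleted module): the
# normalized containment and quality lookup described in the Accept
# docstring above, assuming werkzeug is installed.
from werkzeug.datastructures import CharsetAccept

a = CharsetAccept([("ISO-8859-1", 1), ("utf-8", 0.7)])
assert "iso-8859-1" in a       # containment is case-insensitive
assert a["utf-8"] == 0.7       # item lookup returns the quality
assert a["utf7"] == 0          # unknown values have quality 0
assert a.best == "ISO-8859-1"  # highest quality wins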
- - """ - - def __init__(self, values=()): - if values is None: - list.__init__(self) - self.provided = False - elif isinstance(values, Accept): - self.provided = values.provided - list.__init__(self, values) - else: - self.provided = True - values = sorted( - values, key=lambda x: (self._specificity(x[0]), x[1]), reverse=True - ) - list.__init__(self, values) - - def _specificity(self, value): - """Returns a tuple describing the value's specificity.""" - return (value != "*",) - - def _value_matches(self, value, item): - """Check if a value matches a given accept item.""" - return item == "*" or item.lower() == value.lower() - - def __getitem__(self, key): - """Besides index lookup (getting item n) you can also pass it a string - to get the quality for the item. If the item is not in the list, the - returned quality is ``0``. - """ - if isinstance(key, str): - return self.quality(key) - return list.__getitem__(self, key) - - def quality(self, key): - """Returns the quality of the key. - - .. versionadded:: 0.6 - In previous versions you had to use the item-lookup syntax - (eg: ``obj[key]`` instead of ``obj.quality(key)``) - """ - for item, quality in self: - if self._value_matches(key, item): - return quality - return 0 - - def __contains__(self, value): - for item, _quality in self: - if self._value_matches(value, item): - return True - return False - - def __repr__(self): - pairs_str = ", ".join(f"({x!r}, {y})" for x, y in self) - return f"{type(self).__name__}([{pairs_str}])" - - def index(self, key): - """Get the position of an entry or raise :exc:`ValueError`. - - :param key: The key to be looked up. - - .. versionchanged:: 0.5 - This used to raise :exc:`IndexError`, which was inconsistent - with the list API. - """ - if isinstance(key, str): - for idx, (item, _quality) in enumerate(self): - if self._value_matches(key, item): - return idx - raise ValueError(key) - return list.index(self, key) - - def find(self, key): - """Get the position of an entry or return -1. - - :param key: The key to be looked up. - """ - try: - return self.index(key) - except ValueError: - return -1 - - def values(self): - """Iterate over all values.""" - for item in self: - yield item[0] - - def to_header(self): - """Convert the header set into an HTTP header string.""" - result = [] - for value, quality in self: - if quality != 1: - value = f"{value};q={quality}" - result.append(value) - return ",".join(result) - - def __str__(self): - return self.to_header() - - def _best_single_match(self, match): - for client_item, quality in self: - if self._value_matches(match, client_item): - # self is sorted by specificity descending, we can exit - return client_item, quality - return None - - def best_match(self, matches, default=None): - """Returns the best match from a list of possible matches based - on the specificity and quality of the client. If two items have the - same quality and specificity, the one is returned that comes first. 
- - :param matches: a list of matches to check for - :param default: the value that is returned if none match - """ - result = default - best_quality = -1 - best_specificity = (-1,) - for server_item in matches: - match = self._best_single_match(server_item) - if not match: - continue - client_item, quality = match - specificity = self._specificity(client_item) - if quality <= 0 or quality < best_quality: - continue - # better quality or same quality but more specific => better match - if quality > best_quality or specificity > best_specificity: - result = server_item - best_quality = quality - best_specificity = specificity - return result - - @property - def best(self): - """The best match as value.""" - if self: - return self[0][0] - - -_mime_split_re = re.compile(r"/|(?:\s*;\s*)") - - -def _normalize_mime(value): - return _mime_split_re.split(value.lower()) - - -class MIMEAccept(Accept): - """Like :class:`Accept` but with special methods and behavior for - mimetypes. - """ - - def _specificity(self, value): - return tuple(x != "*" for x in _mime_split_re.split(value)) - - def _value_matches(self, value, item): - # item comes from the client, can't match if it's invalid. - if "/" not in item: - return False - - # value comes from the application, tell the developer when it - # doesn't look valid. - if "/" not in value: - raise ValueError(f"invalid mimetype {value!r}") - - # Split the match value into type, subtype, and a sorted list of parameters. - normalized_value = _normalize_mime(value) - value_type, value_subtype = normalized_value[:2] - value_params = sorted(normalized_value[2:]) - - # "*/*" is the only valid value that can start with "*". - if value_type == "*" and value_subtype != "*": - raise ValueError(f"invalid mimetype {value!r}") - - # Split the accept item into type, subtype, and parameters. - normalized_item = _normalize_mime(item) - item_type, item_subtype = normalized_item[:2] - item_params = sorted(normalized_item[2:]) - - # "*/not-*" from the client is invalid, can't match. - if item_type == "*" and item_subtype != "*": - return False - - return ( - (item_type == "*" and item_subtype == "*") - or (value_type == "*" and value_subtype == "*") - ) or ( - item_type == value_type - and ( - item_subtype == "*" - or value_subtype == "*" - or (item_subtype == value_subtype and item_params == value_params) - ) - ) - - @property - def accept_html(self): - """True if this object accepts HTML.""" - return ( - "text/html" in self or "application/xhtml+xml" in self or self.accept_xhtml - ) - - @property - def accept_xhtml(self): - """True if this object accepts XHTML.""" - return "application/xhtml+xml" in self or "application/xml" in self - - @property - def accept_json(self): - """True if this object accepts JSON.""" - return "application/json" in self - - -_locale_delim_re = re.compile(r"[_-]") - - -def _normalize_lang(value): - """Process a language tag for matching.""" - return _locale_delim_re.split(value.lower()) - - -class LanguageAccept(Accept): - """Like :class:`Accept` but with normalization for language tags.""" - - def _value_matches(self, value, item): - return item == "*" or _normalize_lang(value) == _normalize_lang(item) - - def best_match(self, matches, default=None): - """Given a list of supported values, finds the best match from - the list of accepted values. - - Language tags are normalized for the purpose of matching, but - are returned unchanged. 
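# Editor's illustrative sketch (not part of the deleted module): the
# type/subtype wildcard rules MIMEAccept applies, as implemented above.
from werkzeug.datastructures import MIMEAccept

accept = MIMEAccept([("text/html", 1), ("application/*", 0.5)])
assert accept.accept_html
assert "application/json" in accept  # matched by application/*
assert "image/png" not in accept
assert accept.best_match(["image/png", "application/json"]) == "application/json"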
- - If no exact match is found, this will fall back to matching - the first subtag (primary language only), first with the - accepted values then with the match values. This partial matching is not - applied to any other language subtags. - - The default is returned if no exact or fallback match is found. - - :param matches: A list of supported languages to find a match. - :param default: The value that is returned if none match. - """ - # Look for an exact match first. If a client accepts "en-US", - # "en-US" is a valid match at this point. - result = super().best_match(matches) - - if result is not None: - return result - - # Fall back to accepting primary tags. If a client accepts - # "en-US", "en" is a valid match at this point. Need to use - # re.split to account for 2 or 3 letter codes. - fallback = Accept( - [(_locale_delim_re.split(item[0], 1)[0], item[1]) for item in self] - ) - result = fallback.best_match(matches) - - if result is not None: - return result - - # Fall back to matching primary tags. If the client accepts - # "en", "en-US" is a valid match at this point. - fallback_matches = [_locale_delim_re.split(item, 1)[0] for item in matches] - result = super().best_match(fallback_matches) - - # Return a value from the original match list. Find the first - # original value that starts with the matched primary tag. - if result is not None: - return next(item for item in matches if item.startswith(result)) - - return default - - -class CharsetAccept(Accept): - """Like :class:`Accept` but with normalization for charsets.""" - - def _value_matches(self, value, item): - def _normalize(name): - try: - return codecs.lookup(name).name - except LookupError: - return name.lower() - - return item == "*" or _normalize(value) == _normalize(item) - - -def cache_control_property(key, empty, type): - """Return a new property object for a cache header. Useful if you - want to add support for a cache extension in a subclass. - - .. versionchanged:: 2.0 - Renamed from ``cache_property``. - """ - return property( - lambda x: x._get_cache_value(key, empty, type), - lambda x, v: x._set_cache_value(key, v, type), - lambda x: x._del_cache_value(key), - f"accessor for {key!r}", - ) - - -class _CacheControl(UpdateDictMixin, dict): - """Subclass of a dict that stores values for a Cache-Control header. It - has accessors for all the cache-control directives specified in RFC 2616. - The class does not differentiate between request and response directives. - - Because the cache-control directives in the HTTP header use dashes the - Python descriptors use underscores for that. - - To get a header of the :class:`CacheControl` object again you can convert - the object into a string or call the :meth:`to_header` method. If you plan - to subclass it and add your own items have a look at the source code for - that class. - - .. versionchanged:: 2.1.0 - Setting int properties such as ``max_age`` will convert the - value to an int. - - .. versionchanged:: 0.4 - - Setting `no_cache` or `private` to boolean `True` will set the implicit - none-value which is ``*``: - - >>> cc = ResponseCacheControl() - >>> cc.no_cache = True - >>> cc - <ResponseCacheControl 'no-cache'> - >>> cc.no_cache - '*' - >>> cc.no_cache = None - >>> cc - <ResponseCacheControl ''> - - In versions before 0.5 the behavior documented here affected the now - no longer existing `CacheControl` class. 
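# Editor's illustrative sketch (not part of the deleted module): the
# primary-tag fallback implemented in LanguageAccept.best_match further
# above, assuming werkzeug is installed.
from werkzeug.datastructures import LanguageAccept

accept = LanguageAccept([("en-US", 1), ("fr", 0.5)])
assert accept.best_match(["en-US", "fr"]) == "en-US"  # exact match
assert accept.best_match(["en", "de"]) == "en"        # "en-US" falls back to "en"
# an accepted primary tag also matches a more specific offered value:
assert LanguageAccept([("en", 1)]).best_match(["en-US"]) == "en-US"
assert accept.best_match(["de"], default="en") == "en"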
- """ - - no_cache = cache_control_property("no-cache", "*", None) - no_store = cache_control_property("no-store", None, bool) - max_age = cache_control_property("max-age", -1, int) - no_transform = cache_control_property("no-transform", None, None) - - def __init__(self, values=(), on_update=None): - dict.__init__(self, values or ()) - self.on_update = on_update - self.provided = values is not None - - def _get_cache_value(self, key, empty, type): - """Used internally by the accessor properties.""" - if type is bool: - return key in self - if key in self: - value = self[key] - if value is None: - return empty - elif type is not None: - try: - value = type(value) - except ValueError: - pass - return value - return None - - def _set_cache_value(self, key, value, type): - """Used internally by the accessor properties.""" - if type is bool: - if value: - self[key] = None - else: - self.pop(key, None) - else: - if value is None: - self.pop(key, None) - elif value is True: - self[key] = None - else: - if type is not None: - self[key] = type(value) - else: - self[key] = value - - def _del_cache_value(self, key): - """Used internally by the accessor properties.""" - if key in self: - del self[key] - - def to_header(self): - """Convert the stored values into a cache control header.""" - return http.dump_header(self) - - def __str__(self): - return self.to_header() - - def __repr__(self): - kv_str = " ".join(f"{k}={v!r}" for k, v in sorted(self.items())) - return f"<{type(self).__name__} {kv_str}>" - - cache_property = staticmethod(cache_control_property) - - -class RequestCacheControl(ImmutableDictMixin, _CacheControl): - """A cache control for requests. This is immutable and gives access - to all the request-relevant cache control headers. - - To get a header of the :class:`RequestCacheControl` object again you can - convert the object into a string or call the :meth:`to_header` method. If - you plan to subclass it and add your own items have a look at the sourcecode - for that class. - - .. versionchanged:: 2.1.0 - Setting int properties such as ``max_age`` will convert the - value to an int. - - .. versionadded:: 0.5 - In previous versions a `CacheControl` class existed that was used - both for request and response. - """ - - max_stale = cache_control_property("max-stale", "*", int) - min_fresh = cache_control_property("min-fresh", "*", int) - only_if_cached = cache_control_property("only-if-cached", None, bool) - - -class ResponseCacheControl(_CacheControl): - """A cache control for responses. Unlike :class:`RequestCacheControl` - this is mutable and gives access to response-relevant cache control - headers. - - To get a header of the :class:`ResponseCacheControl` object again you can - convert the object into a string or call the :meth:`to_header` method. If - you plan to subclass it and add your own items have a look at the sourcecode - for that class. - - .. versionchanged:: 2.1.1 - ``s_maxage`` converts the value to an int. - - .. versionchanged:: 2.1.0 - Setting int properties such as ``max_age`` will convert the - value to an int. - - .. versionadded:: 0.5 - In previous versions a `CacheControl` class existed that was used - both for request and response. 
- """ - - public = cache_control_property("public", None, bool) - private = cache_control_property("private", "*", None) - must_revalidate = cache_control_property("must-revalidate", None, bool) - proxy_revalidate = cache_control_property("proxy-revalidate", None, bool) - s_maxage = cache_control_property("s-maxage", None, int) - immutable = cache_control_property("immutable", None, bool) - - -def csp_property(key): - """Return a new property object for a content security policy header. - Useful if you want to add support for a csp extension in a - subclass. - """ - return property( - lambda x: x._get_value(key), - lambda x, v: x._set_value(key, v), - lambda x: x._del_value(key), - f"accessor for {key!r}", - ) - - -class ContentSecurityPolicy(UpdateDictMixin, dict): - """Subclass of a dict that stores values for a Content Security Policy - header. It has accessors for all the level 3 policies. - - Because the csp directives in the HTTP header use dashes the - python descriptors use underscores for that. - - To get a header of the :class:`ContentSecuirtyPolicy` object again - you can convert the object into a string or call the - :meth:`to_header` method. If you plan to subclass it and add your - own items have a look at the sourcecode for that class. - - .. versionadded:: 1.0.0 - Support for Content Security Policy headers was added. - - """ - - base_uri = csp_property("base-uri") - child_src = csp_property("child-src") - connect_src = csp_property("connect-src") - default_src = csp_property("default-src") - font_src = csp_property("font-src") - form_action = csp_property("form-action") - frame_ancestors = csp_property("frame-ancestors") - frame_src = csp_property("frame-src") - img_src = csp_property("img-src") - manifest_src = csp_property("manifest-src") - media_src = csp_property("media-src") - navigate_to = csp_property("navigate-to") - object_src = csp_property("object-src") - prefetch_src = csp_property("prefetch-src") - plugin_types = csp_property("plugin-types") - report_to = csp_property("report-to") - report_uri = csp_property("report-uri") - sandbox = csp_property("sandbox") - script_src = csp_property("script-src") - script_src_attr = csp_property("script-src-attr") - script_src_elem = csp_property("script-src-elem") - style_src = csp_property("style-src") - style_src_attr = csp_property("style-src-attr") - style_src_elem = csp_property("style-src-elem") - worker_src = csp_property("worker-src") - - def __init__(self, values=(), on_update=None): - dict.__init__(self, values or ()) - self.on_update = on_update - self.provided = values is not None - - def _get_value(self, key): - """Used internally by the accessor properties.""" - return self.get(key) - - def _set_value(self, key, value): - """Used internally by the accessor properties.""" - if value is None: - self.pop(key, None) - else: - self[key] = value - - def _del_value(self, key): - """Used internally by the accessor properties.""" - if key in self: - del self[key] - - def to_header(self): - """Convert the stored values into a cache control header.""" - return http.dump_csp_header(self) - - def __str__(self): - return self.to_header() - - def __repr__(self): - kv_str = " ".join(f"{k}={v!r}" for k, v in sorted(self.items())) - return f"<{type(self).__name__} {kv_str}>" - - -class CallbackDict(UpdateDictMixin, dict): - """A dict that calls a function passed every time something is changed. - The function is passed the dict instance. 
- """ - - def __init__(self, initial=None, on_update=None): - dict.__init__(self, initial or ()) - self.on_update = on_update - - def __repr__(self): - return f"<{type(self).__name__} {dict.__repr__(self)}>" - - -class HeaderSet(MutableSet): - """Similar to the :class:`ETags` class this implements a set-like structure. - Unlike :class:`ETags` this is case insensitive and used for vary, allow, and - content-language headers. - - If not constructed using the :func:`parse_set_header` function the - instantiation works like this: - - >>> hs = HeaderSet(['foo', 'bar', 'baz']) - >>> hs - HeaderSet(['foo', 'bar', 'baz']) - """ - - def __init__(self, headers=None, on_update=None): - self._headers = list(headers or ()) - self._set = {x.lower() for x in self._headers} - self.on_update = on_update - - def add(self, header): - """Add a new header to the set.""" - self.update((header,)) - - def remove(self, header): - """Remove a header from the set. This raises an :exc:`KeyError` if the - header is not in the set. - - .. versionchanged:: 0.5 - In older versions a :exc:`IndexError` was raised instead of a - :exc:`KeyError` if the object was missing. - - :param header: the header to be removed. - """ - key = header.lower() - if key not in self._set: - raise KeyError(header) - self._set.remove(key) - for idx, key in enumerate(self._headers): - if key.lower() == header: - del self._headers[idx] - break - if self.on_update is not None: - self.on_update(self) - - def update(self, iterable): - """Add all the headers from the iterable to the set. - - :param iterable: updates the set with the items from the iterable. - """ - inserted_any = False - for header in iterable: - key = header.lower() - if key not in self._set: - self._headers.append(header) - self._set.add(key) - inserted_any = True - if inserted_any and self.on_update is not None: - self.on_update(self) - - def discard(self, header): - """Like :meth:`remove` but ignores errors. - - :param header: the header to be discarded. - """ - try: - self.remove(header) - except KeyError: - pass - - def find(self, header): - """Return the index of the header in the set or return -1 if not found. - - :param header: the header to be looked up. - """ - header = header.lower() - for idx, item in enumerate(self._headers): - if item.lower() == header: - return idx - return -1 - - def index(self, header): - """Return the index of the header in the set or raise an - :exc:`IndexError`. - - :param header: the header to be looked up. - """ - rv = self.find(header) - if rv < 0: - raise IndexError(header) - return rv - - def clear(self): - """Clear the set.""" - self._set.clear() - del self._headers[:] - if self.on_update is not None: - self.on_update(self) - - def as_set(self, preserve_casing=False): - """Return the set as real python set type. When calling this, all - the items are converted to lowercase and the ordering is lost. - - :param preserve_casing: if set to `True` the items in the set returned - will have the original case like in the - :class:`HeaderSet`, otherwise they will - be lowercase. 
- """ - if preserve_casing: - return set(self._headers) - return set(self._set) - - def to_header(self): - """Convert the header set into an HTTP header string.""" - return ", ".join(map(http.quote_header_value, self._headers)) - - def __getitem__(self, idx): - return self._headers[idx] - - def __delitem__(self, idx): - rv = self._headers.pop(idx) - self._set.remove(rv.lower()) - if self.on_update is not None: - self.on_update(self) - - def __setitem__(self, idx, value): - old = self._headers[idx] - self._set.remove(old.lower()) - self._headers[idx] = value - self._set.add(value.lower()) - if self.on_update is not None: - self.on_update(self) - - def __contains__(self, header): - return header.lower() in self._set - - def __len__(self): - return len(self._set) - - def __iter__(self): - return iter(self._headers) - - def __bool__(self): - return bool(self._set) - - def __str__(self): - return self.to_header() - - def __repr__(self): - return f"{type(self).__name__}({self._headers!r})" - - -class ETags(Collection): - """A set that can be used to check if one etag is present in a collection - of etags. - """ - - def __init__(self, strong_etags=None, weak_etags=None, star_tag=False): - if not star_tag and strong_etags: - self._strong = frozenset(strong_etags) - else: - self._strong = frozenset() - - self._weak = frozenset(weak_etags or ()) - self.star_tag = star_tag - - def as_set(self, include_weak=False): - """Convert the `ETags` object into a python set. Per default all the - weak etags are not part of this set.""" - rv = set(self._strong) - if include_weak: - rv.update(self._weak) - return rv - - def is_weak(self, etag): - """Check if an etag is weak.""" - return etag in self._weak - - def is_strong(self, etag): - """Check if an etag is strong.""" - return etag in self._strong - - def contains_weak(self, etag): - """Check if an etag is part of the set including weak and strong tags.""" - return self.is_weak(etag) or self.contains(etag) - - def contains(self, etag): - """Check if an etag is part of the set ignoring weak tags. - It is also possible to use the ``in`` operator. - """ - if self.star_tag: - return True - return self.is_strong(etag) - - def contains_raw(self, etag): - """When passed a quoted tag it will check if this tag is part of the - set. If the tag is weak it is checked against weak and strong tags, - otherwise strong only.""" - etag, weak = http.unquote_etag(etag) - if weak: - return self.contains_weak(etag) - return self.contains(etag) - - def to_header(self): - """Convert the etags set into a HTTP header string.""" - if self.star_tag: - return "*" - return ", ".join( - [f'"{x}"' for x in self._strong] + [f'W/"{x}"' for x in self._weak] - ) - - def __call__(self, etag=None, data=None, include_weak=False): - if [etag, data].count(None) != 1: - raise TypeError("either tag or data required, but at least one") - if etag is None: - etag = http.generate_etag(data) - if include_weak: - if etag in self._weak: - return True - return etag in self._strong - - def __bool__(self): - return bool(self.star_tag or self._strong or self._weak) - - def __str__(self): - return self.to_header() - - def __len__(self): - return len(self._strong) - - def __iter__(self): - return iter(self._strong) - - def __contains__(self, etag): - return self.contains(etag) - - def __repr__(self): - return f"<{type(self).__name__} {str(self)!r}>" - - -class IfRange: - """Very simple object that represents the `If-Range` header in parsed - form. 
It will have an etag or a date, or neither, but - never both. - - .. versionadded:: 0.7 - """ - - def __init__(self, etag=None, date=None): - #: The etag parsed and unquoted. Ranges always operate on strong - #: etags so the weakness information is not necessary. - self.etag = etag - #: The date in parsed format or `None`. - self.date = date - - def to_header(self): - """Converts the object back into an HTTP header.""" - if self.date is not None: - return http.http_date(self.date) - if self.etag is not None: - return http.quote_etag(self.etag) - return "" - - def __str__(self): - return self.to_header() - - def __repr__(self): - return f"<{type(self).__name__} {str(self)!r}>" - - -class Range: - """Represents a ``Range`` header. All methods only support - bytes as the unit. Stores a list of ranges if given, but the methods - only work if exactly one range is provided. - - :raise ValueError: If the ranges provided are invalid. - - .. versionchanged:: 0.15 - The ranges passed in are validated. - - .. versionadded:: 0.7 - """ - - def __init__(self, units, ranges): - #: The units of this range. Usually "bytes". - self.units = units - #: A list of ``(begin, end)`` tuples for the range header provided. - #: The ranges are non-inclusive. - self.ranges = ranges - - for start, end in ranges: - if start is None or (end is not None and (start < 0 or start >= end)): - raise ValueError(f"{(start, end)} is not a valid range.") - - def range_for_length(self, length): - """If the range is for bytes, the length is not `None`, there is - exactly one range, and it is satisfiable, it returns a ``(start, stop)`` - tuple; otherwise `None`. - """ - if self.units != "bytes" or length is None or len(self.ranges) != 1: - return None - start, end = self.ranges[0] - if end is None: - end = length - if start < 0: - start += length - if http.is_byte_range_valid(start, end, length): - return start, min(end, length) - return None - - def make_content_range(self, length): - """Creates a :class:`~werkzeug.datastructures.ContentRange` object - from the current range and given content length. - """ - rng = self.range_for_length(length) - if rng is not None: - return ContentRange(self.units, rng[0], rng[1], length) - return None - - def to_header(self): - """Converts the object back into an HTTP header.""" - ranges = [] - for begin, end in self.ranges: - if end is None: - ranges.append(f"{begin}-" if begin >= 0 else str(begin)) - else: - ranges.append(f"{begin}-{end - 1}") - return f"{self.units}={','.join(ranges)}" - - def to_content_range_header(self, length): - """Converts the object into `Content-Range` HTTP header, - based on given length - """ - range = self.range_for_length(length) - if range is not None: - return f"{self.units} {range[0]}-{range[1] - 1}/{length}" - return None - - def __str__(self): - return self.to_header() - - def __repr__(self): - return f"<{type(self).__name__} {str(self)!r}>" - - -def _callback_property(name): - def fget(self): - return getattr(self, name) - - def fset(self, value): - setattr(self, name, value) - if self.on_update is not None: - self.on_update(self) - - return property(fget, fset) - - -class ContentRange: - """Represents the content range header. - - .. 
versionadded:: 0.7 - """ - - def __init__(self, units, start, stop, length=None, on_update=None): - assert http.is_byte_range_valid(start, stop, length), "Bad range provided" - self.on_update = on_update - self.set(start, stop, length, units) - - #: The units to use, usually "bytes" - units = _callback_property("_units") - #: The start point of the range or `None`. - start = _callback_property("_start") - #: The stop point of the range (non-inclusive) or `None`. Can only be - #: `None` if also start is `None`. - stop = _callback_property("_stop") - #: The length of the range or `None`. - length = _callback_property("_length") - - def set(self, start, stop, length=None, units="bytes"): - """Simple method to update the ranges.""" - assert http.is_byte_range_valid(start, stop, length), "Bad range provided" - self._units = units - self._start = start - self._stop = stop - self._length = length - if self.on_update is not None: - self.on_update(self) - - def unset(self): - """Sets the units to `None` which indicates that the header should - no longer be used. - """ - self.set(None, None, units=None) - - def to_header(self): - if self.units is None: - return "" - if self.length is None: - length = "*" - else: - length = self.length - if self.start is None: - return f"{self.units} */{length}" - return f"{self.units} {self.start}-{self.stop - 1}/{length}" - - def __bool__(self): - return self.units is not None - - def __str__(self): - return self.to_header() - - def __repr__(self): - return f"<{type(self).__name__} {str(self)!r}>" - - -class Authorization(ImmutableDictMixin, dict): - """Represents an ``Authorization`` header sent by the client. - - This is returned by - :func:`~werkzeug.http.parse_authorization_header`. It can be useful - to create the object manually to pass to the test - :class:`~werkzeug.test.Client`. - - .. versionchanged:: 0.5 - This object became immutable. - """ - - def __init__(self, auth_type, data=None): - dict.__init__(self, data or {}) - self.type = auth_type - - @property - def username(self): - """The username transmitted. This is set for both basic and digest - auth all the time. - """ - return self.get("username") - - @property - def password(self): - """When the authentication type is basic this is the password - transmitted by the client, else `None`. - """ - return self.get("password") - - @property - def realm(self): - """This is the server realm sent back for HTTP digest auth.""" - return self.get("realm") - - @property - def nonce(self): - """The nonce the server sent for digest auth, sent back by the client. - A nonce should be unique for every 401 response for HTTP digest auth. - """ - return self.get("nonce") - - @property - def uri(self): - """The URI from Request-URI of the Request-Line; duplicated because - proxies are allowed to change the Request-Line in transit. HTTP - digest auth only. - """ - return self.get("uri") - - @property - def nc(self): - """The nonce count value transmitted by clients if a qop-header is - also transmitted. HTTP digest auth only. - """ - return self.get("nc") - - @property - def cnonce(self): - """If the server sent a qop-header in the ``WWW-Authenticate`` - header, the client has to provide this value for HTTP digest auth. - See the RFC for more details. - """ - return self.get("cnonce") - - @property - def response(self): - """A string of 32 hex digits computed as defined in RFC 2617, which - proves that the user knows a password. Digest auth only. 
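# Editor's illustrative sketch (not part of the deleted module): building
# an Authorization value by hand, e.g. for the test client; to_header()
# is defined just below.
from werkzeug.datastructures import Authorization

auth = Authorization("basic", {"username": "alice", "password": "secret"})
assert auth.username == "alice"
assert auth.to_header() == "Basic YWxpY2U6c2VjcmV0"  # base64 of "alice:secret"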
- """ - return self.get("response") - - @property - def opaque(self): - """The opaque header from the server returned unchanged by the client. - It is recommended that this string be base64 or hexadecimal data. - Digest auth only. - """ - return self.get("opaque") - - @property - def qop(self): - """Indicates what "quality of protection" the client has applied to - the message for HTTP digest auth. Note that this is a single token, - not a quoted list of alternatives as in WWW-Authenticate. - """ - return self.get("qop") - - def to_header(self): - """Convert to a string value for an ``Authorization`` header. - - .. versionadded:: 2.0 - Added to support passing authorization to the test client. - """ - if self.type == "basic": - value = base64.b64encode( - f"{self.username}:{self.password}".encode() - ).decode("utf8") - return f"Basic {value}" - - if self.type == "digest": - return f"Digest {http.dump_header(self)}" - - raise ValueError(f"Unsupported type {self.type!r}.") - - -def auth_property(name, doc=None): - """A static helper function for Authentication subclasses to add - extra authentication system properties onto a class:: - - class FooAuthenticate(WWWAuthenticate): - special_realm = auth_property('special_realm') - - For more information have a look at the sourcecode to see how the - regular properties (:attr:`realm` etc.) are implemented. - """ - - def _set_value(self, value): - if value is None: - self.pop(name, None) - else: - self[name] = str(value) - - return property(lambda x: x.get(name), _set_value, doc=doc) - - -def _set_property(name, doc=None): - def fget(self): - def on_update(header_set): - if not header_set and name in self: - del self[name] - elif header_set: - self[name] = header_set.to_header() - - return http.parse_set_header(self.get(name), on_update) - - return property(fget, doc=doc) - - -class WWWAuthenticate(UpdateDictMixin, dict): - """Provides simple access to `WWW-Authenticate` headers.""" - - #: list of keys that require quoting in the generated header - _require_quoting = frozenset(["domain", "nonce", "opaque", "realm", "qop"]) - - def __init__(self, auth_type=None, values=None, on_update=None): - dict.__init__(self, values or ()) - if auth_type: - self["__auth_type__"] = auth_type - self.on_update = on_update - - def set_basic(self, realm="authentication required"): - """Clear the auth info and enable basic auth.""" - dict.clear(self) - dict.update(self, {"__auth_type__": "basic", "realm": realm}) - if self.on_update: - self.on_update(self) - - def set_digest( - self, realm, nonce, qop=("auth",), opaque=None, algorithm=None, stale=False - ): - """Clear the auth info and enable digest auth.""" - d = { - "__auth_type__": "digest", - "realm": realm, - "nonce": nonce, - "qop": http.dump_header(qop), - } - if stale: - d["stale"] = "TRUE" - if opaque is not None: - d["opaque"] = opaque - if algorithm is not None: - d["algorithm"] = algorithm - dict.clear(self) - dict.update(self, d) - if self.on_update: - self.on_update(self) - - def to_header(self): - """Convert the stored values into a WWW-Authenticate header.""" - d = dict(self) - auth_type = d.pop("__auth_type__", None) or "basic" - kv_items = ( - (k, http.quote_header_value(v, allow_token=k not in self._require_quoting)) - for k, v in d.items() - ) - kv_string = ", ".join([f"{k}={v}" for k, v in kv_items]) - return f"{auth_type.title()} {kv_string}" - - def __str__(self): - return self.to_header() - - def __repr__(self): - return f"<{type(self).__name__} {self.to_header()!r}>" - - type = auth_property( 
- "__auth_type__", - doc="""The type of the auth mechanism. HTTP currently specifies - ``Basic`` and ``Digest``.""", - ) - realm = auth_property( - "realm", - doc="""A string to be displayed to users so they know which - username and password to use. This string should contain at - least the name of the host performing the authentication and - might additionally indicate the collection of users who might - have access.""", - ) - domain = _set_property( - "domain", - doc="""A list of URIs that define the protection space. If a URI - is an absolute path, it is relative to the canonical root URL of - the server being accessed.""", - ) - nonce = auth_property( - "nonce", - doc=""" - A server-specified data string which should be uniquely generated - each time a 401 response is made. It is recommended that this - string be base64 or hexadecimal data.""", - ) - opaque = auth_property( - "opaque", - doc="""A string of data, specified by the server, which should - be returned by the client unchanged in the Authorization header - of subsequent requests with URIs in the same protection space. - It is recommended that this string be base64 or hexadecimal - data.""", - ) - algorithm = auth_property( - "algorithm", - doc="""A string indicating a pair of algorithms used to produce - the digest and a checksum. If this is not present it is assumed - to be "MD5". If the algorithm is not understood, the challenge - should be ignored (and a different one used, if there is more - than one).""", - ) - qop = _set_property( - "qop", - doc="""A set of quality-of-privacy directives such as auth and - auth-int.""", - ) - - @property - def stale(self): - """A flag, indicating that the previous request from the client - was rejected because the nonce value was stale. - """ - val = self.get("stale") - if val is not None: - return val.lower() == "true" - - @stale.setter - def stale(self, value): - if value is None: - self.pop("stale", None) - else: - self["stale"] = "TRUE" if value else "FALSE" - - auth_property = staticmethod(auth_property) - - -class FileStorage: - """The :class:`FileStorage` class is a thin wrapper over incoming files. - It is used by the request object to represent uploaded files. All the - attributes of the wrapper stream are proxied by the file storage so - it's possible to do ``storage.read()`` instead of the long form - ``storage.stream.read()``. - """ - - def __init__( - self, - stream=None, - filename=None, - name=None, - content_type=None, - content_length=None, - headers=None, - ): - self.name = name - self.stream = stream or BytesIO() - - # If no filename is provided, attempt to get the filename from - # the stream object. Python names special streams like - # ```` with angular brackets, skip these streams. - if filename is None: - filename = getattr(stream, "name", None) - - if filename is not None: - filename = os.fsdecode(filename) - - if filename and filename[0] == "<" and filename[-1] == ">": - filename = None - else: - filename = os.fsdecode(filename) - - self.filename = filename - - if headers is None: - headers = Headers() - self.headers = headers - if content_type is not None: - headers["Content-Type"] = content_type - if content_length is not None: - headers["Content-Length"] = str(content_length) - - def _parse_content_type(self): - if not hasattr(self, "_parsed_content_type"): - self._parsed_content_type = http.parse_options_header(self.content_type) - - @property - def content_type(self): - """The content-type sent in the header. 
Usually not available""" - return self.headers.get("content-type") - - @property - def content_length(self): - """The content-length sent in the header. Usually not available""" - try: - return int(self.headers.get("content-length") or 0) - except ValueError: - return 0 - - @property - def mimetype(self): - """Like :attr:`content_type`, but without parameters (eg, without - charset, type etc.) and always lowercase. For example if the content - type is ``text/HTML; charset=utf-8`` the mimetype would be - ``'text/html'``. - - .. versionadded:: 0.7 - """ - self._parse_content_type() - return self._parsed_content_type[0].lower() - - @property - def mimetype_params(self): - """The mimetype parameters as dict. For example if the content - type is ``text/html; charset=utf-8`` the params would be - ``{'charset': 'utf-8'}``. - - .. versionadded:: 0.7 - """ - self._parse_content_type() - return self._parsed_content_type[1] - - def save(self, dst, buffer_size=16384): - """Save the file to a destination path or file object. If the - destination is a file object you have to close it yourself after the - call. The buffer size is the number of bytes held in memory during - the copy process. It defaults to 16KB. - - For secure file saving also have a look at :func:`secure_filename`. - - :param dst: a filename, :class:`os.PathLike`, or open file - object to write to. - :param buffer_size: Passed as the ``length`` parameter of - :func:`shutil.copyfileobj`. - - .. versionchanged:: 1.0 - Supports :mod:`pathlib`. - """ - from shutil import copyfileobj - - close_dst = False - - if hasattr(dst, "__fspath__"): - dst = fspath(dst) - - if isinstance(dst, str): - dst = open(dst, "wb") - close_dst = True - - try: - copyfileobj(self.stream, dst, buffer_size) - finally: - if close_dst: - dst.close() - - def close(self): - """Close the underlying file if possible.""" - try: - self.stream.close() - except Exception: - pass - - def __bool__(self): - return bool(self.filename) - - def __getattr__(self, name): - try: - return getattr(self.stream, name) - except AttributeError: - # SpooledTemporaryFile doesn't implement IOBase, get the - # attribute from its backing file instead. - # https://github.com/python/cpython/pull/3249 - if hasattr(self.stream, "_file"): - return getattr(self.stream._file, name) - raise - - def __iter__(self): - return iter(self.stream) - - def __repr__(self): - return f"<{type(self).__name__}: {self.filename!r} ({self.content_type!r})>" - - -# circular dependencies -from . 
import http diff --git a/src/werkzeug/datastructures.pyi b/src/werkzeug/datastructures.pyi deleted file mode 100644 index 7bf7297..0000000 --- a/src/werkzeug/datastructures.pyi +++ /dev/null @@ -1,921 +0,0 @@ -from datetime import datetime -from os import PathLike -from typing import Any -from typing import Callable -from typing import Collection -from typing import Dict -from typing import FrozenSet -from typing import Generic -from typing import Hashable -from typing import IO -from typing import Iterable -from typing import Iterator -from typing import List -from typing import Mapping -from typing import NoReturn -from typing import Optional -from typing import overload -from typing import Set -from typing import Tuple -from typing import Type -from typing import TypeVar -from typing import Union -from _typeshed import SupportsKeysAndGetItem -from _typeshed.wsgi import WSGIEnvironment - -from typing_extensions import Literal -from typing_extensions import SupportsIndex - -K = TypeVar("K") -V = TypeVar("V") -T = TypeVar("T") -D = TypeVar("D") -_CD = TypeVar("_CD", bound="CallbackDict") - -def is_immutable(self: object) -> NoReturn: ... -def iter_multi_items( - mapping: Union[Mapping[K, Union[V, Iterable[V]]], Iterable[Tuple[K, V]]] -) -> Iterator[Tuple[K, V]]: ... - -class ImmutableListMixin(List[V]): - _hash_cache: Optional[int] - def __hash__(self) -> int: ... # type: ignore - def __delitem__(self, key: Union[SupportsIndex, slice]) -> NoReturn: ... - def __iadd__(self, other: t.Any) -> NoReturn: ... # type: ignore - def __imul__(self, other: SupportsIndex) -> NoReturn: ... - def __setitem__( # type: ignore - self, key: Union[int, slice], value: V - ) -> NoReturn: ... - def append(self, value: V) -> NoReturn: ... - def remove(self, value: V) -> NoReturn: ... - def extend(self, values: Iterable[V]) -> NoReturn: ... - def insert(self, pos: SupportsIndex, value: V) -> NoReturn: ... - def pop(self, index: SupportsIndex = -1) -> NoReturn: ... - def reverse(self) -> NoReturn: ... - def sort( - self, key: Optional[Callable[[V], Any]] = None, reverse: bool = False - ) -> NoReturn: ... - -class ImmutableList(ImmutableListMixin[V]): ... - -class ImmutableDictMixin(Dict[K, V]): - _hash_cache: Optional[int] - @classmethod - def fromkeys( # type: ignore - cls, keys: Iterable[K], value: Optional[V] = None - ) -> ImmutableDictMixin[K, V]: ... - def _iter_hashitems(self) -> Iterable[Hashable]: ... - def __hash__(self) -> int: ... # type: ignore - def setdefault(self, key: K, default: Optional[V] = None) -> NoReturn: ... - def update(self, *args: Any, **kwargs: V) -> NoReturn: ... - def pop(self, key: K, default: Optional[V] = None) -> NoReturn: ... # type: ignore - def popitem(self) -> NoReturn: ... - def __setitem__(self, key: K, value: V) -> NoReturn: ... - def __delitem__(self, key: K) -> NoReturn: ... - def clear(self) -> NoReturn: ... - -class ImmutableMultiDictMixin(ImmutableDictMixin[K, V]): - def _iter_hashitems(self) -> Iterable[Hashable]: ... - def add(self, key: K, value: V) -> NoReturn: ... - def popitemlist(self) -> NoReturn: ... - def poplist(self, key: K) -> NoReturn: ... - def setlist(self, key: K, new_list: Iterable[V]) -> NoReturn: ... - def setlistdefault( - self, key: K, default_list: Optional[Iterable[V]] = None - ) -> NoReturn: ... - -def _calls_update(name: str) -> Callable[[UpdateDictMixin[K, V]], Any]: ... - -class UpdateDictMixin(Dict[K, V]): - on_update: Optional[Callable[[UpdateDictMixin[K, V]], None]] - def setdefault(self, key: K, default: Optional[V] = None) -> V: ... 
- @overload - def pop(self, key: K) -> V: ... - @overload - def pop(self, key: K, default: Union[V, T] = ...) -> Union[V, T]: ... - def __setitem__(self, key: K, value: V) -> None: ... - def __delitem__(self, key: K) -> None: ... - def clear(self) -> None: ... - def popitem(self) -> Tuple[K, V]: ... - @overload - def update(self, __m: SupportsKeysAndGetItem[K, V], **kwargs: V) -> None: ... - @overload - def update(self, __m: Iterable[Tuple[K, V]], **kwargs: V) -> None: ... - @overload - def update(self, **kwargs: V) -> None: ... - -class TypeConversionDict(Dict[K, V]): - @overload - def get(self, key: K, default: None = ..., type: None = ...) -> Optional[V]: ... - @overload - def get(self, key: K, default: D, type: None = ...) -> Union[D, V]: ... - @overload - def get(self, key: K, default: D, type: Callable[[V], T]) -> Union[D, T]: ... - @overload - def get(self, key: K, type: Callable[[V], T]) -> Optional[T]: ... - -class ImmutableTypeConversionDict(ImmutableDictMixin[K, V], TypeConversionDict[K, V]): - def copy(self) -> TypeConversionDict[K, V]: ... - def __copy__(self) -> ImmutableTypeConversionDict: ... - -class MultiDict(TypeConversionDict[K, V]): - def __init__( - self, - mapping: Optional[ - Union[Mapping[K, Union[Iterable[V], V]], Iterable[Tuple[K, V]]] - ] = None, - ) -> None: ... - def __getitem__(self, item: K) -> V: ... - def __setitem__(self, key: K, value: V) -> None: ... - def add(self, key: K, value: V) -> None: ... - @overload - def getlist(self, key: K) -> List[V]: ... - @overload - def getlist(self, key: K, type: Callable[[V], T] = ...) -> List[T]: ... - def setlist(self, key: K, new_list: Iterable[V]) -> None: ... - def setdefault(self, key: K, default: Optional[V] = None) -> V: ... - def setlistdefault( - self, key: K, default_list: Optional[Iterable[V]] = None - ) -> List[V]: ... - def items(self, multi: bool = False) -> Iterator[Tuple[K, V]]: ... # type: ignore - def lists(self) -> Iterator[Tuple[K, List[V]]]: ... - def values(self) -> Iterator[V]: ... # type: ignore - def listvalues(self) -> Iterator[List[V]]: ... - def copy(self) -> MultiDict[K, V]: ... - def deepcopy(self, memo: Any = None) -> MultiDict[K, V]: ... - @overload - def to_dict(self) -> Dict[K, V]: ... - @overload - def to_dict(self, flat: Literal[False]) -> Dict[K, List[V]]: ... - def update( # type: ignore - self, mapping: Union[Mapping[K, Union[Iterable[V], V]], Iterable[Tuple[K, V]]] - ) -> None: ... - @overload - def pop(self, key: K) -> V: ... - @overload - def pop(self, key: K, default: Union[V, T] = ...) -> Union[V, T]: ... - def popitem(self) -> Tuple[K, V]: ... - def poplist(self, key: K) -> List[V]: ... - def popitemlist(self) -> Tuple[K, List[V]]: ... - def __copy__(self) -> MultiDict[K, V]: ... - def __deepcopy__(self, memo: Any) -> MultiDict[K, V]: ... - -class _omd_bucket(Generic[K, V]): - prev: Optional[_omd_bucket] - next: Optional[_omd_bucket] - key: K - value: V - def __init__(self, omd: OrderedMultiDict, key: K, value: V) -> None: ... - def unlink(self, omd: OrderedMultiDict) -> None: ... - -class OrderedMultiDict(MultiDict[K, V]): - _first_bucket: Optional[_omd_bucket] - _last_bucket: Optional[_omd_bucket] - def __init__(self, mapping: Optional[Mapping[K, V]] = None) -> None: ... - def __eq__(self, other: object) -> bool: ... - def __getitem__(self, key: K) -> V: ... - def __setitem__(self, key: K, value: V) -> None: ... - def __delitem__(self, key: K) -> None: ... - def keys(self) -> Iterator[K]: ... # type: ignore - def __iter__(self) -> Iterator[K]: ... 
- def values(self) -> Iterator[V]: ... # type: ignore - def items(self, multi: bool = False) -> Iterator[Tuple[K, V]]: ... # type: ignore - def lists(self) -> Iterator[Tuple[K, List[V]]]: ... - def listvalues(self) -> Iterator[List[V]]: ... - def add(self, key: K, value: V) -> None: ... - @overload - def getlist(self, key: K) -> List[V]: ... - @overload - def getlist(self, key: K, type: Callable[[V], T] = ...) -> List[T]: ... - def setlist(self, key: K, new_list: Iterable[V]) -> None: ... - def setlistdefault( - self, key: K, default_list: Optional[Iterable[V]] = None - ) -> List[V]: ... - def update( # type: ignore - self, mapping: Union[Mapping[K, V], Iterable[Tuple[K, V]]] - ) -> None: ... - def poplist(self, key: K) -> List[V]: ... - @overload - def pop(self, key: K) -> V: ... - @overload - def pop(self, key: K, default: Union[V, T] = ...) -> Union[V, T]: ... - def popitem(self) -> Tuple[K, V]: ... - def popitemlist(self) -> Tuple[K, List[V]]: ... - -def _options_header_vkw( - value: str, kw: Mapping[str, Optional[Union[str, int]]] -) -> str: ... -def _unicodify_header_value(value: Union[str, int]) -> str: ... - -HV = Union[str, int] - -class Headers(Dict[str, str]): - _list: List[Tuple[str, str]] - def __init__( - self, - defaults: Optional[ - Union[Mapping[str, Union[HV, Iterable[HV]]], Iterable[Tuple[str, HV]]] - ] = None, - ) -> None: ... - @overload - def __getitem__(self, key: str) -> str: ... - @overload - def __getitem__(self, key: int) -> Tuple[str, str]: ... - @overload - def __getitem__(self, key: slice) -> Headers: ... - @overload - def __getitem__(self, key: str, _get_mode: Literal[True] = ...) -> str: ... - def __eq__(self, other: object) -> bool: ... - @overload # type: ignore - def get(self, key: str, default: str) -> str: ... - @overload - def get(self, key: str, default: Optional[str] = None) -> Optional[str]: ... - @overload - def get( - self, key: str, default: Optional[T] = None, type: Callable[[str], T] = ... - ) -> Optional[T]: ... - @overload - def getlist(self, key: str) -> List[str]: ... - @overload - def getlist(self, key: str, type: Callable[[str], T]) -> List[T]: ... - def get_all(self, name: str) -> List[str]: ... - def items( # type: ignore - self, lower: bool = False - ) -> Iterator[Tuple[str, str]]: ... - def keys(self, lower: bool = False) -> Iterator[str]: ... # type: ignore - def values(self) -> Iterator[str]: ... # type: ignore - def extend( - self, - *args: Union[Mapping[str, Union[HV, Iterable[HV]]], Iterable[Tuple[str, HV]]], - **kwargs: Union[HV, Iterable[HV]], - ) -> None: ... - @overload - def __delitem__(self, key: Union[str, int, slice]) -> None: ... - @overload - def __delitem__(self, key: str, _index_operation: Literal[False]) -> None: ... - def remove(self, key: str) -> None: ... - @overload # type: ignore - def pop(self, key: str, default: Optional[str] = None) -> str: ... - @overload - def pop( - self, key: Optional[int] = None, default: Optional[Tuple[str, str]] = None - ) -> Tuple[str, str]: ... - def popitem(self) -> Tuple[str, str]: ... - def __contains__(self, key: str) -> bool: ... # type: ignore - def has_key(self, key: str) -> bool: ... - def __iter__(self) -> Iterator[Tuple[str, str]]: ... # type: ignore - def add(self, _key: str, _value: HV, **kw: HV) -> None: ... - def _validate_value(self, value: str) -> None: ... - def add_header(self, _key: str, _value: HV, **_kw: HV) -> None: ... - def clear(self) -> None: ... - def set(self, _key: str, _value: HV, **kw: HV) -> None: ... 
- def setlist(self, key: str, values: Iterable[HV]) -> None: ... - def setdefault(self, key: str, default: HV) -> str: ... # type: ignore - def setlistdefault(self, key: str, default: Iterable[HV]) -> None: ... - @overload - def __setitem__(self, key: str, value: HV) -> None: ... - @overload - def __setitem__(self, key: int, value: Tuple[str, HV]) -> None: ... - @overload - def __setitem__(self, key: slice, value: Iterable[Tuple[str, HV]]) -> None: ... - @overload - def update( - self, __m: SupportsKeysAndGetItem[str, HV], **kwargs: Union[HV, Iterable[HV]] - ) -> None: ... - @overload - def update( - self, __m: Iterable[Tuple[str, HV]], **kwargs: Union[HV, Iterable[HV]] - ) -> None: ... - @overload - def update(self, **kwargs: Union[HV, Iterable[HV]]) -> None: ... - def to_wsgi_list(self) -> List[Tuple[str, str]]: ... - def copy(self) -> Headers: ... - def __copy__(self) -> Headers: ... - -class ImmutableHeadersMixin(Headers): - def __delitem__(self, key: Any, _index_operation: bool = True) -> NoReturn: ... - def __setitem__(self, key: Any, value: Any) -> NoReturn: ... - def set(self, _key: Any, _value: Any, **kw: Any) -> NoReturn: ... - def setlist(self, key: Any, values: Any) -> NoReturn: ... - def add(self, _key: Any, _value: Any, **kw: Any) -> NoReturn: ... - def add_header(self, _key: Any, _value: Any, **_kw: Any) -> NoReturn: ... - def remove(self, key: Any) -> NoReturn: ... - def extend(self, *args: Any, **kwargs: Any) -> NoReturn: ... - def update(self, *args: Any, **kwargs: Any) -> NoReturn: ... - def insert(self, pos: Any, value: Any) -> NoReturn: ... - def pop(self, key: Any = None, default: Any = ...) -> NoReturn: ... - def popitem(self) -> NoReturn: ... - def setdefault(self, key: Any, default: Any) -> NoReturn: ... # type: ignore - def setlistdefault(self, key: Any, default: Any) -> NoReturn: ... - -class EnvironHeaders(ImmutableHeadersMixin, Headers): - environ: WSGIEnvironment - def __init__(self, environ: WSGIEnvironment) -> None: ... - def __eq__(self, other: object) -> bool: ... - def __getitem__( # type: ignore - self, key: str, _get_mode: Literal[False] = False - ) -> str: ... - def __iter__(self) -> Iterator[Tuple[str, str]]: ... # type: ignore - def copy(self) -> NoReturn: ... - -class CombinedMultiDict(ImmutableMultiDictMixin[K, V], MultiDict[K, V]): # type: ignore - dicts: List[MultiDict[K, V]] - def __init__(self, dicts: Optional[Iterable[MultiDict[K, V]]]) -> None: ... - @classmethod - def fromkeys(cls, keys: Any, value: Any = None) -> NoReturn: ... - def __getitem__(self, key: K) -> V: ... - @overload # type: ignore - def get(self, key: K) -> Optional[V]: ... - @overload - def get(self, key: K, default: Union[V, T] = ...) -> Union[V, T]: ... - @overload - def get( - self, key: K, default: Optional[T] = None, type: Callable[[V], T] = ... - ) -> Optional[T]: ... - @overload - def getlist(self, key: K) -> List[V]: ... - @overload - def getlist(self, key: K, type: Callable[[V], T] = ...) -> List[T]: ... - def _keys_impl(self) -> Set[K]: ... - def keys(self) -> Set[K]: ... # type: ignore - def __iter__(self) -> Set[K]: ... # type: ignore - def items(self, multi: bool = False) -> Iterator[Tuple[K, V]]: ... # type: ignore - def values(self) -> Iterator[V]: ... # type: ignore - def lists(self) -> Iterator[Tuple[K, List[V]]]: ... - def listvalues(self) -> Iterator[List[V]]: ... - def copy(self) -> MultiDict[K, V]: ... - @overload - def to_dict(self) -> Dict[K, V]: ... - @overload - def to_dict(self, flat: Literal[False]) -> Dict[K, List[V]]: ... 
- def __contains__(self, key: K) -> bool: ... # type: ignore - def has_key(self, key: K) -> bool: ... - -class FileMultiDict(MultiDict[str, "FileStorage"]): - def add_file( - self, - name: str, - file: Union[FileStorage, str, IO[bytes]], - filename: Optional[str] = None, - content_type: Optional[str] = None, - ) -> None: ... - -class ImmutableDict(ImmutableDictMixin[K, V], Dict[K, V]): - def copy(self) -> Dict[K, V]: ... - def __copy__(self) -> ImmutableDict[K, V]: ... - -class ImmutableMultiDict( # type: ignore - ImmutableMultiDictMixin[K, V], MultiDict[K, V] -): - def copy(self) -> MultiDict[K, V]: ... - def __copy__(self) -> ImmutableMultiDict[K, V]: ... - -class ImmutableOrderedMultiDict( # type: ignore - ImmutableMultiDictMixin[K, V], OrderedMultiDict[K, V] -): - def _iter_hashitems(self) -> Iterator[Tuple[int, Tuple[K, V]]]: ... - def copy(self) -> OrderedMultiDict[K, V]: ... - def __copy__(self) -> ImmutableOrderedMultiDict[K, V]: ... - -class Accept(ImmutableList[Tuple[str, int]]): - provided: bool - def __init__( - self, values: Optional[Union[Accept, Iterable[Tuple[str, float]]]] = None - ) -> None: ... - def _specificity(self, value: str) -> Tuple[bool, ...]: ... - def _value_matches(self, value: str, item: str) -> bool: ... - @overload # type: ignore - def __getitem__(self, key: str) -> int: ... - @overload - def __getitem__(self, key: int) -> Tuple[str, int]: ... - @overload - def __getitem__(self, key: slice) -> Iterable[Tuple[str, int]]: ... - def quality(self, key: str) -> int: ... - def __contains__(self, value: str) -> bool: ... # type: ignore - def index(self, key: str) -> int: ... # type: ignore - def find(self, key: str) -> int: ... - def values(self) -> Iterator[str]: ... - def to_header(self) -> str: ... - def _best_single_match(self, match: str) -> Optional[Tuple[str, int]]: ... - def best_match( - self, matches: Iterable[str], default: Optional[str] = None - ) -> Optional[str]: ... - @property - def best(self) -> str: ... - -def _normalize_mime(value: str) -> List[str]: ... - -class MIMEAccept(Accept): - def _specificity(self, value: str) -> Tuple[bool, ...]: ... - def _value_matches(self, value: str, item: str) -> bool: ... - @property - def accept_html(self) -> bool: ... - @property - def accept_xhtml(self) -> bool: ... - @property - def accept_json(self) -> bool: ... - -def _normalize_lang(value: str) -> List[str]: ... - -class LanguageAccept(Accept): - def _value_matches(self, value: str, item: str) -> bool: ... - def best_match( - self, matches: Iterable[str], default: Optional[str] = None - ) -> Optional[str]: ... - -class CharsetAccept(Accept): - def _value_matches(self, value: str, item: str) -> bool: ... - -_CPT = TypeVar("_CPT", str, int, bool) -_OptCPT = Optional[_CPT] - -def cache_control_property(key: str, empty: _OptCPT, type: Type[_CPT]) -> property: ... - -class _CacheControl(UpdateDictMixin[str, _OptCPT], Dict[str, _OptCPT]): - provided: bool - def __init__( - self, - values: Union[Mapping[str, _OptCPT], Iterable[Tuple[str, _OptCPT]]] = (), - on_update: Optional[Callable[[_CacheControl], None]] = None, - ) -> None: ... - @property - def no_cache(self) -> Optional[bool]: ... - @no_cache.setter - def no_cache(self, value: Optional[bool]) -> None: ... - @no_cache.deleter - def no_cache(self) -> None: ... - @property - def no_store(self) -> Optional[bool]: ... - @no_store.setter - def no_store(self, value: Optional[bool]) -> None: ... - @no_store.deleter - def no_store(self) -> None: ... - @property - def max_age(self) -> Optional[int]: ... 
- @max_age.setter - def max_age(self, value: Optional[int]) -> None: ... - @max_age.deleter - def max_age(self) -> None: ... - @property - def no_transform(self) -> Optional[bool]: ... - @no_transform.setter - def no_transform(self, value: Optional[bool]) -> None: ... - @no_transform.deleter - def no_transform(self) -> None: ... - def _get_cache_value(self, key: str, empty: Optional[T], type: Type[T]) -> T: ... - def _set_cache_value(self, key: str, value: Optional[T], type: Type[T]) -> None: ... - def _del_cache_value(self, key: str) -> None: ... - def to_header(self) -> str: ... - @staticmethod - def cache_property(key: str, empty: _OptCPT, type: Type[_CPT]) -> property: ... - -class RequestCacheControl(ImmutableDictMixin[str, _OptCPT], _CacheControl): - @property - def max_stale(self) -> Optional[int]: ... - @max_stale.setter - def max_stale(self, value: Optional[int]) -> None: ... - @max_stale.deleter - def max_stale(self) -> None: ... - @property - def min_fresh(self) -> Optional[int]: ... - @min_fresh.setter - def min_fresh(self, value: Optional[int]) -> None: ... - @min_fresh.deleter - def min_fresh(self) -> None: ... - @property - def only_if_cached(self) -> Optional[bool]: ... - @only_if_cached.setter - def only_if_cached(self, value: Optional[bool]) -> None: ... - @only_if_cached.deleter - def only_if_cached(self) -> None: ... - -class ResponseCacheControl(_CacheControl): - @property - def public(self) -> Optional[bool]: ... - @public.setter - def public(self, value: Optional[bool]) -> None: ... - @public.deleter - def public(self) -> None: ... - @property - def private(self) -> Optional[bool]: ... - @private.setter - def private(self, value: Optional[bool]) -> None: ... - @private.deleter - def private(self) -> None: ... - @property - def must_revalidate(self) -> Optional[bool]: ... - @must_revalidate.setter - def must_revalidate(self, value: Optional[bool]) -> None: ... - @must_revalidate.deleter - def must_revalidate(self) -> None: ... - @property - def proxy_revalidate(self) -> Optional[bool]: ... - @proxy_revalidate.setter - def proxy_revalidate(self, value: Optional[bool]) -> None: ... - @proxy_revalidate.deleter - def proxy_revalidate(self) -> None: ... - @property - def s_maxage(self) -> Optional[int]: ... - @s_maxage.setter - def s_maxage(self, value: Optional[int]) -> None: ... - @s_maxage.deleter - def s_maxage(self) -> None: ... - @property - def immutable(self) -> Optional[bool]: ... - @immutable.setter - def immutable(self, value: Optional[bool]) -> None: ... - @immutable.deleter - def immutable(self) -> None: ... - -def csp_property(key: str) -> property: ... - -class ContentSecurityPolicy(UpdateDictMixin[str, str], Dict[str, str]): - @property - def base_uri(self) -> Optional[str]: ... - @base_uri.setter - def base_uri(self, value: Optional[str]) -> None: ... - @base_uri.deleter - def base_uri(self) -> None: ... - @property - def child_src(self) -> Optional[str]: ... - @child_src.setter - def child_src(self, value: Optional[str]) -> None: ... - @child_src.deleter - def child_src(self) -> None: ... - @property - def connect_src(self) -> Optional[str]: ... - @connect_src.setter - def connect_src(self, value: Optional[str]) -> None: ... - @connect_src.deleter - def connect_src(self) -> None: ... - @property - def default_src(self) -> Optional[str]: ... - @default_src.setter - def default_src(self, value: Optional[str]) -> None: ... - @default_src.deleter - def default_src(self) -> None: ... - @property - def font_src(self) -> Optional[str]: ... 
- @font_src.setter - def font_src(self, value: Optional[str]) -> None: ... - @font_src.deleter - def font_src(self) -> None: ... - @property - def form_action(self) -> Optional[str]: ... - @form_action.setter - def form_action(self, value: Optional[str]) -> None: ... - @form_action.deleter - def form_action(self) -> None: ... - @property - def frame_ancestors(self) -> Optional[str]: ... - @frame_ancestors.setter - def frame_ancestors(self, value: Optional[str]) -> None: ... - @frame_ancestors.deleter - def frame_ancestors(self) -> None: ... - @property - def frame_src(self) -> Optional[str]: ... - @frame_src.setter - def frame_src(self, value: Optional[str]) -> None: ... - @frame_src.deleter - def frame_src(self) -> None: ... - @property - def img_src(self) -> Optional[str]: ... - @img_src.setter - def img_src(self, value: Optional[str]) -> None: ... - @img_src.deleter - def img_src(self) -> None: ... - @property - def manifest_src(self) -> Optional[str]: ... - @manifest_src.setter - def manifest_src(self, value: Optional[str]) -> None: ... - @manifest_src.deleter - def manifest_src(self) -> None: ... - @property - def media_src(self) -> Optional[str]: ... - @media_src.setter - def media_src(self, value: Optional[str]) -> None: ... - @media_src.deleter - def media_src(self) -> None: ... - @property - def navigate_to(self) -> Optional[str]: ... - @navigate_to.setter - def navigate_to(self, value: Optional[str]) -> None: ... - @navigate_to.deleter - def navigate_to(self) -> None: ... - @property - def object_src(self) -> Optional[str]: ... - @object_src.setter - def object_src(self, value: Optional[str]) -> None: ... - @object_src.deleter - def object_src(self) -> None: ... - @property - def prefetch_src(self) -> Optional[str]: ... - @prefetch_src.setter - def prefetch_src(self, value: Optional[str]) -> None: ... - @prefetch_src.deleter - def prefetch_src(self) -> None: ... - @property - def plugin_types(self) -> Optional[str]: ... - @plugin_types.setter - def plugin_types(self, value: Optional[str]) -> None: ... - @plugin_types.deleter - def plugin_types(self) -> None: ... - @property - def report_to(self) -> Optional[str]: ... - @report_to.setter - def report_to(self, value: Optional[str]) -> None: ... - @report_to.deleter - def report_to(self) -> None: ... - @property - def report_uri(self) -> Optional[str]: ... - @report_uri.setter - def report_uri(self, value: Optional[str]) -> None: ... - @report_uri.deleter - def report_uri(self) -> None: ... - @property - def sandbox(self) -> Optional[str]: ... - @sandbox.setter - def sandbox(self, value: Optional[str]) -> None: ... - @sandbox.deleter - def sandbox(self) -> None: ... - @property - def script_src(self) -> Optional[str]: ... - @script_src.setter - def script_src(self, value: Optional[str]) -> None: ... - @script_src.deleter - def script_src(self) -> None: ... - @property - def script_src_attr(self) -> Optional[str]: ... - @script_src_attr.setter - def script_src_attr(self, value: Optional[str]) -> None: ... - @script_src_attr.deleter - def script_src_attr(self) -> None: ... - @property - def script_src_elem(self) -> Optional[str]: ... - @script_src_elem.setter - def script_src_elem(self, value: Optional[str]) -> None: ... - @script_src_elem.deleter - def script_src_elem(self) -> None: ... - @property - def style_src(self) -> Optional[str]: ... - @style_src.setter - def style_src(self, value: Optional[str]) -> None: ... - @style_src.deleter - def style_src(self) -> None: ... - @property - def style_src_attr(self) -> Optional[str]: ... 
- @style_src_attr.setter - def style_src_attr(self, value: Optional[str]) -> None: ... - @style_src_attr.deleter - def style_src_attr(self) -> None: ... - @property - def style_src_elem(self) -> Optional[str]: ... - @style_src_elem.setter - def style_src_elem(self, value: Optional[str]) -> None: ... - @style_src_elem.deleter - def style_src_elem(self) -> None: ... - @property - def worker_src(self) -> Optional[str]: ... - @worker_src.setter - def worker_src(self, value: Optional[str]) -> None: ... - @worker_src.deleter - def worker_src(self) -> None: ... - provided: bool - def __init__( - self, - values: Union[Mapping[str, str], Iterable[Tuple[str, str]]] = (), - on_update: Optional[Callable[[ContentSecurityPolicy], None]] = None, - ) -> None: ... - def _get_value(self, key: str) -> Optional[str]: ... - def _set_value(self, key: str, value: str) -> None: ... - def _del_value(self, key: str) -> None: ... - def to_header(self) -> str: ... - -class CallbackDict(UpdateDictMixin[K, V], Dict[K, V]): - def __init__( - self, - initial: Optional[Union[Mapping[K, V], Iterable[Tuple[K, V]]]] = None, - on_update: Optional[Callable[[_CD], None]] = None, - ) -> None: ... - -class HeaderSet(Set[str]): - _headers: List[str] - _set: Set[str] - on_update: Optional[Callable[[HeaderSet], None]] - def __init__( - self, - headers: Optional[Iterable[str]] = None, - on_update: Optional[Callable[[HeaderSet], None]] = None, - ) -> None: ... - def add(self, header: str) -> None: ... - def remove(self, header: str) -> None: ... - def update(self, iterable: Iterable[str]) -> None: ... # type: ignore - def discard(self, header: str) -> None: ... - def find(self, header: str) -> int: ... - def index(self, header: str) -> int: ... - def clear(self) -> None: ... - def as_set(self, preserve_casing: bool = False) -> Set[str]: ... - def to_header(self) -> str: ... - def __getitem__(self, idx: int) -> str: ... - def __delitem__(self, idx: int) -> None: ... - def __setitem__(self, idx: int, value: str) -> None: ... - def __contains__(self, header: str) -> bool: ... # type: ignore - def __len__(self) -> int: ... - def __iter__(self) -> Iterator[str]: ... - -class ETags(Collection[str]): - _strong: FrozenSet[str] - _weak: FrozenSet[str] - star_tag: bool - def __init__( - self, - strong_etags: Optional[Iterable[str]] = None, - weak_etags: Optional[Iterable[str]] = None, - star_tag: bool = False, - ) -> None: ... - def as_set(self, include_weak: bool = False) -> Set[str]: ... - def is_weak(self, etag: str) -> bool: ... - def is_strong(self, etag: str) -> bool: ... - def contains_weak(self, etag: str) -> bool: ... - def contains(self, etag: str) -> bool: ... - def contains_raw(self, etag: str) -> bool: ... - def to_header(self) -> str: ... - def __call__( - self, - etag: Optional[str] = None, - data: Optional[bytes] = None, - include_weak: bool = False, - ) -> bool: ... - def __len__(self) -> int: ... - def __iter__(self) -> Iterator[str]: ... - def __contains__(self, item: str) -> bool: ... # type: ignore - -class IfRange: - etag: Optional[str] - date: Optional[datetime] - def __init__( - self, etag: Optional[str] = None, date: Optional[datetime] = None - ) -> None: ... - def to_header(self) -> str: ... - -class Range: - units: str - ranges: List[Tuple[int, Optional[int]]] - def __init__(self, units: str, ranges: List[Tuple[int, Optional[int]]]) -> None: ... - def range_for_length(self, length: Optional[int]) -> Optional[Tuple[int, int]]: ... - def make_content_range(self, length: Optional[int]) -> Optional[ContentRange]: ... 
- def to_header(self) -> str: ... - def to_content_range_header(self, length: Optional[int]) -> Optional[str]: ... - -def _callback_property(name: str) -> property: ... - -class ContentRange: - on_update: Optional[Callable[[ContentRange], None]] - def __init__( - self, - units: Optional[str], - start: Optional[int], - stop: Optional[int], - length: Optional[int] = None, - on_update: Optional[Callable[[ContentRange], None]] = None, - ) -> None: ... - @property - def units(self) -> Optional[str]: ... - @units.setter - def units(self, value: Optional[str]) -> None: ... - @property - def start(self) -> Optional[int]: ... - @start.setter - def start(self, value: Optional[int]) -> None: ... - @property - def stop(self) -> Optional[int]: ... - @stop.setter - def stop(self, value: Optional[int]) -> None: ... - @property - def length(self) -> Optional[int]: ... - @length.setter - def length(self, value: Optional[int]) -> None: ... - def set( - self, - start: Optional[int], - stop: Optional[int], - length: Optional[int] = None, - units: Optional[str] = "bytes", - ) -> None: ... - def unset(self) -> None: ... - def to_header(self) -> str: ... - -class Authorization(ImmutableDictMixin[str, str], Dict[str, str]): - type: str - def __init__( - self, - auth_type: str, - data: Optional[Union[Mapping[str, str], Iterable[Tuple[str, str]]]] = None, - ) -> None: ... - @property - def username(self) -> Optional[str]: ... - @property - def password(self) -> Optional[str]: ... - @property - def realm(self) -> Optional[str]: ... - @property - def nonce(self) -> Optional[str]: ... - @property - def uri(self) -> Optional[str]: ... - @property - def nc(self) -> Optional[str]: ... - @property - def cnonce(self) -> Optional[str]: ... - @property - def response(self) -> Optional[str]: ... - @property - def opaque(self) -> Optional[str]: ... - @property - def qop(self) -> Optional[str]: ... - def to_header(self) -> str: ... - -def auth_property(name: str, doc: Optional[str] = None) -> property: ... -def _set_property(name: str, doc: Optional[str] = None) -> property: ... - -class WWWAuthenticate(UpdateDictMixin[str, str], Dict[str, str]): - _require_quoting: FrozenSet[str] - def __init__( - self, - auth_type: Optional[str] = None, - values: Optional[Union[Mapping[str, str], Iterable[Tuple[str, str]]]] = None, - on_update: Optional[Callable[[WWWAuthenticate], None]] = None, - ) -> None: ... - def set_basic(self, realm: str = ...) -> None: ... - def set_digest( - self, - realm: str, - nonce: str, - qop: Iterable[str] = ("auth",), - opaque: Optional[str] = None, - algorithm: Optional[str] = None, - stale: bool = False, - ) -> None: ... - def to_header(self) -> str: ... - @property - def type(self) -> Optional[str]: ... - @type.setter - def type(self, value: Optional[str]) -> None: ... - @property - def realm(self) -> Optional[str]: ... - @realm.setter - def realm(self, value: Optional[str]) -> None: ... - @property - def domain(self) -> HeaderSet: ... - @property - def nonce(self) -> Optional[str]: ... - @nonce.setter - def nonce(self, value: Optional[str]) -> None: ... - @property - def opaque(self) -> Optional[str]: ... - @opaque.setter - def opaque(self, value: Optional[str]) -> None: ... - @property - def algorithm(self) -> Optional[str]: ... - @algorithm.setter - def algorithm(self, value: Optional[str]) -> None: ... - @property - def qop(self) -> HeaderSet: ... - @property - def stale(self) -> Optional[bool]: ... - @stale.setter - def stale(self, value: Optional[bool]) -> None: ... 
- @staticmethod - def auth_property(name: str, doc: Optional[str] = None) -> property: ... - -class FileStorage: - name: Optional[str] - stream: IO[bytes] - filename: Optional[str] - headers: Headers - _parsed_content_type: Tuple[str, Dict[str, str]] - def __init__( - self, - stream: Optional[IO[bytes]] = None, - filename: Union[str, PathLike, None] = None, - name: Optional[str] = None, - content_type: Optional[str] = None, - content_length: Optional[int] = None, - headers: Optional[Headers] = None, - ) -> None: ... - def _parse_content_type(self) -> None: ... - @property - def content_type(self) -> str: ... - @property - def content_length(self) -> int: ... - @property - def mimetype(self) -> str: ... - @property - def mimetype_params(self) -> Dict[str, str]: ... - def save( - self, dst: Union[str, PathLike, IO[bytes]], buffer_size: int = ... - ) -> None: ... - def close(self) -> None: ... - def __bool__(self) -> bool: ... - def __getattr__(self, name: str) -> Any: ... - def __iter__(self) -> Iterator[bytes]: ... - def __repr__(self) -> str: ... diff --git a/src/werkzeug/datastructures/__init__.py b/src/werkzeug/datastructures/__init__.py new file mode 100644 index 0000000..846ffce --- /dev/null +++ b/src/werkzeug/datastructures/__init__.py @@ -0,0 +1,34 @@ +from .accept import Accept as Accept +from .accept import CharsetAccept as CharsetAccept +from .accept import LanguageAccept as LanguageAccept +from .accept import MIMEAccept as MIMEAccept +from .auth import Authorization as Authorization +from .auth import WWWAuthenticate as WWWAuthenticate +from .cache_control import RequestCacheControl as RequestCacheControl +from .cache_control import ResponseCacheControl as ResponseCacheControl +from .csp import ContentSecurityPolicy as ContentSecurityPolicy +from .etag import ETags as ETags +from .file_storage import FileMultiDict as FileMultiDict +from .file_storage import FileStorage as FileStorage +from .headers import EnvironHeaders as EnvironHeaders +from .headers import Headers as Headers +from .mixins import ImmutableDictMixin as ImmutableDictMixin +from .mixins import ImmutableHeadersMixin as ImmutableHeadersMixin +from .mixins import ImmutableListMixin as ImmutableListMixin +from .mixins import ImmutableMultiDictMixin as ImmutableMultiDictMixin +from .mixins import UpdateDictMixin as UpdateDictMixin +from .range import ContentRange as ContentRange +from .range import IfRange as IfRange +from .range import Range as Range +from .structures import CallbackDict as CallbackDict +from .structures import CombinedMultiDict as CombinedMultiDict +from .structures import HeaderSet as HeaderSet +from .structures import ImmutableDict as ImmutableDict +from .structures import ImmutableList as ImmutableList +from .structures import ImmutableMultiDict as ImmutableMultiDict +from .structures import ImmutableOrderedMultiDict as ImmutableOrderedMultiDict +from .structures import ImmutableTypeConversionDict as ImmutableTypeConversionDict +from .structures import iter_multi_items as iter_multi_items +from .structures import MultiDict as MultiDict +from .structures import OrderedMultiDict as OrderedMultiDict +from .structures import TypeConversionDict as TypeConversionDict diff --git a/src/werkzeug/datastructures/accept.py b/src/werkzeug/datastructures/accept.py new file mode 100644 index 0000000..d80f0bb --- /dev/null +++ b/src/werkzeug/datastructures/accept.py @@ -0,0 +1,326 @@ +from __future__ import annotations + +import codecs +import re + +from .structures import ImmutableList + + +class 
Accept(ImmutableList): + """An :class:`Accept` object is just a list subclass for lists of + ``(value, quality)`` tuples. It is automatically sorted by specificity + and quality. + + All :class:`Accept` objects work similarly to a list but provide extra + functionality for working with the data. Containment checks are + normalized to the rules of that header: + + >>> a = CharsetAccept([('ISO-8859-1', 1), ('utf-8', 0.7)]) + >>> a.best + 'ISO-8859-1' + >>> 'iso-8859-1' in a + True + >>> 'UTF8' in a + True + >>> 'utf7' in a + False + + To get the quality for an item you can use normal item lookup: + + >>> print(a['utf-8']) + 0.7 + >>> a['utf7'] + 0 + + .. versionchanged:: 0.5 + :class:`Accept` objects are forced immutable now. + + .. versionchanged:: 1.0.0 + :class:`Accept` internal values are no longer ordered + alphabetically for equal quality tags. Instead the initial + order is preserved. + + """ + + def __init__(self, values=()): + if values is None: + list.__init__(self) + self.provided = False + elif isinstance(values, Accept): + self.provided = values.provided + list.__init__(self, values) + else: + self.provided = True + values = sorted( + values, key=lambda x: (self._specificity(x[0]), x[1]), reverse=True + ) + list.__init__(self, values) + + def _specificity(self, value): + """Returns a tuple describing the value's specificity.""" + return (value != "*",) + + def _value_matches(self, value, item): + """Check if a value matches a given accept item.""" + return item == "*" or item.lower() == value.lower() + + def __getitem__(self, key): + """Besides index lookup (getting item n) you can also pass it a string + to get the quality for the item. If the item is not in the list, the + returned quality is ``0``. + """ + if isinstance(key, str): + return self.quality(key) + return list.__getitem__(self, key) + + def quality(self, key): + """Returns the quality of the key. + + .. versionadded:: 0.6 + In previous versions you had to use the item-lookup syntax + (eg: ``obj[key]`` instead of ``obj.quality(key)``) + """ + for item, quality in self: + if self._value_matches(key, item): + return quality + return 0 + + def __contains__(self, value): + for item, _quality in self: + if self._value_matches(value, item): + return True + return False + + def __repr__(self): + pairs_str = ", ".join(f"({x!r}, {y})" for x, y in self) + return f"{type(self).__name__}([{pairs_str}])" + + def index(self, key): + """Get the position of an entry or raise :exc:`ValueError`. + + :param key: The key to be looked up. + + .. versionchanged:: 0.5 + This used to raise :exc:`IndexError`, which was inconsistent + with the list API. + """ + if isinstance(key, str): + for idx, (item, _quality) in enumerate(self): + if self._value_matches(key, item): + return idx + raise ValueError(key) + return list.index(self, key) + + def find(self, key): + """Get the position of an entry or return -1. + + :param key: The key to be looked up.
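+ + As a quick sketch of the behavior (the header values here are only + illustrative): + + >>> a = Accept([("de", 1), ("en", 0.5)]) + >>> a.find("en") + 1 + >>> a.find("fr") + -1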
+ """ + try: + return self.index(key) + except ValueError: + return -1 + + def values(self): + """Iterate over all values.""" + for item in self: + yield item[0] + + def to_header(self): + """Convert the header set into an HTTP header string.""" + result = [] + for value, quality in self: + if quality != 1: + value = f"{value};q={quality}" + result.append(value) + return ",".join(result) + + def __str__(self): + return self.to_header() + + def _best_single_match(self, match): + for client_item, quality in self: + if self._value_matches(match, client_item): + # self is sorted by specificity descending, we can exit + return client_item, quality + return None + + def best_match(self, matches, default=None): + """Returns the best match from a list of possible matches based + on the specificity and quality of the client. If two items have the + same quality and specificity, the one is returned that comes first. + + :param matches: a list of matches to check for + :param default: the value that is returned if none match + """ + result = default + best_quality = -1 + best_specificity = (-1,) + for server_item in matches: + match = self._best_single_match(server_item) + if not match: + continue + client_item, quality = match + specificity = self._specificity(client_item) + if quality <= 0 or quality < best_quality: + continue + # better quality or same quality but more specific => better match + if quality > best_quality or specificity > best_specificity: + result = server_item + best_quality = quality + best_specificity = specificity + return result + + @property + def best(self): + """The best match as value.""" + if self: + return self[0][0] + + +_mime_split_re = re.compile(r"/|(?:\s*;\s*)") + + +def _normalize_mime(value): + return _mime_split_re.split(value.lower()) + + +class MIMEAccept(Accept): + """Like :class:`Accept` but with special methods and behavior for + mimetypes. + """ + + def _specificity(self, value): + return tuple(x != "*" for x in _mime_split_re.split(value)) + + def _value_matches(self, value, item): + # item comes from the client, can't match if it's invalid. + if "/" not in item: + return False + + # value comes from the application, tell the developer when it + # doesn't look valid. + if "/" not in value: + raise ValueError(f"invalid mimetype {value!r}") + + # Split the match value into type, subtype, and a sorted list of parameters. + normalized_value = _normalize_mime(value) + value_type, value_subtype = normalized_value[:2] + value_params = sorted(normalized_value[2:]) + + # "*/*" is the only valid value that can start with "*". + if value_type == "*" and value_subtype != "*": + raise ValueError(f"invalid mimetype {value!r}") + + # Split the accept item into type, subtype, and parameters. + normalized_item = _normalize_mime(item) + item_type, item_subtype = normalized_item[:2] + item_params = sorted(normalized_item[2:]) + + # "*/not-*" from the client is invalid, can't match. 
+ if item_type == "*" and item_subtype != "*": + return False + + return ( + (item_type == "*" and item_subtype == "*") + or (value_type == "*" and value_subtype == "*") + ) or ( + item_type == value_type + and ( + item_subtype == "*" + or value_subtype == "*" + or (item_subtype == value_subtype and item_params == value_params) + ) + ) + + @property + def accept_html(self): + """True if this object accepts HTML.""" + return ( + "text/html" in self or "application/xhtml+xml" in self or self.accept_xhtml + ) + + @property + def accept_xhtml(self): + """True if this object accepts XHTML.""" + return "application/xhtml+xml" in self or "application/xml" in self + + @property + def accept_json(self): + """True if this object accepts JSON.""" + return "application/json" in self + + +_locale_delim_re = re.compile(r"[_-]") + + +def _normalize_lang(value): + """Process a language tag for matching.""" + return _locale_delim_re.split(value.lower()) + + +class LanguageAccept(Accept): + """Like :class:`Accept` but with normalization for language tags.""" + + def _value_matches(self, value, item): + return item == "*" or _normalize_lang(value) == _normalize_lang(item) + + def best_match(self, matches, default=None): + """Given a list of supported values, finds the best match from + the list of accepted values. + + Language tags are normalized for the purpose of matching, but + are returned unchanged. + + If no exact match is found, this will fall back to matching + the first subtag (primary language only), first with the + accepted values then with the match values. This partial is not + applied to any other language subtags. + + The default is returned if no exact or fallback match is found. + + :param matches: A list of supported languages to find a match. + :param default: The value that is returned if none match. + """ + # Look for an exact match first. If a client accepts "en-US", + # "en-US" is a valid match at this point. + result = super().best_match(matches) + + if result is not None: + return result + + # Fall back to accepting primary tags. If a client accepts + # "en-US", "en" is a valid match at this point. Need to use + # re.split to account for 2 or 3 letter codes. + fallback = Accept( + [(_locale_delim_re.split(item[0], 1)[0], item[1]) for item in self] + ) + result = fallback.best_match(matches) + + if result is not None: + return result + + # Fall back to matching primary tags. If the client accepts + # "en", "en-US" is a valid match at this point. + fallback_matches = [_locale_delim_re.split(item, 1)[0] for item in matches] + result = super().best_match(fallback_matches) + + # Return a value from the original match list. Find the first + # original value that starts with the matched primary tag. 
+ if result is not None: + return next(item for item in matches if item.startswith(result)) + + return default + + +class CharsetAccept(Accept): + """Like :class:`Accept` but with normalization for charsets.""" + + def _value_matches(self, value, item): + def _normalize(name): + try: + return codecs.lookup(name).name + except LookupError: + return name.lower() + + return item == "*" or _normalize(value) == _normalize(item) diff --git a/src/werkzeug/datastructures/accept.pyi b/src/werkzeug/datastructures/accept.pyi new file mode 100644 index 0000000..4b74dd9 --- /dev/null +++ b/src/werkzeug/datastructures/accept.pyi @@ -0,0 +1,54 @@ +from collections.abc import Iterable +from collections.abc import Iterator +from typing import overload + +from .structures import ImmutableList + +class Accept(ImmutableList[tuple[str, int]]): + provided: bool + def __init__( + self, values: Accept | Iterable[tuple[str, float]] | None = None + ) -> None: ... + def _specificity(self, value: str) -> tuple[bool, ...]: ... + def _value_matches(self, value: str, item: str) -> bool: ... + @overload # type: ignore + def __getitem__(self, key: str) -> int: ... + @overload + def __getitem__(self, key: int) -> tuple[str, int]: ... + @overload + def __getitem__(self, key: slice) -> Iterable[tuple[str, int]]: ... + def quality(self, key: str) -> int: ... + def __contains__(self, value: str) -> bool: ... # type: ignore + def index(self, key: str) -> int: ... # type: ignore + def find(self, key: str) -> int: ... + def values(self) -> Iterator[str]: ... + def to_header(self) -> str: ... + def _best_single_match(self, match: str) -> tuple[str, int] | None: ... + @overload + def best_match(self, matches: Iterable[str], default: str) -> str: ... + @overload + def best_match( + self, matches: Iterable[str], default: str | None = None + ) -> str | None: ... + @property + def best(self) -> str: ... + +def _normalize_mime(value: str) -> list[str]: ... + +class MIMEAccept(Accept): + def _specificity(self, value: str) -> tuple[bool, ...]: ... + def _value_matches(self, value: str, item: str) -> bool: ... + @property + def accept_html(self) -> bool: ... + @property + def accept_xhtml(self) -> bool: ... + @property + def accept_json(self) -> bool: ... + +def _normalize_lang(value: str) -> list[str]: ... + +class LanguageAccept(Accept): + def _value_matches(self, value: str, item: str) -> bool: ... + +class CharsetAccept(Accept): + def _value_matches(self, value: str, item: str) -> bool: ... diff --git a/src/werkzeug/datastructures/auth.py b/src/werkzeug/datastructures/auth.py new file mode 100644 index 0000000..a3ca0de --- /dev/null +++ b/src/werkzeug/datastructures/auth.py @@ -0,0 +1,316 @@ +from __future__ import annotations + +import base64 +import binascii +import typing as t + +from ..http import dump_header +from ..http import parse_dict_header +from ..http import quote_header_value +from .structures import CallbackDict + +if t.TYPE_CHECKING: + import typing_extensions as te + + +class Authorization: + """Represents the parts of an ``Authorization`` request header. + + :attr:`.Request.authorization` returns an instance if the header is set. + + An instance can be used with the test :class:`.Client` request methods' ``auth`` + parameter to send the header in test requests. + + Depending on the auth scheme, either :attr:`parameters` or :attr:`token` will be + set. The ``Basic`` scheme's token is decoded into the ``username`` and ``password`` + parameters. 
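+ + For example, the header value ``Basic dXNlcjpzZWNyZXQ=`` (the base64 encoding + of ``user:secret``) is decoded into ``username="user"`` and + ``password="secret"``.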
+ + For convenience, ``auth["key"]`` and ``auth.key`` both access the key in the + :attr:`parameters` dict, along with ``auth.get("key")`` and ``"key" in auth``. + + .. versionchanged:: 2.3 + The ``token`` parameter and attribute was added to support auth schemes that use + a token instead of parameters, such as ``Bearer``. + + .. versionchanged:: 2.3 + The object is no longer a ``dict``. + + .. versionchanged:: 0.5 + The object is an immutable dict. + """ + + def __init__( + self, + auth_type: str, + data: dict[str, str | None] | None = None, + token: str | None = None, + ) -> None: + self.type = auth_type + """The authorization scheme, like ``basic``, ``digest``, or ``bearer``.""" + + if data is None: + data = {} + + self.parameters = data + """A dict of parameters parsed from the header. Either this or :attr:`token` + will have a value for a given scheme. + """ + + self.token = token + """A token parsed from the header. Either this or :attr:`parameters` will have a + value for a given scheme. + + .. versionadded:: 2.3 + """ + + def __getattr__(self, name: str) -> str | None: + return self.parameters.get(name) + + def __getitem__(self, name: str) -> str | None: + return self.parameters.get(name) + + def get(self, key: str, default: str | None = None) -> str | None: + return self.parameters.get(key, default) + + def __contains__(self, key: str) -> bool: + return key in self.parameters + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Authorization): + return NotImplemented + + return ( + other.type == self.type + and other.token == self.token + and other.parameters == self.parameters + ) + + @classmethod + def from_header(cls, value: str | None) -> te.Self | None: + """Parse an ``Authorization`` header value and return an instance, or ``None`` + if the value is empty. + + :param value: The header value to parse. + + .. versionadded:: 2.3 + """ + if not value: + return None + + scheme, _, rest = value.partition(" ") + scheme = scheme.lower() + rest = rest.strip() + + if scheme == "basic": + try: + username, _, password = base64.b64decode(rest).decode().partition(":") + except (binascii.Error, UnicodeError): + return None + + return cls(scheme, {"username": username, "password": password}) + + if "=" in rest.rstrip("="): + # = that is not trailing, this is parameters. + return cls(scheme, parse_dict_header(rest), None) + + # No = or only trailing =, this is a token. + return cls(scheme, None, rest) + + def to_header(self) -> str: + """Produce an ``Authorization`` header value representing this data. + + .. versionadded:: 2.0 + """ + if self.type == "basic": + value = base64.b64encode( + f"{self.username}:{self.password}".encode() + ).decode("ascii") + return f"Basic {value}" + + if self.token is not None: + return f"{self.type.title()} {self.token}" + + return f"{self.type.title()} {dump_header(self.parameters)}" + + def __str__(self) -> str: + return self.to_header() + + def __repr__(self) -> str: + return f"<{type(self).__name__} {self.to_header()}>" + + +class WWWAuthenticate: + """Represents the parts of a ``WWW-Authenticate`` response header. + + Set :attr:`.Response.www_authenticate` to an instance of list of instances to set + values for this header in the response. Modifying this instance will modify the + header value. + + Depending on the auth scheme, either :attr:`parameters` or :attr:`token` should be + set. The ``Basic`` scheme will encode ``username`` and ``password`` parameters to a + token. 
+ + For convenience, ``auth["key"]`` and ``auth.key`` both act on the :attr:`parameters` + dict, and can be used to get, set, or delete parameters. ``auth.get("key")`` and + ``"key" in auth`` are also provided. + + .. versionchanged:: 2.3 + The ``token`` parameter and attribute was added to support auth schemes that use + a token instead of parameters, such as ``Bearer``. + + .. versionchanged:: 2.3 + The object is no longer a ``dict``. + + .. versionchanged:: 2.3 + The ``on_update`` parameter was removed. + """ + + def __init__( + self, + auth_type: str, + values: dict[str, str | None] | None = None, + token: str | None = None, + ): + self._type = auth_type.lower() + self._parameters: dict[str, str | None] = CallbackDict( + values, lambda _: self._trigger_on_update() + ) + self._token = token + self._on_update: t.Callable[[WWWAuthenticate], None] | None = None + + def _trigger_on_update(self) -> None: + if self._on_update is not None: + self._on_update(self) + + @property + def type(self) -> str: + """The authorization scheme, like ``basic``, ``digest``, or ``bearer``.""" + return self._type + + @type.setter + def type(self, value: str) -> None: + self._type = value + self._trigger_on_update() + + @property + def parameters(self) -> dict[str, str | None]: + """A dict of parameters for the header. Only one of this or :attr:`token` should + have a value for a given scheme. + """ + return self._parameters + + @parameters.setter + def parameters(self, value: dict[str, str]) -> None: + self._parameters = CallbackDict(value, lambda _: self._trigger_on_update()) + self._trigger_on_update() + + @property + def token(self) -> str | None: + """A dict of parameters for the header. Only one of this or :attr:`token` should + have a value for a given scheme. + """ + return self._token + + @token.setter + def token(self, value: str | None) -> None: + """A token for the header. Only one of this or :attr:`parameters` should have a + value for a given scheme. + + .. versionadded:: 2.3 + """ + self._token = value + self._trigger_on_update() + + def __getitem__(self, key: str) -> str | None: + return self.parameters.get(key) + + def __setitem__(self, key: str, value: str | None) -> None: + if value is None: + if key in self.parameters: + del self.parameters[key] + else: + self.parameters[key] = value + + self._trigger_on_update() + + def __delitem__(self, key: str) -> None: + if key in self.parameters: + del self.parameters[key] + self._trigger_on_update() + + def __getattr__(self, name: str) -> str | None: + return self[name] + + def __setattr__(self, name: str, value: str | None) -> None: + if name in {"_type", "_parameters", "_token", "_on_update"}: + super().__setattr__(name, value) + else: + self[name] = value + + def __delattr__(self, name: str) -> None: + del self[name] + + def __contains__(self, key: str) -> bool: + return key in self.parameters + + def __eq__(self, other: object) -> bool: + if not isinstance(other, WWWAuthenticate): + return NotImplemented + + return ( + other.type == self.type + and other.token == self.token + and other.parameters == self.parameters + ) + + def get(self, key: str, default: str | None = None) -> str | None: + return self.parameters.get(key, default) + + @classmethod + def from_header(cls, value: str | None) -> te.Self | None: + """Parse a ``WWW-Authenticate`` header value and return an instance, or ``None`` + if the value is empty. + + :param value: The header value to parse. + + .. 
versionadded:: 2.3 + """ + if not value: + return None + + scheme, _, rest = value.partition(" ") + scheme = scheme.lower() + rest = rest.strip() + + if "=" in rest.rstrip("="): + # = that is not trailing, this is parameters. + return cls(scheme, parse_dict_header(rest), None) + + # No = or only trailing =, this is a token. + return cls(scheme, None, rest) + + def to_header(self) -> str: + """Produce a ``WWW-Authenticate`` header value representing this data.""" + if self.token is not None: + return f"{self.type.title()} {self.token}" + + if self.type == "digest": + items = [] + + for key, value in self.parameters.items(): + if key in {"realm", "domain", "nonce", "opaque", "qop"}: + value = quote_header_value(value, allow_token=False) + else: + value = quote_header_value(value) + + items.append(f"{key}={value}") + + return f"Digest {', '.join(items)}" + + return f"{self.type.title()} {dump_header(self.parameters)}" + + def __str__(self) -> str: + return self.to_header() + + def __repr__(self) -> str: + return f"<{type(self).__name__} {self.to_header()}>" diff --git a/src/werkzeug/datastructures/cache_control.py b/src/werkzeug/datastructures/cache_control.py new file mode 100644 index 0000000..bff4c18 --- /dev/null +++ b/src/werkzeug/datastructures/cache_control.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +from .mixins import ImmutableDictMixin +from .mixins import UpdateDictMixin + + +def cache_control_property(key, empty, type): + """Return a new property object for a cache header. Useful if you + want to add support for a cache extension in a subclass. + + .. versionchanged:: 2.0 + Renamed from ``cache_property``. + """ + return property( + lambda x: x._get_cache_value(key, empty, type), + lambda x, v: x._set_cache_value(key, v, type), + lambda x: x._del_cache_value(key), + f"accessor for {key!r}", + ) + + +class _CacheControl(UpdateDictMixin, dict): + """Subclass of a dict that stores values for a Cache-Control header. It + has accessors for all the cache-control directives specified in RFC 2616. + The class does not differentiate between request and response directives. + + Because the cache-control directives in the HTTP header use dashes, the + Python descriptors use underscores for that. + + To get a header of the :class:`CacheControl` object again you can convert + the object into a string or call the :meth:`to_header` method. If you plan + to subclass it and add your own items, have a look at the sourcecode for + that class. + + .. versionchanged:: 2.1.0 + Setting int properties such as ``max_age`` will convert the + value to an int. + + .. versionchanged:: 0.4 + + Setting `no_cache` or `private` to boolean `True` will set the implicit + none-value which is ``*``: + + >>> cc = ResponseCacheControl() + >>> cc.no_cache = True + >>> cc + <ResponseCacheControl 'no-cache'> + >>> cc.no_cache + '*' + >>> cc.no_cache = None + >>> cc + <ResponseCacheControl ''> + + In versions before 0.5 the behavior documented here affected the now + no longer existing `CacheControl` class.
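+ + As a sketch, a hypothetical subclass could add a cache extension directive + via :func:`cache_control_property`:: + + class MyResponseCacheControl(ResponseCacheControl): + stale_while_revalidate = cache_control_property( + "stale-while-revalidate", None, int + )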
+ """ + + no_cache = cache_control_property("no-cache", "*", None) + no_store = cache_control_property("no-store", None, bool) + max_age = cache_control_property("max-age", -1, int) + no_transform = cache_control_property("no-transform", None, None) + + def __init__(self, values=(), on_update=None): + dict.__init__(self, values or ()) + self.on_update = on_update + self.provided = values is not None + + def _get_cache_value(self, key, empty, type): + """Used internally by the accessor properties.""" + if type is bool: + return key in self + if key in self: + value = self[key] + if value is None: + return empty + elif type is not None: + try: + value = type(value) + except ValueError: + pass + return value + return None + + def _set_cache_value(self, key, value, type): + """Used internally by the accessor properties.""" + if type is bool: + if value: + self[key] = None + else: + self.pop(key, None) + else: + if value is None: + self.pop(key, None) + elif value is True: + self[key] = None + else: + if type is not None: + self[key] = type(value) + else: + self[key] = value + + def _del_cache_value(self, key): + """Used internally by the accessor properties.""" + if key in self: + del self[key] + + def to_header(self): + """Convert the stored values into a cache control header.""" + return http.dump_header(self) + + def __str__(self): + return self.to_header() + + def __repr__(self): + kv_str = " ".join(f"{k}={v!r}" for k, v in sorted(self.items())) + return f"<{type(self).__name__} {kv_str}>" + + cache_property = staticmethod(cache_control_property) + + +class RequestCacheControl(ImmutableDictMixin, _CacheControl): + """A cache control for requests. This is immutable and gives access + to all the request-relevant cache control headers. + + To get a header of the :class:`RequestCacheControl` object again you can + convert the object into a string or call the :meth:`to_header` method. If + you plan to subclass it and add your own items have a look at the sourcecode + for that class. + + .. versionchanged:: 2.1.0 + Setting int properties such as ``max_age`` will convert the + value to an int. + + .. versionadded:: 0.5 + In previous versions a `CacheControl` class existed that was used + both for request and response. + """ + + max_stale = cache_control_property("max-stale", "*", int) + min_fresh = cache_control_property("min-fresh", "*", int) + only_if_cached = cache_control_property("only-if-cached", None, bool) + + +class ResponseCacheControl(_CacheControl): + """A cache control for responses. Unlike :class:`RequestCacheControl` + this is mutable and gives access to response-relevant cache control + headers. + + To get a header of the :class:`ResponseCacheControl` object again you can + convert the object into a string or call the :meth:`to_header` method. If + you plan to subclass it and add your own items have a look at the sourcecode + for that class. + + .. versionchanged:: 2.1.1 + ``s_maxage`` converts the value to an int. + + .. versionchanged:: 2.1.0 + Setting int properties such as ``max_age`` will convert the + value to an int. + + .. versionadded:: 0.5 + In previous versions a `CacheControl` class existed that was used + both for request and response. 
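+
+    A sketch of adding a custom directive in a subclass via
+    :func:`cache_control_property` (the directive name here is only an
+    example, not part of this class)::
+
+        class MyResponseCacheControl(ResponseCacheControl):
+            stale_while_revalidate = cache_control_property(
+                "stale-while-revalidate", None, int
+            )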
+ """ + + public = cache_control_property("public", None, bool) + private = cache_control_property("private", "*", None) + must_revalidate = cache_control_property("must-revalidate", None, bool) + proxy_revalidate = cache_control_property("proxy-revalidate", None, bool) + s_maxage = cache_control_property("s-maxage", None, int) + immutable = cache_control_property("immutable", None, bool) + + +# circular dependencies +from .. import http diff --git a/src/werkzeug/datastructures/cache_control.pyi b/src/werkzeug/datastructures/cache_control.pyi new file mode 100644 index 0000000..54ec020 --- /dev/null +++ b/src/werkzeug/datastructures/cache_control.pyi @@ -0,0 +1,115 @@ +from collections.abc import Callable +from collections.abc import Iterable +from collections.abc import Mapping +from typing import TypeVar + +from .mixins import ImmutableDictMixin +from .mixins import UpdateDictMixin + +T = TypeVar("T") +_CPT = TypeVar("_CPT", str, int, bool) + +def cache_control_property( + key: str, empty: _CPT | None, type: type[_CPT] +) -> property: ... + +class _CacheControl( + UpdateDictMixin[str, str | int | bool | None], dict[str, str | int | bool | None] +): + provided: bool + def __init__( + self, + values: Mapping[str, str | int | bool | None] + | Iterable[tuple[str, str | int | bool | None]] = (), + on_update: Callable[[_CacheControl], None] | None = None, + ) -> None: ... + @property + def no_cache(self) -> bool | None: ... + @no_cache.setter + def no_cache(self, value: bool | None) -> None: ... + @no_cache.deleter + def no_cache(self) -> None: ... + @property + def no_store(self) -> bool | None: ... + @no_store.setter + def no_store(self, value: bool | None) -> None: ... + @no_store.deleter + def no_store(self) -> None: ... + @property + def max_age(self) -> int | None: ... + @max_age.setter + def max_age(self, value: int | None) -> None: ... + @max_age.deleter + def max_age(self) -> None: ... + @property + def no_transform(self) -> bool | None: ... + @no_transform.setter + def no_transform(self, value: bool | None) -> None: ... + @no_transform.deleter + def no_transform(self) -> None: ... + def _get_cache_value(self, key: str, empty: T | None, type: type[T]) -> T: ... + def _set_cache_value(self, key: str, value: T | None, type: type[T]) -> None: ... + def _del_cache_value(self, key: str) -> None: ... + def to_header(self) -> str: ... + @staticmethod + def cache_property(key: str, empty: _CPT | None, type: type[_CPT]) -> property: ... + +class RequestCacheControl( # type: ignore[misc] + ImmutableDictMixin[str, str | int | bool | None], _CacheControl +): + @property + def max_stale(self) -> int | None: ... + @max_stale.setter + def max_stale(self, value: int | None) -> None: ... + @max_stale.deleter + def max_stale(self) -> None: ... + @property + def min_fresh(self) -> int | None: ... + @min_fresh.setter + def min_fresh(self, value: int | None) -> None: ... + @min_fresh.deleter + def min_fresh(self) -> None: ... + @property + def only_if_cached(self) -> bool | None: ... + @only_if_cached.setter + def only_if_cached(self, value: bool | None) -> None: ... + @only_if_cached.deleter + def only_if_cached(self) -> None: ... + +class ResponseCacheControl(_CacheControl): + @property + def public(self) -> bool | None: ... + @public.setter + def public(self, value: bool | None) -> None: ... + @public.deleter + def public(self) -> None: ... + @property + def private(self) -> bool | None: ... + @private.setter + def private(self, value: bool | None) -> None: ... 
+    @private.deleter
+    def private(self) -> None: ...
+    @property
+    def must_revalidate(self) -> bool | None: ...
+    @must_revalidate.setter
+    def must_revalidate(self, value: bool | None) -> None: ...
+    @must_revalidate.deleter
+    def must_revalidate(self) -> None: ...
+    @property
+    def proxy_revalidate(self) -> bool | None: ...
+    @proxy_revalidate.setter
+    def proxy_revalidate(self, value: bool | None) -> None: ...
+    @proxy_revalidate.deleter
+    def proxy_revalidate(self) -> None: ...
+    @property
+    def s_maxage(self) -> int | None: ...
+    @s_maxage.setter
+    def s_maxage(self, value: int | None) -> None: ...
+    @s_maxage.deleter
+    def s_maxage(self) -> None: ...
+    @property
+    def immutable(self) -> bool | None: ...
+    @immutable.setter
+    def immutable(self, value: bool | None) -> None: ...
+    @immutable.deleter
+    def immutable(self) -> None: ...
diff --git a/src/werkzeug/datastructures/csp.py b/src/werkzeug/datastructures/csp.py
new file mode 100644
index 0000000..dde9414
--- /dev/null
+++ b/src/werkzeug/datastructures/csp.py
@@ -0,0 +1,94 @@
+from __future__ import annotations
+
+from .mixins import UpdateDictMixin
+
+
+def csp_property(key):
+    """Return a new property object for a content security policy header.
+    Useful if you want to add support for a csp extension in a
+    subclass.
+    """
+    return property(
+        lambda x: x._get_value(key),
+        lambda x, v: x._set_value(key, v),
+        lambda x: x._del_value(key),
+        f"accessor for {key!r}",
+    )
+
+
+class ContentSecurityPolicy(UpdateDictMixin, dict):
+    """Subclass of a dict that stores values for a Content Security Policy
+    header. It has accessors for all the level 3 policies.
+
+    Because the CSP directives in the HTTP header use dashes, the
+    Python descriptors use underscores instead.
+
+    To get a header of the :class:`ContentSecurityPolicy` object again
+    you can convert the object into a string or call the
+    :meth:`to_header` method. If you plan to subclass it and add your
+    own items have a look at the sourcecode for that class.
+
+    .. versionadded:: 1.0.0
+       Support for Content Security Policy headers was added.
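+
+    A short usage sketch (header rendering shown is illustrative)::
+
+        >>> csp = ContentSecurityPolicy()
+        >>> csp.default_src = "'self'"
+        >>> csp.img_src = "*"
+        >>> csp.to_header()
+        "default-src 'self'; img-src *"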
+
+    """
+
+    base_uri = csp_property("base-uri")
+    child_src = csp_property("child-src")
+    connect_src = csp_property("connect-src")
+    default_src = csp_property("default-src")
+    font_src = csp_property("font-src")
+    form_action = csp_property("form-action")
+    frame_ancestors = csp_property("frame-ancestors")
+    frame_src = csp_property("frame-src")
+    img_src = csp_property("img-src")
+    manifest_src = csp_property("manifest-src")
+    media_src = csp_property("media-src")
+    navigate_to = csp_property("navigate-to")
+    object_src = csp_property("object-src")
+    prefetch_src = csp_property("prefetch-src")
+    plugin_types = csp_property("plugin-types")
+    report_to = csp_property("report-to")
+    report_uri = csp_property("report-uri")
+    sandbox = csp_property("sandbox")
+    script_src = csp_property("script-src")
+    script_src_attr = csp_property("script-src-attr")
+    script_src_elem = csp_property("script-src-elem")
+    style_src = csp_property("style-src")
+    style_src_attr = csp_property("style-src-attr")
+    style_src_elem = csp_property("style-src-elem")
+    worker_src = csp_property("worker-src")
+
+    def __init__(self, values=(), on_update=None):
+        dict.__init__(self, values or ())
+        self.on_update = on_update
+        self.provided = values is not None
+
+    def _get_value(self, key):
+        """Used internally by the accessor properties."""
+        return self.get(key)
+
+    def _set_value(self, key, value):
+        """Used internally by the accessor properties."""
+        if value is None:
+            self.pop(key, None)
+        else:
+            self[key] = value
+
+    def _del_value(self, key):
+        """Used internally by the accessor properties."""
+        if key in self:
+            del self[key]
+
+    def to_header(self):
+        """Convert the stored values into a content security policy header."""
+        from ..http import dump_csp_header
+
+        return dump_csp_header(self)
+
+    def __str__(self):
+        return self.to_header()
+
+    def __repr__(self):
+        kv_str = " ".join(f"{k}={v!r}" for k, v in sorted(self.items()))
+        return f"<{type(self).__name__} {kv_str}>"
diff --git a/src/werkzeug/datastructures/csp.pyi b/src/werkzeug/datastructures/csp.pyi
new file mode 100644
index 0000000..f9e2ac0
--- /dev/null
+++ b/src/werkzeug/datastructures/csp.pyi
@@ -0,0 +1,169 @@
+from collections.abc import Callable
+from collections.abc import Iterable
+from collections.abc import Mapping
+
+from .mixins import UpdateDictMixin
+
+def csp_property(key: str) -> property: ...
+
+class ContentSecurityPolicy(UpdateDictMixin[str, str], dict[str, str]):
+    @property
+    def base_uri(self) -> str | None: ...
+    @base_uri.setter
+    def base_uri(self, value: str | None) -> None: ...
+    @base_uri.deleter
+    def base_uri(self) -> None: ...
+    @property
+    def child_src(self) -> str | None: ...
+    @child_src.setter
+    def child_src(self, value: str | None) -> None: ...
+    @child_src.deleter
+    def child_src(self) -> None: ...
+    @property
+    def connect_src(self) -> str | None: ...
+    @connect_src.setter
+    def connect_src(self, value: str | None) -> None: ...
+    @connect_src.deleter
+    def connect_src(self) -> None: ...
+    @property
+    def default_src(self) -> str | None: ...
+    @default_src.setter
+    def default_src(self, value: str | None) -> None: ...
+    @default_src.deleter
+    def default_src(self) -> None: ...
+    @property
+    def font_src(self) -> str | None: ...
+    @font_src.setter
+    def font_src(self, value: str | None) -> None: ...
+    @font_src.deleter
+    def font_src(self) -> None: ...
+    @property
+    def form_action(self) -> str | None: ...
+    @form_action.setter
+    def form_action(self, value: str | None) -> None: ...
+ @form_action.deleter + def form_action(self) -> None: ... + @property + def frame_ancestors(self) -> str | None: ... + @frame_ancestors.setter + def frame_ancestors(self, value: str | None) -> None: ... + @frame_ancestors.deleter + def frame_ancestors(self) -> None: ... + @property + def frame_src(self) -> str | None: ... + @frame_src.setter + def frame_src(self, value: str | None) -> None: ... + @frame_src.deleter + def frame_src(self) -> None: ... + @property + def img_src(self) -> str | None: ... + @img_src.setter + def img_src(self, value: str | None) -> None: ... + @img_src.deleter + def img_src(self) -> None: ... + @property + def manifest_src(self) -> str | None: ... + @manifest_src.setter + def manifest_src(self, value: str | None) -> None: ... + @manifest_src.deleter + def manifest_src(self) -> None: ... + @property + def media_src(self) -> str | None: ... + @media_src.setter + def media_src(self, value: str | None) -> None: ... + @media_src.deleter + def media_src(self) -> None: ... + @property + def navigate_to(self) -> str | None: ... + @navigate_to.setter + def navigate_to(self, value: str | None) -> None: ... + @navigate_to.deleter + def navigate_to(self) -> None: ... + @property + def object_src(self) -> str | None: ... + @object_src.setter + def object_src(self, value: str | None) -> None: ... + @object_src.deleter + def object_src(self) -> None: ... + @property + def prefetch_src(self) -> str | None: ... + @prefetch_src.setter + def prefetch_src(self, value: str | None) -> None: ... + @prefetch_src.deleter + def prefetch_src(self) -> None: ... + @property + def plugin_types(self) -> str | None: ... + @plugin_types.setter + def plugin_types(self, value: str | None) -> None: ... + @plugin_types.deleter + def plugin_types(self) -> None: ... + @property + def report_to(self) -> str | None: ... + @report_to.setter + def report_to(self, value: str | None) -> None: ... + @report_to.deleter + def report_to(self) -> None: ... + @property + def report_uri(self) -> str | None: ... + @report_uri.setter + def report_uri(self, value: str | None) -> None: ... + @report_uri.deleter + def report_uri(self) -> None: ... + @property + def sandbox(self) -> str | None: ... + @sandbox.setter + def sandbox(self, value: str | None) -> None: ... + @sandbox.deleter + def sandbox(self) -> None: ... + @property + def script_src(self) -> str | None: ... + @script_src.setter + def script_src(self, value: str | None) -> None: ... + @script_src.deleter + def script_src(self) -> None: ... + @property + def script_src_attr(self) -> str | None: ... + @script_src_attr.setter + def script_src_attr(self, value: str | None) -> None: ... + @script_src_attr.deleter + def script_src_attr(self) -> None: ... + @property + def script_src_elem(self) -> str | None: ... + @script_src_elem.setter + def script_src_elem(self, value: str | None) -> None: ... + @script_src_elem.deleter + def script_src_elem(self) -> None: ... + @property + def style_src(self) -> str | None: ... + @style_src.setter + def style_src(self, value: str | None) -> None: ... + @style_src.deleter + def style_src(self) -> None: ... + @property + def style_src_attr(self) -> str | None: ... + @style_src_attr.setter + def style_src_attr(self, value: str | None) -> None: ... + @style_src_attr.deleter + def style_src_attr(self) -> None: ... + @property + def style_src_elem(self) -> str | None: ... + @style_src_elem.setter + def style_src_elem(self, value: str | None) -> None: ... + @style_src_elem.deleter + def style_src_elem(self) -> None: ... 
+    @property
+    def worker_src(self) -> str | None: ...
+    @worker_src.setter
+    def worker_src(self, value: str | None) -> None: ...
+    @worker_src.deleter
+    def worker_src(self) -> None: ...
+    provided: bool
+    def __init__(
+        self,
+        values: Mapping[str, str] | Iterable[tuple[str, str]] = (),
+        on_update: Callable[[ContentSecurityPolicy], None] | None = None,
+    ) -> None: ...
+    def _get_value(self, key: str) -> str | None: ...
+    def _set_value(self, key: str, value: str) -> None: ...
+    def _del_value(self, key: str) -> None: ...
+    def to_header(self) -> str: ...
diff --git a/src/werkzeug/datastructures/etag.py b/src/werkzeug/datastructures/etag.py
new file mode 100644
index 0000000..747d996
--- /dev/null
+++ b/src/werkzeug/datastructures/etag.py
@@ -0,0 +1,95 @@
+from __future__ import annotations
+
+from collections.abc import Collection
+
+
+class ETags(Collection):
+    """A set that can be used to check if one etag is present in a collection
+    of etags.
+    """
+
+    def __init__(self, strong_etags=None, weak_etags=None, star_tag=False):
+        if not star_tag and strong_etags:
+            self._strong = frozenset(strong_etags)
+        else:
+            self._strong = frozenset()
+
+        self._weak = frozenset(weak_etags or ())
+        self.star_tag = star_tag
+
+    def as_set(self, include_weak=False):
+        """Convert the `ETags` object into a Python set. By default the
+        weak etags are not part of this set."""
+        rv = set(self._strong)
+        if include_weak:
+            rv.update(self._weak)
+        return rv
+
+    def is_weak(self, etag):
+        """Check if an etag is weak."""
+        return etag in self._weak
+
+    def is_strong(self, etag):
+        """Check if an etag is strong."""
+        return etag in self._strong
+
+    def contains_weak(self, etag):
+        """Check if an etag is part of the set including weak and strong tags."""
+        return self.is_weak(etag) or self.contains(etag)
+
+    def contains(self, etag):
+        """Check if an etag is part of the set ignoring weak tags.
+        It is also possible to use the ``in`` operator.
+        """
+        if self.star_tag:
+            return True
+        return self.is_strong(etag)
+
+    def contains_raw(self, etag):
+        """When passed a quoted tag it will check if this tag is part of the
+        set. If the tag is weak it is checked against weak and strong tags,
+        otherwise strong only."""
+        from ..http import unquote_etag
+
+        etag, weak = unquote_etag(etag)
+        if weak:
+            return self.contains_weak(etag)
+        return self.contains(etag)
+
+    def to_header(self):
+        """Convert the etags set into an HTTP header string."""
+        if self.star_tag:
+            return "*"
+        return ", ".join(
+            [f'"{x}"' for x in self._strong] + [f'W/"{x}"' for x in self._weak]
+        )
+
+    def __call__(self, etag=None, data=None, include_weak=False):
+        if [etag, data].count(None) != 1:
+            raise TypeError("exactly one of etag or data must be provided")
+        if etag is None:
+            from ..http import generate_etag
+
+            etag = generate_etag(data)
+        if include_weak:
+            if etag in self._weak:
+                return True
+        return etag in self._strong
+
+    def __bool__(self):
+        return bool(self.star_tag or self._strong or self._weak)
+
+    def __str__(self):
+        return self.to_header()
+
+    def __len__(self):
+        return len(self._strong)
+
+    def __iter__(self):
+        return iter(self._strong)
+
+    def __contains__(self, etag):
+        return self.contains(etag)
+
+    def __repr__(self):
+        return f"<{type(self).__name__} {str(self)!r}>"
diff --git a/src/werkzeug/datastructures/etag.pyi b/src/werkzeug/datastructures/etag.pyi
new file mode 100644
index 0000000..88e54f1
--- /dev/null
+++ b/src/werkzeug/datastructures/etag.pyi
@@ -0,0 +1,30 @@
+from collections.abc import Collection
+from collections.abc import Iterable
+from collections.abc import Iterator
+
+class ETags(Collection[str]):
+    _strong: frozenset[str]
+    _weak: frozenset[str]
+    star_tag: bool
+    def __init__(
+        self,
+        strong_etags: Iterable[str] | None = None,
+        weak_etags: Iterable[str] | None = None,
+        star_tag: bool = False,
+    ) -> None: ...
+    def as_set(self, include_weak: bool = False) -> set[str]: ...
+    def is_weak(self, etag: str) -> bool: ...
+    def is_strong(self, etag: str) -> bool: ...
+    def contains_weak(self, etag: str) -> bool: ...
+    def contains(self, etag: str) -> bool: ...
+    def contains_raw(self, etag: str) -> bool: ...
+    def to_header(self) -> str: ...
+    def __call__(
+        self,
+        etag: str | None = None,
+        data: bytes | None = None,
+        include_weak: bool = False,
+    ) -> bool: ...
+    def __len__(self) -> int: ...
+    def __iter__(self) -> Iterator[str]: ...
+    def __contains__(self, item: str) -> bool: ...  # type: ignore
diff --git a/src/werkzeug/datastructures/file_storage.py b/src/werkzeug/datastructures/file_storage.py
new file mode 100644
index 0000000..e878a56
--- /dev/null
+++ b/src/werkzeug/datastructures/file_storage.py
@@ -0,0 +1,196 @@
+from __future__ import annotations
+
+import mimetypes
+from io import BytesIO
+from os import fsdecode
+from os import fspath
+
+from .._internal import _plain_int
+from .structures import MultiDict
+
+
+class FileStorage:
+    """The :class:`FileStorage` class is a thin wrapper over incoming files.
+    It is used by the request object to represent uploaded files. All the
+    attributes of the wrapper stream are proxied by the file storage so
+    it's possible to do ``storage.read()`` instead of the long form
+    ``storage.stream.read()``.
+    """
+
+    def __init__(
+        self,
+        stream=None,
+        filename=None,
+        name=None,
+        content_type=None,
+        content_length=None,
+        headers=None,
+    ):
+        self.name = name
+        self.stream = stream or BytesIO()
+
+        # If no filename is provided, attempt to get the filename from
+        # the stream object. Python names special streams like
+        # ``<stderr>`` with angular brackets, skip these streams.
+ if filename is None: + filename = getattr(stream, "name", None) + + if filename is not None: + filename = fsdecode(filename) + + if filename and filename[0] == "<" and filename[-1] == ">": + filename = None + else: + filename = fsdecode(filename) + + self.filename = filename + + if headers is None: + from .headers import Headers + + headers = Headers() + self.headers = headers + if content_type is not None: + headers["Content-Type"] = content_type + if content_length is not None: + headers["Content-Length"] = str(content_length) + + def _parse_content_type(self): + if not hasattr(self, "_parsed_content_type"): + self._parsed_content_type = http.parse_options_header(self.content_type) + + @property + def content_type(self): + """The content-type sent in the header. Usually not available""" + return self.headers.get("content-type") + + @property + def content_length(self): + """The content-length sent in the header. Usually not available""" + if "content-length" in self.headers: + try: + return _plain_int(self.headers["content-length"]) + except ValueError: + pass + + return 0 + + @property + def mimetype(self): + """Like :attr:`content_type`, but without parameters (eg, without + charset, type etc.) and always lowercase. For example if the content + type is ``text/HTML; charset=utf-8`` the mimetype would be + ``'text/html'``. + + .. versionadded:: 0.7 + """ + self._parse_content_type() + return self._parsed_content_type[0].lower() + + @property + def mimetype_params(self): + """The mimetype parameters as dict. For example if the content + type is ``text/html; charset=utf-8`` the params would be + ``{'charset': 'utf-8'}``. + + .. versionadded:: 0.7 + """ + self._parse_content_type() + return self._parsed_content_type[1] + + def save(self, dst, buffer_size=16384): + """Save the file to a destination path or file object. If the + destination is a file object you have to close it yourself after the + call. The buffer size is the number of bytes held in memory during + the copy process. It defaults to 16KB. + + For secure file saving also have a look at :func:`secure_filename`. + + :param dst: a filename, :class:`os.PathLike`, or open file + object to write to. + :param buffer_size: Passed as the ``length`` parameter of + :func:`shutil.copyfileobj`. + + .. versionchanged:: 1.0 + Supports :mod:`pathlib`. + """ + from shutil import copyfileobj + + close_dst = False + + if hasattr(dst, "__fspath__"): + dst = fspath(dst) + + if isinstance(dst, str): + dst = open(dst, "wb") + close_dst = True + + try: + copyfileobj(self.stream, dst, buffer_size) + finally: + if close_dst: + dst.close() + + def close(self): + """Close the underlying file if possible.""" + try: + self.stream.close() + except Exception: + pass + + def __bool__(self): + return bool(self.filename) + + def __getattr__(self, name): + try: + return getattr(self.stream, name) + except AttributeError: + # SpooledTemporaryFile doesn't implement IOBase, get the + # attribute from its backing file instead. + # https://github.com/python/cpython/pull/3249 + if hasattr(self.stream, "_file"): + return getattr(self.stream._file, name) + raise + + def __iter__(self): + return iter(self.stream) + + def __repr__(self): + return f"<{type(self).__name__}: {self.filename!r} ({self.content_type!r})>" + + +class FileMultiDict(MultiDict): + """A special :class:`MultiDict` that has convenience methods to add + files to it. This is used for :class:`EnvironBuilder` and generally + useful for unittesting. + + .. 
versionadded:: 0.5 + """ + + def add_file(self, name, file, filename=None, content_type=None): + """Adds a new file to the dict. `file` can be a file name or + a :class:`file`-like or a :class:`FileStorage` object. + + :param name: the name of the field. + :param file: a filename or :class:`file`-like object + :param filename: an optional filename + :param content_type: an optional content type + """ + if isinstance(file, FileStorage): + value = file + else: + if isinstance(file, str): + if filename is None: + filename = file + file = open(file, "rb") + if filename and content_type is None: + content_type = ( + mimetypes.guess_type(filename)[0] or "application/octet-stream" + ) + value = FileStorage(file, filename, name, content_type) + + self.add(name, value) + + +# circular dependencies +from .. import http diff --git a/src/werkzeug/datastructures/file_storage.pyi b/src/werkzeug/datastructures/file_storage.pyi new file mode 100644 index 0000000..36a7ed9 --- /dev/null +++ b/src/werkzeug/datastructures/file_storage.pyi @@ -0,0 +1,49 @@ +from collections.abc import Iterator +from os import PathLike +from typing import Any +from typing import IO + +from .headers import Headers +from .structures import MultiDict + +class FileStorage: + name: str | None + stream: IO[bytes] + filename: str | None + headers: Headers + _parsed_content_type: tuple[str, dict[str, str]] + def __init__( + self, + stream: IO[bytes] | None = None, + filename: str | PathLike[str] | None = None, + name: str | None = None, + content_type: str | None = None, + content_length: int | None = None, + headers: Headers | None = None, + ) -> None: ... + def _parse_content_type(self) -> None: ... + @property + def content_type(self) -> str: ... + @property + def content_length(self) -> int: ... + @property + def mimetype(self) -> str: ... + @property + def mimetype_params(self) -> dict[str, str]: ... + def save( + self, dst: str | PathLike[str] | IO[bytes], buffer_size: int = ... + ) -> None: ... + def close(self) -> None: ... + def __bool__(self) -> bool: ... + def __getattr__(self, name: str) -> Any: ... + def __iter__(self) -> Iterator[bytes]: ... + def __repr__(self) -> str: ... + +class FileMultiDict(MultiDict[str, FileStorage]): + def add_file( + self, + name: str, + file: FileStorage | str | IO[bytes], + filename: str | None = None, + content_type: str | None = None, + ) -> None: ... diff --git a/src/werkzeug/datastructures/headers.py b/src/werkzeug/datastructures/headers.py new file mode 100644 index 0000000..d9dd655 --- /dev/null +++ b/src/werkzeug/datastructures/headers.py @@ -0,0 +1,515 @@ +from __future__ import annotations + +import re +import typing as t + +from .._internal import _missing +from ..exceptions import BadRequestKeyError +from .mixins import ImmutableHeadersMixin +from .structures import iter_multi_items +from .structures import MultiDict + + +class Headers: + """An object that stores some headers. It has a dict-like interface, + but is ordered, can store the same key multiple times, and iterating + yields ``(key, value)`` pairs instead of only keys. + + This data structure is useful if you want a nicer way to handle WSGI + headers which are stored as tuples in a list. + + From Werkzeug 0.3 onwards, the :exc:`KeyError` raised by this class is + also a subclass of the :class:`~exceptions.BadRequest` HTTP exception + and will render a page for a ``400 BAD REQUEST`` if caught in a + catch-all for HTTP exceptions. 
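+
+    A short usage sketch (outputs follow the methods defined below)::
+
+        >>> d = Headers()
+        >>> d.add('Set-Cookie', 'a=1')
+        >>> d.add('Set-Cookie', 'b=2')
+        >>> d['Set-Cookie']
+        'a=1'
+        >>> d.to_wsgi_list()
+        [('Set-Cookie', 'a=1'), ('Set-Cookie', 'b=2')]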
+ + Headers is mostly compatible with the Python :class:`wsgiref.headers.Headers` + class, with the exception of `__getitem__`. :mod:`wsgiref` will return + `None` for ``headers['missing']``, whereas :class:`Headers` will raise + a :class:`KeyError`. + + To create a new ``Headers`` object, pass it a list, dict, or + other ``Headers`` object with default values. These values are + validated the same way values added later are. + + :param defaults: The list of default values for the :class:`Headers`. + + .. versionchanged:: 2.1.0 + Default values are validated the same as values added later. + + .. versionchanged:: 0.9 + This data structure now stores unicode values similar to how the + multi dicts do it. The main difference is that bytes can be set as + well which will automatically be latin1 decoded. + + .. versionchanged:: 0.9 + The :meth:`linked` function was removed without replacement as it + was an API that does not support the changes to the encoding model. + """ + + def __init__(self, defaults=None): + self._list = [] + if defaults is not None: + self.extend(defaults) + + def __getitem__(self, key, _get_mode=False): + if not _get_mode: + if isinstance(key, int): + return self._list[key] + elif isinstance(key, slice): + return self.__class__(self._list[key]) + if not isinstance(key, str): + raise BadRequestKeyError(key) + ikey = key.lower() + for k, v in self._list: + if k.lower() == ikey: + return v + # micro optimization: if we are in get mode we will catch that + # exception one stack level down so we can raise a standard + # key error instead of our special one. + if _get_mode: + raise KeyError() + raise BadRequestKeyError(key) + + def __eq__(self, other): + def lowered(item): + return (item[0].lower(),) + item[1:] + + return other.__class__ is self.__class__ and set( + map(lowered, other._list) + ) == set(map(lowered, self._list)) + + __hash__ = None + + def get(self, key, default=None, type=None): + """Return the default value if the requested data doesn't exist. + If `type` is provided and is a callable it should convert the value, + return it or raise a :exc:`ValueError` if that is not possible. In + this case the function will return the default as if the value was not + found: + + >>> d = Headers([('Content-Length', '42')]) + >>> d.get('Content-Length', type=int) + 42 + + :param key: The key to be looked up. + :param default: The default value to be returned if the key can't + be looked up. If not further specified `None` is + returned. + :param type: A callable that is used to cast the value in the + :class:`Headers`. If a :exc:`ValueError` is raised + by this callable the default value is returned. + + .. versionchanged:: 3.0 + The ``as_bytes`` parameter was removed. + + .. versionchanged:: 0.9 + The ``as_bytes`` parameter was added. + """ + try: + rv = self.__getitem__(key, _get_mode=True) + except KeyError: + return default + if type is None: + return rv + try: + return type(rv) + except ValueError: + return default + + def getlist(self, key, type=None): + """Return the list of items for a given key. If that key is not in the + :class:`Headers`, the return value will be an empty list. Just like + :meth:`get`, :meth:`getlist` accepts a `type` parameter. All items will + be converted with the callable defined there. + + :param key: The key to be looked up. + :param type: A callable that is used to cast the value in the + :class:`Headers`. If a :exc:`ValueError` is raised + by this callable the value will be removed from the list. 
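+                     Values the callable cannot convert are skipped, for
+                     example::
+
+                         >>> d = Headers([('X-Num', '1'), ('X-Num', 'x'), ('X-Num', '3')])
+                         >>> d.getlist('X-Num', type=int)
+                         [1, 3]
+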
+ :return: a :class:`list` of all the values for the key. + + .. versionchanged:: 3.0 + The ``as_bytes`` parameter was removed. + + .. versionchanged:: 0.9 + The ``as_bytes`` parameter was added. + """ + ikey = key.lower() + result = [] + for k, v in self: + if k.lower() == ikey: + if type is not None: + try: + v = type(v) + except ValueError: + continue + result.append(v) + return result + + def get_all(self, name): + """Return a list of all the values for the named field. + + This method is compatible with the :mod:`wsgiref` + :meth:`~wsgiref.headers.Headers.get_all` method. + """ + return self.getlist(name) + + def items(self, lower=False): + for key, value in self: + if lower: + key = key.lower() + yield key, value + + def keys(self, lower=False): + for key, _ in self.items(lower): + yield key + + def values(self): + for _, value in self.items(): + yield value + + def extend(self, *args, **kwargs): + """Extend headers in this object with items from another object + containing header items as well as keyword arguments. + + To replace existing keys instead of extending, use + :meth:`update` instead. + + If provided, the first argument can be another :class:`Headers` + object, a :class:`MultiDict`, :class:`dict`, or iterable of + pairs. + + .. versionchanged:: 1.0 + Support :class:`MultiDict`. Allow passing ``kwargs``. + """ + if len(args) > 1: + raise TypeError(f"update expected at most 1 arguments, got {len(args)}") + + if args: + for key, value in iter_multi_items(args[0]): + self.add(key, value) + + for key, value in iter_multi_items(kwargs): + self.add(key, value) + + def __delitem__(self, key, _index_operation=True): + if _index_operation and isinstance(key, (int, slice)): + del self._list[key] + return + key = key.lower() + new = [] + for k, v in self._list: + if k.lower() != key: + new.append((k, v)) + self._list[:] = new + + def remove(self, key): + """Remove a key. + + :param key: The key to be removed. + """ + return self.__delitem__(key, _index_operation=False) + + def pop(self, key=None, default=_missing): + """Removes and returns a key or index. + + :param key: The key to be popped. If this is an integer the item at + that position is removed, if it's a string the value for + that key is. If the key is omitted or `None` the last + item is removed. + :return: an item. + """ + if key is None: + return self._list.pop() + if isinstance(key, int): + return self._list.pop(key) + try: + rv = self[key] + self.remove(key) + except KeyError: + if default is not _missing: + return default + raise + return rv + + def popitem(self): + """Removes a key or index and returns a (key, value) item.""" + return self.pop() + + def __contains__(self, key): + """Check if a key is present.""" + try: + self.__getitem__(key, _get_mode=True) + except KeyError: + return False + return True + + def __iter__(self): + """Yield ``(key, value)`` tuples.""" + return iter(self._list) + + def __len__(self): + return len(self._list) + + def add(self, _key, _value, **kw): + """Add a new header tuple to the list. + + Keyword arguments can specify additional parameters for the header + value, with underscores converted to dashes:: + + >>> d = Headers() + >>> d.add('Content-Type', 'text/plain') + >>> d.add('Content-Disposition', 'attachment', filename='foo.png') + + The keyword argument dumping uses :func:`dump_options_header` + behind the scenes. + + .. versionadded:: 0.4.1 + keyword arguments were added for :mod:`wsgiref` compatibility. 
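+
+        Continuing the doctest above, the keyword options end up in the
+        stored value (output illustrative)::
+
+            >>> d['Content-Disposition']
+            'attachment; filename=foo.png'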
+ """ + if kw: + _value = _options_header_vkw(_value, kw) + _value = _str_header_value(_value) + self._list.append((_key, _value)) + + def add_header(self, _key, _value, **_kw): + """Add a new header tuple to the list. + + An alias for :meth:`add` for compatibility with the :mod:`wsgiref` + :meth:`~wsgiref.headers.Headers.add_header` method. + """ + self.add(_key, _value, **_kw) + + def clear(self): + """Clears all headers.""" + del self._list[:] + + def set(self, _key, _value, **kw): + """Remove all header tuples for `key` and add a new one. The newly + added key either appears at the end of the list if there was no + entry or replaces the first one. + + Keyword arguments can specify additional parameters for the header + value, with underscores converted to dashes. See :meth:`add` for + more information. + + .. versionchanged:: 0.6.1 + :meth:`set` now accepts the same arguments as :meth:`add`. + + :param key: The key to be inserted. + :param value: The value to be inserted. + """ + if kw: + _value = _options_header_vkw(_value, kw) + _value = _str_header_value(_value) + if not self._list: + self._list.append((_key, _value)) + return + listiter = iter(self._list) + ikey = _key.lower() + for idx, (old_key, _old_value) in enumerate(listiter): + if old_key.lower() == ikey: + # replace first occurrence + self._list[idx] = (_key, _value) + break + else: + self._list.append((_key, _value)) + return + self._list[idx + 1 :] = [t for t in listiter if t[0].lower() != ikey] + + def setlist(self, key, values): + """Remove any existing values for a header and add new ones. + + :param key: The header key to set. + :param values: An iterable of values to set for the key. + + .. versionadded:: 1.0 + """ + if values: + values_iter = iter(values) + self.set(key, next(values_iter)) + + for value in values_iter: + self.add(key, value) + else: + self.remove(key) + + def setdefault(self, key, default): + """Return the first value for the key if it is in the headers, + otherwise set the header to the value given by ``default`` and + return that. + + :param key: The header key to get. + :param default: The value to set for the key if it is not in the + headers. + """ + if key in self: + return self[key] + + self.set(key, default) + return default + + def setlistdefault(self, key, default): + """Return the list of values for the key if it is in the + headers, otherwise set the header to the list of values given + by ``default`` and return that. + + Unlike :meth:`MultiDict.setlistdefault`, modifying the returned + list will not affect the headers. + + :param key: The header key to get. + :param default: An iterable of values to set for the key if it + is not in the headers. + + .. versionadded:: 1.0 + """ + if key not in self: + self.setlist(key, default) + + return self.getlist(key) + + def __setitem__(self, key, value): + """Like :meth:`set` but also supports index/slice based setting.""" + if isinstance(key, (slice, int)): + if isinstance(key, int): + value = [value] + value = [(k, _str_header_value(v)) for (k, v) in value] + if isinstance(key, int): + self._list[key] = value[0] + else: + self._list[key] = value + else: + self.set(key, value) + + def update(self, *args, **kwargs): + """Replace headers in this object with items from another + headers object and keyword arguments. + + To extend existing keys instead of replacing, use :meth:`extend` + instead. + + If provided, the first argument can be another :class:`Headers` + object, a :class:`MultiDict`, :class:`dict`, or iterable of + pairs. + + .. 
versionadded:: 1.0 + """ + if len(args) > 1: + raise TypeError(f"update expected at most 1 arguments, got {len(args)}") + + if args: + mapping = args[0] + + if isinstance(mapping, (Headers, MultiDict)): + for key in mapping.keys(): + self.setlist(key, mapping.getlist(key)) + elif isinstance(mapping, dict): + for key, value in mapping.items(): + if isinstance(value, (list, tuple)): + self.setlist(key, value) + else: + self.set(key, value) + else: + for key, value in mapping: + self.set(key, value) + + for key, value in kwargs.items(): + if isinstance(value, (list, tuple)): + self.setlist(key, value) + else: + self.set(key, value) + + def to_wsgi_list(self): + """Convert the headers into a list suitable for WSGI. + + :return: list + """ + return list(self) + + def copy(self): + return self.__class__(self._list) + + def __copy__(self): + return self.copy() + + def __str__(self): + """Returns formatted headers suitable for HTTP transmission.""" + strs = [] + for key, value in self.to_wsgi_list(): + strs.append(f"{key}: {value}") + strs.append("\r\n") + return "\r\n".join(strs) + + def __repr__(self): + return f"{type(self).__name__}({list(self)!r})" + + +def _options_header_vkw(value: str, kw: dict[str, t.Any]): + return http.dump_options_header( + value, {k.replace("_", "-"): v for k, v in kw.items()} + ) + + +_newline_re = re.compile(r"[\r\n]") + + +def _str_header_value(value: t.Any) -> str: + if not isinstance(value, str): + value = str(value) + + if _newline_re.search(value) is not None: + raise ValueError("Header values must not contain newline characters.") + + return value + + +class EnvironHeaders(ImmutableHeadersMixin, Headers): + """Read only version of the headers from a WSGI environment. This + provides the same interface as `Headers` and is constructed from + a WSGI environment. + From Werkzeug 0.3 onwards, the `KeyError` raised by this class is also a + subclass of the :exc:`~exceptions.BadRequest` HTTP exception and will + render a page for a ``400 BAD REQUEST`` if caught in a catch-all for + HTTP exceptions. + """ + + def __init__(self, environ): + self.environ = environ + + def __eq__(self, other): + return self.environ is other.environ + + __hash__ = None + + def __getitem__(self, key, _get_mode=False): + # _get_mode is a no-op for this class as there is no index but + # used because get() calls it. + if not isinstance(key, str): + raise KeyError(key) + key = key.upper().replace("-", "_") + if key in {"CONTENT_TYPE", "CONTENT_LENGTH"}: + return self.environ[key] + return self.environ[f"HTTP_{key}"] + + def __len__(self): + # the iter is necessary because otherwise list calls our + # len which would call list again and so forth. + return len(list(iter(self))) + + def __iter__(self): + for key, value in self.environ.items(): + if key.startswith("HTTP_") and key not in { + "HTTP_CONTENT_TYPE", + "HTTP_CONTENT_LENGTH", + }: + yield key[5:].replace("_", "-").title(), value + elif key in {"CONTENT_TYPE", "CONTENT_LENGTH"} and value: + yield key.replace("_", "-").title(), value + + def copy(self): + raise TypeError(f"cannot create {type(self).__name__!r} copies") + + +# circular dependencies +from .. 
import http diff --git a/src/werkzeug/datastructures/headers.pyi b/src/werkzeug/datastructures/headers.pyi new file mode 100644 index 0000000..8650222 --- /dev/null +++ b/src/werkzeug/datastructures/headers.pyi @@ -0,0 +1,109 @@ +from collections.abc import Callable +from collections.abc import Iterable +from collections.abc import Iterator +from collections.abc import Mapping +from typing import Literal +from typing import NoReturn +from typing import overload +from typing import TypeVar + +from _typeshed import SupportsKeysAndGetItem +from _typeshed.wsgi import WSGIEnvironment + +from .mixins import ImmutableHeadersMixin + +D = TypeVar("D") +T = TypeVar("T") + +class Headers(dict[str, str]): + _list: list[tuple[str, str]] + def __init__( + self, + defaults: Mapping[str, str | Iterable[str]] + | Iterable[tuple[str, str]] + | None = None, + ) -> None: ... + @overload + def __getitem__(self, key: str) -> str: ... + @overload + def __getitem__(self, key: int) -> tuple[str, str]: ... + @overload + def __getitem__(self, key: slice) -> Headers: ... + @overload + def __getitem__(self, key: str, _get_mode: Literal[True] = ...) -> str: ... + def __eq__(self, other: object) -> bool: ... + @overload # type: ignore + def get(self, key: str, default: str) -> str: ... + @overload + def get(self, key: str, default: str | None = None) -> str | None: ... + @overload + def get( + self, key: str, default: T | None = None, type: Callable[[str], T] = ... + ) -> T | None: ... + @overload + def getlist(self, key: str) -> list[str]: ... + @overload + def getlist(self, key: str, type: Callable[[str], T]) -> list[T]: ... + def get_all(self, name: str) -> list[str]: ... + def items( # type: ignore + self, lower: bool = False + ) -> Iterator[tuple[str, str]]: ... + def keys(self, lower: bool = False) -> Iterator[str]: ... # type: ignore + def values(self) -> Iterator[str]: ... # type: ignore + def extend( + self, + *args: Mapping[str, str | Iterable[str]] | Iterable[tuple[str, str]], + **kwargs: str | Iterable[str], + ) -> None: ... + @overload + def __delitem__(self, key: str | int | slice) -> None: ... + @overload + def __delitem__(self, key: str, _index_operation: Literal[False]) -> None: ... + def remove(self, key: str) -> None: ... + @overload # type: ignore + def pop(self, key: str, default: str | None = None) -> str: ... + @overload + def pop( + self, key: int | None = None, default: tuple[str, str] | None = None + ) -> tuple[str, str]: ... + def popitem(self) -> tuple[str, str]: ... + def __contains__(self, key: str) -> bool: ... # type: ignore + def has_key(self, key: str) -> bool: ... + def __iter__(self) -> Iterator[tuple[str, str]]: ... # type: ignore + def add(self, _key: str, _value: str, **kw: str) -> None: ... + def _validate_value(self, value: str) -> None: ... + def add_header(self, _key: str, _value: str, **_kw: str) -> None: ... + def clear(self) -> None: ... + def set(self, _key: str, _value: str, **kw: str) -> None: ... + def setlist(self, key: str, values: Iterable[str]) -> None: ... + def setdefault(self, key: str, default: str) -> str: ... + def setlistdefault(self, key: str, default: Iterable[str]) -> None: ... + @overload + def __setitem__(self, key: str, value: str) -> None: ... + @overload + def __setitem__(self, key: int, value: tuple[str, str]) -> None: ... + @overload + def __setitem__(self, key: slice, value: Iterable[tuple[str, str]]) -> None: ... + @overload + def update( + self, __m: SupportsKeysAndGetItem[str, str], **kwargs: str | Iterable[str] + ) -> None: ... 
+ @overload + def update( + self, __m: Iterable[tuple[str, str]], **kwargs: str | Iterable[str] + ) -> None: ... + @overload + def update(self, **kwargs: str | Iterable[str]) -> None: ... + def to_wsgi_list(self) -> list[tuple[str, str]]: ... + def copy(self) -> Headers: ... + def __copy__(self) -> Headers: ... + +class EnvironHeaders(ImmutableHeadersMixin, Headers): + environ: WSGIEnvironment + def __init__(self, environ: WSGIEnvironment) -> None: ... + def __eq__(self, other: object) -> bool: ... + def __getitem__( # type: ignore + self, key: str, _get_mode: Literal[False] = False + ) -> str: ... + def __iter__(self) -> Iterator[tuple[str, str]]: ... # type: ignore + def copy(self) -> NoReturn: ... diff --git a/src/werkzeug/datastructures/mixins.py b/src/werkzeug/datastructures/mixins.py new file mode 100644 index 0000000..2c84ca8 --- /dev/null +++ b/src/werkzeug/datastructures/mixins.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +from itertools import repeat + +from .._internal import _missing + + +def is_immutable(self): + raise TypeError(f"{type(self).__name__!r} objects are immutable") + + +class ImmutableListMixin: + """Makes a :class:`list` immutable. + + .. versionadded:: 0.5 + + :private: + """ + + _hash_cache = None + + def __hash__(self): + if self._hash_cache is not None: + return self._hash_cache + rv = self._hash_cache = hash(tuple(self)) + return rv + + def __reduce_ex__(self, protocol): + return type(self), (list(self),) + + def __delitem__(self, key): + is_immutable(self) + + def __iadd__(self, other): + is_immutable(self) + + def __imul__(self, other): + is_immutable(self) + + def __setitem__(self, key, value): + is_immutable(self) + + def append(self, item): + is_immutable(self) + + def remove(self, item): + is_immutable(self) + + def extend(self, iterable): + is_immutable(self) + + def insert(self, pos, value): + is_immutable(self) + + def pop(self, index=-1): + is_immutable(self) + + def reverse(self): + is_immutable(self) + + def sort(self, key=None, reverse=False): + is_immutable(self) + + +class ImmutableDictMixin: + """Makes a :class:`dict` immutable. + + .. versionadded:: 0.5 + + :private: + """ + + _hash_cache = None + + @classmethod + def fromkeys(cls, keys, value=None): + instance = super().__new__(cls) + instance.__init__(zip(keys, repeat(value))) + return instance + + def __reduce_ex__(self, protocol): + return type(self), (dict(self),) + + def _iter_hashitems(self): + return self.items() + + def __hash__(self): + if self._hash_cache is not None: + return self._hash_cache + rv = self._hash_cache = hash(frozenset(self._iter_hashitems())) + return rv + + def setdefault(self, key, default=None): + is_immutable(self) + + def update(self, *args, **kwargs): + is_immutable(self) + + def pop(self, key, default=None): + is_immutable(self) + + def popitem(self): + is_immutable(self) + + def __setitem__(self, key, value): + is_immutable(self) + + def __delitem__(self, key): + is_immutable(self) + + def clear(self): + is_immutable(self) + + +class ImmutableMultiDictMixin(ImmutableDictMixin): + """Makes a :class:`MultiDict` immutable. + + .. 
versionadded:: 0.5 + + :private: + """ + + def __reduce_ex__(self, protocol): + return type(self), (list(self.items(multi=True)),) + + def _iter_hashitems(self): + return self.items(multi=True) + + def add(self, key, value): + is_immutable(self) + + def popitemlist(self): + is_immutable(self) + + def poplist(self, key): + is_immutable(self) + + def setlist(self, key, new_list): + is_immutable(self) + + def setlistdefault(self, key, default_list=None): + is_immutable(self) + + +class ImmutableHeadersMixin: + """Makes a :class:`Headers` immutable. We do not mark them as + hashable though since the only usecase for this datastructure + in Werkzeug is a view on a mutable structure. + + .. versionadded:: 0.5 + + :private: + """ + + def __delitem__(self, key, **kwargs): + is_immutable(self) + + def __setitem__(self, key, value): + is_immutable(self) + + def set(self, _key, _value, **kwargs): + is_immutable(self) + + def setlist(self, key, values): + is_immutable(self) + + def add(self, _key, _value, **kwargs): + is_immutable(self) + + def add_header(self, _key, _value, **_kwargs): + is_immutable(self) + + def remove(self, key): + is_immutable(self) + + def extend(self, *args, **kwargs): + is_immutable(self) + + def update(self, *args, **kwargs): + is_immutable(self) + + def insert(self, pos, value): + is_immutable(self) + + def pop(self, key=None, default=_missing): + is_immutable(self) + + def popitem(self): + is_immutable(self) + + def setdefault(self, key, default): + is_immutable(self) + + def setlistdefault(self, key, default): + is_immutable(self) + + +def _calls_update(name): + def oncall(self, *args, **kw): + rv = getattr(super(UpdateDictMixin, self), name)(*args, **kw) + + if self.on_update is not None: + self.on_update(self) + + return rv + + oncall.__name__ = name + return oncall + + +class UpdateDictMixin(dict): + """Makes dicts call `self.on_update` on modifications. + + .. versionadded:: 0.5 + + :private: + """ + + on_update = None + + def setdefault(self, key, default=None): + modified = key not in self + rv = super().setdefault(key, default) + if modified and self.on_update is not None: + self.on_update(self) + return rv + + def pop(self, key, default=_missing): + modified = key in self + if default is _missing: + rv = super().pop(key) + else: + rv = super().pop(key, default) + if modified and self.on_update is not None: + self.on_update(self) + return rv + + __setitem__ = _calls_update("__setitem__") + __delitem__ = _calls_update("__delitem__") + clear = _calls_update("clear") + popitem = _calls_update("popitem") + update = _calls_update("update") diff --git a/src/werkzeug/datastructures/mixins.pyi b/src/werkzeug/datastructures/mixins.pyi new file mode 100644 index 0000000..40453f7 --- /dev/null +++ b/src/werkzeug/datastructures/mixins.pyi @@ -0,0 +1,97 @@ +from collections.abc import Callable +from collections.abc import Hashable +from collections.abc import Iterable +from typing import Any +from typing import NoReturn +from typing import overload +from typing import SupportsIndex +from typing import TypeVar + +from _typeshed import SupportsKeysAndGetItem + +from .headers import Headers + +K = TypeVar("K") +T = TypeVar("T") +V = TypeVar("V") + +def is_immutable(self: object) -> NoReturn: ... + +class ImmutableListMixin(list[V]): + _hash_cache: int | None + def __hash__(self) -> int: ... # type: ignore + def __delitem__(self, key: SupportsIndex | slice) -> NoReturn: ... + def __iadd__(self, other: Any) -> NoReturn: ... 
# type: ignore
+    def __imul__(self, other: SupportsIndex) -> NoReturn: ...
+    def __setitem__(self, key: int | slice, value: V) -> NoReturn: ...  # type: ignore
+    def append(self, value: V) -> NoReturn: ...
+    def remove(self, value: V) -> NoReturn: ...
+    def extend(self, values: Iterable[V]) -> NoReturn: ...
+    def insert(self, pos: SupportsIndex, value: V) -> NoReturn: ...
+    def pop(self, index: SupportsIndex = -1) -> NoReturn: ...
+    def reverse(self) -> NoReturn: ...
+    def sort(
+        self, key: Callable[[V], Any] | None = None, reverse: bool = False
+    ) -> NoReturn: ...
+
+class ImmutableDictMixin(dict[K, V]):
+    _hash_cache: int | None
+    @classmethod
+    def fromkeys(  # type: ignore
+        cls, keys: Iterable[K], value: V | None = None
+    ) -> ImmutableDictMixin[K, V]: ...
+    def _iter_hashitems(self) -> Iterable[Hashable]: ...
+    def __hash__(self) -> int: ...  # type: ignore
+    def setdefault(self, key: K, default: V | None = None) -> NoReturn: ...
+    def update(self, *args: Any, **kwargs: V) -> NoReturn: ...
+    def pop(self, key: K, default: V | None = None) -> NoReturn: ...  # type: ignore
+    def popitem(self) -> NoReturn: ...
+    def __setitem__(self, key: K, value: V) -> NoReturn: ...
+    def __delitem__(self, key: K) -> NoReturn: ...
+    def clear(self) -> NoReturn: ...
+
+class ImmutableMultiDictMixin(ImmutableDictMixin[K, V]):
+    def _iter_hashitems(self) -> Iterable[Hashable]: ...
+    def add(self, key: K, value: V) -> NoReturn: ...
+    def popitemlist(self) -> NoReturn: ...
+    def poplist(self, key: K) -> NoReturn: ...
+    def setlist(self, key: K, new_list: Iterable[V]) -> NoReturn: ...
+    def setlistdefault(
+        self, key: K, default_list: Iterable[V] | None = None
+    ) -> NoReturn: ...
+
+class ImmutableHeadersMixin(Headers):
+    def __delitem__(self, key: Any, _index_operation: bool = True) -> NoReturn: ...
+    def __setitem__(self, key: Any, value: Any) -> NoReturn: ...
+    def set(self, _key: Any, _value: Any, **kw: Any) -> NoReturn: ...
+    def setlist(self, key: Any, values: Any) -> NoReturn: ...
+    def add(self, _key: Any, _value: Any, **kw: Any) -> NoReturn: ...
+    def add_header(self, _key: Any, _value: Any, **_kw: Any) -> NoReturn: ...
+    def remove(self, key: Any) -> NoReturn: ...
+    def extend(self, *args: Any, **kwargs: Any) -> NoReturn: ...
+    def update(self, *args: Any, **kwargs: Any) -> NoReturn: ...
+    def insert(self, pos: Any, value: Any) -> NoReturn: ...
+    def pop(self, key: Any = None, default: Any = ...) -> NoReturn: ...
+    def popitem(self) -> NoReturn: ...
+    def setdefault(self, key: Any, default: Any) -> NoReturn: ...
+    def setlistdefault(self, key: Any, default: Any) -> NoReturn: ...
+
+def _calls_update(name: str) -> Callable[[UpdateDictMixin[K, V]], Any]: ...
+
+class UpdateDictMixin(dict[K, V]):
+    on_update: Callable[[UpdateDictMixin[K, V]], None] | None
+    def setdefault(self, key: K, default: V | None = None) -> V: ...
+    @overload
+    def pop(self, key: K) -> V: ...
+    @overload
+    def pop(self, key: K, default: V | T = ...) -> V | T: ...
+    def __setitem__(self, key: K, value: V) -> None: ...
+    def __delitem__(self, key: K) -> None: ...
+    def clear(self) -> None: ...
+    def popitem(self) -> tuple[K, V]: ...
+    @overload
+    def update(self, __m: SupportsKeysAndGetItem[K, V], **kwargs: V) -> None: ...
+    @overload
+    def update(self, __m: Iterable[tuple[K, V]], **kwargs: V) -> None: ...
+    @overload
+    def update(self, **kwargs: V) -> None: ...
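
A minimal usage sketch of the ``UpdateDictMixin`` defined above (the
``WatchedDict`` name is illustrative and not part of this changeset; the
import path is the module added by this diff):

    from werkzeug.datastructures.mixins import UpdateDictMixin

    class WatchedDict(UpdateDictMixin, dict):
        """Calls ``on_update`` with the mapping whenever it is mutated."""

    d = WatchedDict(a=1)
    d.on_update = lambda mapping: print("changed:", dict(mapping))
    d["b"] = 2               # wrapped __setitem__ fires the callback
    d.pop("missing", None)   # no modification, so no callback
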
diff --git a/src/werkzeug/datastructures/range.py b/src/werkzeug/datastructures/range.py
new file mode 100644
index 0000000..7011ea4
--- /dev/null
+++ b/src/werkzeug/datastructures/range.py
@@ -0,0 +1,180 @@
+from __future__ import annotations
+
+
+class IfRange:
+    """Very simple object that represents the `If-Range` header in parsed
+    form. It will have either an etag or a date, or neither, but never
+    both.
+
+    .. versionadded:: 0.7
+    """
+
+    def __init__(self, etag=None, date=None):
+        #: The etag parsed and unquoted. Ranges always operate on strong
+        #: etags so the weakness information is not necessary.
+        self.etag = etag
+        #: The date in parsed format or `None`.
+        self.date = date
+
+    def to_header(self):
+        """Converts the object back into an HTTP header."""
+        if self.date is not None:
+            return http.http_date(self.date)
+        if self.etag is not None:
+            return http.quote_etag(self.etag)
+        return ""
+
+    def __str__(self):
+        return self.to_header()
+
+    def __repr__(self):
+        return f"<{type(self).__name__} {str(self)!r}>"
+
+
+class Range:
+    """Represents a ``Range`` header. All methods support only bytes as
+    the unit. Stores a list of ranges if given, but the methods only
+    work if a single range is provided.
+
+    :raise ValueError: If the ranges provided are invalid.
+
+    .. versionchanged:: 0.15
+        The ranges passed in are validated.
+
+    .. versionadded:: 0.7
+    """
+
+    def __init__(self, units, ranges):
+        #: The units of this range. Usually "bytes".
+        self.units = units
+        #: A list of ``(begin, end)`` tuples for the range header provided.
+        #: The ranges are non-inclusive.
+        self.ranges = ranges
+
+        for start, end in ranges:
+            if start is None or (end is not None and (start < 0 or start >= end)):
+                raise ValueError(f"{(start, end)} is not a valid range.")
+
+    def range_for_length(self, length):
+        """If the range is for bytes, the length is not None, there is
+        exactly one range, and it is satisfiable, it returns a ``(start, stop)``
+        tuple, otherwise `None`.
+        """
+        if self.units != "bytes" or length is None or len(self.ranges) != 1:
+            return None
+        start, end = self.ranges[0]
+        if end is None:
+            end = length
+            if start < 0:
+                start += length
+        if http.is_byte_range_valid(start, end, length):
+            return start, min(end, length)
+        return None
+
+    def make_content_range(self, length):
+        """Creates a :class:`~werkzeug.datastructures.ContentRange` object
+        from the current range and given content length.
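+
+        For example (outputs follow the code above)::
+
+            >>> rng = Range("bytes", [(0, 100)])
+            >>> rng.range_for_length(50)
+            (0, 50)
+            >>> rng.make_content_range(50)
+            <ContentRange 'bytes 0-49/50'>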
+ """ + rng = self.range_for_length(length) + if rng is not None: + return ContentRange(self.units, rng[0], rng[1], length) + return None + + def to_header(self): + """Converts the object back into an HTTP header.""" + ranges = [] + for begin, end in self.ranges: + if end is None: + ranges.append(f"{begin}-" if begin >= 0 else str(begin)) + else: + ranges.append(f"{begin}-{end - 1}") + return f"{self.units}={','.join(ranges)}" + + def to_content_range_header(self, length): + """Converts the object into `Content-Range` HTTP header, + based on given length + """ + range = self.range_for_length(length) + if range is not None: + return f"{self.units} {range[0]}-{range[1] - 1}/{length}" + return None + + def __str__(self): + return self.to_header() + + def __repr__(self): + return f"<{type(self).__name__} {str(self)!r}>" + + +def _callback_property(name): + def fget(self): + return getattr(self, name) + + def fset(self, value): + setattr(self, name, value) + if self.on_update is not None: + self.on_update(self) + + return property(fget, fset) + + +class ContentRange: + """Represents the content range header. + + .. versionadded:: 0.7 + """ + + def __init__(self, units, start, stop, length=None, on_update=None): + assert http.is_byte_range_valid(start, stop, length), "Bad range provided" + self.on_update = on_update + self.set(start, stop, length, units) + + #: The units to use, usually "bytes" + units = _callback_property("_units") + #: The start point of the range or `None`. + start = _callback_property("_start") + #: The stop point of the range (non-inclusive) or `None`. Can only be + #: `None` if also start is `None`. + stop = _callback_property("_stop") + #: The length of the range or `None`. + length = _callback_property("_length") + + def set(self, start, stop, length=None, units="bytes"): + """Simple method to update the ranges.""" + assert http.is_byte_range_valid(start, stop, length), "Bad range provided" + self._units = units + self._start = start + self._stop = stop + self._length = length + if self.on_update is not None: + self.on_update(self) + + def unset(self): + """Sets the units to `None` which indicates that the header should + no longer be used. + """ + self.set(None, None, units=None) + + def to_header(self): + if self.units is None: + return "" + if self.length is None: + length = "*" + else: + length = self.length + if self.start is None: + return f"{self.units} */{length}" + return f"{self.units} {self.start}-{self.stop - 1}/{length}" + + def __bool__(self): + return self.units is not None + + def __str__(self): + return self.to_header() + + def __repr__(self): + return f"<{type(self).__name__} {str(self)!r}>" + + +# circular dependencies +from .. import http diff --git a/src/werkzeug/datastructures/range.pyi b/src/werkzeug/datastructures/range.pyi new file mode 100644 index 0000000..f38ad69 --- /dev/null +++ b/src/werkzeug/datastructures/range.pyi @@ -0,0 +1,57 @@ +from collections.abc import Callable +from datetime import datetime + +class IfRange: + etag: str | None + date: datetime | None + def __init__( + self, etag: str | None = None, date: datetime | None = None + ) -> None: ... + def to_header(self) -> str: ... + +class Range: + units: str + ranges: list[tuple[int, int | None]] + def __init__(self, units: str, ranges: list[tuple[int, int | None]]) -> None: ... + def range_for_length(self, length: int | None) -> tuple[int, int] | None: ... + def make_content_range(self, length: int | None) -> ContentRange | None: ... + def to_header(self) -> str: ... 
+    def to_content_range_header(self, length: int | None) -> str | None: ...
+
+def _callback_property(name: str) -> property: ...
+
+class ContentRange:
+    on_update: Callable[[ContentRange], None] | None
+    def __init__(
+        self,
+        units: str | None,
+        start: int | None,
+        stop: int | None,
+        length: int | None = None,
+        on_update: Callable[[ContentRange], None] | None = None,
+    ) -> None: ...
+    @property
+    def units(self) -> str | None: ...
+    @units.setter
+    def units(self, value: str | None) -> None: ...
+    @property
+    def start(self) -> int | None: ...
+    @start.setter
+    def start(self, value: int | None) -> None: ...
+    @property
+    def stop(self) -> int | None: ...
+    @stop.setter
+    def stop(self, value: int | None) -> None: ...
+    @property
+    def length(self) -> int | None: ...
+    @length.setter
+    def length(self, value: int | None) -> None: ...
+    def set(
+        self,
+        start: int | None,
+        stop: int | None,
+        length: int | None = None,
+        units: str | None = "bytes",
+    ) -> None: ...
+    def unset(self) -> None: ...
+    def to_header(self) -> str: ...
diff --git a/src/werkzeug/datastructures/structures.py b/src/werkzeug/datastructures/structures.py
new file mode 100644
index 0000000..4279ceb
--- /dev/null
+++ b/src/werkzeug/datastructures/structures.py
@@ -0,0 +1,1010 @@
+from __future__ import annotations
+
+from collections.abc import MutableSet
+from copy import deepcopy
+
+from .. import exceptions
+from .._internal import _missing
+from .mixins import ImmutableDictMixin
+from .mixins import ImmutableListMixin
+from .mixins import ImmutableMultiDictMixin
+from .mixins import UpdateDictMixin
+
+
+def is_immutable(self):
+    raise TypeError(f"{type(self).__name__!r} objects are immutable")
+
+
+def iter_multi_items(mapping):
+    """Iterates over the items of a mapping yielding keys and values
+    without dropping any from more complex structures.
+    """
+    if isinstance(mapping, MultiDict):
+        yield from mapping.items(multi=True)
+    elif isinstance(mapping, dict):
+        for key, value in mapping.items():
+            if isinstance(value, (tuple, list)):
+                for v in value:
+                    yield key, v
+            else:
+                yield key, value
+    else:
+        yield from mapping
+
+
+class ImmutableList(ImmutableListMixin, list):
+    """An immutable :class:`list`.
+
+    .. versionadded:: 0.5
+
+    :private:
+    """
+
+    def __repr__(self):
+        return f"{type(self).__name__}({list.__repr__(self)})"
+
+
+class TypeConversionDict(dict):
+    """Works like a regular dict but the :meth:`get` method can perform
+    type conversions. :class:`MultiDict` and :class:`CombinedMultiDict`
+    are subclasses of this class and provide the same feature.
+
+    .. versionadded:: 0.5
+    """
+
+    def get(self, key, default=None, type=None):
+        """Return the default value if the requested data doesn't exist.
+        If `type` is provided and is a callable, it should convert the value
+        and return it, or raise a :exc:`ValueError` if that is not possible.
+        In this case the function will return the default as if the value
+        was not found:
+
+        >>> d = TypeConversionDict(foo='42', bar='blub')
+        >>> d.get('foo', type=int)
+        42
+        >>> d.get('bar', -1, type=int)
+        -1
+
+        :param key: The key to be looked up.
+        :param default: The default value to be returned if the key can't
+                        be looked up. If not further specified, `None` is
+                        returned.
+        :param type: A callable that is used to cast the value in the
+                     :class:`MultiDict`. If a :exc:`ValueError` or a
+                     :exc:`TypeError` is raised by this callable the default
+                     value is returned.
+
+        .. versionchanged:: 3.0.2
+           Returns the default value on :exc:`TypeError`, too.
+        """
+        try:
+            rv = self[key]
+        except KeyError:
+            return default
+        if type is not None:
+            try:
+                rv = type(rv)
+            except (ValueError, TypeError):
+                rv = default
+        return rv
+
+
+class ImmutableTypeConversionDict(ImmutableDictMixin, TypeConversionDict):
+    """Works like a :class:`TypeConversionDict` but does not support
+    modifications.
+
+    .. versionadded:: 0.5
+    """
+
+    def copy(self):
+        """Return a shallow mutable copy of this object. Keep in mind that
+        the standard library's :func:`copy` function is a no-op for this class
+        like for any other Python immutable type (e.g. :class:`tuple`).
+        """
+        return TypeConversionDict(self)
+
+    def __copy__(self):
+        return self
+
+
+class MultiDict(TypeConversionDict):
+    """A :class:`MultiDict` is a dictionary subclass customized to deal with
+    multiple values for the same key, which is used, for example, by the
+    parsing functions in the wrappers. This is necessary because some HTML
+    form elements pass multiple values for the same key.
+
+    :class:`MultiDict` implements all standard dictionary methods.
+    Internally, it saves all values for a key as a list, but the standard dict
+    access methods will only return the first value for a key. If you want to
+    gain access to the other values, too, you have to use the `list` methods as
+    explained below.
+
+    Basic Usage:
+
+    >>> d = MultiDict([('a', 'b'), ('a', 'c')])
+    >>> d
+    MultiDict([('a', 'b'), ('a', 'c')])
+    >>> d['a']
+    'b'
+    >>> d.getlist('a')
+    ['b', 'c']
+    >>> 'a' in d
+    True
+
+    It behaves like a normal dict, thus all dict functions will only return
+    the first value when multiple values for one key are found.
+
+    From Werkzeug 0.3 onwards, the `KeyError` raised by this class is also a
+    subclass of the :exc:`~exceptions.BadRequest` HTTP exception and will
+    render a page for a ``400 BAD REQUEST`` if caught in a catch-all for HTTP
+    exceptions.
+
+    A :class:`MultiDict` can be constructed from an iterable of
+    ``(key, value)`` tuples, a dict, a :class:`MultiDict` or, from Werkzeug
+    0.2 onwards, some keyword parameters.
+
+    :param mapping: the initial value for the :class:`MultiDict`. Either a
+                    regular dict, an iterable of ``(key, value)`` tuples
+                    or `None`.
+    """
+
+    def __init__(self, mapping=None):
+        if isinstance(mapping, MultiDict):
+            dict.__init__(self, ((k, vs[:]) for k, vs in mapping.lists()))
+        elif isinstance(mapping, dict):
+            tmp = {}
+            for key, value in mapping.items():
+                if isinstance(value, (tuple, list)):
+                    if len(value) == 0:
+                        continue
+                    value = list(value)
+                else:
+                    value = [value]
+                tmp[key] = value
+            dict.__init__(self, tmp)
+        else:
+            tmp = {}
+            for key, value in mapping or ():
+                tmp.setdefault(key, []).append(value)
+            dict.__init__(self, tmp)
+
+    def __getstate__(self):
+        return dict(self.lists())
+
+    def __setstate__(self, value):
+        dict.clear(self)
+        dict.update(self, value)
+
+    def __iter__(self):
+        # Work around https://bugs.python.org/issue43246.
+        # (`return super().__iter__()` also works here, which makes this look
+        # even more like it should be a no-op, yet it isn't.)
+        return dict.__iter__(self)
+
+    def __getitem__(self, key):
+        """Return the first data value for this key;
+        raises KeyError if not found.
+
+        :param key: The key to be looked up.
+        :raise KeyError: if the key does not exist.
+        """
+
+        if key in self:
+            lst = dict.__getitem__(self, key)
+            if len(lst) > 0:
+                return lst[0]
+        raise exceptions.BadRequestKeyError(key)
+
+    def __setitem__(self, key, value):
+        """Like :meth:`add` but removes an existing key first.
+
+        :param key: the key for the value.
+        :param value: the value to set.
+        """
+        dict.__setitem__(self, key, [value])
+
+    def add(self, key, value):
+        """Adds a new value for the key.
+
+        .. versionadded:: 0.6
+
+        :param key: the key for the value.
+        :param value: the value to add.
+        """
+        dict.setdefault(self, key, []).append(value)
+
+    def getlist(self, key, type=None):
+        """Return the list of items for a given key. If that key is not in the
+        `MultiDict`, the return value will be an empty list. Just like `get`,
+        `getlist` accepts a `type` parameter. All items will be converted
+        with the callable defined there.
+
+        :param key: The key to be looked up.
+        :param type: A callable that is used to cast the value in the
+                     :class:`MultiDict`. If a :exc:`ValueError` is raised
+                     by this callable the value will be removed from the list.
+        :return: a :class:`list` of all the values for the key.
+        """
+        try:
+            rv = dict.__getitem__(self, key)
+        except KeyError:
+            return []
+        if type is None:
+            return list(rv)
+        result = []
+        for item in rv:
+            try:
+                result.append(type(item))
+            except ValueError:
+                pass
+        return result
+
+    def setlist(self, key, new_list):
+        """Remove the old values for a key and add new ones. Note that the
+        list of values you pass in will be shallow-copied before it is
+        inserted into the dictionary.
+
+        >>> d = MultiDict()
+        >>> d.setlist('foo', ['1', '2'])
+        >>> d['foo']
+        '1'
+        >>> d.getlist('foo')
+        ['1', '2']
+
+        :param key: The key for which the values are set.
+        :param new_list: An iterable with the new values for the key. Old values
+                         are removed first.
+        """
+        dict.__setitem__(self, key, list(new_list))
+
+    def setdefault(self, key, default=None):
+        """Returns the value for the key if it is in the dict, otherwise it
+        returns `default` and sets that value for `key`.
+
+        :param key: The key to be looked up.
+        :param default: The default value to be returned if the key is not
+                        in the dict. If not further specified, it's `None`.
+        """
+        if key not in self:
+            self[key] = default
+        else:
+            default = self[key]
+        return default
+
+    def setlistdefault(self, key, default_list=None):
+        """Like `setdefault` but sets multiple values. The list returned
+        is not a copy, but the list that is actually used internally. This
+        means that you can put new values into the dict by appending items
+        to the list:
+
+        >>> d = MultiDict({"foo": 1})
+        >>> d.setlistdefault("foo").extend([2, 3])
+        >>> d.getlist("foo")
+        [1, 2, 3]
+
+        :param key: The key to be looked up.
+        :param default_list: An iterable of default values. It is either copied
+                             (in case it was a list) or converted into a list
+                             before being returned.
+        :return: a :class:`list`
+        """
+        if key not in self:
+            default_list = list(default_list or ())
+            dict.__setitem__(self, key, default_list)
+        else:
+            default_list = dict.__getitem__(self, key)
+        return default_list
+
+    def items(self, multi=False):
+        """Return an iterator of ``(key, value)`` pairs.
+
+        :param multi: If set to `True` the iterator returned will have a pair
+                      for each value of each key. Otherwise it will only
+                      contain pairs for the first value of each key.
+        """
+        for key, values in dict.items(self):
+            if multi:
+                for value in values:
+                    yield key, value
+            else:
+                yield key, values[0]
+
+    def lists(self):
+        """Return an iterator of ``(key, values)`` pairs, where values is the
+        list of all values associated with the key."""
+        for key, values in dict.items(self):
+            yield key, list(values)
+
+    def values(self):
+        """Returns an iterator of the first value in every key's value list."""
+        for values in dict.values(self):
+            yield values[0]
+
+    def listvalues(self):
+        """Return an iterator of all values associated with a key. Zipping
+        :meth:`keys` and this is the same as calling :meth:`lists`:
+
+        >>> d = MultiDict({"foo": [1, 2, 3]})
+        >>> list(zip(d.keys(), d.listvalues())) == list(d.lists())
+        True
+        """
+        return dict.values(self)
+
+    def copy(self):
+        """Return a shallow copy of this object."""
+        return self.__class__(self)
+
+    def deepcopy(self, memo=None):
+        """Return a deep copy of this object."""
+        return self.__class__(deepcopy(self.to_dict(flat=False), memo))
+
+    def to_dict(self, flat=True):
+        """Return the contents as a regular dict. If `flat` is `True` the
+        returned dict will only have the first item present, if `flat` is
+        `False` all values will be returned as lists.
+
+        :param flat: If set to `False` the dict returned will have lists
+                     with all the values in it. Otherwise it will only
+                     contain the first value for each key.
+        :return: a :class:`dict`
+        """
+        if flat:
+            return dict(self.items())
+        return dict(self.lists())
+
+    def update(self, mapping):
+        """update() extends rather than replaces existing key lists:
+
+        >>> a = MultiDict({'x': 1})
+        >>> b = MultiDict({'x': 2, 'y': 3})
+        >>> a.update(b)
+        >>> a
+        MultiDict([('y', 3), ('x', 1), ('x', 2)])
+
+        If the value list for a key in ``mapping`` is empty, no new values
+        will be added to the dict and the key will not be created:
+
+        >>> x = {'empty_list': []}
+        >>> y = MultiDict()
+        >>> y.update(x)
+        >>> y
+        MultiDict([])
+        """
+        for key, value in iter_multi_items(mapping):
+            MultiDict.add(self, key, value)
+
+    def pop(self, key, default=_missing):
+        """Pop the first item for a list in the dict. Afterwards the
+        key is removed from the dict, so additional values are discarded:
+
+        >>> d = MultiDict({"foo": [1, 2, 3]})
+        >>> d.pop("foo")
+        1
+        >>> "foo" in d
+        False
+
+        :param key: the key to pop.
+        :param default: if provided the value to return if the key was
+                        not in the dictionary.
+        """
+        try:
+            lst = dict.pop(self, key)
+
+            if len(lst) == 0:
+                raise exceptions.BadRequestKeyError(key)
+
+            return lst[0]
+        except KeyError:
+            if default is not _missing:
+                return default
+
+            raise exceptions.BadRequestKeyError(key) from None
+
+    def popitem(self):
+        """Pop an item from the dict."""
+        try:
+            item = dict.popitem(self)
+
+            if len(item[1]) == 0:
+                raise exceptions.BadRequestKeyError(item[0])
+
+            return (item[0], item[1][0])
+        except KeyError as e:
+            raise exceptions.BadRequestKeyError(e.args[0]) from None
+
+    def poplist(self, key):
+        """Pop the list for a key from the dict. If the key is not in the dict
+        an empty list is returned.
+
+        .. versionchanged:: 0.5
+           If the key no longer exists, a list is returned instead of
+           raising an error.
+ """ + return dict.pop(self, key, []) + + def popitemlist(self): + """Pop a ``(key, list)`` tuple from the dict.""" + try: + return dict.popitem(self) + except KeyError as e: + raise exceptions.BadRequestKeyError(e.args[0]) from None + + def __copy__(self): + return self.copy() + + def __deepcopy__(self, memo): + return self.deepcopy(memo=memo) + + def __repr__(self): + return f"{type(self).__name__}({list(self.items(multi=True))!r})" + + +class _omd_bucket: + """Wraps values in the :class:`OrderedMultiDict`. This makes it + possible to keep an order over multiple different keys. It requires + a lot of extra memory and slows down access a lot, but makes it + possible to access elements in O(1) and iterate in O(n). + """ + + __slots__ = ("prev", "key", "value", "next") + + def __init__(self, omd, key, value): + self.prev = omd._last_bucket + self.key = key + self.value = value + self.next = None + + if omd._first_bucket is None: + omd._first_bucket = self + if omd._last_bucket is not None: + omd._last_bucket.next = self + omd._last_bucket = self + + def unlink(self, omd): + if self.prev: + self.prev.next = self.next + if self.next: + self.next.prev = self.prev + if omd._first_bucket is self: + omd._first_bucket = self.next + if omd._last_bucket is self: + omd._last_bucket = self.prev + + +class OrderedMultiDict(MultiDict): + """Works like a regular :class:`MultiDict` but preserves the + order of the fields. To convert the ordered multi dict into a + list you can use the :meth:`items` method and pass it ``multi=True``. + + In general an :class:`OrderedMultiDict` is an order of magnitude + slower than a :class:`MultiDict`. + + .. admonition:: note + + Due to a limitation in Python you cannot convert an ordered + multi dict into a regular dict by using ``dict(multidict)``. + Instead you have to use the :meth:`to_dict` method, otherwise + the internal bucket objects are exposed. 
+ """ + + def __init__(self, mapping=None): + dict.__init__(self) + self._first_bucket = self._last_bucket = None + if mapping is not None: + OrderedMultiDict.update(self, mapping) + + def __eq__(self, other): + if not isinstance(other, MultiDict): + return NotImplemented + if isinstance(other, OrderedMultiDict): + iter1 = iter(self.items(multi=True)) + iter2 = iter(other.items(multi=True)) + try: + for k1, v1 in iter1: + k2, v2 = next(iter2) + if k1 != k2 or v1 != v2: + return False + except StopIteration: + return False + try: + next(iter2) + except StopIteration: + return True + return False + if len(self) != len(other): + return False + for key, values in self.lists(): + if other.getlist(key) != values: + return False + return True + + __hash__ = None + + def __reduce_ex__(self, protocol): + return type(self), (list(self.items(multi=True)),) + + def __getstate__(self): + return list(self.items(multi=True)) + + def __setstate__(self, values): + dict.clear(self) + for key, value in values: + self.add(key, value) + + def __getitem__(self, key): + if key in self: + return dict.__getitem__(self, key)[0].value + raise exceptions.BadRequestKeyError(key) + + def __setitem__(self, key, value): + self.poplist(key) + self.add(key, value) + + def __delitem__(self, key): + self.pop(key) + + def keys(self): + return (key for key, value in self.items()) + + def __iter__(self): + return iter(self.keys()) + + def values(self): + return (value for key, value in self.items()) + + def items(self, multi=False): + ptr = self._first_bucket + if multi: + while ptr is not None: + yield ptr.key, ptr.value + ptr = ptr.next + else: + returned_keys = set() + while ptr is not None: + if ptr.key not in returned_keys: + returned_keys.add(ptr.key) + yield ptr.key, ptr.value + ptr = ptr.next + + def lists(self): + returned_keys = set() + ptr = self._first_bucket + while ptr is not None: + if ptr.key not in returned_keys: + yield ptr.key, self.getlist(ptr.key) + returned_keys.add(ptr.key) + ptr = ptr.next + + def listvalues(self): + for _key, values in self.lists(): + yield values + + def add(self, key, value): + dict.setdefault(self, key, []).append(_omd_bucket(self, key, value)) + + def getlist(self, key, type=None): + try: + rv = dict.__getitem__(self, key) + except KeyError: + return [] + if type is None: + return [x.value for x in rv] + result = [] + for item in rv: + try: + result.append(type(item.value)) + except ValueError: + pass + return result + + def setlist(self, key, new_list): + self.poplist(key) + for value in new_list: + self.add(key, value) + + def setlistdefault(self, key, default_list=None): + raise TypeError("setlistdefault is unsupported for ordered multi dicts") + + def update(self, mapping): + for key, value in iter_multi_items(mapping): + OrderedMultiDict.add(self, key, value) + + def poplist(self, key): + buckets = dict.pop(self, key, ()) + for bucket in buckets: + bucket.unlink(self) + return [x.value for x in buckets] + + def pop(self, key, default=_missing): + try: + buckets = dict.pop(self, key) + except KeyError: + if default is not _missing: + return default + + raise exceptions.BadRequestKeyError(key) from None + + for bucket in buckets: + bucket.unlink(self) + + return buckets[0].value + + def popitem(self): + try: + key, buckets = dict.popitem(self) + except KeyError as e: + raise exceptions.BadRequestKeyError(e.args[0]) from None + + for bucket in buckets: + bucket.unlink(self) + + return key, buckets[0].value + + def popitemlist(self): + try: + key, buckets = dict.popitem(self) + 
except KeyError as e:
+            raise exceptions.BadRequestKeyError(e.args[0]) from None
+
+        for bucket in buckets:
+            bucket.unlink(self)
+
+        return key, [x.value for x in buckets]
+
+
+class CombinedMultiDict(ImmutableMultiDictMixin, MultiDict):
+    """A read-only :class:`MultiDict` to which you can pass multiple
+    :class:`MultiDict` instances as a sequence; it will combine the return
+    values of all wrapped dicts:
+
+    >>> from werkzeug.datastructures import CombinedMultiDict, MultiDict
+    >>> post = MultiDict([('foo', 'bar')])
+    >>> get = MultiDict([('blub', 'blah')])
+    >>> combined = CombinedMultiDict([get, post])
+    >>> combined['foo']
+    'bar'
+    >>> combined['blub']
+    'blah'
+
+    This works for all read operations and will raise a `TypeError` for
+    methods that would usually change data, which isn't possible here.
+
+    From Werkzeug 0.3 onwards, the `KeyError` raised by this class is also a
+    subclass of the :exc:`~exceptions.BadRequest` HTTP exception and will
+    render a page for a ``400 BAD REQUEST`` if caught in a catch-all for HTTP
+    exceptions.
+    """
+
+    def __reduce_ex__(self, protocol):
+        return type(self), (self.dicts,)
+
+    def __init__(self, dicts=None):
+        # list(dicts or ()) also handles the documented dicts=None default,
+        # which list(dicts) or [] would not.
+        self.dicts = list(dicts or ())
+
+    @classmethod
+    def fromkeys(cls, keys, value=None):
+        raise TypeError(f"cannot create {cls.__name__!r} instances by fromkeys")
+
+    def __getitem__(self, key):
+        for d in self.dicts:
+            if key in d:
+                return d[key]
+        raise exceptions.BadRequestKeyError(key)
+
+    def get(self, key, default=None, type=None):
+        for d in self.dicts:
+            if key in d:
+                if type is not None:
+                    try:
+                        return type(d[key])
+                    except ValueError:
+                        continue
+                return d[key]
+        return default
+
+    def getlist(self, key, type=None):
+        rv = []
+        for d in self.dicts:
+            rv.extend(d.getlist(key, type))
+        return rv
+
+    def _keys_impl(self):
+        """This function exists so __len__ can be implemented more efficiently,
+        saving one list creation from an iterator.
+        """
+        rv = set()
+        rv.update(*self.dicts)
+        return rv
+
+    def keys(self):
+        return self._keys_impl()
+
+    def __iter__(self):
+        return iter(self.keys())
+
+    def items(self, multi=False):
+        found = set()
+        for d in self.dicts:
+            for key, value in d.items(multi):
+                if multi:
+                    yield key, value
+                elif key not in found:
+                    found.add(key)
+                    yield key, value
+
+    def values(self):
+        for _key, value in self.items():
+            yield value
+
+    def lists(self):
+        rv = {}
+        for d in self.dicts:
+            for key, values in d.lists():
+                rv.setdefault(key, []).extend(values)
+        return list(rv.items())
+
+    def listvalues(self):
+        return (x[1] for x in self.lists())
+
+    def copy(self):
+        """Return a shallow mutable copy of this object.
+
+        This returns a :class:`MultiDict` representing the data at the
+        time of copying. The copy will no longer reflect changes to the
+        wrapped dicts.
+
+        .. versionchanged:: 0.15
+            Return a mutable :class:`MultiDict`.
+        """
+        return MultiDict(self)
+
+    def to_dict(self, flat=True):
+        """Return the contents as a regular dict. If `flat` is `True` the
+        returned dict will only have the first item present, if `flat` is
+        `False` all values will be returned as lists.
+
+        :param flat: If set to `False` the dict returned will have lists
+                     with all the values in it. Otherwise it will only
+                     contain the first item for each key.
+        :return: a :class:`dict`
+        """
+        if flat:
+            return dict(self.items())
+
+        return dict(self.lists())
+
+    def __len__(self):
+        return len(self._keys_impl())
+
+    def __contains__(self, key):
+        for d in self.dicts:
+            if key in d:
+                return True
+        return False
+
+    def __repr__(self):
+        return f"{type(self).__name__}({self.dicts!r})"
+
+
+class ImmutableDict(ImmutableDictMixin, dict):
+    """An immutable :class:`dict`.
+
+    .. versionadded:: 0.5
+    """
+
+    def __repr__(self):
+        return f"{type(self).__name__}({dict.__repr__(self)})"
+
+    def copy(self):
+        """Return a shallow mutable copy of this object. Keep in mind that
+        the standard library's :func:`copy` function is a no-op for this class
+        like for any other Python immutable type (e.g. :class:`tuple`).
+        """
+        return dict(self)
+
+    def __copy__(self):
+        return self
+
+
+class ImmutableMultiDict(ImmutableMultiDictMixin, MultiDict):
+    """An immutable :class:`MultiDict`.
+
+    .. versionadded:: 0.5
+    """
+
+    def copy(self):
+        """Return a shallow mutable copy of this object. Keep in mind that
+        the standard library's :func:`copy` function is a no-op for this class
+        like for any other Python immutable type (e.g. :class:`tuple`).
+        """
+        return MultiDict(self)
+
+    def __copy__(self):
+        return self
+
+
+class ImmutableOrderedMultiDict(ImmutableMultiDictMixin, OrderedMultiDict):
+    """An immutable :class:`OrderedMultiDict`.
+
+    .. versionadded:: 0.6
+    """
+
+    def _iter_hashitems(self):
+        return enumerate(self.items(multi=True))
+
+    def copy(self):
+        """Return a shallow mutable copy of this object. Keep in mind that
+        the standard library's :func:`copy` function is a no-op for this class
+        like for any other Python immutable type (e.g. :class:`tuple`).
+        """
+        return OrderedMultiDict(self)
+
+    def __copy__(self):
+        return self
+
+
+class CallbackDict(UpdateDictMixin, dict):
+    """A dict that calls a function every time something is changed.
+    The function is passed the dict instance.
+    """
+
+    def __init__(self, initial=None, on_update=None):
+        dict.__init__(self, initial or ())
+        self.on_update = on_update
+
+    def __repr__(self):
+        return f"<{type(self).__name__} {dict.__repr__(self)}>"
+
+
+class HeaderSet(MutableSet):
+    """Similar to the :class:`ETags` class, this implements a set-like
+    structure. Unlike :class:`ETags`, it is case insensitive and is used
+    for vary, allow, and content-language headers.
+
+    If not constructed using the :func:`parse_set_header` function the
+    instantiation works like this:
+
+    >>> hs = HeaderSet(['foo', 'bar', 'baz'])
+    >>> hs
+    HeaderSet(['foo', 'bar', 'baz'])
+    """
+
+    def __init__(self, headers=None, on_update=None):
+        self._headers = list(headers or ())
+        self._set = {x.lower() for x in self._headers}
+        self.on_update = on_update
+
+    def add(self, header):
+        """Add a new header to the set."""
+        self.update((header,))
+
+    def remove(self, header):
+        """Remove a header from the set. This raises a :exc:`KeyError` if the
+        header is not in the set.
+
+        .. versionchanged:: 0.5
+            In older versions an :exc:`IndexError` was raised instead of a
+            :exc:`KeyError` if the object was missing.
+
+        :param header: the header to be removed.
+        """
+        key = header.lower()
+        if key not in self._set:
+            raise KeyError(header)
+        self._set.remove(key)
+        for idx, item in enumerate(self._headers):
+            # Compare lowercased values so removal is case insensitive,
+            # matching the membership check above.
+            if item.lower() == key:
+                del self._headers[idx]
+                break
+        if self.on_update is not None:
+            self.on_update(self)
+
+    def update(self, iterable):
+        """Add all the headers from the iterable to the set.
+
+        :param iterable: updates the set with the items from the iterable.
+        """
+        inserted_any = False
+        for header in iterable:
+            key = header.lower()
+            if key not in self._set:
+                self._headers.append(header)
+                self._set.add(key)
+                inserted_any = True
+        if inserted_any and self.on_update is not None:
+            self.on_update(self)
+
+    def discard(self, header):
+        """Like :meth:`remove` but ignores errors.
+
+        :param header: the header to be discarded.
+        """
+        try:
+            self.remove(header)
+        except KeyError:
+            pass
+
+    def find(self, header):
+        """Return the index of the header in the set or return -1 if not found.
+
+        :param header: the header to be looked up.
+        """
+        header = header.lower()
+        for idx, item in enumerate(self._headers):
+            if item.lower() == header:
+                return idx
+        return -1
+
+    def index(self, header):
+        """Return the index of the header in the set or raise an
+        :exc:`IndexError`.
+
+        :param header: the header to be looked up.
+        """
+        rv = self.find(header)
+        if rv < 0:
+            raise IndexError(header)
+        return rv
+
+    def clear(self):
+        """Clear the set."""
+        self._set.clear()
+        del self._headers[:]
+        if self.on_update is not None:
+            self.on_update(self)
+
+    def as_set(self, preserve_casing=False):
+        """Return the set as a real Python set type. When calling this, all
+        the items are converted to lowercase and the ordering is lost.
+
+        :param preserve_casing: if set to `True` the items in the set returned
+                                will have the original case like in the
+                                :class:`HeaderSet`, otherwise they will
+                                be lowercase.
+        """
+        if preserve_casing:
+            return set(self._headers)
+        return set(self._set)
+
+    def to_header(self):
+        """Convert the header set into an HTTP header string."""
+        return ", ".join(map(http.quote_header_value, self._headers))
+
+    def __getitem__(self, idx):
+        return self._headers[idx]
+
+    def __delitem__(self, idx):
+        rv = self._headers.pop(idx)
+        self._set.remove(rv.lower())
+        if self.on_update is not None:
+            self.on_update(self)
+
+    def __setitem__(self, idx, value):
+        old = self._headers[idx]
+        self._set.remove(old.lower())
+        self._headers[idx] = value
+        self._set.add(value.lower())
+        if self.on_update is not None:
+            self.on_update(self)
+
+    def __contains__(self, header):
+        return header.lower() in self._set
+
+    def __len__(self):
+        return len(self._set)
+
+    def __iter__(self):
+        return iter(self._headers)
+
+    def __bool__(self):
+        return bool(self._set)
+
+    def __str__(self):
+        return self.to_header()
+
+    def __repr__(self):
+        return f"{type(self).__name__}({self._headers!r})"
+
+
+# circular dependencies
+from .. import http
diff --git a/src/werkzeug/datastructures/structures.pyi b/src/werkzeug/datastructures/structures.pyi
new file mode 100644
index 0000000..7086dda
--- /dev/null
+++ b/src/werkzeug/datastructures/structures.pyi
@@ -0,0 +1,206 @@
+from collections.abc import Callable
+from collections.abc import Iterable
+from collections.abc import Iterator
+from collections.abc import Mapping
+from typing import Any
+from typing import Generic
+from typing import Literal
+from typing import NoReturn
+from typing import overload
+from typing import TypeVar
+
+from .mixins import ImmutableDictMixin
+from .mixins import ImmutableListMixin
+from .mixins import ImmutableMultiDictMixin
+from .mixins import UpdateDictMixin
+
+D = TypeVar("D")
+K = TypeVar("K")
+T = TypeVar("T")
+V = TypeVar("V")
+_CD = TypeVar("_CD", bound="CallbackDict[Any, Any]")
+
+def is_immutable(self: object) -> NoReturn: ...
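``iter_multi_items`` (typed here, implemented in ``structures.py`` above) is the normalization step behind ``MultiDict.update``: plain dicts have list and tuple values expanded into repeated keys, while ``MultiDict`` sources yield every stored value. A short sketch of its behavior, assuming the module layout introduced in this diff:

```python
from werkzeug.datastructures.structures import MultiDict, iter_multi_items

# Plain dict: sequence values are flattened into repeated (key, value) pairs.
print(list(iter_multi_items({"a": [1, 2], "b": 3})))
# [('a', 1), ('a', 2), ('b', 3)]

# MultiDict: all values are yielded, not just the first one per key.
md = MultiDict([("a", 1), ("a", 2)])
print(list(iter_multi_items(md)))
# [('a', 1), ('a', 2)]
```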
+def iter_multi_items( + mapping: Mapping[K, V | Iterable[V]] | Iterable[tuple[K, V]], +) -> Iterator[tuple[K, V]]: ... + +class ImmutableList(ImmutableListMixin[V]): ... + +class TypeConversionDict(dict[K, V]): + @overload + def get(self, key: K, default: None = ..., type: None = ...) -> V | None: ... + @overload + def get(self, key: K, default: D, type: None = ...) -> D | V: ... + @overload + def get(self, key: K, default: D, type: Callable[[V], T]) -> D | T: ... + @overload + def get(self, key: K, type: Callable[[V], T]) -> T | None: ... + +class ImmutableTypeConversionDict(ImmutableDictMixin[K, V], TypeConversionDict[K, V]): + def copy(self) -> TypeConversionDict[K, V]: ... + def __copy__(self) -> ImmutableTypeConversionDict[K, V]: ... + +class MultiDict(TypeConversionDict[K, V]): + def __init__( + self, + mapping: Mapping[K, Iterable[V] | V] | Iterable[tuple[K, V]] | None = None, + ) -> None: ... + def __getitem__(self, item: K) -> V: ... + def __setitem__(self, key: K, value: V) -> None: ... + def add(self, key: K, value: V) -> None: ... + @overload + def getlist(self, key: K) -> list[V]: ... + @overload + def getlist(self, key: K, type: Callable[[V], T] = ...) -> list[T]: ... + def setlist(self, key: K, new_list: Iterable[V]) -> None: ... + def setdefault(self, key: K, default: V | None = None) -> V: ... + def setlistdefault( + self, key: K, default_list: Iterable[V] | None = None + ) -> list[V]: ... + def items(self, multi: bool = False) -> Iterator[tuple[K, V]]: ... # type: ignore + def lists(self) -> Iterator[tuple[K, list[V]]]: ... + def values(self) -> Iterator[V]: ... # type: ignore + def listvalues(self) -> Iterator[list[V]]: ... + def copy(self) -> MultiDict[K, V]: ... + def deepcopy(self, memo: Any = None) -> MultiDict[K, V]: ... + @overload + def to_dict(self) -> dict[K, V]: ... + @overload + def to_dict(self, flat: Literal[False]) -> dict[K, list[V]]: ... + def update( # type: ignore + self, mapping: Mapping[K, Iterable[V] | V] | Iterable[tuple[K, V]] + ) -> None: ... + @overload + def pop(self, key: K) -> V: ... + @overload + def pop(self, key: K, default: V | T = ...) -> V | T: ... + def popitem(self) -> tuple[K, V]: ... + def poplist(self, key: K) -> list[V]: ... + def popitemlist(self) -> tuple[K, list[V]]: ... + def __copy__(self) -> MultiDict[K, V]: ... + def __deepcopy__(self, memo: Any) -> MultiDict[K, V]: ... + +class _omd_bucket(Generic[K, V]): + prev: _omd_bucket[K, V] | None + next: _omd_bucket[K, V] | None + key: K + value: V + def __init__(self, omd: OrderedMultiDict[K, V], key: K, value: V) -> None: ... + def unlink(self, omd: OrderedMultiDict[K, V]) -> None: ... + +class OrderedMultiDict(MultiDict[K, V]): + _first_bucket: _omd_bucket[K, V] | None + _last_bucket: _omd_bucket[K, V] | None + def __init__(self, mapping: Mapping[K, V] | None = None) -> None: ... + def __eq__(self, other: object) -> bool: ... + def __getitem__(self, key: K) -> V: ... + def __setitem__(self, key: K, value: V) -> None: ... + def __delitem__(self, key: K) -> None: ... + def keys(self) -> Iterator[K]: ... # type: ignore + def __iter__(self) -> Iterator[K]: ... + def values(self) -> Iterator[V]: ... # type: ignore + def items(self, multi: bool = False) -> Iterator[tuple[K, V]]: ... # type: ignore + def lists(self) -> Iterator[tuple[K, list[V]]]: ... + def listvalues(self) -> Iterator[list[V]]: ... + def add(self, key: K, value: V) -> None: ... + @overload + def getlist(self, key: K) -> list[V]: ... + @overload + def getlist(self, key: K, type: Callable[[V], T] = ...) -> list[T]: ... 
+ def setlist(self, key: K, new_list: Iterable[V]) -> None: ... + def setlistdefault( + self, key: K, default_list: Iterable[V] | None = None + ) -> list[V]: ... + def update( # type: ignore + self, mapping: Mapping[K, V] | Iterable[tuple[K, V]] + ) -> None: ... + def poplist(self, key: K) -> list[V]: ... + @overload + def pop(self, key: K) -> V: ... + @overload + def pop(self, key: K, default: V | T = ...) -> V | T: ... + def popitem(self) -> tuple[K, V]: ... + def popitemlist(self) -> tuple[K, list[V]]: ... + +class CombinedMultiDict(ImmutableMultiDictMixin[K, V], MultiDict[K, V]): # type: ignore + dicts: list[MultiDict[K, V]] + def __init__(self, dicts: Iterable[MultiDict[K, V]] | None) -> None: ... + @classmethod + def fromkeys(cls, keys: Any, value: Any = None) -> NoReturn: ... + def __getitem__(self, key: K) -> V: ... + @overload # type: ignore + def get(self, key: K) -> V | None: ... + @overload + def get(self, key: K, default: V | T = ...) -> V | T: ... + @overload + def get( + self, key: K, default: T | None = None, type: Callable[[V], T] = ... + ) -> T | None: ... + @overload + def getlist(self, key: K) -> list[V]: ... + @overload + def getlist(self, key: K, type: Callable[[V], T] = ...) -> list[T]: ... + def _keys_impl(self) -> set[K]: ... + def keys(self) -> set[K]: ... # type: ignore + def __iter__(self) -> set[K]: ... # type: ignore + def items(self, multi: bool = False) -> Iterator[tuple[K, V]]: ... # type: ignore + def values(self) -> Iterator[V]: ... # type: ignore + def lists(self) -> Iterator[tuple[K, list[V]]]: ... + def listvalues(self) -> Iterator[list[V]]: ... + def copy(self) -> MultiDict[K, V]: ... + @overload + def to_dict(self) -> dict[K, V]: ... + @overload + def to_dict(self, flat: Literal[False]) -> dict[K, list[V]]: ... + def __contains__(self, key: K) -> bool: ... # type: ignore + def has_key(self, key: K) -> bool: ... + +class ImmutableDict(ImmutableDictMixin[K, V], dict[K, V]): + def copy(self) -> dict[K, V]: ... + def __copy__(self) -> ImmutableDict[K, V]: ... + +class ImmutableMultiDict( # type: ignore + ImmutableMultiDictMixin[K, V], MultiDict[K, V] +): + def copy(self) -> MultiDict[K, V]: ... + def __copy__(self) -> ImmutableMultiDict[K, V]: ... + +class ImmutableOrderedMultiDict( # type: ignore + ImmutableMultiDictMixin[K, V], OrderedMultiDict[K, V] +): + def _iter_hashitems(self) -> Iterator[tuple[int, tuple[K, V]]]: ... + def copy(self) -> OrderedMultiDict[K, V]: ... + def __copy__(self) -> ImmutableOrderedMultiDict[K, V]: ... + +class CallbackDict(UpdateDictMixin[K, V], dict[K, V]): + def __init__( + self, + initial: Mapping[K, V] | Iterable[tuple[K, V]] | None = None, + on_update: Callable[[_CD], None] | None = None, + ) -> None: ... + +class HeaderSet(set[str]): + _headers: list[str] + _set: set[str] + on_update: Callable[[HeaderSet], None] | None + def __init__( + self, + headers: Iterable[str] | None = None, + on_update: Callable[[HeaderSet], None] | None = None, + ) -> None: ... + def add(self, header: str) -> None: ... + def remove(self, header: str) -> None: ... + def update(self, iterable: Iterable[str]) -> None: ... # type: ignore + def discard(self, header: str) -> None: ... + def find(self, header: str) -> int: ... + def index(self, header: str) -> int: ... + def clear(self) -> None: ... + def as_set(self, preserve_casing: bool = False) -> set[str]: ... + def to_header(self) -> str: ... + def __getitem__(self, idx: int) -> str: ... + def __delitem__(self, idx: int) -> None: ... + def __setitem__(self, idx: int, value: str) -> None: ... 
+ def __contains__(self, header: str) -> bool: ... # type: ignore + def __len__(self) -> int: ... + def __iter__(self) -> Iterator[str]: ... diff --git a/src/werkzeug/debug/__init__.py b/src/werkzeug/debug/__init__.py index e0dcc65..6bef30f 100644 --- a/src/werkzeug/debug/__init__.py +++ b/src/werkzeug/debug/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import getpass import hashlib import json @@ -9,7 +11,6 @@ import typing as t import uuid from contextlib import ExitStack -from contextlib import nullcontext from io import BytesIO from itertools import chain from os.path import basename @@ -18,7 +19,9 @@ from .._internal import _log from ..exceptions import NotFound +from ..exceptions import SecurityError from ..http import parse_cookie +from ..sansio.utils import host_is_trusted from ..security import gen_salt from ..utils import send_file from ..wrappers.request import Request @@ -41,16 +44,16 @@ def hash_pin(pin: str) -> str: return hashlib.sha1(f"{pin} added salt".encode("utf-8", "replace")).hexdigest()[:12] -_machine_id: t.Optional[t.Union[str, bytes]] = None +_machine_id: str | bytes | None = None -def get_machine_id() -> t.Optional[t.Union[str, bytes]]: +def get_machine_id() -> str | bytes | None: global _machine_id if _machine_id is not None: return _machine_id - def _generate() -> t.Optional[t.Union[str, bytes]]: + def _generate() -> str | bytes | None: linux = b"" # machine-id is stable across boots, boot_id is not. @@ -81,7 +84,8 @@ def _generate() -> t.Optional[t.Union[str, bytes]]: try: # subprocess may not be available, e.g. Google App Engine # https://github.com/pallets/werkzeug/issues/925 - from subprocess import Popen, PIPE + from subprocess import PIPE + from subprocess import Popen dump = Popen( ["ioreg", "-c", "IOPlatformExpertDevice", "-d", "2"], stdout=PIPE @@ -104,12 +108,12 @@ def _generate() -> t.Optional[t.Union[str, bytes]]: 0, winreg.KEY_READ | winreg.KEY_WOW64_64KEY, ) as rk: - guid: t.Union[str, bytes] + guid: str | bytes guid_type: int guid, guid_type = winreg.QueryValueEx(rk, "MachineGuid") if guid_type == winreg.REG_SZ: - return guid.encode("utf-8") + return guid.encode() return guid except OSError: @@ -126,7 +130,7 @@ class _ConsoleFrame: standalone console. """ - def __init__(self, namespace: t.Dict[str, t.Any]): + def __init__(self, namespace: dict[str, t.Any]): self.console = Console(namespace) self.id = 0 @@ -135,8 +139,8 @@ def eval(self, code: str) -> t.Any: def get_pin_and_cookie_name( - app: "WSGIApplication", -) -> t.Union[t.Tuple[str, str], t.Tuple[None, None]]: + app: WSGIApplication, +) -> tuple[str, str] | tuple[None, None]: """Given an application object this returns a semi-stable 9 digit pin code and a random key. The hope is that this is stable between restarts to not make debugging particularly frustrating. If the pin @@ -161,7 +165,7 @@ def get_pin_and_cookie_name( num = pin modname = getattr(app, "__module__", t.cast(object, app).__class__.__module__) - username: t.Optional[str] + username: str | None try: # getuser imports the pwd module, which does not exist in Google @@ -192,7 +196,7 @@ def get_pin_and_cookie_name( if not bit: continue if isinstance(bit, str): - bit = bit.encode("utf-8") + bit = bit.encode() h.update(bit) h.update(b"cookiesalt") @@ -229,8 +233,8 @@ class DebuggedApplication: The ``evalex`` argument allows evaluating expressions in any frame of a traceback. This works by preserving each frame with its local - state. Some state, such as :doc:`local`, cannot be restored with the - frame by default. 
When ``evalex`` is enabled, + state. Some state, such as context globals, cannot be restored with + the frame by default. When ``evalex`` is enabled, ``environ["werkzeug.debug.preserve_context"]`` will be a callable that takes a context manager, and can be called multiple times. Each context manager will be entered before evaluating code in the @@ -262,11 +266,11 @@ class DebuggedApplication: def __init__( self, - app: "WSGIApplication", + app: WSGIApplication, evalex: bool = False, request_key: str = "werkzeug.request", console_path: str = "/console", - console_init_func: t.Optional[t.Callable[[], t.Dict[str, t.Any]]] = None, + console_init_func: t.Callable[[], dict[str, t.Any]] | None = None, show_hidden_frames: bool = False, pin_security: bool = True, pin_logging: bool = True, @@ -275,8 +279,8 @@ def __init__( console_init_func = None self.app = app self.evalex = evalex - self.frames: t.Dict[int, t.Union[DebugFrameSummary, _ConsoleFrame]] = {} - self.frame_contexts: t.Dict[int, t.List[t.ContextManager[None]]] = {} + self.frames: dict[int, DebugFrameSummary | _ConsoleFrame] = {} + self.frame_contexts: dict[int, list[t.ContextManager[None]]] = {} self.request_key = request_key self.console_path = console_path self.console_init_func = console_init_func @@ -296,8 +300,16 @@ def __init__( else: self.pin = None + self.trusted_hosts: list[str] = [".localhost", "127.0.0.1"] + """List of domains to allow requests to the debugger from. A leading dot + allows all subdomains. This only allows ``".localhost"`` domains by + default. + + .. versionadded:: 3.0.3 + """ + @property - def pin(self) -> t.Optional[str]: + def pin(self) -> str | None: if not hasattr(self, "_pin"): pin_cookie = get_pin_and_cookie_name(self.app) self._pin, self._pin_cookie = pin_cookie # type: ignore @@ -316,10 +328,10 @@ def pin_cookie_name(self) -> str: return self._pin_cookie def debug_application( - self, environ: "WSGIEnvironment", start_response: "StartResponse" + self, environ: WSGIEnvironment, start_response: StartResponse ) -> t.Iterator[bytes]: """Run the application and conserve the traceback frames.""" - contexts: t.List[t.ContextManager[t.Any]] = [] + contexts: list[t.ContextManager[t.Any]] = [] if self.evalex: environ["werkzeug.debug.preserve_context"] = contexts.append @@ -329,7 +341,7 @@ def debug_application( app_iter = self.app(environ, start_response) yield from app_iter if hasattr(app_iter, "close"): - app_iter.close() # type: ignore + app_iter.close() except Exception as e: if hasattr(app_iter, "close"): app_iter.close() # type: ignore @@ -342,7 +354,7 @@ def debug_application( is_trusted = bool(self.check_pin_trust(environ)) html = tb.render_debugger_html( - evalex=self.evalex, + evalex=self.evalex and self.check_host_trust(environ), secret=self.secret, evalex_trusted=is_trusted, ) @@ -367,9 +379,12 @@ def execute_command( # type: ignore[return] self, request: Request, command: str, - frame: t.Union[DebugFrameSummary, _ConsoleFrame], + frame: DebugFrameSummary | _ConsoleFrame, ) -> Response: """Execute a command in a console.""" + if not self.check_host_trust(request.environ): + return SecurityError() # type: ignore[return-value] + contexts = self.frame_contexts.get(id(frame), []) with ExitStack() as exit_stack: @@ -380,6 +395,9 @@ def execute_command( # type: ignore[return] def display_console(self, request: Request) -> Response: """Display a standalone shell.""" + if not self.check_host_trust(request.environ): + return SecurityError() # type: ignore[return-value] + if 0 not in self.frames: if 
self.console_init_func is None: ns = {} @@ -410,7 +428,7 @@ def get_resource(self, request: Request, filename: str) -> Response: BytesIO(data), request.environ, download_name=filename, etag=etag ) - def check_pin_trust(self, environ: "WSGIEnvironment") -> t.Optional[bool]: + def check_pin_trust(self, environ: WSGIEnvironment) -> bool | None: """Checks if the request passed the pin test. This returns `True` if the request is trusted on a pin/cookie basis and returns `False` if not. Additionally if the cookie's stored pin hash is wrong it will return @@ -432,12 +450,18 @@ def check_pin_trust(self, environ: "WSGIEnvironment") -> t.Optional[bool]: return None return (time.time() - PIN_TIME) < ts + def check_host_trust(self, environ: WSGIEnvironment) -> bool: + return host_is_trusted(environ.get("HTTP_HOST"), self.trusted_hosts) + def _fail_pin_auth(self) -> None: time.sleep(5.0 if self._failed_pin_auth > 5 else 0.5) self._failed_pin_auth += 1 def pin_auth(self, request: Request) -> Response: """Authenticates with the pin.""" + if not self.check_host_trust(request.environ): + return SecurityError() # type: ignore[return-value] + exhausted = False auth = False trust = self.check_pin_trust(request.environ) @@ -487,8 +511,11 @@ def pin_auth(self, request: Request) -> Response: rv.delete_cookie(self.pin_cookie_name) return rv - def log_pin_request(self) -> Response: + def log_pin_request(self, request: Request) -> Response: """Log the pin if needed.""" + if not self.check_host_trust(request.environ): + return SecurityError() # type: ignore[return-value] + if self.pin_logging and self.pin is not None: _log( "info", " * To enable the debugger you need to enter the security pin:" @@ -497,7 +524,7 @@ def log_pin_request(self) -> Response: return Response("") def __call__( - self, environ: "WSGIEnvironment", start_response: "StartResponse" + self, environ: WSGIEnvironment, start_response: StartResponse ) -> t.Iterable[bytes]: """Dispatch the requests.""" # important: don't ever access a function here that reads the incoming @@ -515,7 +542,7 @@ def __call__( elif cmd == "pinauth" and secret == self.secret: response = self.pin_auth(request) # type: ignore elif cmd == "printpin" and secret == self.secret: - response = self.log_pin_request() # type: ignore + response = self.log_pin_request(request) # type: ignore elif ( self.evalex and cmd is not None diff --git a/src/werkzeug/debug/console.py b/src/werkzeug/debug/console.py index 69974d1..4e40475 100644 --- a/src/werkzeug/debug/console.py +++ b/src/werkzeug/debug/console.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import code import sys import typing as t @@ -10,18 +12,15 @@ from .repr import dump from .repr import helper -if t.TYPE_CHECKING: - import codeop # noqa: F401 - -_stream: ContextVar["HTMLStringO"] = ContextVar("werkzeug.debug.console.stream") -_ipy: ContextVar = ContextVar("werkzeug.debug.console.ipy") +_stream: ContextVar[HTMLStringO] = ContextVar("werkzeug.debug.console.stream") +_ipy: ContextVar[_InteractiveConsole] = ContextVar("werkzeug.debug.console.ipy") class HTMLStringO: """A StringO version that HTML escapes on write.""" def __init__(self) -> None: - self._buffer: t.List[str] = [] + self._buffer: list[str] = [] def isatty(self) -> bool: return False @@ -48,8 +47,6 @@ def reset(self) -> str: return val def _write(self, x: str) -> None: - if isinstance(x, bytes): - x = x.decode("utf-8", "replace") self._buffer.append(x) def write(self, x: str) -> None: @@ -94,7 +91,7 @@ def displayhook(obj: object) -> None: def 
__setattr__(self, name: str, value: t.Any) -> None: raise AttributeError(f"read only attribute {name}") - def __dir__(self) -> t.List[str]: + def __dir__(self) -> list[str]: return dir(sys.__stdout__) def __getattribute__(self, name: str) -> t.Any: @@ -116,7 +113,7 @@ def __repr__(self) -> str: class _ConsoleLoader: def __init__(self) -> None: - self._storage: t.Dict[int, str] = {} + self._storage: dict[int, str] = {} def register(self, code: CodeType, source: str) -> None: self._storage[id(code)] = source @@ -125,7 +122,7 @@ def register(self, code: CodeType, source: str) -> None: if isinstance(var, CodeType): self._storage[id(var)] = source - def get_source_by_code(self, code: CodeType) -> t.Optional[str]: + def get_source_by_code(self, code: CodeType) -> str | None: try: return self._storage[id(code)] except KeyError: @@ -133,9 +130,9 @@ def get_source_by_code(self, code: CodeType) -> t.Optional[str]: class _InteractiveConsole(code.InteractiveInterpreter): - locals: t.Dict[str, t.Any] + locals: dict[str, t.Any] - def __init__(self, globals: t.Dict[str, t.Any], locals: t.Dict[str, t.Any]) -> None: + def __init__(self, globals: dict[str, t.Any], locals: dict[str, t.Any]) -> None: self.loader = _ConsoleLoader() locals = { **globals, @@ -147,7 +144,7 @@ def __init__(self, globals: t.Dict[str, t.Any], locals: t.Dict[str, t.Any]) -> N super().__init__(locals) original_compile = self.compile - def compile(source: str, filename: str, symbol: str) -> t.Optional[CodeType]: + def compile(source: str, filename: str, symbol: str) -> CodeType | None: code = original_compile(source, filename, symbol) if code is not None: @@ -157,7 +154,7 @@ def compile(source: str, filename: str, symbol: str) -> t.Optional[CodeType]: self.compile = compile # type: ignore[assignment] self.more = False - self.buffer: t.List[str] = [] + self.buffer: list[str] = [] def runsource(self, source: str, **kwargs: t.Any) -> str: # type: ignore source = f"{source.rstrip()}\n" @@ -188,7 +185,7 @@ def showtraceback(self) -> None: te = DebugTraceback(exc, skip=1) sys.stdout._write(te.render_traceback_html()) # type: ignore - def showsyntaxerror(self, filename: t.Optional[str] = None) -> None: + def showsyntaxerror(self, filename: str | None = None) -> None: from .tbtools import DebugTraceback exc = t.cast(BaseException, sys.exc_info()[1]) @@ -204,8 +201,8 @@ class Console: def __init__( self, - globals: t.Optional[t.Dict[str, t.Any]] = None, - locals: t.Optional[t.Dict[str, t.Any]] = None, + globals: dict[str, t.Any] | None = None, + locals: dict[str, t.Any] | None = None, ) -> None: if locals is None: locals = {} diff --git a/src/werkzeug/debug/repr.py b/src/werkzeug/debug/repr.py index c0872f1..2bbd9d5 100644 --- a/src/werkzeug/debug/repr.py +++ b/src/werkzeug/debug/repr.py @@ -4,6 +4,9 @@ Together with the CSS and JavaScript of the debugger this gives a colorful and more compact output. """ + +from __future__ import annotations + import codecs import re import sys @@ -57,7 +60,7 @@ class _Helper: def __repr__(self) -> str: return "Type help(object) for help about object." 
- def __call__(self, topic: t.Optional[t.Any] = None) -> None: + def __call__(self, topic: t.Any | None = None) -> None: if topic is None: sys.stdout._write(f"{self!r}") # type: ignore return @@ -65,8 +68,6 @@ def __call__(self, topic: t.Optional[t.Any] = None) -> None: pydoc.help(topic) rv = sys.stdout.reset() # type: ignore - if isinstance(rv, bytes): - rv = rv.decode("utf-8", "ignore") paragraphs = _paragraph_re.split(rv) if len(paragraphs) > 1: title = paragraphs[0] @@ -80,9 +81,7 @@ def __call__(self, topic: t.Optional[t.Any] = None) -> None: helper = _Helper() -def _add_subclass_info( - inner: str, obj: object, base: t.Union[t.Type, t.Tuple[t.Type, ...]] -) -> str: +def _add_subclass_info(inner: str, obj: object, base: type | tuple[type, ...]) -> str: if isinstance(base, tuple): for cls in base: if type(obj) is cls: @@ -96,9 +95,9 @@ def _add_subclass_info( def _sequence_repr_maker( - left: str, right: str, base: t.Type, limit: int = 8 -) -> t.Callable[["DebugReprGenerator", t.Iterable, bool], str]: - def proxy(self: "DebugReprGenerator", obj: t.Iterable, recursive: bool) -> str: + left: str, right: str, base: type, limit: int = 8 +) -> t.Callable[[DebugReprGenerator, t.Iterable[t.Any], bool], str]: + def proxy(self: DebugReprGenerator, obj: t.Iterable[t.Any], recursive: bool) -> str: if recursive: return _add_subclass_info(f"{left}...{right}", obj, base) buf = [left] @@ -120,7 +119,7 @@ def proxy(self: "DebugReprGenerator", obj: t.Iterable, recursive: bool) -> str: class DebugReprGenerator: def __init__(self) -> None: - self._stack: t.List[t.Any] = [] + self._stack: list[t.Any] = [] list_repr = _sequence_repr_maker("[", "]", list) tuple_repr = _sequence_repr_maker("(", ")", tuple) @@ -130,13 +129,13 @@ def __init__(self) -> None: 'collections.deque([', "])", deque ) - def regex_repr(self, obj: t.Pattern) -> str: + def regex_repr(self, obj: t.Pattern[t.AnyStr]) -> str: pattern = repr(obj.pattern) - pattern = codecs.decode(pattern, "unicode-escape", "ignore") # type: ignore + pattern = codecs.decode(pattern, "unicode-escape", "ignore") pattern = f"r{pattern}" return f're.compile({pattern})' - def string_repr(self, obj: t.Union[str, bytes], limit: int = 70) -> str: + def string_repr(self, obj: str | bytes, limit: int = 70) -> str: buf = [''] r = repr(obj) @@ -165,7 +164,7 @@ def string_repr(self, obj: t.Union[str, bytes], limit: int = 70) -> str: def dict_repr( self, - d: t.Union[t.Dict[int, None], t.Dict[str, int], t.Dict[t.Union[str, int], int]], + d: dict[int, None] | dict[str, int] | dict[str | int, int], recursive: bool, limit: int = 5, ) -> str: @@ -188,9 +187,7 @@ def dict_repr( buf.append("}") return _add_subclass_info("".join(buf), d, dict) - def object_repr( - self, obj: t.Optional[t.Union[t.Type[dict], t.Callable, t.Type[list]]] - ) -> str: + def object_repr(self, obj: t.Any) -> str: r = repr(obj) return f'{escape(r)}' @@ -244,7 +241,7 @@ def repr(self, obj: object) -> str: def dump_object(self, obj: object) -> str: repr = None - items: t.Optional[t.List[t.Tuple[str, str]]] = None + items: list[tuple[str, str]] | None = None if isinstance(obj, dict): title = "Contents of" @@ -266,12 +263,12 @@ def dump_object(self, obj: object) -> str: title += f" {object.__repr__(obj)[1:-1]}" return self.render_object_dump(items, title, repr) - def dump_locals(self, d: t.Dict[str, t.Any]) -> str: + def dump_locals(self, d: dict[str, t.Any]) -> str: items = [(key, self.repr(value)) for key, value in d.items()] return self.render_object_dump(items, "Local variables in frame") def 
render_object_dump( - self, items: t.List[t.Tuple[str, str]], title: str, repr: t.Optional[str] = None + self, items: list[tuple[str, str]], title: str, repr: str | None = None ) -> str: html_items = [] for key, value in items: diff --git a/src/werkzeug/debug/shared/debugger.js b/src/werkzeug/debug/shared/debugger.js index 2354f03..18c6583 100644 --- a/src/werkzeug/debug/shared/debugger.js +++ b/src/werkzeug/debug/shared/debugger.js @@ -48,7 +48,7 @@ function initPinBox() { btn.disabled = true; fetch( - `${document.location.pathname}?__debugger__=yes&cmd=pinauth&pin=${pin}&s=${encodedSecret}` + `${document.location}?__debugger__=yes&cmd=pinauth&pin=${pin}&s=${encodedSecret}` ) .then((res) => res.json()) .then(({auth, exhausted}) => { @@ -79,7 +79,7 @@ function promptForPin() { if (!EVALEX_TRUSTED) { const encodedSecret = encodeURIComponent(SECRET); fetch( - `${document.location.pathname}?__debugger__=yes&cmd=printpin&s=${encodedSecret}` + `${document.location}?__debugger__=yes&cmd=printpin&s=${encodedSecret}` ); const pinPrompt = document.getElementsByClassName("pin-prompt")[0]; fadeIn(pinPrompt); @@ -305,7 +305,8 @@ function handleConsoleSubmit(e, command, frameID) { wrapperSpan.append(spanToWrap); spanToWrap.hidden = true; - expansionButton.addEventListener("click", () => { + expansionButton.addEventListener("click", (event) => { + event.preventDefault(); spanToWrap.hidden = !spanToWrap.hidden; expansionButton.classList.toggle("open"); return false; diff --git a/src/werkzeug/debug/tbtools.py b/src/werkzeug/debug/tbtools.py index ea90de9..0574c96 100644 --- a/src/werkzeug/debug/tbtools.py +++ b/src/werkzeug/debug/tbtools.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import itertools import linecache import os @@ -123,7 +125,7 @@ def _process_traceback( exc: BaseException, - te: t.Optional[traceback.TracebackException] = None, + te: traceback.TracebackException | None = None, *, skip: int = 0, hide: bool = True, @@ -146,7 +148,7 @@ def _process_traceback( frame_gen = itertools.islice(frame_gen, skip, None) del te.stack[:skip] - new_stack: t.List[DebugFrameSummary] = [] + new_stack: list[DebugFrameSummary] = [] hidden = False # Match each frame with the FrameSummary that was generated. 
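The ``skip`` handling in ``_process_traceback`` above advances the live-frame iterator and trims ``TracebackException.stack`` in lockstep, so each remaining frame stays paired with its ``FrameSummary``. A rough stdlib-only sketch of that idea; ``trimmed_traceback`` is an illustrative helper, not Werkzeug API:

.. code-block:: python

    import itertools
    import traceback

    def trimmed_traceback(
        exc: BaseException, skip: int = 0
    ) -> traceback.TracebackException:
        # Build the rich traceback first, then drop the first `skip` frames
        # from both the live-frame walk and the summary stack, keeping the
        # two sequences aligned.
        te = traceback.TracebackException.from_exception(exc)
        frames = itertools.islice(traceback.walk_tb(exc.__traceback__), skip, None)
        del te.stack[:skip]

        for (frame, _lineno), fs in zip(frames, te.stack):
            # each remaining FrameSummary matches its originating frame
            assert fs.name == frame.f_code.co_name

        return te

    try:
        raise ValueError("boom")
    except ValueError as e:
        te = trimmed_traceback(e)
        print("".join(te.format_exception_only()).strip())  # ValueError: boom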
@@ -175,7 +177,7 @@ def _process_traceback( elif hide_value or hidden: continue - frame_args: t.Dict[str, t.Any] = { + frame_args: dict[str, t.Any] = { "filename": fs.filename, "lineno": fs.lineno, "name": fs.name, @@ -184,8 +186,8 @@ def _process_traceback( } if hasattr(fs, "colno"): - frame_args["colno"] = fs.colno # type: ignore[attr-defined] - frame_args["end_colno"] = fs.end_colno # type: ignore[attr-defined] + frame_args["colno"] = fs.colno + frame_args["end_colno"] = fs.end_colno new_stack.append(DebugFrameSummary(**frame_args)) @@ -221,7 +223,7 @@ class DebugTraceback: def __init__( self, exc: BaseException, - te: t.Optional[traceback.TracebackException] = None, + te: traceback.TracebackException | None = None, *, skip: int = 0, hide: bool = True, @@ -234,7 +236,7 @@ def __str__(self) -> str: @cached_property def all_tracebacks( self, - ) -> t.List[t.Tuple[t.Optional[str], traceback.TracebackException]]: + ) -> list[tuple[str | None, traceback.TracebackException]]: out = [] current = self._te @@ -261,9 +263,11 @@ def all_tracebacks( return out @cached_property - def all_frames(self) -> t.List["DebugFrameSummary"]: + def all_frames(self) -> list[DebugFrameSummary]: return [ - f for _, te in self.all_tracebacks for f in te.stack # type: ignore[misc] + f # type: ignore[misc] + for _, te in self.all_tracebacks + for f in te.stack ] def render_traceback_text(self) -> str: @@ -325,7 +329,7 @@ def render_debugger_html( "evalex": "true" if evalex else "false", "evalex_trusted": "true" if evalex_trusted else "false", "console": "false", - "title": exc_lines[0], + "title": escape(exc_lines[0]), "exception": escape("".join(exc_lines)), "exception_type": escape(self._te.exc_type.__name__), "summary": self.render_traceback_html(include_title=False), @@ -351,8 +355,8 @@ class DebugFrameSummary(traceback.FrameSummary): def __init__( self, *, - locals: t.Dict[str, t.Any], - globals: t.Dict[str, t.Any], + locals: dict[str, t.Any], + globals: dict[str, t.Any], **kwargs: t.Any, ) -> None: super().__init__(locals=None, **kwargs) @@ -360,7 +364,7 @@ def __init__( self.global_ns = globals @cached_property - def info(self) -> t.Optional[str]: + def info(self) -> str | None: return self.local_ns.get("__traceback_info__") @cached_property diff --git a/src/werkzeug/exceptions.py b/src/werkzeug/exceptions.py index 013df72..6ce7ef9 100644 --- a/src/werkzeug/exceptions.py +++ b/src/werkzeug/exceptions.py @@ -43,6 +43,9 @@ def application(request): return e """ + +from __future__ import annotations + import typing as t from datetime import datetime @@ -52,13 +55,13 @@ def application(request): from ._internal import _get_environ if t.TYPE_CHECKING: - import typing_extensions as te from _typeshed.wsgi import StartResponse from _typeshed.wsgi import WSGIEnvironment + from .datastructures import WWWAuthenticate from .sansio.response import Response - from .wrappers.request import Request as WSGIRequest # noqa: F401 - from .wrappers.response import Response as WSGIResponse # noqa: F401 + from .wrappers.request import Request as WSGIRequest + from .wrappers.response import Response as WSGIResponse class HTTPException(Exception): @@ -70,13 +73,13 @@ class HTTPException(Exception): Removed the ``wrap`` class method. 
""" - code: t.Optional[int] = None - description: t.Optional[str] = None + code: int | None = None + description: str | None = None def __init__( self, - description: t.Optional[str] = None, - response: t.Optional["Response"] = None, + description: str | None = None, + response: Response | None = None, ) -> None: super().__init__() if description is not None: @@ -92,14 +95,12 @@ def name(self) -> str: def get_description( self, - environ: t.Optional["WSGIEnvironment"] = None, - scope: t.Optional[dict] = None, + environ: WSGIEnvironment | None = None, + scope: dict[str, t.Any] | None = None, ) -> str: """Get the description.""" if self.description is None: description = "" - elif not isinstance(self.description, str): - description = str(self.description) else: description = self.description @@ -108,8 +109,8 @@ def get_description( def get_body( self, - environ: t.Optional["WSGIEnvironment"] = None, - scope: t.Optional[dict] = None, + environ: WSGIEnvironment | None = None, + scope: dict[str, t.Any] | None = None, ) -> str: """Get the HTML body.""" return ( @@ -122,17 +123,17 @@ def get_body( def get_headers( self, - environ: t.Optional["WSGIEnvironment"] = None, - scope: t.Optional[dict] = None, - ) -> t.List[t.Tuple[str, str]]: + environ: WSGIEnvironment | None = None, + scope: dict[str, t.Any] | None = None, + ) -> list[tuple[str, str]]: """Get a list of headers.""" return [("Content-Type", "text/html; charset=utf-8")] def get_response( self, - environ: t.Optional[t.Union["WSGIEnvironment", "WSGIRequest"]] = None, - scope: t.Optional[dict] = None, - ) -> "Response": + environ: WSGIEnvironment | WSGIRequest | None = None, + scope: dict[str, t.Any] | None = None, + ) -> Response: """Get a response object. If one was passed to the exception it's returned directly. @@ -151,7 +152,7 @@ def get_response( return WSGIResponse(self.get_body(environ, scope), self.code, headers) def __call__( - self, environ: "WSGIEnvironment", start_response: "StartResponse" + self, environ: WSGIEnvironment, start_response: StartResponse ) -> t.Iterable[bytes]: """Call the exception as WSGI application. @@ -196,7 +197,7 @@ class BadRequestKeyError(BadRequest, KeyError): #: useful in a debug mode. 
show_exception = False - def __init__(self, arg: t.Optional[str] = None, *args: t.Any, **kwargs: t.Any): + def __init__(self, arg: str | None = None, *args: t.Any, **kwargs: t.Any): super().__init__(*args, **kwargs) if arg is None: @@ -205,7 +206,7 @@ def __init__(self, arg: t.Optional[str] = None, *args: t.Any, **kwargs: t.Any): KeyError.__init__(self, arg) @property # type: ignore - def description(self) -> str: # type: ignore + def description(self) -> str: if self.show_exception: return ( f"{self._description}\n" @@ -297,11 +298,9 @@ class Unauthorized(HTTPException): def __init__( self, - description: t.Optional[str] = None, - response: t.Optional["Response"] = None, - www_authenticate: t.Optional[ - t.Union["WWWAuthenticate", t.Iterable["WWWAuthenticate"]] - ] = None, + description: str | None = None, + response: Response | None = None, + www_authenticate: None | (WWWAuthenticate | t.Iterable[WWWAuthenticate]) = None, ) -> None: super().__init__(description, response) @@ -314,9 +313,9 @@ def __init__( def get_headers( self, - environ: t.Optional["WSGIEnvironment"] = None, - scope: t.Optional[dict] = None, - ) -> t.List[t.Tuple[str, str]]: + environ: WSGIEnvironment | None = None, + scope: dict[str, t.Any] | None = None, + ) -> list[tuple[str, str]]: headers = super().get_headers(environ, scope) if self.www_authenticate: headers.extend(("WWW-Authenticate", str(x)) for x in self.www_authenticate) @@ -367,9 +366,9 @@ class MethodNotAllowed(HTTPException): def __init__( self, - valid_methods: t.Optional[t.Iterable[str]] = None, - description: t.Optional[str] = None, - response: t.Optional["Response"] = None, + valid_methods: t.Iterable[str] | None = None, + description: str | None = None, + response: Response | None = None, ) -> None: """Takes an optional list of valid http methods starting with werkzeug 0.3 the list will be mandatory.""" @@ -378,9 +377,9 @@ def __init__( def get_headers( self, - environ: t.Optional["WSGIEnvironment"] = None, - scope: t.Optional[dict] = None, - ) -> t.List[t.Tuple[str, str]]: + environ: WSGIEnvironment | None = None, + scope: dict[str, t.Any] | None = None, + ) -> list[tuple[str, str]]: headers = super().get_headers(environ, scope) if self.valid_methods: headers.append(("Allow", ", ".join(self.valid_methods))) @@ -524,10 +523,10 @@ class RequestedRangeNotSatisfiable(HTTPException): def __init__( self, - length: t.Optional[int] = None, + length: int | None = None, units: str = "bytes", - description: t.Optional[str] = None, - response: t.Optional["Response"] = None, + description: str | None = None, + response: Response | None = None, ) -> None: """Takes an optional `Content-Range` header value based on ``length`` parameter. 
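A short usage sketch for the exception hunks above. It assumes ``WWWAuthenticate`` accepts an auth type plus a parameter dict, as in current Werkzeug; the exact rendered header value is illustrative:

.. code-block:: python

    from werkzeug.datastructures import WWWAuthenticate
    from werkzeug.exceptions import MethodNotAllowed, Unauthorized

    # Unauthorized accepts one or more challenges; get_headers() appends a
    # WWW-Authenticate header per challenge on top of the base Content-Type.
    exc = Unauthorized(www_authenticate=WWWAuthenticate("basic", {"realm": "api"}))
    print(exc.get_headers())
    # e.g. [('Content-Type', 'text/html; charset=utf-8'),
    #       ('WWW-Authenticate', 'Basic realm="api"')]

    # MethodNotAllowed advertises the permitted methods via an Allow header.
    allow = dict(MethodNotAllowed(valid_methods=["GET", "HEAD"]).get_headers())["Allow"]
    print(allow)  # GET, HEAD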
@@ -538,9 +537,9 @@ def __init__( def get_headers( self, - environ: t.Optional["WSGIEnvironment"] = None, - scope: t.Optional[dict] = None, - ) -> t.List[t.Tuple[str, str]]: + environ: WSGIEnvironment | None = None, + scope: dict[str, t.Any] | None = None, + ) -> list[tuple[str, str]]: headers = super().get_headers(environ, scope) if self.length is not None: headers.append(("Content-Range", f"{self.units} */{self.length}")) @@ -638,18 +637,18 @@ class _RetryAfter(HTTPException): def __init__( self, - description: t.Optional[str] = None, - response: t.Optional["Response"] = None, - retry_after: t.Optional[t.Union[datetime, int]] = None, + description: str | None = None, + response: Response | None = None, + retry_after: datetime | int | None = None, ) -> None: super().__init__(description, response) self.retry_after = retry_after def get_headers( self, - environ: t.Optional["WSGIEnvironment"] = None, - scope: t.Optional[dict] = None, - ) -> t.List[t.Tuple[str, str]]: + environ: WSGIEnvironment | None = None, + scope: dict[str, t.Any] | None = None, + ) -> list[tuple[str, str]]: headers = super().get_headers(environ, scope) if self.retry_after: @@ -728,9 +727,9 @@ class InternalServerError(HTTPException): def __init__( self, - description: t.Optional[str] = None, - response: t.Optional["Response"] = None, - original_exception: t.Optional[BaseException] = None, + description: str | None = None, + response: Response | None = None, + original_exception: BaseException | None = None, ) -> None: #: The original exception that caused this 500 error. Can be #: used by frameworks to provide context when handling @@ -809,7 +808,7 @@ class HTTPVersionNotSupported(HTTPException): ) -default_exceptions: t.Dict[int, t.Type[HTTPException]] = {} +default_exceptions: dict[int, type[HTTPException]] = {} def _find_exceptions() -> None: @@ -841,8 +840,8 @@ class Aborter: def __init__( self, - mapping: t.Optional[t.Dict[int, t.Type[HTTPException]]] = None, - extra: t.Optional[t.Dict[int, t.Type[HTTPException]]] = None, + mapping: dict[int, type[HTTPException]] | None = None, + extra: dict[int, type[HTTPException]] | None = None, ) -> None: if mapping is None: mapping = default_exceptions @@ -851,8 +850,8 @@ def __init__( self.mapping.update(extra) def __call__( - self, code: t.Union[int, "Response"], *args: t.Any, **kwargs: t.Any - ) -> "te.NoReturn": + self, code: int | Response, *args: t.Any, **kwargs: t.Any + ) -> t.NoReturn: from .sansio.response import Response if isinstance(code, Response): @@ -864,9 +863,7 @@ def __call__( raise self.mapping[code](*args, **kwargs) -def abort( - status: t.Union[int, "Response"], *args: t.Any, **kwargs: t.Any -) -> "te.NoReturn": +def abort(status: int | Response, *args: t.Any, **kwargs: t.Any) -> t.NoReturn: """Raises an :py:exc:`HTTPException` for the given status code or WSGI application. diff --git a/src/werkzeug/formparser.py b/src/werkzeug/formparser.py index 10d58ca..ba84721 100644 --- a/src/werkzeug/formparser.py +++ b/src/werkzeug/formparser.py @@ -1,13 +1,14 @@ +from __future__ import annotations + import typing as t -from functools import update_wrapper from io import BytesIO -from itertools import chain -from typing import Union +from urllib.parse import parse_qsl -from . 
import exceptions +from ._internal import _plain_int from .datastructures import FileStorage from .datastructures import Headers from .datastructures import MultiDict +from .exceptions import RequestEntityTooLarge from .http import parse_options_header from .sansio.multipart import Data from .sansio.multipart import Epilogue @@ -15,8 +16,6 @@ from .sansio.multipart import File from .sansio.multipart import MultipartDecoder from .sansio.multipart import NeedData -from .urls import url_decode_stream -from .wsgi import _make_chunk_iter from .wsgi import get_content_length from .wsgi import get_input_stream @@ -31,35 +30,31 @@ if t.TYPE_CHECKING: import typing as te + from _typeshed.wsgi import WSGIEnvironment - t_parse_result = t.Tuple[t.IO[bytes], MultiDict, MultiDict] + t_parse_result = t.Tuple[ + t.IO[bytes], MultiDict[str, str], MultiDict[str, FileStorage] + ] class TStreamFactory(te.Protocol): def __call__( self, - total_content_length: t.Optional[int], - content_type: t.Optional[str], - filename: t.Optional[str], - content_length: t.Optional[int] = None, - ) -> t.IO[bytes]: - ... + total_content_length: int | None, + content_type: str | None, + filename: str | None, + content_length: int | None = None, + ) -> t.IO[bytes]: ... F = t.TypeVar("F", bound=t.Callable[..., t.Any]) -def _exhaust(stream: t.IO[bytes]) -> None: - bts = stream.read(64 * 1024) - while bts: - bts = stream.read(64 * 1024) - - def default_stream_factory( - total_content_length: t.Optional[int], - content_type: t.Optional[str], - filename: t.Optional[str], - content_length: t.Optional[int] = None, + total_content_length: int | None, + content_type: str | None, + filename: str | None, + content_length: int | None = None, ) -> t.IO[bytes]: max_size = 1024 * 500 @@ -72,15 +67,15 @@ def default_stream_factory( def parse_form_data( - environ: "WSGIEnvironment", - stream_factory: t.Optional["TStreamFactory"] = None, - charset: str = "utf-8", - errors: str = "replace", - max_form_memory_size: t.Optional[int] = None, - max_content_length: t.Optional[int] = None, - cls: t.Optional[t.Type[MultiDict]] = None, + environ: WSGIEnvironment, + stream_factory: TStreamFactory | None = None, + max_form_memory_size: int | None = None, + max_content_length: int | None = None, + cls: type[MultiDict[str, t.Any]] | None = None, silent: bool = True, -) -> "t_parse_result": + *, + max_form_parts: int | None = None, +) -> t_parse_result: """Parse the form data in the environ and return it as tuple in the form ``(stream, form, files)``. You should only call this method if the transport method is `POST`, `PUT`, or `PATCH`. @@ -92,21 +87,10 @@ def parse_form_data( This is a shortcut for the common usage of :class:`FormDataParser`. - Have a look at :doc:`/request_data` for more details. - - .. versionadded:: 0.5 - The `max_form_memory_size`, `max_content_length` and - `cls` parameters were added. - - .. versionadded:: 0.5.1 - The optional `silent` flag was added. - :param environ: the WSGI environment to be used for parsing. :param stream_factory: An optional callable that returns a new read and writeable file descriptor. This callable works the same as :meth:`Response._get_file_stream`. - :param charset: The character set for URL and url encoded form data. - :param errors: The encoding error behavior. :param max_form_memory_size: the maximum number of bytes to be accepted for in-memory stored form data. If the data exceeds the value specified an @@ -119,38 +103,31 @@ def parse_form_data( :param cls: an optional dict class to use. 
If this is not specified or `None` the default :class:`MultiDict` is used. :param silent: If set to False parsing errors will not be caught. + :param max_form_parts: The maximum number of multipart parts to be parsed. If this + is exceeded, a :exc:`~exceptions.RequestEntityTooLarge` exception is raised. :return: A tuple in the form ``(stream, form, files)``. - """ - return FormDataParser( - stream_factory, - charset, - errors, - max_form_memory_size, - max_content_length, - cls, - silent, - ).parse_from_environ(environ) + .. versionchanged:: 3.0 + The ``charset`` and ``errors`` parameters were removed. -def exhaust_stream(f: F) -> F: - """Helper decorator for methods that exhausts the stream on return.""" + .. versionchanged:: 2.3 + Added the ``max_form_parts`` parameter. - def wrapper(self, stream, *args, **kwargs): # type: ignore - try: - return f(self, stream, *args, **kwargs) - finally: - exhaust = getattr(stream, "exhaust", None) - - if exhaust is not None: - exhaust() - else: - while True: - chunk = stream.read(1024 * 64) - - if not chunk: - break + .. versionadded:: 0.5.1 + Added the ``silent`` parameter. - return update_wrapper(t.cast(F, wrapper), f) + .. versionadded:: 0.5 + Added the ``max_form_memory_size``, ``max_content_length``, and ``cls`` + parameters. + """ + return FormDataParser( + stream_factory=stream_factory, + max_form_memory_size=max_form_memory_size, + max_content_length=max_content_length, + max_form_parts=max_form_parts, + silent=silent, + cls=cls, + ).parse_from_environ(environ) class FormDataParser: @@ -160,13 +137,9 @@ class FormDataParser: untouched stream and expose it as separate attributes on a request object. - .. versionadded:: 0.8 - :param stream_factory: An optional callable that returns a new read and writeable file descriptor. This callable works the same as :meth:`Response._get_file_stream`. - :param charset: The character set for URL and url encoded form data. - :param errors: The encoding error behavior. :param max_form_memory_size: the maximum number of bytes to be accepted for in-memory stored form data. If the data exceeds the value specified an @@ -179,61 +152,68 @@ class FormDataParser: :param cls: an optional dict class to use. If this is not specified or `None` the default :class:`MultiDict` is used. :param silent: If set to False parsing errors will not be caught. + :param max_form_parts: The maximum number of multipart parts to be parsed. If this + is exceeded, a :exc:`~exceptions.RequestEntityTooLarge` exception is raised. + + .. versionchanged:: 3.0 + The ``charset`` and ``errors`` parameters were removed. + + .. versionchanged:: 3.0 + The ``parse_functions`` attribute and ``get_parse_func`` methods were removed. + + .. versionchanged:: 2.2.3 + Added the ``max_form_parts`` parameter. + + .. 
versionadded:: 0.8 """ def __init__( self, - stream_factory: t.Optional["TStreamFactory"] = None, - charset: str = "utf-8", - errors: str = "replace", - max_form_memory_size: t.Optional[int] = None, - max_content_length: t.Optional[int] = None, - cls: t.Optional[t.Type[MultiDict]] = None, + stream_factory: TStreamFactory | None = None, + max_form_memory_size: int | None = None, + max_content_length: int | None = None, + cls: type[MultiDict[str, t.Any]] | None = None, silent: bool = True, + *, + max_form_parts: int | None = None, ) -> None: if stream_factory is None: stream_factory = default_stream_factory self.stream_factory = stream_factory - self.charset = charset - self.errors = errors self.max_form_memory_size = max_form_memory_size self.max_content_length = max_content_length + self.max_form_parts = max_form_parts if cls is None: - cls = MultiDict + cls = t.cast("type[MultiDict[str, t.Any]]", MultiDict) self.cls = cls self.silent = silent - def get_parse_func( - self, mimetype: str, options: t.Dict[str, str] - ) -> t.Optional[ - t.Callable[ - ["FormDataParser", t.IO[bytes], str, t.Optional[int], t.Dict[str, str]], - "t_parse_result", - ] - ]: - return self.parse_functions.get(mimetype) - - def parse_from_environ(self, environ: "WSGIEnvironment") -> "t_parse_result": + def parse_from_environ(self, environ: WSGIEnvironment) -> t_parse_result: """Parses the information from the environment as form data. :param environ: the WSGI environment to be used for parsing. :return: A tuple in the form ``(stream, form, files)``. """ - content_type = environ.get("CONTENT_TYPE", "") + stream = get_input_stream(environ, max_content_length=self.max_content_length) content_length = get_content_length(environ) - mimetype, options = parse_options_header(content_type) - return self.parse(get_input_stream(environ), mimetype, content_length, options) + mimetype, options = parse_options_header(environ.get("CONTENT_TYPE")) + return self.parse( + stream, + content_length=content_length, + mimetype=mimetype, + options=options, + ) def parse( self, stream: t.IO[bytes], mimetype: str, - content_length: t.Optional[int], - options: t.Optional[t.Dict[str, str]] = None, - ) -> "t_parse_result": + content_length: int | None, + options: dict[str, str] | None = None, + ) -> t_parse_result: """Parses the information from the given stream, mimetype, content length and mimetype parameters. @@ -243,43 +223,40 @@ def parse( :param options: optional mimetype parameters (used for the multipart boundary for instance) :return: A tuple in the form ``(stream, form, files)``. + + .. versionchanged:: 3.0 + The invalid ``application/x-url-encoded`` content type is not + treated as ``application/x-www-form-urlencoded``. 
""" - if ( - self.max_content_length is not None - and content_length is not None - and content_length > self.max_content_length - ): - # if the input stream is not exhausted, firefox reports Connection Reset - _exhaust(stream) - raise exceptions.RequestEntityTooLarge() + if mimetype == "multipart/form-data": + parse_func = self._parse_multipart + elif mimetype == "application/x-www-form-urlencoded": + parse_func = self._parse_urlencoded + else: + return stream, self.cls(), self.cls() if options is None: options = {} - parse_func = self.get_parse_func(mimetype, options) - - if parse_func is not None: - try: - return parse_func(self, stream, mimetype, content_length, options) - except ValueError: - if not self.silent: - raise + try: + return parse_func(stream, mimetype, content_length, options) + except ValueError: + if not self.silent: + raise return stream, self.cls(), self.cls() - @exhaust_stream def _parse_multipart( self, stream: t.IO[bytes], mimetype: str, - content_length: t.Optional[int], - options: t.Dict[str, str], - ) -> "t_parse_result": + content_length: int | None, + options: dict[str, str], + ) -> t_parse_result: parser = MultiPartParser( - self.stream_factory, - self.charset, - self.errors, + stream_factory=self.stream_factory, max_form_memory_size=self.max_form_memory_size, + max_form_parts=self.max_form_parts, cls=self.cls, ) boundary = options.get("boundary", "").encode("ascii") @@ -290,66 +267,43 @@ def _parse_multipart( form, files = parser.parse(stream, boundary, content_length) return stream, form, files - @exhaust_stream def _parse_urlencoded( self, stream: t.IO[bytes], mimetype: str, - content_length: t.Optional[int], - options: t.Dict[str, str], - ) -> "t_parse_result": + content_length: int | None, + options: dict[str, str], + ) -> t_parse_result: if ( self.max_form_memory_size is not None and content_length is not None and content_length > self.max_form_memory_size ): - # if the input stream is not exhausted, firefox reports Connection Reset - _exhaust(stream) - raise exceptions.RequestEntityTooLarge() - - form = url_decode_stream(stream, self.charset, errors=self.errors, cls=self.cls) - return stream, form, self.cls() - - #: mapping of mimetypes to parsing functions - parse_functions: t.Dict[ - str, - t.Callable[ - ["FormDataParser", t.IO[bytes], str, t.Optional[int], t.Dict[str, str]], - "t_parse_result", - ], - ] = { - "multipart/form-data": _parse_multipart, - "application/x-www-form-urlencoded": _parse_urlencoded, - "application/x-url-encoded": _parse_urlencoded, - } - - -def _line_parse(line: str) -> t.Tuple[str, bool]: - """Removes line ending characters and returns a tuple (`stripped_line`, - `is_terminated`). 
- """ - if line[-2:] == "\r\n": - return line[:-2], True + raise RequestEntityTooLarge() - elif line[-1:] in {"\r", "\n"}: - return line[:-1], True + try: + items = parse_qsl( + stream.read().decode(), + keep_blank_values=True, + errors="werkzeug.url_quote", + ) + except ValueError as e: + raise RequestEntityTooLarge() from e - return line, False + return stream, self.cls(items), self.cls() class MultiPartParser: def __init__( self, - stream_factory: t.Optional["TStreamFactory"] = None, - charset: str = "utf-8", - errors: str = "replace", - max_form_memory_size: t.Optional[int] = None, - cls: t.Optional[t.Type[MultiDict]] = None, + stream_factory: TStreamFactory | None = None, + max_form_memory_size: int | None = None, + cls: type[MultiDict[str, t.Any]] | None = None, buffer_size: int = 64 * 1024, + max_form_parts: int | None = None, ) -> None: - self.charset = charset - self.errors = errors self.max_form_memory_size = max_form_memory_size + self.max_form_parts = max_form_parts if stream_factory is None: stream_factory = default_stream_factory @@ -357,13 +311,12 @@ def __init__( self.stream_factory = stream_factory if cls is None: - cls = MultiDict + cls = t.cast("type[MultiDict[str, t.Any]]", MultiDict) self.cls = cls - self.buffer_size = buffer_size - def fail(self, message: str) -> "te.NoReturn": + def fail(self, message: str) -> te.NoReturn: raise ValueError(message) def get_part_charset(self, headers: Headers) -> str: @@ -371,18 +324,23 @@ def get_part_charset(self, headers: Headers) -> str: content_type = headers.get("content-type") if content_type: - mimetype, ct_params = parse_options_header(content_type) - return ct_params.get("charset", self.charset) + parameters = parse_options_header(content_type)[1] + ct_charset = parameters.get("charset", "").lower() + + # A safe list of encodings. Modern clients should only send ASCII or UTF-8. + # This list will not be extended further. 
+ if ct_charset in {"ascii", "us-ascii", "utf-8", "iso-8859-1"}: + return ct_charset - return self.charset + return "utf-8" def start_file_streaming( - self, event: File, total_content_length: t.Optional[int] + self, event: File, total_content_length: int | None ) -> t.IO[bytes]: content_type = event.headers.get("content-type") try: - content_length = int(event.headers["content-length"]) + content_length = _plain_int(event.headers["content-length"]) except (KeyError, ValueError): content_length = 0 @@ -395,27 +353,22 @@ def start_file_streaming( return container def parse( - self, stream: t.IO[bytes], boundary: bytes, content_length: t.Optional[int] - ) -> t.Tuple[MultiDict, MultiDict]: - container: t.Union[t.IO[bytes], t.List[bytes]] + self, stream: t.IO[bytes], boundary: bytes, content_length: int | None + ) -> tuple[MultiDict[str, str], MultiDict[str, FileStorage]]: + current_part: Field | File + container: t.IO[bytes] | list[bytes] _write: t.Callable[[bytes], t.Any] - iterator = chain( - _make_chunk_iter( - stream, - limit=content_length, - buffer_size=self.buffer_size, - ), - [None], + parser = MultipartDecoder( + boundary, + max_form_memory_size=self.max_form_memory_size, + max_parts=self.max_form_parts, ) - parser = MultipartDecoder(boundary, self.max_form_memory_size) - fields = [] files = [] - current_part: Union[Field, File] - for data in iterator: + for data in _chunk_iter(stream.read, self.buffer_size): parser.receive_data(data) event = parser.next_event() while not isinstance(event, (Epilogue, NeedData)): @@ -432,7 +385,7 @@ def parse( if not event.more_data: if isinstance(current_part, Field): value = b"".join(container).decode( - self.get_part_charset(current_part.headers), self.errors + self.get_part_charset(current_part.headers), "replace" ) fields.append((current_part.name, value)) else: @@ -453,3 +406,18 @@ def parse( event = parser.next_event() return self.cls(fields), self.cls(files) + + +def _chunk_iter(read: t.Callable[[int], bytes], size: int) -> t.Iterator[bytes | None]: + """Read data in chunks for multipart/form-data parsing. Stop if no data is read. + Yield ``None`` at the end to signal end of parsing. + """ + while True: + data = read(size) + + if not data: + break + + yield data + + yield None diff --git a/src/werkzeug/http.py b/src/werkzeug/http.py index 9777685..27fa9af 100644 --- a/src/werkzeug/http.py +++ b/src/werkzeug/http.py @@ -1,7 +1,7 @@ -import base64 +from __future__ import annotations + import email.utils import re -import typing import typing as t import warnings from datetime import date @@ -13,74 +13,20 @@ from hashlib import sha1 from time import mktime from time import struct_time -from urllib.parse import unquote_to_bytes as _unquote +from urllib.parse import quote +from urllib.parse import unquote from urllib.request import parse_http_list as _parse_list_header -from ._internal import _cookie_quote from ._internal import _dt_as_utc -from ._internal import _make_cookie_domain -from ._internal import _to_bytes -from ._internal import _to_str -from ._internal import _wsgi_decoding_dance +from ._internal import _plain_int if t.TYPE_CHECKING: from _typeshed.wsgi import WSGIEnvironment -# for explanation of "media-range", etc. 
see Sections 5.3.{1,2} of RFC 7231 -_accept_re = re.compile( - r""" - ( # media-range capturing-parenthesis - [^\s;,]+ # type/subtype - (?:[ \t]*;[ \t]* # ";" - (?: # parameter non-capturing-parenthesis - [^\s;,q][^\s;,]* # token that doesn't start with "q" - | # or - q[^\s;,=][^\s;,]* # token that is more than just "q" - ) - )* # zero or more parameters - ) # end of media-range - (?:[ \t]*;[ \t]*q= # weight is a "q" parameter - (\d*(?:\.\d+)?) # qvalue capturing-parentheses - [^,]* # "extension" accept params: who cares? - )? # accept params are optional - """, - re.VERBOSE, -) _token_chars = frozenset( "!#$%&'*+-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ^_`abcdefghijklmnopqrstuvwxyz|~" ) _etag_re = re.compile(r'([Ww]/)?(?:"(.*?)"|(.*?))(?:\s*,\s*|$)') -_option_header_piece_re = re.compile( - r""" - ;\s*,?\s* # newlines were replaced with commas - (?P - "[^"\\]*(?:\\.[^"\\]*)*" # quoted string - | - [^\s;,=*]+ # token - ) - (?:\*(?P\d+))? # *1, optional continuation index - \s* - (?: # optionally followed by =value - (?: # equals sign, possibly with encoding - \*\s*=\s* # * indicates extended notation - (?: # optional encoding - (?P[^\s]+?) - '(?P[^\s]*?)' - )? - | - =\s* # basic notation - ) - (?P - "[^"\\]*(?:\\.[^"\\]*)*" # quoted string - | - [^;,]+ # token - )? - )? - \s* - """, - flags=re.VERBOSE, -) -_option_header_start_mime_type = re.compile(r",\s*([^;,\s]+)([;,]\s*.+)?") _entity_headers = frozenset( [ "allow", @@ -190,108 +136,155 @@ class COOP(Enum): SAME_ORIGIN = "same-origin" -def quote_header_value( - value: t.Union[str, int], extra_chars: str = "", allow_token: bool = True -) -> str: - """Quote a header value if necessary. +def quote_header_value(value: t.Any, allow_token: bool = True) -> str: + """Add double quotes around a header value. If the header contains only ASCII token + characters, it will be returned unchanged. If the header contains ``"`` or ``\\`` + characters, they will be escaped with an additional ``\\`` character. - .. versionadded:: 0.5 + This is the reverse of :func:`unquote_header_value`. + + :param value: The value to quote. Will be converted to a string. + :param allow_token: Disable to quote the value even if it only has token characters. + + .. versionchanged:: 3.0 + Passing bytes is not supported. - :param value: the value to quote. - :param extra_chars: a list of extra characters to skip quoting. - :param allow_token: if this is enabled token values are returned - unchanged. + .. versionchanged:: 3.0 + The ``extra_chars`` parameter is removed. + + .. versionchanged:: 2.3 + The value is quoted if it is the empty string. + + .. versionadded:: 0.5 """ - if isinstance(value, bytes): - value = value.decode("latin1") - value = str(value) + value_str = str(value) + + if not value_str: + return '""' + if allow_token: - token_chars = _token_chars | set(extra_chars) - if set(value).issubset(token_chars): - return value - value = value.replace("\\", "\\\\").replace('"', '\\"') - return f'"{value}"' + token_chars = _token_chars + if token_chars.issuperset(value_str): + return value_str -def unquote_header_value(value: str, is_filename: bool = False) -> str: - r"""Unquotes a header value. (Reversal of :func:`quote_header_value`). - This does not use the real unquoting but what browsers are actually - using for quoting. + value_str = value_str.replace("\\", "\\\\").replace('"', '\\"') + return f'"{value_str}"' - .. versionadded:: 0.5 - :param value: the header value to unquote. - :param is_filename: The value represents a filename or path. 
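The rewritten ``quote_header_value`` above and its inverse ``unquote_header_value``, redefined just below, round-trip cleanly. A quick sketch using only behavior stated in the new docstrings:

.. code-block:: python

    from werkzeug.http import quote_header_value, unquote_header_value

    assert quote_header_value("simple-token") == "simple-token"  # token chars pass through
    assert quote_header_value("") == '""'  # the empty string is quoted

    value = 'a "quoted" value\\with slash'
    quoted = quote_header_value(value)
    print(quoted)  # "a \"quoted\" value\\with slash"
    assert unquote_header_value(quoted) == value  # escapes decode back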
+def unquote_header_value(value: str) -> str: + """Remove double quotes and decode slash-escaped ``"`` and ``\\`` characters in a + header value. + + This is the reverse of :func:`quote_header_value`. + + :param value: The header value to unquote. + + .. versionchanged:: 3.0 + The ``is_filename`` parameter is removed. """ - if value and value[0] == value[-1] == '"': - # this is not the real unquoting, but fixing this so that the - # RFC is met will result in bugs with internet explorer and - # probably some other browsers as well. IE for example is - # uploading files with "C:\foo\bar.txt" as filename + if len(value) >= 2 and value[0] == value[-1] == '"': value = value[1:-1] + return value.replace("\\\\", "\\").replace('\\"', '"') - # if this is a filename and the starting characters look like - # a UNC path, then just return the value without quotes. Using the - # replace sequence below on a UNC path has the effect of turning - # the leading double slash into a single slash and then - # _fix_ie_filename() doesn't work correctly. See #458. - if not is_filename or value[:2] != "\\\\": - return value.replace("\\\\", "\\").replace('\\"', '"') return value -def dump_options_header( - header: t.Optional[str], options: t.Mapping[str, t.Optional[t.Union[str, int]]] -) -> str: - """The reverse function to :func:`parse_options_header`. +def dump_options_header(header: str | None, options: t.Mapping[str, t.Any]) -> str: + """Produce a header value and ``key=value`` parameters separated by semicolons + ``;``. For example, the ``Content-Type`` header. + + .. code-block:: python + + dump_options_header("text/html", {"charset": "UTF-8"}) + 'text/html; charset=UTF-8' + + This is the reverse of :func:`parse_options_header`. + + If a value contains non-token characters, it will be quoted. + + If a value is ``None``, the parameter is skipped. - :param header: the header to dump - :param options: a dict of options to append. + In some keys for some headers, a UTF-8 value can be encoded using a special + ``key*=UTF-8''value`` form, where ``value`` is percent encoded. This function will + not produce that format automatically, but if a given key ends with an asterisk + ``*``, the value is assumed to have that form and will not be quoted further. + + :param header: The primary header value. + :param options: Parameters to encode as ``key=value`` pairs. + + .. versionchanged:: 2.3 + Keys with ``None`` values are skipped rather than treated as a bare key. + + .. versionchanged:: 2.2.3 + If a key ends with ``*``, its value will not be quoted. """ segments = [] + if header is not None: segments.append(header) + for key, value in options.items(): if value is None: - segments.append(key) + continue + + if key[-1] == "*": + segments.append(f"{key}={value}") else: segments.append(f"{key}={quote_header_value(value)}") + return "; ".join(segments) -def dump_header( - iterable: t.Union[t.Dict[str, t.Union[str, int]], t.Iterable[str]], - allow_token: bool = True, -) -> str: - """Dump an HTTP header again. This is the reversal of - :func:`parse_list_header`, :func:`parse_set_header` and - :func:`parse_dict_header`. This also quotes strings that include an - equals sign unless you pass it as dict of key, value pairs. - - >>> dump_header({'foo': 'bar baz'}) - 'foo="bar baz"' - >>> dump_header(('foo', 'bar baz')) - 'foo, "bar baz"' - - :param iterable: the iterable or dict of values to quote. - :param allow_token: if set to `False` tokens as values are disallowed. - See :func:`quote_header_value` for more details. 
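And a round trip for the options helpers: ``dump_options_header`` above skips ``None`` values and emits an asterisk key verbatim, while ``parse_options_header`` (rewritten later in this patch) decodes the RFC 2231 form. Sample values are illustrative:

.. code-block:: python

    from werkzeug.http import dump_options_header, parse_options_header

    header = dump_options_header(
        "form-data",
        {"name": "file", "filename*": "UTF-8''b%C3%A4r.txt", "skipped": None},
    )
    print(header)
    # form-data; name=file; filename*=UTF-8''b%C3%A4r.txt

    value, options = parse_options_header("attachment; filename*=UTF-8''b%C3%A4r.txt")
    print(value, options)  # attachment {'filename': 'bär.txt'}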
+def dump_header(iterable: dict[str, t.Any] | t.Iterable[t.Any]) -> str: + """Produce a header value from a list of items or ``key=value`` pairs, separated by + commas ``,``. + + This is the reverse of :func:`parse_list_header`, :func:`parse_dict_header`, and + :func:`parse_set_header`. + + If a value contains non-token characters, it will be quoted. + + If a value is ``None``, the key is output alone. + + In some keys for some headers, a UTF-8 value can be encoded using a special + ``key*=UTF-8''value`` form, where ``value`` is percent encoded. This function will + not produce that format automatically, but if a given key ends with an asterisk + ``*``, the value is assumed to have that form and will not be quoted further. + + .. code-block:: python + + dump_header(["foo", "bar baz"]) + 'foo, "bar baz"' + + dump_header({"foo": "bar baz"}) + 'foo="bar baz"' + + :param iterable: The items to create a header from. + + .. versionchanged:: 3.0 + The ``allow_token`` parameter is removed. + + .. versionchanged:: 2.2.3 + If a key ends with ``*``, its value will not be quoted. """ if isinstance(iterable, dict): items = [] + for key, value in iterable.items(): if value is None: items.append(key) + elif key[-1] == "*": + items.append(f"{key}={value}") else: - items.append( - f"{key}={quote_header_value(value, allow_token=allow_token)}" - ) + items.append(f"{key}={quote_header_value(value)}") else: - items = [quote_header_value(x, allow_token=allow_token) for x in iterable] + items = [quote_header_value(x) for x in iterable] + return ", ".join(items) -def dump_csp_header(header: "ds.ContentSecurityPolicy") -> str: +def dump_csp_header(header: ds.ContentSecurityPolicy) -> str: """Dump a Content Security Policy header. These are structured into policies such as "default-src 'self'; @@ -304,187 +297,285 @@ def dump_csp_header(header: "ds.ContentSecurityPolicy") -> str: return "; ".join(f"{key} {value}" for key, value in header.items()) -def parse_list_header(value: str) -> t.List[str]: - """Parse lists as described by RFC 2068 Section 2. +def parse_list_header(value: str) -> list[str]: + """Parse a header value that consists of a list of comma separated items according + to `RFC 9110 `__. - In particular, parse comma-separated lists where the elements of - the list may include quoted-strings. A quoted-string could - contain a comma. A non-quoted string could have quotes in the - middle. Quotes are removed automatically after parsing. + This extends :func:`urllib.request.parse_http_list` to remove surrounding quotes + from values. - It basically works like :func:`parse_set_header` just that items - may appear multiple times and case sensitivity is preserved. + .. code-block:: python - The return value is a standard :class:`list`: + parse_list_header('token, "quoted value"') + ['token', 'quoted value'] - >>> parse_list_header('token, "quoted value"') - ['token', 'quoted value'] + This is the reverse of :func:`dump_header`. - To create a header from the :class:`list` again, use the - :func:`dump_header` function. - - :param value: a string with a list header. - :return: :class:`list` + :param value: The header value to parse. 
""" result = [] + for item in _parse_list_header(value): - if item[:1] == item[-1:] == '"': - item = unquote_header_value(item[1:-1]) + if len(item) >= 2 and item[0] == item[-1] == '"': + item = item[1:-1] + result.append(item) + return result -def parse_dict_header(value: str, cls: t.Type[dict] = dict) -> t.Dict[str, str]: - """Parse lists of key, value pairs as described by RFC 2068 Section 2 and - convert them into a python dict (or any other mapping object created from - the type with a dict like interface provided by the `cls` argument): +def parse_dict_header(value: str) -> dict[str, str | None]: + """Parse a list header using :func:`parse_list_header`, then parse each item as a + ``key=value`` pair. - >>> d = parse_dict_header('foo="is a fish", bar="as well"') - >>> type(d) is dict - True - >>> sorted(d.items()) - [('bar', 'as well'), ('foo', 'is a fish')] + .. code-block:: python - If there is no value for a key it will be `None`: + parse_dict_header('a=b, c="d, e", f') + {"a": "b", "c": "d, e", "f": None} - >>> parse_dict_header('key_without_value') - {'key_without_value': None} + This is the reverse of :func:`dump_header`. - To create a header from the :class:`dict` again, use the - :func:`dump_header` function. + If a key does not have a value, it is ``None``. - .. versionchanged:: 0.9 - Added support for `cls` argument. + This handles charsets for values as described in + `RFC 2231 `__. Only ASCII, UTF-8, + and ISO-8859-1 charsets are accepted, otherwise the value remains quoted. - :param value: a string with a dict header. - :param cls: callable to use for storage of parsed results. - :return: an instance of `cls` + :param value: The header value to parse. + + .. versionchanged:: 3.0 + Passing bytes is not supported. + + .. versionchanged:: 3.0 + The ``cls`` argument is removed. + + .. versionchanged:: 2.3 + Added support for ``key*=charset''value`` encoded items. + + .. versionchanged:: 0.9 + The ``cls`` argument was added. """ - result = cls() - if isinstance(value, bytes): - value = value.decode("latin1") - for item in _parse_list_header(value): - if "=" not in item: - result[item] = None + result: dict[str, str | None] = {} + + for item in parse_list_header(value): + key, has_value, value = item.partition("=") + key = key.strip() + + if not has_value: + result[key] = None continue - name, value = item.split("=", 1) - if value[:1] == value[-1:] == '"': - value = unquote_header_value(value[1:-1]) - result[name] = value + + value = value.strip() + encoding: str | None = None + + if key[-1] == "*": + # key*=charset''value becomes key=value, where value is percent encoded + # adapted from parse_options_header, without the continuation handling + key = key[:-1] + match = _charset_value_re.match(value) + + if match: + # If there is a charset marker in the value, split it off. + encoding, value = match.groups() + encoding = encoding.lower() + + # A safe list of encodings. Modern clients should only send ASCII or UTF-8. + # This list will not be extended further. An invalid encoding will leave the + # value quoted. 
+ if encoding in {"ascii", "us-ascii", "utf-8", "iso-8859-1"}: + # invalid bytes are replaced during unquoting + value = unquote(value, encoding=encoding) + + if len(value) >= 2 and value[0] == value[-1] == '"': + value = value[1:-1] + + result[key] = value + return result -def parse_options_header(value: t.Optional[str]) -> t.Tuple[str, t.Dict[str, str]]: - """Parse a ``Content-Type``-like header into a tuple with the - value and any options: +# https://httpwg.org/specs/rfc9110.html#parameter +_parameter_re = re.compile( + r""" + # don't match multiple empty parts, that causes backtracking + \s*;\s* # find the part delimiter + (?: + ([\w!#$%&'*+\-.^`|~]+) # key, one or more token chars + = # equals, with no space on either side + ( # value, token or quoted string + [\w!#$%&'*+\-.^`|~]+ # one or more token chars + | + "(?:\\\\|\\"|.)*?" # quoted string, consuming slash escapes + ) + )? # optionally match key=value, to account for empty parts + """, + re.ASCII | re.VERBOSE, +) +# https://www.rfc-editor.org/rfc/rfc2231#section-4 +_charset_value_re = re.compile( + r""" + ([\w!#$%&*+\-.^`|~]*)' # charset part, could be empty + [\w!#$%&*+\-.^`|~]*' # don't care about language part, usually empty + ([\w!#$%&'*+\-.^`|~]+) # one or more token chars with percent encoding + """, + re.ASCII | re.VERBOSE, +) +# https://www.rfc-editor.org/rfc/rfc2231#section-3 +_continuation_re = re.compile(r"\*(\d+)$", re.ASCII) + + +def parse_options_header(value: str | None) -> tuple[str, dict[str, str]]: + """Parse a header that consists of a value with ``key=value`` parameters separated + by semicolons ``;``. For example, the ``Content-Type`` header. + + .. code-block:: python + + parse_options_header("text/html; charset=UTF-8") + ('text/html', {'charset': 'UTF-8'}) + + parse_options_header("") + ("", {}) + + This is the reverse of :func:`dump_options_header`. + + This parses valid parameter parts as described in + `RFC 9110 `__. Invalid parts are + skipped. + + This handles continuations and charsets as described in + `RFC 2231 `__, although not as + strictly as the RFC. Only ASCII, UTF-8, and ISO-8859-1 charsets are accepted, + otherwise the value remains quoted. - >>> parse_options_header('text/html; charset=utf8') - ('text/html', {'charset': 'utf8'}) + Clients may not be consistent in how they handle a quote character within a quoted + value. The `HTML Standard `__ + replaces it with ``%22`` in multipart form data. + `RFC 9110 `__ uses backslash + escapes in HTTP headers. Both are decoded to the ``"`` character. - This should is not for ``Cache-Control``-like headers, which use a - different format. For those, use :func:`parse_dict_header`. + Clients may not be consistent in how they handle non-ASCII characters. HTML + documents must declare ````, otherwise browsers may replace with + HTML character references, which can be decoded using :func:`html.unescape`. :param value: The header value to parse. + :return: ``(value, options)``, where ``options`` is a dict + + .. versionchanged:: 2.3 + Invalid parts, such as keys with no value, quoted keys, and incorrectly quoted + values, are discarded instead of treating as ``None``. + + .. versionchanged:: 2.3 + Only ASCII, UTF-8, and ISO-8859-1 are accepted for charset values. + + .. versionchanged:: 2.3 + Escaped quotes in quoted values, like ``%22`` and ``\\"``, are handled. .. versionchanged:: 2.2 Option names are always converted to lowercase. - .. versionchanged:: 2.1 - The ``multiple`` parameter is deprecated and will be removed in - Werkzeug 2.2. + .. 
versionchanged:: 2.2 + The ``multiple`` parameter was removed. .. versionchanged:: 0.15 :rfc:`2231` parameter continuations are handled. .. versionadded:: 0.5 """ - if not value: + if value is None: return "", {} - result: t.List[t.Any] = [] + value, _, rest = value.partition(";") + value = value.strip() + rest = rest.strip() - value = "," + value.replace("\n", ",") - while value: - match = _option_header_start_mime_type.match(value) - if not match: - break - result.append(match.group(1)) # mimetype - options: t.Dict[str, str] = {} - # Parse options - rest = match.group(2) - encoding: t.Optional[str] - continued_encoding: t.Optional[str] = None - while rest: - optmatch = _option_header_piece_re.match(rest) - if not optmatch: - break - option, count, encoding, language, option_value = optmatch.groups() - # Continuations don't have to supply the encoding after the - # first line. If we're in a continuation, track the current - # encoding to use for subsequent lines. Reset it when the - # continuation ends. - if not count: - continued_encoding = None - else: - if not encoding: - encoding = continued_encoding - continued_encoding = encoding - option = unquote_header_value(option).lower() + if not value or not rest: + # empty (invalid) value, or value without options + return value, {} - if option_value is not None: - option_value = unquote_header_value(option_value, option == "filename") + rest = f";{rest}" + options: dict[str, str] = {} + encoding: str | None = None + continued_encoding: str | None = None - if encoding is not None: - option_value = _unquote(option_value).decode(encoding) + for pk, pv in _parameter_re.findall(rest): + if not pk: + # empty or invalid part + continue - if count: - # Continuations append to the existing value. For - # simplicity, this ignores the possibility of - # out-of-order indices, which shouldn't happen anyway. - if option_value is not None: - options[option] = options.get(option, "") + option_value - else: - options[option] = option_value # type: ignore[assignment] + pk = pk.lower() + + if pk[-1] == "*": + # key*=charset''value becomes key=value, where value is percent encoded + pk = pk[:-1] + match = _charset_value_re.match(pv) + + if match: + # If there is a valid charset marker in the value, split it off. + encoding, pv = match.groups() + # This might be the empty string, handled next. + encoding = encoding.lower() + + # No charset marker, or marker with empty charset value. + if not encoding: + encoding = continued_encoding + + # A safe list of encodings. Modern clients should only send ASCII or UTF-8. + # This list will not be extended further. An invalid encoding will leave the + # value quoted. + if encoding in {"ascii", "us-ascii", "utf-8", "iso-8859-1"}: + # Continuation parts don't require their own charset marker. This is + # looser than the RFC, it will persist across different keys and allows + # changing the charset during a continuation. But this implementation is + # much simpler than tracking the full state. + continued_encoding = encoding + # invalid bytes are replaced during unquoting + pv = unquote(pv, encoding=encoding) + + # Remove quotes. At this point the value cannot be empty or a single quote. 
+ if pv[0] == pv[-1] == '"': + # HTTP headers use slash, multipart form data uses percent + pv = pv[1:-1].replace("\\\\", "\\").replace('\\"', '"').replace("%22", '"') - rest = rest[optmatch.end() :] - result.append(options) - return tuple(result) # type: ignore[return-value] + match = _continuation_re.search(pk) + + if match: + # key*0=a; key*1=b becomes key=ab + pk = pk[: match.start()] + options[pk] = options.get(pk, "") + pv + else: + options[pk] = pv - return tuple(result) if result else ("", {}) # type: ignore[return-value] + return value, options +_q_value_re = re.compile(r"-?\d+(\.\d+)?", re.ASCII) _TAnyAccept = t.TypeVar("_TAnyAccept", bound="ds.Accept") -@typing.overload -def parse_accept_header(value: t.Optional[str]) -> "ds.Accept": - ... +@t.overload +def parse_accept_header(value: str | None) -> ds.Accept: ... -@typing.overload -def parse_accept_header( - value: t.Optional[str], cls: t.Type[_TAnyAccept] -) -> _TAnyAccept: - ... +@t.overload +def parse_accept_header(value: str | None, cls: type[_TAnyAccept]) -> _TAnyAccept: ... def parse_accept_header( - value: t.Optional[str], cls: t.Optional[t.Type[_TAnyAccept]] = None + value: str | None, cls: type[_TAnyAccept] | None = None ) -> _TAnyAccept: - """Parses an HTTP Accept-* header. This does not implement a complete - valid algorithm but one that supports at least value and quality - extraction. + """Parse an ``Accept`` header according to + `RFC 9110 `__. - Returns a new :class:`Accept` object (basically a list of ``(value, quality)`` - tuples sorted by the quality with some additional accessor methods). + Returns an :class:`.Accept` instance, which can sort and inspect items based on + their quality parameter. When parsing ``Accept-Charset``, ``Accept-Encoding``, or + ``Accept-Language``, pass the appropriate :class:`.Accept` subclass. - The second parameter can be a subclass of :class:`Accept` that is created - with the parsed values and returned. + :param value: The header value to parse. + :param cls: The :class:`.Accept` class to wrap the result in. + :return: An instance of ``cls``. - :param value: the accept header string to be parsed. - :param cls: the wrapper class for the return value (can be - :class:`Accept` or a subclass thereof) - :return: an instance of `cls`. + .. versionchanged:: 2.3 + Parse according to RFC 9110. Items with invalid ``q`` values are skipped. 
""" if cls is None: cls = t.cast(t.Type[_TAnyAccept], ds.Accept) @@ -493,38 +584,57 @@ def parse_accept_header( return cls(None) result = [] - for match in _accept_re.finditer(value): - quality_match = match.group(2) - if not quality_match: - quality: float = 1 + + for item in parse_list_header(value): + item, options = parse_options_header(item) + + if "q" in options: + # pop q, remaining options are reconstructed + q_str = options.pop("q").strip() + + if _q_value_re.fullmatch(q_str) is None: + # ignore an invalid q + continue + + q = float(q_str) + + if q < 0 or q > 1: + # ignore an invalid q + continue else: - quality = max(min(float(quality_match), 1), 0) - result.append((match.group(1), quality)) + q = 1 + + if options: + # reconstruct the media type with any options + item = dump_options_header(item, options) + + result.append((item, q)) + return cls(result) -_TAnyCC = t.TypeVar("_TAnyCC", bound="ds._CacheControl") -_t_cc_update = t.Optional[t.Callable[[_TAnyCC], None]] +_TAnyCC = t.TypeVar("_TAnyCC", bound="ds.cache_control._CacheControl") -@typing.overload +@t.overload def parse_cache_control_header( - value: t.Optional[str], on_update: _t_cc_update, cls: None = None -) -> "ds.RequestCacheControl": - ... + value: str | None, + on_update: t.Callable[[ds.cache_control._CacheControl], None] | None = None, +) -> ds.RequestCacheControl: ... -@typing.overload +@t.overload def parse_cache_control_header( - value: t.Optional[str], on_update: _t_cc_update, cls: t.Type[_TAnyCC] -) -> _TAnyCC: - ... + value: str | None, + on_update: t.Callable[[ds.cache_control._CacheControl], None] | None = None, + cls: type[_TAnyCC] = ..., +) -> _TAnyCC: ... def parse_cache_control_header( - value: t.Optional[str], - on_update: _t_cc_update = None, - cls: t.Optional[t.Type[_TAnyCC]] = None, + value: str | None, + on_update: t.Callable[[ds.cache_control._CacheControl], None] | None = None, + cls: type[_TAnyCC] | None = None, ) -> _TAnyCC: """Parse a cache control header. The RFC differs between response and request cache control, this method does not. It's your responsibility @@ -543,7 +653,7 @@ def parse_cache_control_header( :return: a `cls` object. """ if cls is None: - cls = t.cast(t.Type[_TAnyCC], ds.RequestCacheControl) + cls = t.cast("type[_TAnyCC]", ds.RequestCacheControl) if not value: return cls((), on_update) @@ -552,27 +662,27 @@ def parse_cache_control_header( _TAnyCSP = t.TypeVar("_TAnyCSP", bound="ds.ContentSecurityPolicy") -_t_csp_update = t.Optional[t.Callable[[_TAnyCSP], None]] -@typing.overload +@t.overload def parse_csp_header( - value: t.Optional[str], on_update: _t_csp_update, cls: None = None -) -> "ds.ContentSecurityPolicy": - ... + value: str | None, + on_update: t.Callable[[ds.ContentSecurityPolicy], None] | None = None, +) -> ds.ContentSecurityPolicy: ... -@typing.overload +@t.overload def parse_csp_header( - value: t.Optional[str], on_update: _t_csp_update, cls: t.Type[_TAnyCSP] -) -> _TAnyCSP: - ... + value: str | None, + on_update: t.Callable[[ds.ContentSecurityPolicy], None] | None = None, + cls: type[_TAnyCSP] = ..., +) -> _TAnyCSP: ... def parse_csp_header( - value: t.Optional[str], - on_update: _t_csp_update = None, - cls: t.Optional[t.Type[_TAnyCSP]] = None, + value: str | None, + on_update: t.Callable[[ds.ContentSecurityPolicy], None] | None = None, + cls: type[_TAnyCSP] | None = None, ) -> _TAnyCSP: """Parse a Content Security Policy header. @@ -587,7 +697,7 @@ def parse_csp_header( :return: a `cls` object. 
""" if cls is None: - cls = t.cast(t.Type[_TAnyCSP], ds.ContentSecurityPolicy) + cls = t.cast("type[_TAnyCSP]", ds.ContentSecurityPolicy) if value is None: return cls((), on_update) @@ -606,9 +716,9 @@ def parse_csp_header( def parse_set_header( - value: t.Optional[str], - on_update: t.Optional[t.Callable[["ds.HeaderSet"], None]] = None, -) -> "ds.HeaderSet": + value: str | None, + on_update: t.Callable[[ds.HeaderSet], None] | None = None, +) -> ds.HeaderSet: """Parse a set-like header and return a :class:`~werkzeug.datastructures.HeaderSet` object: @@ -638,76 +748,7 @@ def parse_set_header( return ds.HeaderSet(parse_list_header(value), on_update) -def parse_authorization_header( - value: t.Optional[str], -) -> t.Optional["ds.Authorization"]: - """Parse an HTTP basic/digest authorization header transmitted by the web - browser. The return value is either `None` if the header was invalid or - not given, otherwise an :class:`~werkzeug.datastructures.Authorization` - object. - - :param value: the authorization header to parse. - :return: a :class:`~werkzeug.datastructures.Authorization` object or `None`. - """ - if not value: - return None - value = _wsgi_decoding_dance(value) - try: - auth_type, auth_info = value.split(None, 1) - auth_type = auth_type.lower() - except ValueError: - return None - if auth_type == "basic": - try: - username, password = base64.b64decode(auth_info).split(b":", 1) - except Exception: - return None - try: - return ds.Authorization( - "basic", - { - "username": _to_str(username, "utf-8"), - "password": _to_str(password, "utf-8"), - }, - ) - except UnicodeDecodeError: - return None - elif auth_type == "digest": - auth_map = parse_dict_header(auth_info) - for key in "username", "realm", "nonce", "uri", "response": - if key not in auth_map: - return None - if "qop" in auth_map: - if not auth_map.get("nc") or not auth_map.get("cnonce"): - return None - return ds.Authorization("digest", auth_map) - return None - - -def parse_www_authenticate_header( - value: t.Optional[str], - on_update: t.Optional[t.Callable[["ds.WWWAuthenticate"], None]] = None, -) -> "ds.WWWAuthenticate": - """Parse an HTTP WWW-Authenticate header into a - :class:`~werkzeug.datastructures.WWWAuthenticate` object. - - :param value: a WWW-Authenticate header to parse. - :param on_update: an optional callable that is called every time a value - on the :class:`~werkzeug.datastructures.WWWAuthenticate` - object is changed. - :return: a :class:`~werkzeug.datastructures.WWWAuthenticate` object. - """ - if not value: - return ds.WWWAuthenticate(on_update=on_update) - try: - auth_type, auth_info = value.split(None, 1) - auth_type = auth_type.lower() - except (ValueError, AttributeError): - return ds.WWWAuthenticate(value.strip().lower(), on_update=on_update) - return ds.WWWAuthenticate(auth_type, parse_dict_header(auth_info), on_update) - - -def parse_if_range_header(value: t.Optional[str]) -> "ds.IfRange": +def parse_if_range_header(value: str | None) -> ds.IfRange: """Parses an if-range header which can be an etag or a date. Returns a :class:`~werkzeug.datastructures.IfRange` object. @@ -726,8 +767,8 @@ def parse_if_range_header(value: t.Optional[str]) -> "ds.IfRange": def parse_range_header( - value: t.Optional[str], make_inclusive: bool = True -) -> t.Optional["ds.Range"]: + value: str | None, make_inclusive: bool = True +) -> ds.Range | None: """Parses a range header into a :class:`~werkzeug.datastructures.Range` object. If the header is missing or malformed `None` is returned. 
`ranges` is a list of ``(start, stop)`` tuples where the ranges are @@ -751,7 +792,7 @@ def parse_range_header( if last_end < 0: return None try: - begin = int(item) + begin = _plain_int(item) except ValueError: return None end = None @@ -762,7 +803,7 @@ def parse_range_header( end_str = end_str.strip() try: - begin = int(begin_str) + begin = _plain_int(begin_str) except ValueError: return None @@ -770,7 +811,7 @@ def parse_range_header( return None if end_str: try: - end = int(end_str) + 1 + end = _plain_int(end_str) + 1 except ValueError: return None @@ -785,9 +826,9 @@ def parse_range_header( def parse_content_range_header( - value: t.Optional[str], - on_update: t.Optional[t.Callable[["ds.ContentRange"], None]] = None, -) -> t.Optional["ds.ContentRange"]: + value: str | None, + on_update: t.Callable[[ds.ContentRange], None] | None = None, +) -> ds.ContentRange | None: """Parses a range header into a :class:`~werkzeug.datastructures.ContentRange` object or `None` if parsing is not possible. @@ -813,19 +854,22 @@ def parse_content_range_header( length = None else: try: - length = int(length_str) + length = _plain_int(length_str) except ValueError: return None if rng == "*": + if not is_byte_range_valid(None, None, length): + return None + return ds.ContentRange(units, None, None, length, on_update=on_update) elif "-" not in rng: return None start_str, stop_str = rng.split("-", 1) try: - start = int(start_str) - stop = int(stop_str) + 1 + start = _plain_int(start_str) + stop = _plain_int(stop_str) + 1 except ValueError: return None @@ -850,8 +894,8 @@ def quote_etag(etag: str, weak: bool = False) -> str: def unquote_etag( - etag: t.Optional[str], -) -> t.Union[t.Tuple[str, bool], t.Tuple[None, None]]: + etag: str | None, +) -> tuple[str, bool] | tuple[None, None]: """Unquote a single etag: >>> unquote_etag('W/"bar"') @@ -874,7 +918,7 @@ def unquote_etag( return etag, weak -def parse_etags(value: t.Optional[str]) -> "ds.ETags": +def parse_etags(value: str | None) -> ds.ETags: """Parse an etag header. :param value: the tag header to parse @@ -912,7 +956,7 @@ def generate_etag(data: bytes) -> str: return sha1(data).hexdigest() -def parse_date(value: t.Optional[str]) -> t.Optional[datetime]: +def parse_date(value: str | None) -> datetime | None: """Parse an :rfc:`2822` date into a timezone-aware :class:`datetime.datetime` object, or ``None`` if parsing fails. @@ -942,7 +986,7 @@ def parse_date(value: t.Optional[str]) -> t.Optional[datetime]: def http_date( - timestamp: t.Optional[t.Union[datetime, date, int, float, struct_time]] = None + timestamp: datetime | date | int | float | struct_time | None = None, ) -> str: """Format a datetime object or timestamp into an :rfc:`2822` date string. @@ -973,7 +1017,7 @@ def http_date( return email.utils.formatdate(timestamp, usegmt=True) -def parse_age(value: t.Optional[str] = None) -> t.Optional[timedelta]: +def parse_age(value: str | None = None) -> timedelta | None: """Parses a base-10 integer count of seconds into a timedelta. If parsing fails, the return value is `None`. @@ -995,7 +1039,7 @@ def parse_age(value: t.Optional[str] = None) -> t.Optional[timedelta]: return None -def dump_age(age: t.Optional[t.Union[timedelta, int]] = None) -> t.Optional[str]: +def dump_age(age: timedelta | int | None = None) -> str | None: """Formats the duration as a base-10 integer. 
:param age: should be an integer number of seconds, @@ -1016,10 +1060,10 @@ def dump_age(age: t.Optional[t.Union[timedelta, int]] = None) -> t.Optional[str] def is_resource_modified( - environ: "WSGIEnvironment", - etag: t.Optional[str] = None, - data: t.Optional[bytes] = None, - last_modified: t.Optional[t.Union[datetime, str]] = None, + environ: WSGIEnvironment, + etag: str | None = None, + data: bytes | None = None, + last_modified: datetime | str | None = None, ignore_if_range: bool = True, ) -> bool: """Convenience method for conditional requests. @@ -1054,7 +1098,7 @@ def is_resource_modified( def remove_entity_headers( - headers: t.Union["ds.Headers", t.List[t.Tuple[str, str]]], + headers: ds.Headers | list[tuple[str, str]], allowed: t.Iterable[str] = ("expires", "content-location"), ) -> None: """Remove all entity headers from a list or :class:`Headers` object. This @@ -1077,9 +1121,7 @@ def remove_entity_headers( ] -def remove_hop_by_hop_headers( - headers: t.Union["ds.Headers", t.List[t.Tuple[str, str]]] -) -> None: +def remove_hop_by_hop_headers(headers: ds.Headers | list[tuple[str, str]]) -> None: """Remove all HTTP/1.1 "Hop-by-Hop" headers from a list or :class:`Headers` object. This operation works in-place. @@ -1115,11 +1157,9 @@ def is_hop_by_hop_header(header: str) -> bool: def parse_cookie( - header: t.Union["WSGIEnvironment", str, bytes, None], - charset: str = "utf-8", - errors: str = "replace", - cls: t.Optional[t.Type["ds.MultiDict"]] = None, -) -> "ds.MultiDict[str, str]": + header: WSGIEnvironment | str | None, + cls: type[ds.MultiDict[str, str]] | None = None, +) -> ds.MultiDict[str, str]: """Parse a cookie from a string or WSGI environ. The same key can be provided multiple times, the values are stored @@ -1129,44 +1169,51 @@ def parse_cookie( :param header: The cookie header as a string, or a WSGI environ dict with a ``HTTP_COOKIE`` key. - :param charset: The charset for the cookie values. - :param errors: The error behavior for the charset decoding. :param cls: A dict-like class to store the parsed cookies in. Defaults to :class:`MultiDict`. - .. versionchanged:: 1.0.0 - Returns a :class:`MultiDict` instead of a - ``TypeConversionDict``. + .. versionchanged:: 3.0 + Passing bytes, and the ``charset`` and ``errors`` parameters, were removed. + + .. versionchanged:: 1.0 + Returns a :class:`MultiDict` instead of a ``TypeConversionDict``. .. versionchanged:: 0.5 - Returns a :class:`TypeConversionDict` instead of a regular dict. - The ``cls`` parameter was added. + Returns a :class:`TypeConversionDict` instead of a regular dict. The ``cls`` + parameter was added. 
""" if isinstance(header, dict): - cookie = header.get("HTTP_COOKIE", "") - elif header is None: - cookie = "" + cookie = header.get("HTTP_COOKIE") else: cookie = header - return _sansio_http.parse_cookie( - cookie=cookie, charset=charset, errors=errors, cls=cls - ) + if cookie: + cookie = cookie.encode("latin1").decode() + + return _sansio_http.parse_cookie(cookie=cookie, cls=cls) + + +_cookie_no_quote_re = re.compile(r"[\w!#$%&'()*+\-./:<=>?@\[\]^`{|}~]*", re.A) +_cookie_slash_re = re.compile(rb"[\x00-\x19\",;\\\x7f-\xff]", re.A) +_cookie_slash_map = {b'"': b'\\"', b"\\": b"\\\\"} +_cookie_slash_map.update( + (v.to_bytes(1, "big"), b"\\%03o" % v) + for v in [*range(0x20), *b",;", *range(0x7F, 256)] +) def dump_cookie( key: str, - value: t.Union[bytes, str] = "", - max_age: t.Optional[t.Union[timedelta, int]] = None, - expires: t.Optional[t.Union[str, datetime, int, float]] = None, - path: t.Optional[str] = "/", - domain: t.Optional[str] = None, + value: str = "", + max_age: timedelta | int | None = None, + expires: str | datetime | int | float | None = None, + path: str | None = "/", + domain: str | None = None, secure: bool = False, httponly: bool = False, - charset: str = "utf-8", sync_expires: bool = True, max_size: int = 4093, - samesite: t.Optional[str] = None, + samesite: str | None = None, ) -> str: """Create a Set-Cookie header without the ``Set-Cookie`` prefix. @@ -1187,7 +1234,7 @@ def dump_cookie( :param path: limits the cookie to a given path, per default it will span the whole domain. :param domain: Use this if you want to set a cross-domain cookie. For - example, ``domain=".example.com"`` will set a cookie + example, ``domain="example.com"`` will set a cookie that is readable by the domain ``www.example.com``, ``foo.example.com`` etc. Otherwise, a cookie will only be readable by the domain that set it. @@ -1206,18 +1253,33 @@ def dump_cookie( .. _`cookie`: http://browsercookielimits.squawky.net/ + .. versionchanged:: 3.0 + Passing bytes, and the ``charset`` parameter, were removed. + + .. versionchanged:: 2.3.3 + The ``path`` parameter is ``/`` by default. + + .. versionchanged:: 2.3.1 + The value allows more characters without quoting. + + .. versionchanged:: 2.3 + ``localhost`` and other names without a dot are allowed for the domain. A + leading dot is ignored. + + .. versionchanged:: 2.3 + The ``path`` parameter is ``None`` by default. + .. versionchanged:: 1.0.0 The string ``'None'`` is accepted for ``samesite``. """ - key = _to_bytes(key, charset) - value = _to_bytes(value, charset) - if path is not None: - from .urls import iri_to_uri - - path = iri_to_uri(path, charset) + # safe = https://url.spec.whatwg.org/#url-path-segment-string + # as well as percent for things that are already quoted + # excluding semicolon since it's part of the header syntax + path = quote(path, safe="%!$&'()*+,/:=@") - domain = _make_cookie_domain(domain) + if domain: + domain = domain.partition(":")[0].lstrip(".").encode("idna").decode("ascii") if isinstance(max_age, timedelta): max_age = int(max_age.total_seconds()) @@ -1234,54 +1296,51 @@ def dump_cookie( if samesite not in {"Strict", "Lax", "None"}: raise ValueError("SameSite must be 'Strict', 'Lax', or 'None'.") - buf = [key + b"=" + _cookie_quote(value)] - - # XXX: In theory all of these parameters that are not marked with `None` - # should be quoted. Because stdlib did not quote it before I did not - # want to introduce quoting there now. 
- for k, v, q in ( - (b"Domain", domain, True), - (b"Expires", expires, False), - (b"Max-Age", max_age, False), - (b"Secure", secure, None), - (b"HttpOnly", httponly, None), - (b"Path", path, False), - (b"SameSite", samesite, False), + # Quote value if it contains characters not allowed by RFC 6265. Slash-escape with + # three octal digits, which matches http.cookies, although the RFC suggests base64. + if not _cookie_no_quote_re.fullmatch(value): + # Work with bytes here, since a UTF-8 character could be multiple bytes. + value = _cookie_slash_re.sub( + lambda m: _cookie_slash_map[m.group()], value.encode() + ).decode("ascii") + value = f'"{value}"' + + # Send a non-ASCII key as mojibake. Everything else should already be ASCII. + # TODO Remove encoding dance, it seems like clients accept UTF-8 keys + buf = [f"{key.encode().decode('latin1')}={value}"] + + for k, v in ( + ("Domain", domain), + ("Expires", expires), + ("Max-Age", max_age), + ("Secure", secure), + ("HttpOnly", httponly), + ("Path", path), + ("SameSite", samesite), ): - if q is None: - if v: - buf.append(k) + if v is None or v is False: continue - if v is None: + if v is True: + buf.append(k) continue - tmp = bytearray(k) - if not isinstance(v, (bytes, bytearray)): - v = _to_bytes(str(v), charset) - if q: - v = _cookie_quote(v) - tmp += b"=" + v - buf.append(bytes(tmp)) - - # The return value will be an incorrectly encoded latin1 header for - # consistency with the headers object. - rv = b"; ".join(buf) - rv = rv.decode("latin1") - - # Warn if the final value of the cookie is larger than the limit. If the - # cookie is too large, then it may be silently ignored by the browser, - # which can be quite hard to debug. + buf.append(f"{k}={v}") + + rv = "; ".join(buf) + + # Warn if the final value of the cookie is larger than the limit. If the cookie is + # too large, then it may be silently ignored by the browser, which can be quite hard + # to debug. cookie_size = len(rv) if max_size and cookie_size > max_size: value_size = len(value) warnings.warn( - f"The {key.decode(charset)!r} cookie is too large: the value was" - f" {value_size} bytes but the" + f"The '{key}' cookie is too large: the value was {value_size} bytes but the" f" header required {cookie_size - value_size} extra bytes. The final size" f" was {cookie_size} bytes but the limit is {max_size} bytes. Browsers may" - f" silently ignore cookies larger than this.", + " silently ignore cookies larger than this.", stacklevel=2, ) @@ -1289,7 +1348,7 @@ def dump_cookie( def is_byte_range_valid( - start: t.Optional[int], stop: t.Optional[int], length: t.Optional[int] + start: int | None, stop: int | None, length: int | None ) -> bool: """Checks if a given byte content range is valid for the given length. diff --git a/src/werkzeug/local.py b/src/werkzeug/local.py index 70e9bf7..302589b 100644 --- a/src/werkzeug/local.py +++ b/src/werkzeug/local.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import copy import math import operator @@ -18,7 +20,7 @@ F = t.TypeVar("F", bound=t.Callable[..., t.Any]) -def release_local(local: t.Union["Local", "LocalStack"]) -> None: +def release_local(local: Local | LocalStack[t.Any]) -> None: """Release the data for the current context in a :class:`Local` or :class:`LocalStack` without using a :class:`LocalManager`. 
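A minimal usage sketch of ``Local`` and ``release_local`` (illustrative only, not part of this patch; the ``user`` attribute is an invented example name):

from werkzeug.local import Local, release_local

loc = Local()
loc.user = "admin"  # stored only for the current context
assert loc.user == "admin"
release_local(loc)  # drop everything stored for the current context
assert not hasattr(loc, "user")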
@@ -49,9 +51,7 @@ class Local: __slots__ = ("__storage",) - def __init__( - self, context_var: t.Optional[ContextVar[t.Dict[str, t.Any]]] = None - ) -> None: + def __init__(self, context_var: ContextVar[dict[str, t.Any]] | None = None) -> None: if context_var is None: # A ContextVar not created at global scope interferes with # Python's garbage collection. However, a local only makes @@ -61,12 +61,12 @@ def __init__( object.__setattr__(self, "_Local__storage", context_var) - def __iter__(self) -> t.Iterator[t.Tuple[str, t.Any]]: + def __iter__(self) -> t.Iterator[tuple[str, t.Any]]: return iter(self.__storage.get({}).items()) def __call__( - self, name: str, *, unbound_message: t.Optional[str] = None - ) -> "LocalProxy": + self, name: str, *, unbound_message: str | None = None + ) -> LocalProxy[t.Any]: """Create a :class:`LocalProxy` that accesses an attribute on this local namespace. @@ -124,7 +124,7 @@ class LocalStack(t.Generic[T]): __slots__ = ("_storage",) - def __init__(self, context_var: t.Optional[ContextVar[t.List[T]]] = None) -> None: + def __init__(self, context_var: ContextVar[list[T]] | None = None) -> None: if context_var is None: # A ContextVar not created at global scope interferes with # Python's garbage collection. However, a local only makes @@ -137,14 +137,14 @@ def __init__(self, context_var: t.Optional[ContextVar[t.List[T]]] = None) -> Non def __release_local__(self) -> None: self._storage.set([]) - def push(self, obj: T) -> t.List[T]: + def push(self, obj: T) -> list[T]: """Add a new item to the top of the stack.""" stack = self._storage.get([]).copy() stack.append(obj) self._storage.set(stack) return stack - def pop(self) -> t.Optional[T]: + def pop(self) -> T | None: """Remove the top item from the stack and return it. If the stack is empty, return ``None``. """ @@ -158,7 +158,7 @@ def pop(self) -> t.Optional[T]: return rv @property - def top(self) -> t.Optional[T]: + def top(self) -> T | None: """The topmost item on the stack. If the stack is empty, `None` is returned. """ @@ -170,8 +170,8 @@ def top(self) -> t.Optional[T]: return stack[-1] def __call__( - self, name: t.Optional[str] = None, *, unbound_message: t.Optional[str] = None - ) -> "LocalProxy": + self, name: str | None = None, *, unbound_message: str | None = None + ) -> LocalProxy[t.Any]: """Create a :class:`LocalProxy` that accesses the top of this local stack. @@ -192,9 +192,8 @@ class LocalManager: :param locals: A local or list of locals to manage. - .. versionchanged:: 2.0 - ``ident_func`` is deprecated and will be removed in Werkzeug - 2.1. + .. versionchanged:: 2.1 + The ``ident_func`` was removed. .. versionchanged:: 0.7 The ``ident_func`` parameter was added. @@ -208,9 +207,8 @@ class LocalManager: def __init__( self, - locals: t.Optional[ - t.Union[Local, LocalStack, t.Iterable[t.Union[Local, LocalStack]]] - ] = None, + locals: None + | (Local | LocalStack[t.Any] | t.Iterable[Local | LocalStack[t.Any]]) = None, ) -> None: if locals is None: self.locals = [] @@ -226,19 +224,19 @@ def cleanup(self) -> None: for local in self.locals: release_local(local) - def make_middleware(self, app: "WSGIApplication") -> "WSGIApplication": + def make_middleware(self, app: WSGIApplication) -> WSGIApplication: """Wrap a WSGI application so that local data is released automatically after the response has been sent for a request.
""" def application( - environ: "WSGIEnvironment", start_response: "StartResponse" + environ: WSGIEnvironment, start_response: StartResponse ) -> t.Iterable[bytes]: return ClosingIterator(app(environ, start_response), self.cleanup) return application - def middleware(self, func: "WSGIApplication") -> "WSGIApplication": + def middleware(self, func: WSGIApplication) -> WSGIApplication: """Like :meth:`make_middleware` but used as a decorator on the WSGI application function. @@ -274,24 +272,28 @@ class _ProxyLookup: def __init__( self, - f: t.Optional[t.Callable] = None, - fallback: t.Optional[t.Callable] = None, - class_value: t.Optional[t.Any] = None, + f: t.Callable[..., t.Any] | None = None, + fallback: t.Callable[[LocalProxy[t.Any]], t.Any] | None = None, + class_value: t.Any | None = None, is_attr: bool = False, ) -> None: - bind_f: t.Optional[t.Callable[["LocalProxy", t.Any], t.Callable]] + bind_f: t.Callable[[LocalProxy[t.Any], t.Any], t.Callable[..., t.Any]] | None if hasattr(f, "__get__"): # A Python function, can be turned into a bound method. - def bind_f(instance: "LocalProxy", obj: t.Any) -> t.Callable: + def bind_f( + instance: LocalProxy[t.Any], obj: t.Any + ) -> t.Callable[..., t.Any]: return f.__get__(obj, type(obj)) # type: ignore elif f is not None: # A C function, use partial to bind the first argument. - def bind_f(instance: "LocalProxy", obj: t.Any) -> t.Callable: - return partial(f, obj) # type: ignore + def bind_f( + instance: LocalProxy[t.Any], obj: t.Any + ) -> t.Callable[..., t.Any]: + return partial(f, obj) else: # Use getattr, which will produce a bound method. @@ -302,10 +304,10 @@ def bind_f(instance: "LocalProxy", obj: t.Any) -> t.Callable: self.class_value = class_value self.is_attr = is_attr - def __set_name__(self, owner: "LocalProxy", name: str) -> None: + def __set_name__(self, owner: LocalProxy[t.Any], name: str) -> None: self.name = name - def __get__(self, instance: "LocalProxy", owner: t.Optional[type] = None) -> t.Any: + def __get__(self, instance: LocalProxy[t.Any], owner: type | None = None) -> t.Any: if instance is None: if self.class_value is not None: return self.class_value @@ -313,7 +315,7 @@ def __get__(self, instance: "LocalProxy", owner: t.Optional[type] = None) -> t.A return self try: - obj = instance._get_current_object() # type: ignore[misc] + obj = instance._get_current_object() except RuntimeError: if self.fallback is None: raise @@ -335,7 +337,9 @@ def __get__(self, instance: "LocalProxy", owner: t.Optional[type] = None) -> t.A def __repr__(self) -> str: return f"proxy {self.name}" - def __call__(self, instance: "LocalProxy", *args: t.Any, **kwargs: t.Any) -> t.Any: + def __call__( + self, instance: LocalProxy[t.Any], *args: t.Any, **kwargs: t.Any + ) -> t.Any: """Support calling unbound methods from the class. For example, this happens with ``copy.copy``, which does ``type(x).__copy__(x)``. 
``type(x)`` can't be proxied, so it @@ -352,12 +356,14 @@ class _ProxyIOp(_ProxyLookup): __slots__ = () def __init__( - self, f: t.Optional[t.Callable] = None, fallback: t.Optional[t.Callable] = None + self, + f: t.Callable[..., t.Any] | None = None, + fallback: t.Callable[[LocalProxy[t.Any]], t.Any] | None = None, ) -> None: super().__init__(f, fallback) - def bind_f(instance: "LocalProxy", obj: t.Any) -> t.Callable: - def i_op(self: t.Any, other: t.Any) -> "LocalProxy": + def bind_f(instance: LocalProxy[t.Any], obj: t.Any) -> t.Callable[..., t.Any]: + def i_op(self: t.Any, other: t.Any) -> LocalProxy[t.Any]: f(self, other) # type: ignore return instance @@ -471,10 +477,10 @@ class LocalProxy(t.Generic[T]): def __init__( self, - local: t.Union[ContextVar[T], Local, LocalStack[T], t.Callable[[], T]], - name: t.Optional[str] = None, + local: ContextVar[T] | Local | LocalStack[T] | t.Callable[[], T], + name: str | None = None, *, - unbound_message: t.Optional[str] = None, + unbound_message: str | None = None, ) -> None: if name is None: get_name = _identity @@ -497,7 +503,7 @@ def _get_current_object() -> T: elif isinstance(local, LocalStack): def _get_current_object() -> T: - obj = local.top # type: ignore[union-attr] + obj = local.top if obj is None: raise RuntimeError(unbound_message) @@ -508,7 +514,7 @@ def _get_current_object() -> T: def _get_current_object() -> T: try: - obj = local.get() # type: ignore[union-attr] + obj = local.get() except LookupError: raise RuntimeError(unbound_message) from None @@ -517,7 +523,7 @@ def _get_current_object() -> T: elif callable(local): def _get_current_object() -> T: - return get_name(local()) # type: ignore + return get_name(local()) else: raise TypeError(f"Don't know how to proxy '{type(local)}'.") @@ -525,32 +531,33 @@ def _get_current_object() -> T: object.__setattr__(self, "_LocalProxy__wrapped", local) object.__setattr__(self, "_get_current_object", _get_current_object) - __doc__ = _ProxyLookup( # type: ignore + __doc__ = _ProxyLookup( # type: ignore[assignment] class_value=__doc__, fallback=lambda self: type(self).__doc__, is_attr=True ) __wrapped__ = _ProxyLookup( - fallback=lambda self: self._LocalProxy__wrapped, is_attr=True + fallback=lambda self: self._LocalProxy__wrapped, # type: ignore[attr-defined] + is_attr=True, ) # __del__ should only delete the proxy - __repr__ = _ProxyLookup( # type: ignore + __repr__ = _ProxyLookup( # type: ignore[assignment] repr, fallback=lambda self: f"<{type(self).__name__} unbound>" ) - __str__ = _ProxyLookup(str) # type: ignore + __str__ = _ProxyLookup(str) # type: ignore[assignment] __bytes__ = _ProxyLookup(bytes) - __format__ = _ProxyLookup() # type: ignore + __format__ = _ProxyLookup() # type: ignore[assignment] __lt__ = _ProxyLookup(operator.lt) __le__ = _ProxyLookup(operator.le) - __eq__ = _ProxyLookup(operator.eq) # type: ignore - __ne__ = _ProxyLookup(operator.ne) # type: ignore + __eq__ = _ProxyLookup(operator.eq) # type: ignore[assignment] + __ne__ = _ProxyLookup(operator.ne) # type: ignore[assignment] __gt__ = _ProxyLookup(operator.gt) __ge__ = _ProxyLookup(operator.ge) - __hash__ = _ProxyLookup(hash) # type: ignore + __hash__ = _ProxyLookup(hash) # type: ignore[assignment] __bool__ = _ProxyLookup(bool, fallback=lambda self: False) __getattr__ = _ProxyLookup(getattr) # __getattribute__ triggered through __getattr__ - __setattr__ = _ProxyLookup(setattr) # type: ignore - __delattr__ = _ProxyLookup(delattr) # type: ignore - __dir__ = _ProxyLookup(dir, fallback=lambda self: []) # type: ignore + 
__setattr__ = _ProxyLookup(setattr) # type: ignore[assignment] + __delattr__ = _ProxyLookup(delattr) # type: ignore[assignment] + __dir__ = _ProxyLookup(dir, fallback=lambda self: []) # type: ignore[assignment] # __get__ (proxying descriptor not supported) # __set__ (descriptor) # __delete__ (descriptor) @@ -561,9 +568,7 @@ def _get_current_object() -> T: # __weakref__ (__getattr__) # __init_subclass__ (proxying metaclass not supported) # __prepare__ (metaclass) - __class__ = _ProxyLookup( - fallback=lambda self: type(self), is_attr=True - ) # type: ignore + __class__ = _ProxyLookup(fallback=lambda self: type(self), is_attr=True) # type: ignore[assignment] __instancecheck__ = _ProxyLookup(lambda self, other: isinstance(other, self)) __subclasscheck__ = _ProxyLookup(lambda self, other: issubclass(other, self)) # __class_getitem__ triggered through __getitem__ diff --git a/src/werkzeug/middleware/__init__.py b/src/werkzeug/middleware/__init__.py index 6ddcf7f..e69de29 100644 --- a/src/werkzeug/middleware/__init__.py +++ b/src/werkzeug/middleware/__init__.py @@ -1,22 +0,0 @@ -""" -Middleware -========== - -A WSGI middleware is a WSGI application that wraps another application -in order to observe or change its behavior. Werkzeug provides some -middleware for common use cases. - -.. toctree:: - :maxdepth: 1 - - proxy_fix - shared_data - dispatcher - http_proxy - lint - profiler - -The :doc:`interactive debugger ` is also a middleware that can -be applied manually, although it is typically used automatically with -the :doc:`development server `. -""" diff --git a/src/werkzeug/middleware/dispatcher.py b/src/werkzeug/middleware/dispatcher.py index ace1c75..e11bacc 100644 --- a/src/werkzeug/middleware/dispatcher.py +++ b/src/werkzeug/middleware/dispatcher.py @@ -30,6 +30,9 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ + +from __future__ import annotations + import typing as t if t.TYPE_CHECKING: @@ -50,14 +53,14 @@ class DispatcherMiddleware: def __init__( self, - app: "WSGIApplication", - mounts: t.Optional[t.Dict[str, "WSGIApplication"]] = None, + app: WSGIApplication, + mounts: dict[str, WSGIApplication] | None = None, ) -> None: self.app = app self.mounts = mounts or {} def __call__( - self, environ: "WSGIEnvironment", start_response: "StartResponse" + self, environ: WSGIEnvironment, start_response: StartResponse ) -> t.Iterable[bytes]: script = environ.get("PATH_INFO", "") path_info = "" diff --git a/src/werkzeug/middleware/http_proxy.py b/src/werkzeug/middleware/http_proxy.py index 1cde458..5e23915 100644 --- a/src/werkzeug/middleware/http_proxy.py +++ b/src/werkzeug/middleware/http_proxy.py @@ -7,13 +7,16 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ + +from __future__ import annotations + import typing as t from http import client +from urllib.parse import quote +from urllib.parse import urlsplit from ..datastructures import EnvironHeaders from ..http import is_hop_by_hop_header -from ..urls import url_parse -from ..urls import url_quote from ..wsgi import get_input_stream if t.TYPE_CHECKING: @@ -78,12 +81,12 @@ class ProxyMiddleware: def __init__( self, - app: "WSGIApplication", - targets: t.Mapping[str, t.Dict[str, t.Any]], + app: WSGIApplication, + targets: t.Mapping[str, dict[str, t.Any]], chunk_size: int = 2 << 13, timeout: int = 10, ) -> None: - def _set_defaults(opts: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: + def _set_defaults(opts: dict[str, t.Any]) -> dict[str, t.Any]: opts.setdefault("remove_prefix", False) opts.setdefault("host", "") 
opts.setdefault("headers", {}) @@ -98,13 +101,14 @@ def _set_defaults(opts: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: self.timeout = timeout def proxy_to( - self, opts: t.Dict[str, t.Any], path: str, prefix: str - ) -> "WSGIApplication": - target = url_parse(opts["target"]) - host = t.cast(str, target.ascii_host) + self, opts: dict[str, t.Any], path: str, prefix: str + ) -> WSGIApplication: + target = urlsplit(opts["target"]) + # socket can handle unicode host, but header must be ascii + host = target.hostname.encode("idna").decode("ascii") def application( - environ: "WSGIEnvironment", start_response: "StartResponse" + environ: WSGIEnvironment, start_response: StartResponse ) -> t.Iterable[bytes]: headers = list(EnvironHeaders(environ).items()) headers[:] = [ @@ -157,7 +161,9 @@ def application( ) con.connect() - remote_url = url_quote(remote_path) + # safe = https://url.spec.whatwg.org/#url-path-segment-string + # as well as percent for things that are already quoted + remote_url = quote(remote_path, safe="!$&'()*+,/:;=@%") querystring = environ["QUERY_STRING"] if querystring: @@ -217,7 +223,7 @@ def read() -> t.Iterator[bytes]: return application def __call__( - self, environ: "WSGIEnvironment", start_response: "StartResponse" + self, environ: WSGIEnvironment, start_response: StartResponse ) -> t.Iterable[bytes]: path = environ["PATH_INFO"] app = self.app diff --git a/src/werkzeug/middleware/lint.py b/src/werkzeug/middleware/lint.py index 6b54630..de93b52 100644 --- a/src/werkzeug/middleware/lint.py +++ b/src/werkzeug/middleware/lint.py @@ -12,6 +12,9 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ + +from __future__ import annotations + import typing as t from types import TracebackType from urllib.parse import urlparse @@ -35,7 +38,7 @@ class HTTPWarning(Warning): """Warning class for HTTP warnings.""" -def check_type(context: str, obj: object, need: t.Type = str) -> None: +def check_type(context: str, obj: object, need: type = str) -> None: if type(obj) is not need: warn( f"{context!r} requires {need.__name__!r}, got {type(obj).__name__!r}.", @@ -117,7 +120,7 @@ def close(self) -> None: class GuardedWrite: - def __init__(self, write: t.Callable[[bytes], object], chunks: t.List[int]) -> None: + def __init__(self, write: t.Callable[[bytes], object], chunks: list[int]) -> None: self._write = write self._chunks = chunks @@ -131,8 +134,8 @@ class GuardedIterator: def __init__( self, iterator: t.Iterable[bytes], - headers_set: t.Tuple[int, Headers], - chunks: t.List[int], + headers_set: tuple[int, Headers], + chunks: list[int], ) -> None: self._iterator = iterator self._next = iter(iterator).__next__ @@ -140,7 +143,7 @@ def __init__( self.headers_set = headers_set self.chunks = chunks - def __iter__(self) -> "GuardedIterator": + def __iter__(self) -> GuardedIterator: return self def __next__(self) -> bytes: @@ -164,7 +167,7 @@ def close(self) -> None: self.closed = True if hasattr(self._iterator, "close"): - self._iterator.close() # type: ignore + self._iterator.close() if self.headers_set: status_code, headers = self.headers_set @@ -178,30 +181,44 @@ def close(self) -> None: key ): warn( - f"Entity header {key!r} found in 304 response.", HTTPWarning + f"Entity header {key!r} found in 304 response.", + HTTPWarning, + stacklevel=2, ) if bytes_sent: - warn("304 responses must not have a body.", HTTPWarning) + warn( + "304 responses must not have a body.", + HTTPWarning, + stacklevel=2, + ) elif 100 <= status_code < 200 or status_code == 204: if content_length != 0: warn( 
f"{status_code} responses must have an empty content length.", HTTPWarning, + stacklevel=2, ) if bytes_sent: - warn(f"{status_code} responses must not have a body.", HTTPWarning) + warn( + f"{status_code} responses must not have a body.", + HTTPWarning, + stacklevel=2, + ) elif content_length is not None and content_length != bytes_sent: warn( "Content-Length and the number of bytes sent to the" " client do not match.", WSGIWarning, + stacklevel=2, ) def __del__(self) -> None: if not self.closed: try: warn( - "Iterator was garbage collected before it was closed.", WSGIWarning + "Iterator was garbage collected before it was closed.", + WSGIWarning, + stacklevel=2, ) except Exception: pass @@ -230,11 +247,11 @@ class LintMiddleware: app = LintMiddleware(app) """ - def __init__(self, app: "WSGIApplication") -> None: + def __init__(self, app: WSGIApplication) -> None: self.app = app - def check_environ(self, environ: "WSGIEnvironment") -> None: - if type(environ) is not dict: + def check_environ(self, environ: WSGIEnvironment) -> None: + if type(environ) is not dict: # noqa: E721 warn( "WSGI environment is not a standard Python dict.", WSGIWarning, @@ -280,11 +297,9 @@ def check_environ(self, environ: "WSGIEnvironment") -> None: def check_start_response( self, status: str, - headers: t.List[t.Tuple[str, str]], - exc_info: t.Optional[ - t.Tuple[t.Type[BaseException], BaseException, TracebackType] - ], - ) -> t.Tuple[int, Headers]: + headers: list[tuple[str, str]], + exc_info: None | (tuple[type[BaseException], BaseException, TracebackType]), + ) -> tuple[int, Headers]: check_type("status", status, str) status_code_str = status.split(None, 1)[0] @@ -304,14 +319,14 @@ def check_start_response( if status_code < 100: warn("Status code < 100 detected.", WSGIWarning, stacklevel=3) - if type(headers) is not list: + if type(headers) is not list: # noqa: E721 warn("Header list is not a list.", WSGIWarning, stacklevel=3) for item in headers: if type(item) is not tuple or len(item) != 2: warn("Header items must be 2-item tuples.", WSGIWarning, stacklevel=3) name, value = item - if type(name) is not str or type(value) is not str: + if type(name) is not str or type(value) is not str: # noqa: E721 warn( "Header keys and values must be strings.", WSGIWarning, stacklevel=3 ) @@ -326,10 +341,10 @@ def check_start_response( if exc_info is not None and not isinstance(exc_info, tuple): warn("Invalid value for exc_info.", WSGIWarning, stacklevel=3) - headers = Headers(headers) - self.check_headers(headers) + headers_obj = Headers(headers) + self.check_headers(headers_obj) - return status_code, headers + return status_code, headers_obj def check_headers(self, headers: Headers) -> None: etag = headers.get("etag") @@ -359,9 +374,9 @@ def check_headers(self, headers: Headers) -> None: ) def check_iterator(self, app_iter: t.Iterable[bytes]) -> None: - if isinstance(app_iter, bytes): + if isinstance(app_iter, str): warn( - "The application returned a bytestring. The response will send one" + "The application returned a string. The response will send one" " character at a time to the client, which will kill performance." 
" Return a list or iterable instead.", WSGIWarning, @@ -377,8 +392,8 @@ def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Iterable[bytes]: "A WSGI app does not take keyword arguments.", WSGIWarning, stacklevel=2 ) - environ: "WSGIEnvironment" = args[0] - start_response: "StartResponse" = args[1] + environ: WSGIEnvironment = args[0] + start_response: StartResponse = args[1] self.check_environ(environ) environ["wsgi.input"] = InputStream(environ["wsgi.input"]) @@ -388,8 +403,8 @@ def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Iterable[bytes]: # iterate to the end and we can check the content length. environ["wsgi.file_wrapper"] = FileWrapper - headers_set: t.List[t.Any] = [] - chunks: t.List[int] = [] + headers_set: list[t.Any] = [] + chunks: list[int] = [] def checking_start_response( *args: t.Any, **kwargs: t.Any @@ -402,13 +417,17 @@ def checking_start_response( ) if kwargs: - warn("'start_response' does not take keyword arguments.", WSGIWarning) + warn( + "'start_response' does not take keyword arguments.", + WSGIWarning, + stacklevel=2, + ) status: str = args[0] - headers: t.List[t.Tuple[str, str]] = args[1] - exc_info: t.Optional[ - t.Tuple[t.Type[BaseException], BaseException, TracebackType] - ] = (args[2] if len(args) == 3 else None) + headers: list[tuple[str, str]] = args[1] + exc_info: ( + None | (tuple[type[BaseException], BaseException, TracebackType]) + ) = args[2] if len(args) == 3 else None headers_set[:] = self.check_start_response(status, headers, exc_info) return GuardedWrite(start_response(status, headers, exc_info), chunks) diff --git a/src/werkzeug/middleware/profiler.py b/src/werkzeug/middleware/profiler.py index 200dae0..112b877 100644 --- a/src/werkzeug/middleware/profiler.py +++ b/src/werkzeug/middleware/profiler.py @@ -11,6 +11,9 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ + +from __future__ import annotations + import os.path import sys import time @@ -42,11 +45,16 @@ class ProfilerMiddleware: - ``{method}`` - The request method; GET, POST, etc. - ``{path}`` - The request path or 'root' should one not exist. - - ``{elapsed}`` - The elapsed time of the request. + - ``{elapsed}`` - The elapsed time of the request in milliseconds. - ``{time}`` - The time of the request. - If it is a callable, it will be called with the WSGI ``environ`` - dict and should return a filename. + If it is a callable, it will be called with the WSGI ``environ`` and + be expected to return a filename string. The ``environ`` dictionary + will also have the ``"werkzeug.profiler"`` key populated with a + dictionary containing the following fields (more may be added in the + future): + - ``{elapsed}`` - The elapsed time of the request in milliseconds. + - ``{time}`` - The time of the request. :param app: The WSGI application to wrap. :param stream: Write stats to this stream. Disable with ``None``. @@ -63,6 +71,10 @@ class ProfilerMiddleware: from werkzeug.middleware.profiler import ProfilerMiddleware app = ProfilerMiddleware(app) + .. versionchanged:: 3.0 + Added the ``"werkzeug.profiler"`` key to the ``filename_format(environ)`` + parameter with the ``elapsed`` and ``time`` fields. + .. versionchanged:: 0.15 Stats are written even if ``profile_dir`` is given, and can be disable by passing ``stream=None``. 
@@ -76,11 +88,11 @@ class ProfilerMiddleware: def __init__( self, - app: "WSGIApplication", - stream: t.IO[str] = sys.stdout, + app: WSGIApplication, + stream: t.IO[str] | None = sys.stdout, sort_by: t.Iterable[str] = ("time", "calls"), - restrictions: t.Iterable[t.Union[str, int, float]] = (), - profile_dir: t.Optional[str] = None, + restrictions: t.Iterable[str | int | float] = (), + profile_dir: str | None = None, filename_format: str = "{method}.{path}.{elapsed:.0f}ms.{time:.0f}.prof", ) -> None: self._app = app @@ -91,9 +103,9 @@ def __init__( self._filename_format = filename_format def __call__( - self, environ: "WSGIEnvironment", start_response: "StartResponse" + self, environ: WSGIEnvironment, start_response: StartResponse ) -> t.Iterable[bytes]: - response_body: t.List[bytes] = [] + response_body: list[bytes] = [] def catching_start_response(status, headers, exc_info=None): # type: ignore start_response(status, headers, exc_info) @@ -106,7 +118,7 @@ def runapp() -> None: response_body.extend(app_iter) if hasattr(app_iter, "close"): - app_iter.close() # type: ignore + app_iter.close() profile = Profile() start = time.time() @@ -116,6 +128,10 @@ def runapp() -> None: if self._profile_dir is not None: if callable(self._filename_format): + environ["werkzeug.profiler"] = { + "elapsed": elapsed * 1000.0, + "time": time.time(), + } filename = self._filename_format(environ) else: filename = self._filename_format.format( diff --git a/src/werkzeug/middleware/proxy_fix.py b/src/werkzeug/middleware/proxy_fix.py index 4cef7cc..cbf4e0b 100644 --- a/src/werkzeug/middleware/proxy_fix.py +++ b/src/werkzeug/middleware/proxy_fix.py @@ -21,6 +21,9 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ + +from __future__ import annotations + import typing as t from ..http import parse_list_header @@ -64,23 +67,16 @@ class ProxyFix: app = ProxyFix(app, x_for=1, x_host=1) .. versionchanged:: 1.0 - Deprecated code has been removed: - - * The ``num_proxies`` argument and attribute. - * The ``get_remote_addr`` method. - * The environ keys ``orig_remote_addr``, - ``orig_wsgi_url_scheme``, and ``orig_http_host``. + The ``num_proxies`` argument and attribute; the ``get_remote_addr`` method; and + the environ keys ``orig_remote_addr``, ``orig_wsgi_url_scheme``, and + ``orig_http_host`` were removed. .. versionchanged:: 0.15 - All headers support multiple values. The ``num_proxies`` - argument is deprecated. Each header is configured with a - separate number of trusted proxies. + All headers support multiple values. Each header is configured with a separate + number of trusted proxies. .. versionchanged:: 0.15 - Original WSGI environ values are stored in the - ``werkzeug.proxy_fix.orig`` dict. ``orig_remote_addr``, - ``orig_wsgi_url_scheme``, and ``orig_http_host`` are deprecated - and will be removed in 1.0. + Original WSGI environ values are stored in the ``werkzeug.proxy_fix.orig`` dict. .. versionchanged:: 0.15 Support ``X-Forwarded-Port`` and ``X-Forwarded-Prefix``. @@ -92,7 +88,7 @@ class ProxyFix: def __init__( self, - app: "WSGIApplication", + app: WSGIApplication, x_for: int = 1, x_proto: int = 1, x_host: int = 0, @@ -106,7 +102,7 @@ def __init__( self.x_port = x_port self.x_prefix = x_prefix - def _get_real_value(self, trusted: int, value: t.Optional[str]) -> t.Optional[str]: + def _get_real_value(self, trusted: int, value: str | None) -> str | None: """Get the real value from a list header based on the configured number of trusted proxies. 
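The selection rule this method implements can be pictured with a small standalone sketch (illustrative only; ``pick_trusted`` is an invented stand-in that mirrors ``_get_real_value``):

from werkzeug.http import parse_list_header

def pick_trusted(trusted, value):
    # With N trusted proxies, use the N-th value from the end of the list.
    if not (trusted and value):
        return None
    values = parse_list_header(value)
    if len(values) >= trusted:
        return values[-trusted]
    return None

assert pick_trusted(1, "client, proxy") == "proxy"
assert pick_trusted(2, "client, proxy") == "client"
assert pick_trusted(3, "client, proxy") is None  # not enough values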
@@ -128,7 +124,7 @@ def _get_real_value(self, trusted: int, value: t.Optional[str]) -> t.Optional[st return None def __call__( - self, environ: "WSGIEnvironment", start_response: "StartResponse" + self, environ: WSGIEnvironment, start_response: StartResponse ) -> t.Iterable[bytes]: """Modify the WSGI environ based on the various ``Forwarded`` headers before calling the wrapped application. Store the diff --git a/src/werkzeug/middleware/shared_data.py b/src/werkzeug/middleware/shared_data.py index 2ec396c..0a0c956 100644 --- a/src/werkzeug/middleware/shared_data.py +++ b/src/werkzeug/middleware/shared_data.py @@ -8,9 +8,12 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ + +from __future__ import annotations + +import importlib.util import mimetypes import os -import pkgutil import posixpath import typing as t from datetime import datetime @@ -36,7 +39,6 @@ class SharedDataMiddleware: - """A WSGI middleware which provides static content for development environments or simple server setups. Its usage is quite simple:: @@ -99,18 +101,18 @@ class SharedDataMiddleware: def __init__( self, - app: "WSGIApplication", - exports: t.Union[ - t.Dict[str, t.Union[str, t.Tuple[str, str]]], - t.Iterable[t.Tuple[str, t.Union[str, t.Tuple[str, str]]]], - ], + app: WSGIApplication, + exports: ( + dict[str, str | tuple[str, str]] + | t.Iterable[tuple[str, str | tuple[str, str]]] + ), disallow: None = None, cache: bool = True, cache_timeout: int = 60 * 60 * 12, fallback_mimetype: str = "application/octet-stream", ) -> None: self.app = app - self.exports: t.List[t.Tuple[str, _TLoader]] = [] + self.exports: list[tuple[str, _TLoader]] = [] self.cache = cache self.cache_timeout = cache_timeout @@ -156,12 +158,12 @@ def get_file_loader(self, filename: str) -> _TLoader: def get_package_loader(self, package: str, package_path: str) -> _TLoader: load_time = datetime.now(timezone.utc) - provider = pkgutil.get_loader(package) - reader = provider.get_resource_reader(package) # type: ignore + spec = importlib.util.find_spec(package) + reader = spec.loader.get_resource_reader(package) # type: ignore[union-attr] def loader( - path: t.Optional[str], - ) -> t.Tuple[t.Optional[str], t.Optional[_TOpener]]: + path: str | None, + ) -> tuple[str | None, _TOpener | None]: if path is None: return None, None @@ -198,8 +200,8 @@ def loader( def get_directory_loader(self, directory: str) -> _TLoader: def loader( - path: t.Optional[str], - ) -> t.Tuple[t.Optional[str], t.Optional[_TOpener]]: + path: str | None, + ) -> tuple[str | None, _TOpener | None]: if path is not None: path = safe_join(directory, path) @@ -216,13 +218,13 @@ def loader( return loader def generate_etag(self, mtime: datetime, file_size: int, real_filename: str) -> str: - real_filename = os.fsencode(real_filename) + fn_str = os.fsencode(real_filename) timestamp = mtime.timestamp() - checksum = adler32(real_filename) & 0xFFFFFFFF + checksum = adler32(fn_str) & 0xFFFFFFFF return f"wzsdm-{timestamp}-{file_size}-{checksum}" def __call__( - self, environ: "WSGIEnvironment", start_response: "StartResponse" + self, environ: WSGIEnvironment, start_response: StartResponse ) -> t.Iterable[bytes]: path = get_path_info(environ) file_loader = None diff --git a/src/werkzeug/routing/__init__.py b/src/werkzeug/routing/__init__.py index 84b043f..62adc48 100644 --- a/src/werkzeug/routing/__init__.py +++ b/src/werkzeug/routing/__init__.py @@ -105,6 +105,7 @@ routing tried to match a ``POST`` request) a ``MethodNotAllowed`` exception is raised. 
""" + from .converters import AnyConverter as AnyConverter from .converters import BaseConverter as BaseConverter from .converters import FloatConverter as FloatConverter diff --git a/src/werkzeug/routing/converters.py b/src/werkzeug/routing/converters.py index bbad29d..6016a97 100644 --- a/src/werkzeug/routing/converters.py +++ b/src/werkzeug/routing/converters.py @@ -1,8 +1,9 @@ +from __future__ import annotations + import re import typing as t import uuid - -from ..urls import _fast_url_quote +from urllib.parse import quote if t.TYPE_CHECKING: from .map import Map @@ -15,22 +16,33 @@ class ValidationError(ValueError): class BaseConverter: - """Base class for all converters.""" + """Base class for all converters. + + .. versionchanged:: 2.3 + ``part_isolating`` defaults to ``False`` if ``regex`` contains a ``/``. + """ regex = "[^/]+" weight = 100 part_isolating = True - def __init__(self, map: "Map", *args: t.Any, **kwargs: t.Any) -> None: + def __init_subclass__(cls, **kwargs: t.Any) -> None: + super().__init_subclass__(**kwargs) + + # If the converter isn't inheriting its regex, disable part_isolating by default + # if the regex contains a / character. + if "regex" in cls.__dict__ and "part_isolating" not in cls.__dict__: + cls.part_isolating = "/" not in cls.regex + + def __init__(self, map: Map, *args: t.Any, **kwargs: t.Any) -> None: self.map = map def to_python(self, value: str) -> t.Any: return value def to_url(self, value: t.Any) -> str: - if isinstance(value, (bytes, bytearray)): - return _fast_url_quote(value) - return _fast_url_quote(str(value).encode(self.map.charset)) + # safe = https://url.spec.whatwg.org/#url-path-segment-string + return quote(str(value), safe="!$&'()*+,/:;=@") class UnicodeConverter(BaseConverter): @@ -51,14 +63,12 @@ class UnicodeConverter(BaseConverter): :param length: the exact length of the string. """ - part_isolating = True - def __init__( self, - map: "Map", + map: Map, minlength: int = 1, - maxlength: t.Optional[int] = None, - length: t.Optional[int] = None, + maxlength: int | None = None, + length: int | None = None, ) -> None: super().__init__(map) if length is not None: @@ -86,9 +96,7 @@ class AnyConverter(BaseConverter): Value is validated when building a URL. """ - part_isolating = True - - def __init__(self, map: "Map", *items: str) -> None: + def __init__(self, map: Map, *items: str) -> None: super().__init__(map) self.items = set(items) self.regex = f"(?:{'|'.join([re.escape(x) for x in items])})" @@ -111,9 +119,9 @@ class PathConverter(BaseConverter): :param map: the :class:`Map`. """ + part_isolating = False regex = "[^/].*?" 
weight = 200 - part_isolating = False class NumberConverter(BaseConverter): @@ -123,15 +131,14 @@ class NumberConverter(BaseConverter): """ weight = 50 - num_convert: t.Callable = int - part_isolating = True + num_convert: t.Callable[[t.Any], t.Any] = int def __init__( self, - map: "Map", + map: Map, fixed_digits: int = 0, - min: t.Optional[int] = None, - max: t.Optional[int] = None, + min: int | None = None, + max: int | None = None, signed: bool = False, ) -> None: if signed: @@ -145,18 +152,18 @@ def __init__( def to_python(self, value: str) -> t.Any: if self.fixed_digits and len(value) != self.fixed_digits: raise ValidationError() - value = self.num_convert(value) - if (self.min is not None and value < self.min) or ( - self.max is not None and value > self.max + value_num = self.num_convert(value) + if (self.min is not None and value_num < self.min) or ( + self.max is not None and value_num > self.max ): raise ValidationError() - return value + return value_num def to_url(self, value: t.Any) -> str: - value = str(self.num_convert(value)) + value_str = str(self.num_convert(value)) if self.fixed_digits: - value = value.zfill(self.fixed_digits) - return value + value_str = value_str.zfill(self.fixed_digits) + return value_str @property def signed_regex(self) -> str: @@ -186,7 +193,6 @@ class IntegerConverter(NumberConverter): """ regex = r"\d+" - part_isolating = True class FloatConverter(NumberConverter): @@ -210,13 +216,12 @@ class FloatConverter(NumberConverter): regex = r"\d+\.\d+" num_convert = float - part_isolating = True def __init__( self, - map: "Map", - min: t.Optional[float] = None, - max: t.Optional[float] = None, + map: Map, + min: float | None = None, + max: float | None = None, signed: bool = False, ) -> None: super().__init__(map, min=min, max=max, signed=signed) # type: ignore @@ -236,7 +241,6 @@ class UUIDConverter(BaseConverter): r"[A-Fa-f0-9]{8}-[A-Fa-f0-9]{4}-" r"[A-Fa-f0-9]{4}-[A-Fa-f0-9]{4}-[A-Fa-f0-9]{12}" ) - part_isolating = True def to_python(self, value: str) -> uuid.UUID: return uuid.UUID(value) @@ -246,7 +250,7 @@ def to_url(self, value: uuid.UUID) -> str: #: the default converter mapping for the map. 
-DEFAULT_CONVERTERS: t.Mapping[str, t.Type[BaseConverter]] = { +DEFAULT_CONVERTERS: t.Mapping[str, type[BaseConverter]] = { "default": UnicodeConverter, "string": UnicodeConverter, "any": AnyConverter, diff --git a/src/werkzeug/routing/exceptions.py b/src/werkzeug/routing/exceptions.py index 7cbe6e9..eeabd4e 100644 --- a/src/werkzeug/routing/exceptions.py +++ b/src/werkzeug/routing/exceptions.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import difflib import typing as t @@ -8,10 +10,11 @@ if t.TYPE_CHECKING: from _typeshed.wsgi import WSGIEnvironment - from .map import MapAdapter - from .rules import Rule # noqa: F401 + from ..wrappers.request import Request from ..wrappers.response import Response + from .map import MapAdapter + from .rules import Rule class RoutingException(Exception): @@ -37,9 +40,9 @@ def __init__(self, new_url: str) -> None: def get_response( self, - environ: t.Optional[t.Union["WSGIEnvironment", "Request"]] = None, - scope: t.Optional[dict] = None, - ) -> "Response": + environ: WSGIEnvironment | Request | None = None, + scope: dict[str, t.Any] | None = None, + ) -> Response: return redirect(self.new_url, self.code) @@ -56,7 +59,7 @@ def __init__(self, path_info: str) -> None: class RequestAliasRedirect(RoutingException): # noqa: B903 """This rule is an alias and wants to redirect to the canonical URL.""" - def __init__(self, matched_values: t.Mapping[str, t.Any], endpoint: str) -> None: + def __init__(self, matched_values: t.Mapping[str, t.Any], endpoint: t.Any) -> None: super().__init__() self.matched_values = matched_values self.endpoint = endpoint @@ -69,10 +72,10 @@ class BuildError(RoutingException, LookupError): def __init__( self, - endpoint: str, + endpoint: t.Any, values: t.Mapping[str, t.Any], - method: t.Optional[str], - adapter: t.Optional["MapAdapter"] = None, + method: str | None, + adapter: MapAdapter | None = None, ) -> None: super().__init__(endpoint, values, method) self.endpoint = endpoint @@ -81,16 +84,19 @@ def __init__( self.adapter = adapter @cached_property - def suggested(self) -> t.Optional["Rule"]: + def suggested(self) -> Rule | None: return self.closest_rule(self.adapter) - def closest_rule(self, adapter: t.Optional["MapAdapter"]) -> t.Optional["Rule"]: - def _score_rule(rule: "Rule") -> float: + def closest_rule(self, adapter: MapAdapter | None) -> Rule | None: + def _score_rule(rule: Rule) -> float: return sum( [ 0.98 * difflib.SequenceMatcher( - None, rule.endpoint, self.endpoint + # endpoints can be any type, compare as strings + None, + str(rule.endpoint), + str(self.endpoint), ).ratio(), 0.01 * bool(set(self.values or ()).issubset(rule.arguments)), 0.01 * bool(rule.methods and self.method in rule.methods), @@ -141,6 +147,6 @@ class WebsocketMismatch(BadRequest): class NoMatch(Exception): __slots__ = ("have_match_for", "websocket_mismatch") - def __init__(self, have_match_for: t.Set[str], websocket_mismatch: bool) -> None: + def __init__(self, have_match_for: set[str], websocket_mismatch: bool) -> None: self.have_match_for = have_match_for self.websocket_mismatch = websocket_mismatch diff --git a/src/werkzeug/routing/map.py b/src/werkzeug/routing/map.py index daf94b6..4d15e88 100644 --- a/src/werkzeug/routing/map.py +++ b/src/werkzeug/routing/map.py @@ -1,12 +1,14 @@ -import posixpath +from __future__ import annotations + import typing as t import warnings from pprint import pformat from threading import Lock +from urllib.parse import quote +from urllib.parse import urljoin +from urllib.parse import urlunsplit -from 
.._internal import _encode_idna from .._internal import _get_environ -from .._internal import _to_str from .._internal import _wsgi_decoding_dance from ..datastructures import ImmutableDict from ..datastructures import MultiDict @@ -14,9 +16,7 @@ from ..exceptions import HTTPException from ..exceptions import MethodNotAllowed from ..exceptions import NotFound -from ..urls import url_encode -from ..urls import url_join -from ..urls import url_quote +from ..urls import _urlencode from ..wsgi import get_host from .converters import DEFAULT_CONVERTERS from .exceptions import BuildError @@ -30,12 +30,12 @@ from .rules import Rule if t.TYPE_CHECKING: - import typing_extensions as te from _typeshed.wsgi import WSGIApplication from _typeshed.wsgi import WSGIEnvironment + + from ..wrappers.request import Request from .converters import BaseConverter from .rules import RuleFactory - from ..wrappers.request import Request class Map: @@ -48,7 +48,6 @@ class Map: :param rules: sequence of url rules for this map. :param default_subdomain: The default subdomain for rules without a subdomain defined. - :param charset: charset of the url. defaults to ``"utf-8"`` :param strict_slashes: If a rule ends with a slash but the matched URL does not, redirect to the URL with a trailing slash. :param merge_slashes: Merge consecutive slashes when matching or @@ -63,24 +62,25 @@ class Map: :param sort_parameters: If set to `True` the url parameters are sorted. See `url_encode` for more details. :param sort_key: The sort key function for `url_encode`. - :param encoding_errors: the error method to use for decoding :param host_matching: if set to `True` it enables the host matching feature and disables the subdomain one. If enabled the `host` parameter to rules is used instead of the `subdomain` one. + .. versionchanged:: 3.0 + The ``charset`` and ``encoding_errors`` parameters were removed. + .. versionchanged:: 1.0 - If ``url_scheme`` is ``ws`` or ``wss``, only WebSocket rules - will match. + If ``url_scheme`` is ``ws`` or ``wss``, only WebSocket rules will match. .. versionchanged:: 1.0 - Added ``merge_slashes``. + The ``merge_slashes`` parameter was added. .. versionchanged:: 0.7 - Added ``encoding_errors`` and ``host_matching``. + The ``encoding_errors`` and ``host_matching`` parameters were added. .. versionchanged:: 0.5 - Added ``sort_parameters`` and ``sort_key``. + The ``sort_parameters`` and ``sort_key`` parameters were added. """ #: A dict of default converters to be used.
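A compact example of the ``Map`` API this docstring describes (illustrative only; the rule and endpoint names are invented):

from werkzeug.routing import Map, Rule

url_map = Map([Rule("/item/<int:id>", endpoint="item")])
adapter = url_map.bind("example.com", url_scheme="https")
assert adapter.match("/item/42") == ("item", {"id": 42})
assert adapter.build("item", {"id": 42}) == "/item/42"
assert adapter.build("item", {"id": 42}, force_external=True) == "https://example.com/item/42"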
@@ -93,28 +93,23 @@ class Map: def __init__( self, - rules: t.Optional[t.Iterable["RuleFactory"]] = None, + rules: t.Iterable[RuleFactory] | None = None, default_subdomain: str = "", - charset: str = "utf-8", strict_slashes: bool = True, merge_slashes: bool = True, redirect_defaults: bool = True, - converters: t.Optional[t.Mapping[str, t.Type["BaseConverter"]]] = None, + converters: t.Mapping[str, type[BaseConverter]] | None = None, sort_parameters: bool = False, - sort_key: t.Optional[t.Callable[[t.Any], t.Any]] = None, - encoding_errors: str = "replace", + sort_key: t.Callable[[t.Any], t.Any] | None = None, host_matching: bool = False, ) -> None: self._matcher = StateMachineMatcher(merge_slashes) - self._rules_by_endpoint: t.Dict[str, t.List[Rule]] = {} + self._rules_by_endpoint: dict[t.Any, list[Rule]] = {} self._remap = True self._remap_lock = self.lock_class() self.default_subdomain = default_subdomain - self.charset = charset - self.encoding_errors = encoding_errors self.strict_slashes = strict_slashes - self.merge_slashes = merge_slashes self.redirect_defaults = redirect_defaults self.host_matching = host_matching @@ -128,7 +123,15 @@ def __init__( for rulefactory in rules or (): self.add(rulefactory) - def is_endpoint_expecting(self, endpoint: str, *arguments: str) -> bool: + @property + def merge_slashes(self) -> bool: + return self._matcher.merge_slashes + + @merge_slashes.setter + def merge_slashes(self, value: bool) -> None: + self._matcher.merge_slashes = value + + def is_endpoint_expecting(self, endpoint: t.Any, *arguments: str) -> bool: """Iterate over all rules and check if the endpoint expects the arguments provided. This is for example useful if you have some URLs that expect a language code and others that do not and @@ -142,17 +145,17 @@ def is_endpoint_expecting(self, endpoint: str, *arguments: str) -> bool: checked. """ self.update() - arguments = set(arguments) + arguments_set = set(arguments) for rule in self._rules_by_endpoint[endpoint]: - if arguments.issubset(rule.arguments): + if arguments_set.issubset(rule.arguments): return True return False @property - def _rules(self) -> t.List[Rule]: + def _rules(self) -> list[Rule]: return [rule for rules in self._rules_by_endpoint.values() for rule in rules] - def iter_rules(self, endpoint: t.Optional[str] = None) -> t.Iterator[Rule]: + def iter_rules(self, endpoint: t.Any | None = None) -> t.Iterator[Rule]: """Iterate over all rules or the rules of an endpoint. :param endpoint: if provided only the rules for that endpoint @@ -164,7 +167,7 @@ def iter_rules(self, endpoint: t.Optional[str] = None) -> t.Iterator[Rule]: return iter(self._rules_by_endpoint[endpoint]) return iter(self._rules) - def add(self, rulefactory: "RuleFactory") -> None: + def add(self, rulefactory: RuleFactory) -> None: """Add a new rule or factory to the map and bind it. Requires that the rule is not bound to another map. @@ -180,13 +183,13 @@ def add(self, rulefactory: "RuleFactory") -> None: def bind( self, server_name: str, - script_name: t.Optional[str] = None, - subdomain: t.Optional[str] = None, + script_name: str | None = None, + subdomain: str | None = None, url_scheme: str = "http", default_method: str = "GET", - path_info: t.Optional[str] = None, - query_args: t.Optional[t.Union[t.Mapping[str, t.Any], str]] = None, - ) -> "MapAdapter": + path_info: str | None = None, + query_args: t.Mapping[str, t.Any] | str | None = None, + ) -> MapAdapter: """Return a new :class:`MapAdapter` with the details specified to the call. 
Note that `script_name` will default to ``'/'`` if not further specified or `None`. The `server_name` at least is a requirement @@ -227,14 +230,17 @@ def bind( if path_info is None: path_info = "/" + # Port isn't part of IDNA, and might push a name over the 63 octet limit. + server_name, port_sep, port = server_name.partition(":") + try: - server_name = _encode_idna(server_name) # type: ignore + server_name = server_name.encode("idna").decode("ascii") except UnicodeError as e: raise BadHost() from e return MapAdapter( self, - server_name, + f"{server_name}{port_sep}{port}", script_name, subdomain, url_scheme, @@ -245,10 +251,10 @@ def bind( def bind_to_environ( self, - environ: t.Union["WSGIEnvironment", "Request"], - server_name: t.Optional[str] = None, - subdomain: t.Optional[str] = None, - ) -> "MapAdapter": + environ: WSGIEnvironment | Request, + server_name: str | None = None, + subdomain: str | None = None, + ) -> MapAdapter: """Like :meth:`bind` but you can pass it a WSGI environment and it will fetch the information from that dictionary. Note that because of limitations in the protocol there is no way to get the current @@ -332,10 +338,10 @@ def bind_to_environ( else: subdomain = ".".join(filter(None, cur_server_name[:offset])) - def _get_wsgi_string(name: str) -> t.Optional[str]: + def _get_wsgi_string(name: str) -> str | None: val = env.get(name) if val is not None: - return _wsgi_decoding_dance(val, self.charset) + return _wsgi_decoding_dance(val) return None script_name = _get_wsgi_string("SCRIPT_NAME") @@ -374,7 +380,6 @@ def __repr__(self) -> str: class MapAdapter: - """Returned by :meth:`Map.bind` or :meth:`Map.bind_to_environ` and does the URL matching and building based on runtime information. """ @@ -384,32 +389,33 @@ def __init__( map: Map, server_name: str, script_name: str, - subdomain: t.Optional[str], + subdomain: str | None, url_scheme: str, path_info: str, default_method: str, - query_args: t.Optional[t.Union[t.Mapping[str, t.Any], str]] = None, + query_args: t.Mapping[str, t.Any] | str | None = None, ): self.map = map - self.server_name = _to_str(server_name) - script_name = _to_str(script_name) + self.server_name = server_name + if not script_name.endswith("/"): script_name += "/" + self.script_name = script_name - self.subdomain = _to_str(subdomain) - self.url_scheme = _to_str(url_scheme) - self.path_info = _to_str(path_info) - self.default_method = _to_str(default_method) + self.subdomain = subdomain + self.url_scheme = url_scheme + self.path_info = path_info + self.default_method = default_method self.query_args = query_args self.websocket = self.url_scheme in {"ws", "wss"} def dispatch( self, - view_func: t.Callable[[str, t.Mapping[str, t.Any]], "WSGIApplication"], - path_info: t.Optional[str] = None, - method: t.Optional[str] = None, + view_func: t.Callable[[str, t.Mapping[str, t.Any]], WSGIApplication], + path_info: str | None = None, + method: str | None = None, catch_http_exceptions: bool = False, - ) -> "WSGIApplication": + ) -> WSGIApplication: """Does the complete dispatching process. `view_func` is called with the endpoint and a dict with the values for the view.
It should look up the view function, call it, and return a response object @@ -464,35 +470,33 @@ def application(environ, start_response): raise @t.overload - def match( # type: ignore + def match( self, - path_info: t.Optional[str] = None, - method: t.Optional[str] = None, - return_rule: "te.Literal[False]" = False, - query_args: t.Optional[t.Union[t.Mapping[str, t.Any], str]] = None, - websocket: t.Optional[bool] = None, - ) -> t.Tuple[str, t.Mapping[str, t.Any]]: - ... + path_info: str | None = None, + method: str | None = None, + return_rule: t.Literal[False] = False, + query_args: t.Mapping[str, t.Any] | str | None = None, + websocket: bool | None = None, + ) -> tuple[t.Any, t.Mapping[str, t.Any]]: ... @t.overload def match( self, - path_info: t.Optional[str] = None, - method: t.Optional[str] = None, - return_rule: "te.Literal[True]" = True, - query_args: t.Optional[t.Union[t.Mapping[str, t.Any], str]] = None, - websocket: t.Optional[bool] = None, - ) -> t.Tuple[Rule, t.Mapping[str, t.Any]]: - ... + path_info: str | None = None, + method: str | None = None, + return_rule: t.Literal[True] = True, + query_args: t.Mapping[str, t.Any] | str | None = None, + websocket: bool | None = None, + ) -> tuple[Rule, t.Mapping[str, t.Any]]: ... def match( self, - path_info: t.Optional[str] = None, - method: t.Optional[str] = None, + path_info: str | None = None, + method: str | None = None, return_rule: bool = False, - query_args: t.Optional[t.Union[t.Mapping[str, t.Any], str]] = None, - websocket: t.Optional[bool] = None, - ) -> t.Tuple[t.Union[str, Rule], t.Mapping[str, t.Any]]: + query_args: t.Mapping[str, t.Any] | str | None = None, + websocket: bool | None = None, + ) -> tuple[t.Any | Rule, t.Mapping[str, t.Any]]: """The usage is simple: you just pass the match method the current path info as well as the method (which defaults to `GET`). The following things can then happen: @@ -583,8 +587,6 @@ def match( self.map.update() if path_info is None: path_info = self.path_info - else: - path_info = _to_str(path_info, self.map.charset) if query_args is None: query_args = self.query_args or {} method = (method or self.default_method).upper() @@ -592,17 +594,20 @@ def match( if websocket is None: websocket = self.websocket - domain_part = self.server_name if self.map.host_matching else self.subdomain + domain_part = self.server_name + + if not self.map.host_matching and self.subdomain is not None: + domain_part = self.subdomain + path_part = f"/{path_info.lstrip('/')}" if path_info else "" try: result = self.map._matcher.match(domain_part, path_part, method, websocket) except RequestPath as e: + # safe = https://url.spec.whatwg.org/#url-path-segment-string + new_path = quote(e.path_info, safe="!$&'()*+,/:;=@") raise RequestRedirect( - self.make_redirect_url( - url_quote(e.path_info, self.map.charset, safe="/:|+"), - query_args, - ) + self.make_redirect_url(new_path, query_args) ) from None except RequestAliasRedirect as e: raise RequestRedirect( @@ -647,7 +652,7 @@ def _handle_match(match: t.Match[str]) -> str: netloc = self.server_name raise RequestRedirect( - url_join( + urljoin( f"{self.url_scheme or 'http'}://{netloc}{self.script_name}", redirect_url, ) @@ -658,9 +663,7 @@ def _handle_match(match: t.Match[str]) -> str: else: return rule.endpoint, rv - def test( - self, path_info: t.Optional[str] = None, method: t.Optional[str] = None - ) -> bool: + def test(self, path_info: str | None = None, method: str | None = None) -> bool: """Test if a rule would match. 
Works like `match` but returns `True` if the URL matches, or `False` if it does not exist. @@ -677,7 +680,7 @@ def test( return False return True - def allowed_methods(self, path_info: t.Optional[str] = None) -> t.Iterable[str]: + def allowed_methods(self, path_info: str | None = None) -> t.Iterable[str]: """Returns the valid methods that match for a given path. .. versionadded:: 0.7 @@ -690,7 +693,7 @@ def allowed_methods(self, path_info: t.Optional[str] = None) -> t.Iterable[str]: pass return [] - def get_host(self, domain_part: t.Optional[str]) -> str: + def get_host(self, domain_part: str | None) -> str: """Figures out the full host name for the given domain part. The domain part is a subdomain in case host matching is disabled or a full host name. @@ -698,12 +701,13 @@ def get_host(self, domain_part: t.Optional[str]) -> str: if self.map.host_matching: if domain_part is None: return self.server_name - return _to_str(domain_part, "ascii") - subdomain = domain_part - if subdomain is None: + + return domain_part + + if domain_part is None: subdomain = self.subdomain else: - subdomain = _to_str(subdomain, "ascii") + subdomain = domain_part if subdomain: return f"{subdomain}.{self.server_name}" @@ -715,8 +719,8 @@ def get_default_redirect( rule: Rule, method: str, values: t.MutableMapping[str, t.Any], - query_args: t.Union[t.Mapping[str, t.Any], str], - ) -> t.Optional[str]: + query_args: t.Mapping[str, t.Any] | str, + ) -> str | None: """A helper that returns the URL to redirect to if it finds one. This is used for default redirecting only. @@ -735,38 +739,41 @@ def get_default_redirect( return self.make_redirect_url(path, query_args, domain_part=domain_part) return None - def encode_query_args(self, query_args: t.Union[t.Mapping[str, t.Any], str]) -> str: + def encode_query_args(self, query_args: t.Mapping[str, t.Any] | str) -> str: if not isinstance(query_args, str): - return url_encode(query_args, self.map.charset) + return _urlencode(query_args) return query_args def make_redirect_url( self, path_info: str, - query_args: t.Optional[t.Union[t.Mapping[str, t.Any], str]] = None, - domain_part: t.Optional[str] = None, + query_args: t.Mapping[str, t.Any] | str | None = None, + domain_part: str | None = None, ) -> str: """Creates a redirect URL. :internal: """ + if query_args is None: + query_args = self.query_args + if query_args: - suffix = f"?{self.encode_query_args(query_args)}" + query_str = self.encode_query_args(query_args) else: - suffix = "" + query_str = None scheme = self.url_scheme or "http" host = self.get_host(domain_part) - path = posixpath.join(self.script_name.strip("/"), path_info.lstrip("/")) - return f"{scheme}://{host}/{path}{suffix}" + path = "/".join((self.script_name.strip("/"), path_info.lstrip("/"))) + return urlunsplit((scheme, host, path, query_str, None)) def make_alias_redirect_url( self, path: str, - endpoint: str, + endpoint: t.Any, values: t.Mapping[str, t.Any], method: str, - query_args: t.Union[t.Mapping[str, t.Any], str], + query_args: t.Mapping[str, t.Any] | str, ) -> str: """Internally called to make an alias redirect URL.""" url = self.build( @@ -779,11 +786,11 @@ def make_alias_redirect_url( def _partial_build( self, - endpoint: str, + endpoint: t.Any, values: t.Mapping[str, t.Any], - method: t.Optional[str], + method: str | None, append_unknown: bool, - ) -> t.Optional[t.Tuple[str, str, bool]]: + ) -> tuple[str, str, bool] | None: """Helper for :meth:`build`. Returns subdomain and path for the rule that accepts this endpoint, values and method. 
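Since ``make_redirect_url`` now hands the pieces to ``urlunsplit`` instead of pasting an f-string, a standalone sketch of the assembly (scheme, host, and paths are made-up values):

    from urllib.parse import urlunsplit

    script_name, path_info = "/app/", "item/1/"
    path = "/".join((script_name.strip("/"), path_info.lstrip("/")))
    # urlunsplit adds the leading slash after the netloc, and appends
    # "?page=2" only because the query component is non-empty; a None
    # query leaves the URL without a trailing "?".
    print(urlunsplit(("https", "example.com", path, "page=2", None)))
    # https://example.com/app/item/1/?page=2
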
@@ -820,12 +827,12 @@ def _partial_build( def build( self, - endpoint: str, - values: t.Optional[t.Mapping[str, t.Any]] = None, - method: t.Optional[str] = None, + endpoint: t.Any, + values: t.Mapping[str, t.Any] | None = None, + method: str | None = None, force_external: bool = False, append_unknown: bool = True, - url_scheme: t.Optional[str] = None, + url_scheme: str | None = None, ) -> str: """Building URLs works pretty much the other way round. Instead of `match` you call `build` and pass it the endpoint and a dict of diff --git a/src/werkzeug/routing/matcher.py b/src/werkzeug/routing/matcher.py index d22b05a..1fd00ef 100644 --- a/src/werkzeug/routing/matcher.py +++ b/src/werkzeug/routing/matcher.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import re import typing as t from dataclasses import dataclass @@ -23,9 +25,9 @@ class State: possible *static* and *dynamic* transitions to the next state. """ - dynamic: t.List[t.Tuple[RulePart, "State"]] = field(default_factory=list) - rules: t.List[Rule] = field(default_factory=list) - static: t.Dict[str, "State"] = field(default_factory=dict) + dynamic: list[tuple[RulePart, State]] = field(default_factory=list) + rules: list[Rule] = field(default_factory=list) + static: dict[str, State] = field(default_factory=dict) class StateMachineMatcher: @@ -66,7 +68,7 @@ def _update_state(state: State) -> None: def match( self, domain: str, path: str, method: str, websocket: bool - ) -> t.Tuple[Rule, t.MutableMapping[str, t.Any]]: + ) -> tuple[Rule, t.MutableMapping[str, t.Any]]: # To match to a rule we need to start at the root state and # try to follow the transitions until we find a match, or find # there is no transition to follow. @@ -75,8 +77,8 @@ def match( websocket_mismatch = False def _match( - state: State, parts: t.List[str], values: t.List[str] - ) -> t.Optional[t.Tuple[Rule, t.List[str]]]: + state: State, parts: list[str], values: list[str] + ) -> tuple[Rule, list[str]] | None: # This function is meant to be called recursively, and will attempt # to match the head part to the state's transitions. nonlocal have_match_for, websocket_mismatch @@ -127,7 +129,22 @@ def _match( remaining = [] match = re.compile(test_part.content).match(target) if match is not None: - rv = _match(new_state, remaining, values + list(match.groups())) + if test_part.suffixed: + # If a part_isolating=False part has a slash suffix, remove the + # suffix from the match and check for the slash redirect next. 
+ suffix = match.groups()[-1] + if suffix == "/": + remaining = [""] + + converter_groups = sorted( + match.groupdict().items(), key=lambda entry: entry[0] + ) + groups = [ + value + for key, value in converter_groups + if key[:11] == "__werkzeug_" + ] + rv = _match(new_state, remaining, values + groups) if rv is not None: return rv @@ -160,7 +177,7 @@ def _match( rv = _match(self._root, [domain, *path.split("/")], []) except SlashRequired: raise RequestPath(f"{path}/") from None - if rv is None: + if rv is None or rv[0].merge_slashes is False: raise NoMatch(have_match_for, websocket_mismatch) else: raise RequestPath(f"{path}") diff --git a/src/werkzeug/routing/rules.py b/src/werkzeug/routing/rules.py index a61717a..6a02f8d 100644 --- a/src/werkzeug/routing/rules.py +++ b/src/werkzeug/routing/rules.py @@ -1,13 +1,15 @@ +from __future__ import annotations + import ast import re import typing as t from dataclasses import dataclass from string import Template from types import CodeType +from urllib.parse import quote -from .._internal import _to_bytes -from ..urls import url_encode -from ..urls import url_quote +from ..datastructures import iter_multi_items +from ..urls import _urlencode from .converters import ValidationError if t.TYPE_CHECKING: @@ -17,9 +19,9 @@ class Weighting(t.NamedTuple): number_static_weights: int - static_weights: t.List[t.Tuple[int, int]] + static_weights: list[tuple[int, int]] number_argument_weights: int - argument_weights: t.List[int] + argument_weights: list[int] @dataclass @@ -36,22 +38,23 @@ class RulePart: content: str final: bool static: bool + suffixed: bool weight: Weighting _part_re = re.compile( r""" (?: - (?P<slash>\/) # a slash + (?P<slash>/) # a slash | - (?P<static>[^<\/]+) # static rule data + (?P<static>[^</]+) # static rule data | (?: < (?: (?P<converter>[a-zA-Z_][a-zA-Z0-9_]*) # converter name (?:\((?P<arguments>.*?)\))? # converter arguments - \: # variable delimiter + : # variable delimiter )? (?P<variable>[a-zA-Z_][a-zA-Z0-9_]*) # variable name > @@ -64,6 +67,7 @@ class RulePart: _simple_rule_re = re.compile(r"<([^>]+)>") _converter_args_re = re.compile( r""" + \s* ((?P<name>\w+)\s*=\s*)? (?P<value> True|False| @@ -92,7 +96,7 @@ def _find(value: str, target: str, pos: int) -> int: return len(value) -def _pythonize(value: str) -> t.Union[None, bool, int, float, str]: +def _pythonize(value: str) -> None | bool | int | float | str: if value in _PYTHON_CONSTANTS: return _PYTHON_CONSTANTS[value] for convert in int, float: @@ -105,12 +109,18 @@ def _pythonize(value: str) -> t.Union[None, bool, int, float, str]: return str(value) -def parse_converter_args(argstr: str) -> t.Tuple[t.Tuple, t.Dict[str, t.Any]]: +def parse_converter_args(argstr: str) -> tuple[tuple[t.Any, ...], dict[str, t.Any]]: argstr += "," args = [] kwargs = {} + position = 0 for item in _converter_args_re.finditer(argstr): + if item.start() != position: + raise ValueError( + f"Cannot parse converter argument '{argstr[position:item.start()]}'" + ) + value = item.group("stringval") if value is None: value = item.group("value") @@ -120,6 +130,7 @@ def parse_converter_args(argstr: str) -> t.Tuple[t.Tuple, t.Dict[str, t.Any]]: else: name = item.group("name") kwargs[name] = value + position = item.end() return tuple(args), kwargs @@ -130,7 +141,7 @@ class RuleFactory: be added by subclassing `RuleFactory` and overriding `get_rules`.
""" - def get_rules(self, map: "Map") -> t.Iterable["Rule"]: + def get_rules(self, map: Map) -> t.Iterable[Rule]: """Subclasses of `RuleFactory` have to override this method and return an iterable of rules.""" raise NotImplementedError() @@ -159,7 +170,7 @@ def __init__(self, subdomain: str, rules: t.Iterable[RuleFactory]) -> None: self.subdomain = subdomain self.rules = rules - def get_rules(self, map: "Map") -> t.Iterator["Rule"]: + def get_rules(self, map: Map) -> t.Iterator[Rule]: for rulefactory in self.rules: for rule in rulefactory.get_rules(map): rule = rule.empty() @@ -185,7 +196,7 @@ def __init__(self, path: str, rules: t.Iterable[RuleFactory]) -> None: self.path = path.rstrip("/") self.rules = rules - def get_rules(self, map: "Map") -> t.Iterator["Rule"]: + def get_rules(self, map: Map) -> t.Iterator[Rule]: for rulefactory in self.rules: for rule in rulefactory.get_rules(map): rule = rule.empty() @@ -210,7 +221,7 @@ def __init__(self, prefix: str, rules: t.Iterable[RuleFactory]) -> None: self.prefix = prefix self.rules = rules - def get_rules(self, map: "Map") -> t.Iterator["Rule"]: + def get_rules(self, map: Map) -> t.Iterator[Rule]: for rulefactory in self.rules: for rule in rulefactory.get_rules(map): rule = rule.empty() @@ -237,10 +248,10 @@ class RuleTemplate: replace the placeholders in all the string parameters. """ - def __init__(self, rules: t.Iterable["Rule"]) -> None: + def __init__(self, rules: t.Iterable[Rule]) -> None: self.rules = list(rules) - def __call__(self, *args: t.Any, **kwargs: t.Any) -> "RuleTemplateFactory": + def __call__(self, *args: t.Any, **kwargs: t.Any) -> RuleTemplateFactory: return RuleTemplateFactory(self.rules, dict(*args, **kwargs)) @@ -252,12 +263,12 @@ class RuleTemplateFactory(RuleFactory): """ def __init__( - self, rules: t.Iterable[RuleFactory], context: t.Dict[str, t.Any] + self, rules: t.Iterable[RuleFactory], context: dict[str, t.Any] ) -> None: self.rules = rules self.context = context - def get_rules(self, map: "Map") -> t.Iterator["Rule"]: + def get_rules(self, map: Map) -> t.Iterator[Rule]: for rulefactory in self.rules: for rule in rulefactory.get_rules(map): new_defaults = subdomain = None @@ -438,25 +449,26 @@ def foo_with_slug(adapter, id): def __init__( self, string: str, - defaults: t.Optional[t.Mapping[str, t.Any]] = None, - subdomain: t.Optional[str] = None, - methods: t.Optional[t.Iterable[str]] = None, + defaults: t.Mapping[str, t.Any] | None = None, + subdomain: str | None = None, + methods: t.Iterable[str] | None = None, build_only: bool = False, - endpoint: t.Optional[str] = None, - strict_slashes: t.Optional[bool] = None, - merge_slashes: t.Optional[bool] = None, - redirect_to: t.Optional[t.Union[str, t.Callable[..., str]]] = None, + endpoint: t.Any | None = None, + strict_slashes: bool | None = None, + merge_slashes: bool | None = None, + redirect_to: str | t.Callable[..., str] | None = None, alias: bool = False, - host: t.Optional[str] = None, + host: str | None = None, websocket: bool = False, ) -> None: if not string.startswith("/"): - raise ValueError("urls must start with a leading slash") + raise ValueError(f"URL rule '{string}' must start with a slash.") + self.rule = string self.is_leaf = not string.endswith("/") self.is_branch = string.endswith("/") - self.map: "Map" = None # type: ignore + self.map: Map = None # type: ignore self.strict_slashes = strict_slashes self.merge_slashes = merge_slashes self.subdomain = subdomain @@ -481,7 +493,7 @@ def __init__( ) self.methods = methods - self.endpoint: str = 
endpoint # type: ignore + self.endpoint: t.Any = endpoint self.redirect_to = redirect_to if defaults: @@ -489,11 +501,11 @@ else: self.arguments = set() - self._converters: t.Dict[str, "BaseConverter"] = {} - self._trace: t.List[t.Tuple[bool, str]] = [] - self._parts: t.List[RulePart] = [] + self._converters: dict[str, BaseConverter] = {} + self._trace: list[tuple[bool, str]] = [] + self._parts: list[RulePart] = [] - def empty(self) -> "Rule": + def empty(self) -> Rule: """ Return an unbound copy of this rule. @@ -530,7 +542,7 @@ def get_empty_kwargs(self) -> t.Mapping[str, t.Any]: host=self.host, ) - def get_rules(self, map: "Map") -> t.Iterator["Rule"]: + def get_rules(self, map: Map) -> t.Iterator[Rule]: yield self def refresh(self) -> None: @@ -541,7 +553,7 @@ def refresh(self) -> None: """ self.bind(self.map, rebind=True) - def bind(self, map: "Map", rebind: bool = False) -> None: + def bind(self, map: Map, rebind: bool = False) -> None: """Bind the url to a map and create a regular expression based on the information from the rule itself and the defaults from the map. @@ -562,9 +574,9 @@ def get_converter( self, variable_name: str, converter_name: str, - args: t.Tuple, + args: tuple[t.Any, ...], kwargs: t.Mapping[str, t.Any], - ) -> "BaseConverter": + ) -> BaseConverter: """Looks up the converter for the given parameter. .. versionadded:: 0.9 @@ -574,19 +586,20 @@ def get_converter( return self.map.converters[converter_name](self.map, *args, **kwargs) def _encode_query_vars(self, query_vars: t.Mapping[str, t.Any]) -> str: - return url_encode( - query_vars, - charset=self.map.charset, - sort=self.map.sort_parameters, - key=self.map.sort_key, - ) + items: t.Iterable[tuple[str, str]] = iter_multi_items(query_vars) + + if self.map.sort_parameters: + items = sorted(items, key=self.map.sort_key) + + return _urlencode(items) def _parse_rule(self, rule: str) -> t.Iterable[RulePart]: content = "" static = True argument_weights = [] - static_weights: t.List[t.Tuple[int, int]] = [] + static_weights: list[tuple[int, int]] = [] final = False + convertor_number = 0 pos = 0 while pos < len(rule): @@ -613,7 +626,8 @@ def _parse_rule(self, rule: str) -> t.Iterable[RulePart]: self.arguments.add(data["variable"]) if not convobj.part_isolating: final = True - content += f"({convobj.regex})" + content += f"(?P<__werkzeug_{convertor_number}>{convobj.regex})" + convertor_number += 1 argument_weights.append(convobj.weight) self._trace.append((True, data["variable"])) @@ -631,16 +645,27 @@ def _parse_rule(self, rule: str) -> t.Iterable[RulePart]: argument_weights, ) yield RulePart( - content=content, final=final, static=static, weight=weight + content=content, + final=final, + static=static, + suffixed=False, + weight=weight, ) content = "" static = True argument_weights = [] static_weights = [] final = False + convertor_number = 0 pos = match.end() + suffixed = False + if final and content[-1] == "/": + # If a converter is part_isolating=False (matches slashes) and ends with a + # slash, augment the regex to support slash redirects. + suffixed = True + content = content[:-1] + "(?<!/)(/?)" @@ ... @@ def _parse_rule(self, rule: str) ->
t.Iterable[RulePart]: -len(argument_weights), argument_weights, ) - yield RulePart(content=content, final=final, static=static, weight=weight) + yield RulePart( + content=content, + final=final, + static=static, + suffixed=suffixed, + weight=weight, + ) + if suffixed: + yield RulePart( + content="", final=False, static=True, suffixed=False, weight=weight + ) def compile(self) -> None: """Compiles the regular expression and stores it.""" @@ -665,7 +700,11 @@ def compile(self) -> None: if domain_rule == "": self._parts = [ RulePart( - content="", final=False, static=True, weight=Weighting(0, [], 0, []) + content="", + final=False, + static=True, + suffixed=False, + weight=Weighting(0, [], 0, []), ) ] else: @@ -676,24 +715,24 @@ def compile(self) -> None: rule = re.sub("/{2,}?", "/", self.rule) self._parts.extend(self._parse_rule(rule)) - self._build: t.Callable[..., t.Tuple[str, str]] + self._build: t.Callable[..., tuple[str, str]] self._build = self._compile_builder(False).__get__(self, None) - self._build_unknown: t.Callable[..., t.Tuple[str, str]] + self._build_unknown: t.Callable[..., tuple[str, str]] self._build_unknown = self._compile_builder(True).__get__(self, None) @staticmethod - def _get_func_code(code: CodeType, name: str) -> t.Callable[..., t.Tuple[str, str]]: - globs: t.Dict[str, t.Any] = {} - locs: t.Dict[str, t.Any] = {} + def _get_func_code(code: CodeType, name: str) -> t.Callable[..., tuple[str, str]]: + globs: dict[str, t.Any] = {} + locs: dict[str, t.Any] = {} exec(code, globs, locs) return locs[name] # type: ignore def _compile_builder( self, append_unknown: bool = True - ) -> t.Callable[..., t.Tuple[str, str]]: + ) -> t.Callable[..., tuple[str, str]]: defaults = self.defaults or {} - dom_ops: t.List[t.Tuple[bool, str]] = [] - url_ops: t.List[t.Tuple[bool, str]] = [] + dom_ops: list[tuple[bool, str]] = [] + url_ops: list[tuple[bool, str]] = [] opl = dom_ops for is_dynamic, data in self._trace: @@ -707,9 +746,8 @@ def _compile_builder( data = self._converters[data].to_url(defaults[data]) opl.append((False, data)) elif not is_dynamic: - opl.append( - (False, url_quote(_to_bytes(data, self.map.charset), safe="/:|+")) - ) + # safe = https://url.spec.whatwg.org/#url-path-segment-string + opl.append((False, quote(data, safe="!$&'()*+,/:;=@"))) else: opl.append((True, data)) @@ -718,17 +756,17 @@ def _convert(elem: str) -> ast.stmt: ret.args = [ast.Name(str(elem), ast.Load())] # type: ignore # str for py2 return ret - def _parts(ops: t.List[t.Tuple[bool, str]]) -> t.List[ast.AST]: + def _parts(ops: list[tuple[bool, str]]) -> list[ast.AST]: parts = [ - _convert(elem) if is_dynamic else ast.Str(s=elem) + _convert(elem) if is_dynamic else ast.Constant(elem) for is_dynamic, elem in ops ] - parts = parts or [ast.Str("")] + parts = parts or [ast.Constant("")] # constant fold ret = [parts[0]] for p in parts[1:]: - if isinstance(p, ast.Str) and isinstance(ret[-1], ast.Str): - ret[-1] = ast.Str(ret[-1].s + p.s) + if isinstance(p, ast.Constant) and isinstance(ret[-1], ast.Constant): + ret[-1] = ast.Constant(ret[-1].value + p.value) else: ret.append(p) return ret @@ -741,7 +779,7 @@ def _parts(ops: t.List[t.Tuple[bool, str]]) -> t.List[ast.AST]: body = [_IF_KWARGS_URL_ENCODE_AST] url_parts.extend(_URL_ENCODE_AST_NAMES) - def _join(parts: t.List[ast.AST]) -> ast.AST: + def _join(parts: list[ast.AST]) -> ast.AST: if len(parts) == 1: # shortcut return parts[0] return ast.JoinedStr(parts) @@ -764,11 +802,11 @@ def _join(parts: t.List[ast.AST]) -> ast.AST: 
func_ast.args.args.append(ast.arg(arg, None)) func_ast.args.kwarg = ast.arg(".kwargs", None) for _ in kargs: - func_ast.args.defaults.append(ast.Str("")) + func_ast.args.defaults.append(ast.Constant("")) func_ast.body = body - # use `ast.parse` instead of `ast.Module` for better portability - # Python 3.8 changes the signature of `ast.Module` + # Use `ast.parse` instead of `ast.Module` for better portability, since the + # signature of `ast.Module` can change. module = ast.parse("") module.body = [func_ast] @@ -779,18 +817,18 @@ def _join(parts: t.List[ast.AST]) -> ast.AST: if "lineno" in node._attributes: node.lineno = 1 if "end_lineno" in node._attributes: - node.end_lineno = node.lineno # type: ignore[attr-defined] + node.end_lineno = node.lineno if "col_offset" in node._attributes: node.col_offset = 0 if "end_col_offset" in node._attributes: - node.end_col_offset = node.col_offset # type: ignore[attr-defined] + node.end_col_offset = node.col_offset code = compile(module, "<werkzeug routing>", "exec") return self._get_func_code(code, func_ast.name) def build( self, values: t.Mapping[str, t.Any], append_unknown: bool = True - ) -> t.Optional[t.Tuple[str, str]]: + ) -> tuple[str, str] | None: """Assembles the relative url for that rule and the subdomain. If building doesn't work for some reasons `None` is returned. @@ -804,7 +842,7 @@ def build( except ValidationError: return None - def provides_defaults_for(self, rule: "Rule") -> bool: + def provides_defaults_for(self, rule: Rule) -> bool: """Check if this rule has defaults for a given rule. :internal: @@ -818,7 +856,7 @@ def provides_defaults_for(self, rule: "Rule") -> bool: ) def suitable_for( - self, values: t.Mapping[str, t.Any], method: t.Optional[str] = None + self, values: t.Mapping[str, t.Any], method: str | None = None ) -> bool: """Check if the dict of values has enough data for url generation. @@ -850,7 +888,7 @@ def suitable_for( return True - def build_compare_key(self) -> t.Tuple[int, int, int]: + def build_compare_key(self) -> tuple[int, int, int]: """The build compare key for sorting.
:internal: @@ -874,6 +912,6 @@ def __repr__(self) -> str: parts.append(f"<{data}>") else: parts.append(data) - parts = "".join(parts).lstrip("|") + parts_str = "".join(parts).lstrip("|") methods = f" ({', '.join(self.methods)})" if self.methods is not None else "" - return f"<{type(self).__name__} {parts!r}{methods} -> {self.endpoint}>" + return f"<{type(self).__name__} {parts_str!r}{methods} -> {self.endpoint}>" diff --git a/src/werkzeug/sansio/http.py b/src/werkzeug/sansio/http.py index 8288882..b2b8877 100644 --- a/src/werkzeug/sansio/http.py +++ b/src/werkzeug/sansio/http.py @@ -1,10 +1,10 @@ +from __future__ import annotations + import re import typing as t from datetime import datetime -from .._internal import _cookie_parse_impl from .._internal import _dt_as_utc -from .._internal import _to_str from ..http import generate_etag from ..http import parse_date from ..http import parse_etags @@ -15,14 +15,14 @@ def is_resource_modified( - http_range: t.Optional[str] = None, - http_if_range: t.Optional[str] = None, - http_if_modified_since: t.Optional[str] = None, - http_if_none_match: t.Optional[str] = None, - http_if_match: t.Optional[str] = None, - etag: t.Optional[str] = None, - data: t.Optional[bytes] = None, - last_modified: t.Optional[t.Union[datetime, str]] = None, + http_range: str | None = None, + http_if_range: str | None = None, + http_if_modified_since: str | None = None, + http_if_none_match: str | None = None, + http_if_match: str | None = None, + etag: str | None = None, + data: bytes | None = None, + last_modified: datetime | str | None = None, ignore_if_range: bool = True, ) -> bool: """Convenience method for conditional requests. @@ -63,7 +63,7 @@ def is_resource_modified( if_range = parse_if_range_header(http_if_range) if if_range is not None and if_range.date is not None: - modified_since: t.Optional[datetime] = if_range.date + modified_since: datetime | None = if_range.date else: modified_since = parse_date(http_if_modified_since) @@ -94,12 +94,36 @@ def is_resource_modified( return not unmodified +_cookie_re = re.compile( + r""" + ([^=;]*) + (?:\s*=\s* + ( + "(?:[^\\"]|\\.)*" + | + .*? + ) + )? + \s*;\s* + """, + flags=re.ASCII | re.VERBOSE, +) +_cookie_unslash_re = re.compile(rb"\\([0-3][0-7]{2}|.)") + + +def _cookie_unslash_replace(m: t.Match[bytes]) -> bytes: + v = m.group(1) + + if len(v) == 1: + return v + + return int(v, 8).to_bytes(1, "big") + + def parse_cookie( - cookie: t.Union[bytes, str, None] = "", - charset: str = "utf-8", - errors: str = "replace", - cls: t.Optional[t.Type["ds.MultiDict"]] = None, -) -> "ds.MultiDict[str, str]": + cookie: str | None = None, + cls: type[ds.MultiDict[str, str]] | None = None, +) -> ds.MultiDict[str, str]: """Parse a cookie from a string. The same key can be provided multiple times, the values are stored @@ -108,32 +132,39 @@ def parse_cookie( :meth:`MultiDict.getlist`. :param cookie: The cookie header as a string. - :param charset: The charset for the cookie values. - :param errors: The error behavior for the charset decoding. :param cls: A dict-like class to store the parsed cookies in. Defaults to :class:`MultiDict`. + .. versionchanged:: 3.0 + Passing bytes, and the ``charset`` and ``errors`` parameters, were removed. + .. versionadded:: 2.2 """ - # PEP 3333 sends headers through the environ as latin1 decoded - # strings. Encode strings back to bytes for parsing. 
- if isinstance(cookie, str): - cookie = cookie.encode("latin1", "replace") - if cls is None: - cls = ds.MultiDict + cls = t.cast("type[ds.MultiDict[str, str]]", ds.MultiDict) + + if not cookie: + return cls() + + cookie = f"{cookie};" + out = [] + + for ck, cv in _cookie_re.findall(cookie): + ck = ck.strip() + cv = cv.strip() - def _parse_pairs() -> t.Iterator[t.Tuple[str, str]]: - for key, val in _cookie_parse_impl(cookie): # type: ignore - key_str = _to_str(key, charset, errors, allow_none_charset=True) + if not ck: + continue - if not key_str: - continue + if len(cv) >= 2 and cv[0] == cv[-1] == '"': + # Work with bytes here, since a UTF-8 character could be multiple bytes. + cv = _cookie_unslash_re.sub( + _cookie_unslash_replace, cv[1:-1].encode() + ).decode(errors="replace") - val_str = _to_str(val, charset, errors, allow_none_charset=True) - yield key_str, val_str + out.append((ck, cv)) - return cls(_parse_pairs()) + return cls(out) # circular dependencies diff --git a/src/werkzeug/sansio/multipart.py b/src/werkzeug/sansio/multipart.py index d8abeb3..fc87353 100644 --- a/src/werkzeug/sansio/multipart.py +++ b/src/werkzeug/sansio/multipart.py @@ -1,14 +1,11 @@ +from __future__ import annotations + import re +import typing as t from dataclasses import dataclass from enum import auto from enum import Enum -from typing import cast -from typing import List -from typing import Optional -from typing import Tuple -from .._internal import _to_bytes -from .._internal import _to_str from ..datastructures import Headers from ..exceptions import RequestEntityTooLarge from ..http import parse_options_header @@ -58,6 +55,7 @@ class State(Enum): PREAMBLE = auto() PART = auto() DATA = auto() + DATA_START = auto() EPILOGUE = auto() COMPLETE = auto() @@ -86,11 +84,14 @@ class MultipartDecoder: def __init__( self, boundary: bytes, - max_form_memory_size: Optional[int] = None, + max_form_memory_size: int | None = None, + *, + max_parts: int | None = None, ) -> None: self.buffer = bytearray() self.complete = False self.max_form_memory_size = max_form_memory_size + self.max_parts = max_parts self.state = State.PREAMBLE self.boundary = boundary @@ -118,20 +119,21 @@ def __init__( re.MULTILINE, ) self._search_position = 0 + self._parts_decoded = 0 - def last_newline(self) -> int: + def last_newline(self, data: bytes) -> int: try: - last_nl = self.buffer.rindex(b"\n") + last_nl = data.rindex(b"\n") except ValueError: - last_nl = len(self.buffer) + last_nl = len(data) try: - last_cr = self.buffer.rindex(b"\r") + last_cr = data.rindex(b"\r") except ValueError: - last_cr = len(self.buffer) + last_cr = len(data) return min(last_nl, last_cr) - def receive_data(self, data: Optional[bytes]) -> None: + def receive_data(self, data: bytes | None) -> None: if data is None: self.complete = True elif ( @@ -168,7 +170,11 @@ def next_event(self) -> Event: match = BLANK_LINE_RE.search(self.buffer, self._search_position) if match is not None: headers = self._parse_headers(self.buffer[: match.start()]) - del self.buffer[: match.end()] + # The final header ends with a single CRLF, however a + # blank line indicates the start of the + # body. Therefore the end is after the first CRLF. 
+ headers_end = (match.start() + match.end()) // 2 + del self.buffer[:headers_end] if "content-disposition" not in headers: raise ValueError("Missing Content-Disposition header") @@ -176,7 +182,7 @@ def next_event(self) -> Event: disposition, extra = parse_options_header( headers["content-disposition"] ) - name = cast(str, extra.get("name")) + name = t.cast(str, extra.get("name")) filename = extra.get("filename") if filename is not None: event = File( @@ -189,36 +195,27 @@ def next_event(self) -> Event: headers=headers, name=name, ) - self.state = State.DATA + self.state = State.DATA_START self._search_position = 0 + self._parts_decoded += 1 + + if self.max_parts is not None and self._parts_decoded > self.max_parts: + raise RequestEntityTooLarge() else: # Update the search start position to be equal to the # current buffer length (already searched) minus a # safe buffer for part of the search target. self._search_position = max(0, len(self.buffer) - SEARCH_EXTRA_LENGTH) - elif self.state == State.DATA: - if self.buffer.find(b"--" + self.boundary) == -1: - # No complete boundary in the buffer, but there may be - # a partial boundary at the end. As the boundary - # starts with either a nl or cr find the earliest and - # return up to that as data. - data_length = del_index = self.last_newline() - more_data = True - else: - match = self.boundary_re.search(self.buffer) - if match is not None: - if match.group(1).startswith(b"--"): - self.state = State.EPILOGUE - else: - self.state = State.PART - data_length = match.start() - del_index = match.end() - else: - data_length = del_index = self.last_newline() - more_data = match is None + elif self.state == State.DATA_START: + data, del_index, more_data = self._parse_data(self.buffer, start=True) + del self.buffer[:del_index] + event = Data(data=data, more_data=more_data) + if more_data: + self.state = State.DATA - data = bytes(self.buffer[:data_length]) + elif self.state == State.DATA: + data, del_index, more_data = self._parse_data(self.buffer, start=False) del self.buffer[:del_index] if data or not more_data: event = Data(data=data, more_data=more_data) @@ -234,16 +231,56 @@ def next_event(self) -> Event: return event def _parse_headers(self, data: bytes) -> Headers: - headers: List[Tuple[str, str]] = [] + headers: list[tuple[str, str]] = [] # Merge the continued headers into one line data = HEADER_CONTINUATION_RE.sub(b" ", data) # Now there is one header per line for line in data.splitlines(): - if line.strip() != b"": - name, value = _to_str(line).strip().split(":", 1) + line = line.strip() + + if line != b"": + name, _, value = line.decode().partition(":") headers.append((name.strip(), value.strip())) return Headers(headers) + def _parse_data(self, data: bytes, *, start: bool) -> tuple[bytes, int, bool]: + # Body parts must start with CRLF (or CR or LF) + if start: + match = LINE_BREAK_RE.match(data) + data_start = t.cast(t.Match[bytes], match).end() + else: + data_start = 0 + + boundary = b"--" + self.boundary + + if self.buffer.find(boundary) == -1: + # No complete boundary in the buffer, but there may be + # a partial boundary at the end. As the boundary + # starts with either a nl or cr find the earliest and + # return up to that as data. + data_end = del_index = self.last_newline(data[data_start:]) + data_start + # If amount of data after last newline is far from + # possible length of partial boundary, we should + # assume that there is no partial boundary in the buffer + # and return all pending data. 
+ if (len(data) - data_end) > len(b"\n" + boundary): + data_end = del_index = len(data) + more_data = True + else: + match = self.boundary_re.search(data) + if match is not None: + if match.group(1).startswith(b"--"): + self.state = State.EPILOGUE + else: + self.state = State.PART + data_end = match.start() + del_index = match.end() + else: + data_end = del_index = self.last_newline(data[data_start:]) + data_start + more_data = match is None + + return bytes(data[data_start:data_end]), del_index, more_data + class MultipartEncoder: def __init__(self, boundary: bytes) -> None: @@ -259,17 +296,22 @@ def send_event(self, event: Event) -> bytes: State.PART, State.DATA, }: - self.state = State.DATA data = b"\r\n--" + self.boundary + b"\r\n" - data += b'Content-Disposition: form-data; name="%s"' % _to_bytes(event.name) + data += b'Content-Disposition: form-data; name="%s"' % event.name.encode() if isinstance(event, File): - data += b'; filename="%s"' % _to_bytes(event.filename) + data += b'; filename="%s"' % event.filename.encode() data += b"\r\n" - for name, value in cast(Field, event).headers: + for name, value in t.cast(Field, event).headers: if name.lower() != "content-disposition": - data += _to_bytes(f"{name}: {value}\r\n") - data += b"\r\n" + data += f"{name}: {value}\r\n".encode() + self.state = State.DATA_START return data + elif isinstance(event, Data) and self.state == State.DATA_START: + self.state = State.DATA + if len(event.data) > 0: + return b"\r\n" + event.data + else: + return event.data elif isinstance(event, Data) and self.state == State.DATA: return event.data elif isinstance(event, Epilogue): diff --git a/src/werkzeug/sansio/request.py b/src/werkzeug/sansio/request.py index 8832baa..dd0805d 100644 --- a/src/werkzeug/sansio/request.py +++ b/src/werkzeug/sansio/request.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import typing as t from datetime import datetime +from urllib.parse import parse_qsl -from .._internal import _to_str from ..datastructures import Accept from ..datastructures import Authorization from ..datastructures import CharsetAccept @@ -17,7 +19,6 @@ from ..datastructures import Range from ..datastructures import RequestCacheControl from ..http import parse_accept_header -from ..http import parse_authorization_header from ..http import parse_cache_control_header from ..http import parse_date from ..http import parse_etags @@ -26,11 +27,11 @@ from ..http import parse_options_header from ..http import parse_range_header from ..http import parse_set_header -from ..urls import url_decode from ..user_agent import UserAgent from ..utils import cached_property from ..utils import header_property from .http import parse_cookie +from .utils import get_content_length from .utils import get_current_url from .utils import get_host @@ -57,15 +58,13 @@ class Request: :param headers: The headers received with the request. :param remote_addr: The address of the client sending the request. + .. versionchanged:: 3.0 + The ``charset``, ``url_charset``, and ``encoding_errors`` attributes + were removed. + .. versionadded:: 2.0 """ - #: The charset used to decode most data in the request. - charset = "utf-8" - - #: the error handling procedure for errors, defaults to 'replace' - encoding_errors = "replace" - #: the class to use for `args` and `form`. The default is an #: :class:`~werkzeug.datastructures.ImmutableMultiDict` which supports #: multiple values per key. 
alternatively it makes sense to use an @@ -75,7 +74,7 @@ class Request: #: possible to use mutable structures, but this is not recommended. #: #: .. versionadded:: 0.6 - parameter_storage_class: t.Type[MultiDict] = ImmutableMultiDict + parameter_storage_class: type[MultiDict[str, t.Any]] = ImmutableMultiDict #: The type to be used for dict values from the incoming WSGI #: environment. (For example for :attr:`cookies`.) By default an @@ -85,16 +84,16 @@ class Request: #: Changed to ``ImmutableMultiDict`` to support multiple values. #: #: .. versionadded:: 0.6 - dict_storage_class: t.Type[MultiDict] = ImmutableMultiDict + dict_storage_class: type[MultiDict[str, t.Any]] = ImmutableMultiDict #: the type to be used for list values from the incoming WSGI environment. #: By default an :class:`~werkzeug.datastructures.ImmutableList` is used #: (for example for :attr:`access_list`). #: #: .. versionadded:: 0.6 - list_storage_class: t.Type[t.List] = ImmutableList + list_storage_class: type[list[t.Any]] = ImmutableList - user_agent_class: t.Type[UserAgent] = UserAgent + user_agent_class: type[UserAgent] = UserAgent """The class used and returned by the :attr:`user_agent` property to parse the header. Defaults to :class:`~werkzeug.user_agent.UserAgent`, which does no parsing. An @@ -114,18 +113,18 @@ class Request: #: the application is being run behind one). #: #: .. versionadded:: 0.9 - trusted_hosts: t.Optional[t.List[str]] = None + trusted_hosts: list[str] | None = None def __init__( self, method: str, scheme: str, - server: t.Optional[t.Tuple[str, t.Optional[int]]], + server: tuple[str, int | None] | None, root_path: str, path: str, query_string: bytes, headers: Headers, - remote_addr: t.Optional[str], + remote_addr: str | None, ) -> None: #: The method the request was made with, such as ``GET``. self.method = method.upper() @@ -157,17 +156,8 @@ def __repr__(self) -> str: return f"<{type(self).__name__} {url!r} [{self.method}]>" - @property - def url_charset(self) -> str: - """The charset that is assumed for URLs. Defaults to the value - of :attr:`charset`. - - .. versionadded:: 0.6 - """ - return self.charset - @cached_property - def args(self) -> "MultiDict[str, str]": + def args(self) -> MultiDict[str, str]: """The parsed URL parameters (the part in the URL after the question mark). @@ -176,16 +166,20 @@ def args(self) -> "MultiDict[str, str]": is returned from this function. This can be changed by setting :attr:`parameter_storage_class` to a different type. This might be necessary if the order of the form data is important. + + .. versionchanged:: 2.3 + Invalid bytes remain percent encoded. """ - return url_decode( - self.query_string, - self.url_charset, - errors=self.encoding_errors, - cls=self.parameter_storage_class, + return self.parameter_storage_class( + parse_qsl( + self.query_string.decode(), + keep_blank_values=True, + errors="werkzeug.url_quote", + ) ) @cached_property - def access_route(self) -> t.List[str]: + def access_route(self) -> list[str]: """If a forwarded header exists this is a list of all ip addresses from the client ip to the last proxy server. 
""" @@ -200,7 +194,7 @@ def access_route(self) -> t.List[str]: @cached_property def full_path(self) -> str: """Requested path, including the query string.""" - return f"{self.path}?{_to_str(self.query_string, self.url_charset)}" + return f"{self.path}?{self.query_string.decode()}" @property def is_secure(self) -> bool: @@ -244,15 +238,12 @@ def host(self) -> str: ) @cached_property - def cookies(self) -> "ImmutableMultiDict[str, str]": + def cookies(self) -> ImmutableMultiDict[str, str]: """A :class:`dict` with the contents of all cookies transmitted with the request.""" wsgi_combined_cookie = ";".join(self.headers.getlist("Cookie")) return parse_cookie( # type: ignore - wsgi_combined_cookie, - self.charset, - self.encoding_errors, - cls=self.dict_storage_class, + wsgi_combined_cookie, cls=self.dict_storage_class ) # Common Descriptors @@ -267,23 +258,16 @@ def cookies(self) -> "ImmutableMultiDict[str, str]": ) @cached_property - def content_length(self) -> t.Optional[int]: + def content_length(self) -> int | None: """The Content-Length entity-header field indicates the size of the entity-body in bytes or, in the case of the HEAD method, the size of the entity-body that would have been sent had the request been a GET. """ - if self.headers.get("Transfer-Encoding", "") == "chunked": - return None - - content_length = self.headers.get("Content-Length") - if content_length is not None: - try: - return max(0, int(content_length)) - except (ValueError, TypeError): - pass - - return None + return get_content_length( + http_content_length=self.headers.get("Content-Length"), + http_transfer_encoding=self.headers.get("Transfer-Encoding"), + ) content_encoding = header_property[str]( "Content-Encoding", @@ -358,7 +342,7 @@ def mimetype(self) -> str: return self._parsed_content_type[0].lower() @property - def mimetype_params(self) -> t.Dict[str, str]: + def mimetype_params(self) -> dict[str, str]: """The mimetype parameters as dict. For example if the content type is ``text/html; charset=utf-8`` the params would be ``{'charset': 'utf-8'}``. @@ -438,7 +422,7 @@ def if_none_match(self) -> ETags: return parse_etags(self.headers.get("If-None-Match")) @cached_property - def if_modified_since(self) -> t.Optional[datetime]: + def if_modified_since(self) -> datetime | None: """The parsed `If-Modified-Since` header as a datetime object. .. versionchanged:: 2.0 @@ -447,7 +431,7 @@ def if_modified_since(self) -> t.Optional[datetime]: return parse_date(self.headers.get("If-Modified-Since")) @cached_property - def if_unmodified_since(self) -> t.Optional[datetime]: + def if_unmodified_since(self) -> datetime | None: """The parsed `If-Unmodified-Since` header as a datetime object. .. versionchanged:: 2.0 @@ -467,7 +451,7 @@ def if_range(self) -> IfRange: return parse_if_range_header(self.headers.get("If-Range")) @cached_property - def range(self) -> t.Optional[Range]: + def range(self) -> Range | None: """The parsed `Range` header. .. versionadded:: 0.7 @@ -485,19 +469,24 @@ def user_agent(self) -> UserAgent: :class:`~werkzeug.user_agent.UserAgent` to provide parsing for the other properties or other extended data. - .. versionchanged:: 2.0 - The built in parser is deprecated and will be removed in - Werkzeug 2.1. A ``UserAgent`` subclass must be set to parse - data from the string. + .. versionchanged:: 2.1 + The built-in parser was removed. Set ``user_agent_class`` to a ``UserAgent`` + subclass to parse data from the string. 
""" return self.user_agent_class(self.headers.get("User-Agent", "")) # Authorization @cached_property - def authorization(self) -> t.Optional[Authorization]: - """The `Authorization` object in parsed form.""" - return parse_authorization_header(self.headers.get("Authorization")) + def authorization(self) -> Authorization | None: + """The ``Authorization`` header parsed into an :class:`.Authorization` object. + ``None`` if the header is not present. + + .. versionchanged:: 2.3 + :class:`Authorization` is no longer a ``dict``. The ``token`` attribute + was added for auth schemes that use a token instead of parameters. + """ + return Authorization.from_header(self.headers.get("Authorization")) # CORS diff --git a/src/werkzeug/sansio/response.py b/src/werkzeug/sansio/response.py index de0bec2..9093b0a 100644 --- a/src/werkzeug/sansio/response.py +++ b/src/werkzeug/sansio/response.py @@ -1,41 +1,44 @@ +from __future__ import annotations + import typing as t from datetime import datetime from datetime import timedelta from datetime import timezone from http import HTTPStatus -from .._internal import _to_str +from ..datastructures import CallbackDict +from ..datastructures import ContentRange +from ..datastructures import ContentSecurityPolicy from ..datastructures import Headers from ..datastructures import HeaderSet +from ..datastructures import ResponseCacheControl +from ..datastructures import WWWAuthenticate +from ..http import COEP +from ..http import COOP +from ..http import dump_age from ..http import dump_cookie +from ..http import dump_header +from ..http import dump_options_header +from ..http import http_date from ..http import HTTP_STATUS_CODES +from ..http import parse_age +from ..http import parse_cache_control_header +from ..http import parse_content_range_header +from ..http import parse_csp_header +from ..http import parse_date +from ..http import parse_options_header +from ..http import parse_set_header +from ..http import quote_etag +from ..http import unquote_etag from ..utils import get_content_type -from werkzeug.datastructures import CallbackDict -from werkzeug.datastructures import ContentRange -from werkzeug.datastructures import ContentSecurityPolicy -from werkzeug.datastructures import ResponseCacheControl -from werkzeug.datastructures import WWWAuthenticate -from werkzeug.http import COEP -from werkzeug.http import COOP -from werkzeug.http import dump_age -from werkzeug.http import dump_header -from werkzeug.http import dump_options_header -from werkzeug.http import http_date -from werkzeug.http import parse_age -from werkzeug.http import parse_cache_control_header -from werkzeug.http import parse_content_range_header -from werkzeug.http import parse_csp_header -from werkzeug.http import parse_date -from werkzeug.http import parse_options_header -from werkzeug.http import parse_set_header -from werkzeug.http import parse_www_authenticate_header -from werkzeug.http import quote_etag -from werkzeug.http import unquote_etag -from werkzeug.utils import header_property - - -def _set_property(name: str, doc: t.Optional[str] = None) -> property: - def fget(self: "Response") -> HeaderSet: +from ..utils import header_property + +if t.TYPE_CHECKING: + from ..datastructures.cache_control import _CacheControl + + +def _set_property(name: str, doc: str | None = None) -> property: + def fget(self: Response) -> HeaderSet: def on_update(header_set: HeaderSet) -> None: if not header_set and name in self.headers: del self.headers[name] @@ -45,10 +48,8 @@ def on_update(header_set: 
HeaderSet) -> None: return parse_set_header(self.headers.get(name), on_update) def fset( - self: "Response", - value: t.Optional[ - t.Union[str, t.Dict[str, t.Union[str, int]], t.Iterable[str]] - ], + self: Response, + value: None | (str | dict[str, str | int] | t.Iterable[str]), ) -> None: if not value: del self.headers[name] @@ -82,17 +83,17 @@ class Response: :param content_type: The full content type of the response. Overrides building the value from ``mimetype``. + .. versionchanged:: 3.0 + The ``charset`` attribute was removed. + .. versionadded:: 2.0 """ - #: the charset of the response. - charset = "utf-8" - #: the default status if none is provided. default_status = 200 #: the default mimetype if none is provided. - default_mimetype: t.Optional[str] = "text/plain" + default_mimetype: str | None = "text/plain" #: Warn if a cookie header exceeds this size. The default, 4093, should be #: safely `supported by most browsers <cookie_>`_. A cookie larger than #: this size will still be sent, but it may be ignored or handled #: incorrectly by some browsers. Set to 0 to disable this check. @@ -109,15 +110,12 @@ class Response: def __init__( self, - status: t.Optional[t.Union[int, str, HTTPStatus]] = None, - headers: t.Optional[ - t.Union[ - t.Mapping[str, t.Union[str, int, t.Iterable[t.Union[str, int]]]], - t.Iterable[t.Tuple[str, t.Union[str, int]]], - ] - ] = None, - mimetype: t.Optional[str] = None, - content_type: t.Optional[str] = None, + status: int | str | HTTPStatus | None = None, + headers: t.Mapping[str, str | t.Iterable[str]] + | t.Iterable[tuple[str, str]] + | None = None, + mimetype: str | None = None, + content_type: str | None = None, ) -> None: if isinstance(headers, Headers): self.headers = headers @@ -130,7 +128,7 @@ def __init__( if mimetype is None and "content-type" not in self.headers: mimetype = self.default_mimetype if mimetype is not None: - mimetype = get_content_type(mimetype, self.charset) + mimetype = get_content_type(mimetype, "utf-8") content_type = mimetype if content_type is not None: self.headers["Content-Type"] = content_type @@ -156,30 +154,29 @@ def status(self) -> str: return self._status @status.setter - def status(self, value: t.Union[str, int, HTTPStatus]) -> None: - if not isinstance(value, (str, bytes, int, HTTPStatus)): - raise TypeError("Invalid status argument") - + def status(self, value: str | int | HTTPStatus) -> None: self._status, self._status_code = self._clean_status(value) - def _clean_status(self, value: t.Union[str, int, HTTPStatus]) -> t.Tuple[str, int]: - if isinstance(value, HTTPStatus): - value = int(value) - status = _to_str(value, self.charset) - split_status = status.split(None, 1) + def _clean_status(self, value: str | int | HTTPStatus) -> tuple[str, int]: + if isinstance(value, (int, HTTPStatus)): + status_code = int(value) + else: + value = value.strip() - if len(split_status) == 0: - raise ValueError("Empty status argument") + if not value: + raise ValueError("Empty status argument") - try: - status_code = int(split_status[0]) - except ValueError: - # only message - return f"0 {status}", 0 + code_str, sep, _ = value.partition(" ") + + try: + status_code = int(code_str) + except ValueError: + # only message + return f"0 {value}", 0 - if len(split_status) > 1: - # code and message - return status, status_code + if sep: + # code and message + return value, status_code # only code, look up message try: @@ -193,13 +190,13 @@ def set_cookie( self, key: str, value: str = "", -
max_age: timedelta | int | None = None, + expires: str | datetime | int | float | None = None, + path: str | None = "/", + domain: str | None = None, secure: bool = False, httponly: bool = False, - samesite: t.Optional[str] = None, + samesite: str | None = None, ) -> None: """Sets a cookie. @@ -215,7 +212,7 @@ def set_cookie( :param path: limits the cookie to a given path, per default it will span the whole domain. :param domain: if you want to set a cross-domain cookie. For example, - ``domain=".example.com"`` will set a cookie that is + ``domain="example.com"`` will set a cookie that is readable by the domain ``www.example.com``, ``foo.example.com`` etc. Otherwise, a cookie will only be readable by the domain that set it. @@ -236,7 +233,6 @@ def set_cookie( domain=domain, secure=secure, httponly=httponly, - charset=self.charset, max_size=self.max_cookie_size, samesite=samesite, ), @@ -245,11 +241,11 @@ def set_cookie( def delete_cookie( self, key: str, - path: str = "/", - domain: t.Optional[str] = None, + path: str | None = "/", + domain: str | None = None, secure: bool = False, httponly: bool = False, - samesite: t.Optional[str] = None, + samesite: str | None = None, ) -> None: """Delete a cookie. Fails silently if key doesn't exist. @@ -290,7 +286,7 @@ def is_json(self) -> bool: # Common Descriptors @property - def mimetype(self) -> t.Optional[str]: + def mimetype(self) -> str | None: """The mimetype (content type without charset etc.)""" ct = self.headers.get("content-type") @@ -301,10 +297,10 @@ def mimetype(self) -> t.Optional[str]: @mimetype.setter def mimetype(self, value: str) -> None: - self.headers["Content-Type"] = get_content_type(value, self.charset) + self.headers["Content-Type"] = get_content_type(value, "utf-8") @property - def mimetype_params(self) -> t.Dict[str, str]: + def mimetype_params(self) -> dict[str, str]: """The mimetype parameters as dict. For example if the content type is ``text/html; charset=utf-8`` the params would be ``{'charset': 'utf-8'}``. @@ -312,7 +308,7 @@ def mimetype_params(self) -> t.Dict[str, str]: .. versionadded:: 0.5 """ - def on_update(d: CallbackDict) -> None: + def on_update(d: CallbackDict[str, str]) -> None: self.headers["Content-Type"] = dump_options_header(self.mimetype, d) d = parse_options_header(self.headers.get("content-type", ""))[1] @@ -421,7 +417,7 @@ def on_update(d: CallbackDict) -> None: ) @property - def retry_after(self) -> t.Optional[datetime]: + def retry_after(self) -> datetime | None: """The Retry-After response-header field can be used with a 503 (Service Unavailable) response to indicate how long the service is expected to be unavailable to the requesting client. @@ -443,7 +439,7 @@ def retry_after(self) -> t.Optional[datetime]: return datetime.now(timezone.utc) + timedelta(seconds=seconds) @retry_after.setter - def retry_after(self, value: t.Optional[t.Union[datetime, int, str]]) -> None: + def retry_after(self, value: datetime | int | str | None) -> None: if value is None: if "retry-after" in self.headers: del self.headers["retry-after"] @@ -487,7 +483,7 @@ def cache_control(self) -> ResponseCacheControl: request/response chain. 
""" - def on_update(cache_control: ResponseCacheControl) -> None: + def on_update(cache_control: _CacheControl) -> None: if not cache_control and "cache-control" in self.headers: del self.headers["cache-control"] elif cache_control: @@ -501,7 +497,7 @@ def set_etag(self, etag: str, weak: bool = False) -> None: """Set the etag, and override the old one if there was one.""" self.headers["ETag"] = quote_etag(etag, weak) - def get_etag(self) -> t.Union[t.Tuple[str, bool], t.Tuple[None, None]]: + def get_etag(self) -> tuple[str, bool] | tuple[None, None]: """Return a tuple in the form ``(etag, is_weak)``. If there is no ETag the return value is ``(None, None)``. """ @@ -542,7 +538,7 @@ def on_update(rng: ContentRange) -> None: return rv @content_range.setter - def content_range(self, value: t.Optional[t.Union[ContentRange, str]]) -> None: + def content_range(self, value: ContentRange | str | None) -> None: if not value: del self.headers["content-range"] elif isinstance(value, str): @@ -554,16 +550,70 @@ def content_range(self, value: t.Optional[t.Union[ContentRange, str]]) -> None: @property def www_authenticate(self) -> WWWAuthenticate: - """The ``WWW-Authenticate`` header in a parsed form.""" + """The ``WWW-Authenticate`` header parsed into a :class:`.WWWAuthenticate` + object. Modifying the object will modify the header value. + + This header is not set by default. To set this header, assign an instance of + :class:`.WWWAuthenticate` to this attribute. + + .. code-block:: python + + response.www_authenticate = WWWAuthenticate( + "basic", {"realm": "Authentication Required"} + ) + + Multiple values for this header can be sent to give the client multiple options. + Assign a list to set multiple headers. However, modifying the items in the list + will not automatically update the header values, and accessing this attribute + will only ever return the first value. + + To unset this header, assign ``None`` or use ``del``. + + .. versionchanged:: 2.3 + This attribute can be assigned to to set the header. A list can be assigned + to set multiple header values. Use ``del`` to unset the header. + + .. versionchanged:: 2.3 + :class:`WWWAuthenticate` is no longer a ``dict``. The ``token`` attribute + was added for auth challenges that use a token instead of parameters. + """ + value = WWWAuthenticate.from_header(self.headers.get("WWW-Authenticate")) + + if value is None: + value = WWWAuthenticate("basic") + + def on_update(value: WWWAuthenticate) -> None: + self.www_authenticate = value + + value._on_update = on_update + return value + + @www_authenticate.setter + def www_authenticate( + self, value: WWWAuthenticate | list[WWWAuthenticate] | None + ) -> None: + if not value: # None or empty list + del self.www_authenticate + elif isinstance(value, list): + # Clear any existing header by setting the first item. + self.headers.set("WWW-Authenticate", value[0].to_header()) + + for item in value[1:]: + # Add additional header lines for additional items. + self.headers.add("WWW-Authenticate", item.to_header()) + else: + self.headers.set("WWW-Authenticate", value.to_header()) + + def on_update(value: WWWAuthenticate) -> None: + self.www_authenticate = value - def on_update(www_auth: WWWAuthenticate) -> None: - if not www_auth and "www-authenticate" in self.headers: - del self.headers["www-authenticate"] - elif www_auth: - self.headers["WWW-Authenticate"] = www_auth.to_header() + # When setting a single value, allow updating it directly. 
+ value._on_update = on_update - header = self.headers.get("www-authenticate") - return parse_www_authenticate_header(header, on_update) + @www_authenticate.deleter + def www_authenticate(self) -> None: + if "WWW-Authenticate" in self.headers: + del self.headers["WWW-Authenticate"] # CSP @@ -590,7 +640,7 @@ def on_update(csp: ContentSecurityPolicy) -> None: @content_security_policy.setter def content_security_policy( - self, value: t.Optional[t.Union[ContentSecurityPolicy, str]] + self, value: ContentSecurityPolicy | str | None ) -> None: if not value: del self.headers["content-security-policy"] @@ -625,7 +675,7 @@ def on_update(csp: ContentSecurityPolicy) -> None: @content_security_policy_report_only.setter def content_security_policy_report_only( - self, value: t.Optional[t.Union[ContentSecurityPolicy, str]] + self, value: ContentSecurityPolicy | str | None ) -> None: if not value: del self.headers["content-security-policy-report-only"] @@ -645,7 +695,7 @@ def access_control_allow_credentials(self) -> bool: return "Access-Control-Allow-Credentials" in self.headers @access_control_allow_credentials.setter - def access_control_allow_credentials(self, value: t.Optional[bool]) -> None: + def access_control_allow_credentials(self, value: bool | None) -> None: if value is True: self.headers["Access-Control-Allow-Credentials"] = "true" else: diff --git a/src/werkzeug/sansio/utils.py b/src/werkzeug/sansio/utils.py index e639dcb..14fa0ac 100644 --- a/src/werkzeug/sansio/utils.py +++ b/src/werkzeug/sansio/utils.py @@ -1,12 +1,14 @@ +from __future__ import annotations + import typing as t +from urllib.parse import quote -from .._internal import _encode_idna +from .._internal import _plain_int from ..exceptions import SecurityError from ..urls import uri_to_iri -from ..urls import url_quote -def host_is_trusted(hostname: str, trusted_list: t.Iterable[str]) -> bool: +def host_is_trusted(hostname: str | None, trusted_list: t.Iterable[str]) -> bool: """Check if a host matches a list of trusted names. :param hostname: The name to check. @@ -18,20 +20,14 @@ def host_is_trusted(hostname: str, trusted_list: t.Iterable[str]) -> bool: if not hostname: return False - if isinstance(trusted_list, str): - trusted_list = [trusted_list] - - def _normalize(hostname: str) -> bytes: - if ":" in hostname: - hostname = hostname.rsplit(":", 1)[0] - - return _encode_idna(hostname) - try: - hostname_bytes = _normalize(hostname) - except UnicodeError: + hostname = hostname.partition(":")[0].encode("idna").decode("ascii") + except UnicodeEncodeError: return False + if isinstance(trusted_list, str): + trusted_list = [trusted_list] + for ref in trusted_list: if ref.startswith("."): ref = ref[1:] @@ -40,14 +36,11 @@ def _normalize(hostname: str) -> bytes: suffix_match = False try: - ref_bytes = _normalize(ref) - except UnicodeError: + ref = ref.partition(":")[0].encode("idna").decode("ascii") + except UnicodeEncodeError: return False - if ref_bytes == hostname_bytes: - return True - - if suffix_match and hostname_bytes.endswith(b"." 
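
Putting the new getter, setter, and deleter together, a short sketch of the
assignment-based API the docstring above describes:

.. code-block:: python

    from werkzeug.datastructures import WWWAuthenticate
    from werkzeug.wrappers import Response

    resp = Response("auth required", status=401)
    resp.www_authenticate = WWWAuthenticate(
        "basic", {"realm": "Authentication Required"}
    )

    # Multiple challenges: one WWW-Authenticate header line per item.
    resp.www_authenticate = [
        WWWAuthenticate("basic", {"realm": "Authentication Required"}),
        WWWAuthenticate("bearer"),
    ]
    assert len(resp.headers.getlist("WWW-Authenticate")) == 2

    del resp.www_authenticate  # unset entirely
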
+ ref_bytes): + if ref == hostname or (suffix_match and hostname.endswith(f".{ref}")): return True return False @@ -55,9 +48,9 @@ def _normalize(hostname: str) -> bytes: def get_host( scheme: str, - host_header: t.Optional[str], - server: t.Optional[t.Tuple[str, t.Optional[int]]] = None, - trusted_hosts: t.Optional[t.Iterable[str]] = None, + host_header: str | None, + server: tuple[str, int | None] | None = None, + trusted_hosts: t.Iterable[str] | None = None, ) -> str: """Return the host for the given parameters. @@ -104,9 +97,9 @@ def get_host( def get_current_url( scheme: str, host: str, - root_path: t.Optional[str] = None, - path: t.Optional[str] = None, - query_string: t.Optional[bytes] = None, + root_path: str | None = None, + path: str | None = None, + query_string: bytes | None = None, ) -> str: """Recreate the URL for a request. If an optional part isn't provided, it and subsequent parts are not included in the URL. @@ -127,39 +120,40 @@ def get_current_url( url.append("/") return uri_to_iri("".join(url)) - url.append(url_quote(root_path.rstrip("/"))) + # safe = https://url.spec.whatwg.org/#url-path-segment-string + # as well as percent for things that are already quoted + url.append(quote(root_path.rstrip("/"), safe="!$&'()*+,/:;=@%")) url.append("/") if path is None: return uri_to_iri("".join(url)) - url.append(url_quote(path.lstrip("/"))) + url.append(quote(path.lstrip("/"), safe="!$&'()*+,/:;=@%")) if query_string: url.append("?") - url.append(url_quote(query_string, safe=":&%=+$!*'(),")) + url.append(quote(query_string, safe="!$&'()*+,/:;=?@%")) return uri_to_iri("".join(url)) def get_content_length( - http_content_length: t.Union[str, None] = None, - http_transfer_encoding: t.Union[str, None] = "", -) -> t.Optional[int]: - """Returns the content length as an integer or ``None`` if - unavailable or chunked transfer encoding is used. + http_content_length: str | None = None, + http_transfer_encoding: str | None = None, +) -> int | None: + """Return the ``Content-Length`` header value as an int. If the header is not given + or the ``Transfer-Encoding`` header is ``chunked``, ``None`` is returned to indicate + a streaming request. If the value is not an integer, or negative, 0 is returned. :param http_content_length: The Content-Length HTTP header. :param http_transfer_encoding: The Transfer-Encoding HTTP header. .. 
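
The matching rules above, exercised directly (``host_is_trusted`` lives in
``werkzeug.sansio.utils``; the host names are illustrative):

.. code-block:: python

    from werkzeug.sansio.utils import host_is_trusted

    assert host_is_trusted("example.com", ["example.com"])
    # A leading dot allows the name itself plus any subdomain.
    assert host_is_trusted("foo.example.com", [".example.com"])
    # Ports are stripped and names IDNA-encoded before comparison.
    assert host_is_trusted("example.com:8080", ["example.com"])
    assert not host_is_trusted("evil.test", ["example.com"])
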
versionadded:: 2.2 """ - if http_transfer_encoding == "chunked": + if http_transfer_encoding == "chunked" or http_content_length is None: return None - if http_content_length is not None: - try: - return max(0, int(http_content_length)) - except (ValueError, TypeError): - pass - return None + try: + return max(0, _plain_int(http_content_length)) + except ValueError: + return 0 diff --git a/src/werkzeug/security.py b/src/werkzeug/security.py index 18d0919..9999509 100644 --- a/src/werkzeug/security.py +++ b/src/werkzeug/security.py @@ -1,113 +1,134 @@ +from __future__ import annotations + import hashlib import hmac import os import posixpath import secrets -import typing as t - -if t.TYPE_CHECKING: - pass SALT_CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" -DEFAULT_PBKDF2_ITERATIONS = 260000 +DEFAULT_PBKDF2_ITERATIONS = 600000 -_os_alt_seps: t.List[str] = list( - sep for sep in [os.path.sep, os.path.altsep] if sep is not None and sep != "/" +_os_alt_seps: list[str] = list( + sep for sep in [os.sep, os.path.altsep] if sep is not None and sep != "/" ) def gen_salt(length: int) -> str: """Generate a random string of SALT_CHARS with specified ``length``.""" if length <= 0: - raise ValueError("Salt length must be positive") + raise ValueError("Salt length must be at least 1.") return "".join(secrets.choice(SALT_CHARS) for _ in range(length)) -def _hash_internal(method: str, salt: str, password: str) -> t.Tuple[str, str]: - """Internal password hash helper. Supports plaintext without salt, - unsalted and salted passwords. In case salted passwords are used - hmac is used. - """ - if method == "plain": - return password, method - - salt = salt.encode("utf-8") - password = password.encode("utf-8") - - if method.startswith("pbkdf2:"): - if not salt: - raise ValueError("Salt is required for PBKDF2") - - args = method[7:].split(":") +def _hash_internal(method: str, salt: str, password: str) -> tuple[str, str]: + method, *args = method.split(":") + salt_bytes = salt.encode() + password_bytes = password.encode() - if len(args) not in (1, 2): - raise ValueError("Invalid number of arguments for PBKDF2") + if method == "scrypt": + if not args: + n = 2**15 + r = 8 + p = 1 + else: + try: + n, r, p = map(int, args) + except ValueError: + raise ValueError("'scrypt' takes 3 arguments.") from None - method = args.pop(0) - iterations = int(args[0] or 0) if args else DEFAULT_PBKDF2_ITERATIONS + maxmem = 132 * n * r * p # ideally 128, but some extra seems needed return ( - hashlib.pbkdf2_hmac(method, password, salt, iterations).hex(), - f"pbkdf2:{method}:{iterations}", + hashlib.scrypt( + password_bytes, salt=salt_bytes, n=n, r=r, p=p, maxmem=maxmem + ).hex(), + f"scrypt:{n}:{r}:{p}", ) + elif method == "pbkdf2": + len_args = len(args) + + if len_args == 0: + hash_name = "sha256" + iterations = DEFAULT_PBKDF2_ITERATIONS + elif len_args == 1: + hash_name = args[0] + iterations = DEFAULT_PBKDF2_ITERATIONS + elif len_args == 2: + hash_name = args[0] + iterations = int(args[1]) + else: + raise ValueError("'pbkdf2' takes 2 arguments.") - if salt: - return hmac.new(salt, password, method).hexdigest(), method - - return hashlib.new(method, password).hexdigest(), method + return ( + hashlib.pbkdf2_hmac( + hash_name, password_bytes, salt_bytes, iterations + ).hex(), + f"pbkdf2:{hash_name}:{iterations}", + ) + else: + raise ValueError(f"Invalid hash method '{method}'.") def generate_password_hash( - password: str, method: str = "pbkdf2:sha256", salt_length: int = 16 + password: str, method: str = 
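
To pin down the new ``get_content_length`` semantics described in the
docstring:

.. code-block:: python

    from werkzeug.sansio.utils import get_content_length

    assert get_content_length("42") == 42
    assert get_content_length(None) is None             # header absent
    assert get_content_length("42", "chunked") is None  # chunked wins
    assert get_content_length("nonsense") == 0          # not an int -> 0
    assert get_content_length("-5") == 0                # negative -> 0
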
"scrypt", salt_length: int = 16 ) -> str: - """Hash a password with the given method and salt with a string of - the given length. The format of the string returned includes the method - that was used so that :func:`check_password_hash` can check the hash. + """Securely hash a password for storage. A password can be compared to a stored hash + using :func:`check_password_hash`. - The format for the hashed string looks like this:: + The following methods are supported: - method$salt$hash + - ``scrypt``, the default. The parameters are ``n``, ``r``, and ``p``, the default + is ``scrypt:32768:8:1``. See :func:`hashlib.scrypt`. + - ``pbkdf2``, less secure. The parameters are ``hash_method`` and ``iterations``, + the default is ``pbkdf2:sha256:600000``. See :func:`hashlib.pbkdf2_hmac`. - This method can **not** generate unsalted passwords but it is possible - to set param method='plain' in order to enforce plaintext passwords. - If a salt is used, hmac is used internally to salt the password. + Default parameters may be updated to reflect current guidelines, and methods may be + deprecated and removed if they are no longer considered secure. To migrate old + hashes, you may generate a new hash when checking an old hash, or you may contact + users with a link to reset their password. - If PBKDF2 is wanted it can be enabled by setting the method to - ``pbkdf2:method:iterations`` where iterations is optional:: + :param password: The plaintext password. + :param method: The key derivation function and parameters. + :param salt_length: The number of characters to generate for the salt. - pbkdf2:sha256:80000$salt$hash - pbkdf2:sha256$salt$hash + .. versionchanged:: 2.3 + Scrypt support was added. - :param password: the password to hash. - :param method: the hash method to use (one that hashlib supports). Can - optionally be in the format ``pbkdf2:method:iterations`` - to enable PBKDF2. - :param salt_length: the length of the salt in letters. + .. versionchanged:: 2.3 + The default iterations for pbkdf2 was increased to 600,000. + + .. versionchanged:: 2.3 + All plain hashes are deprecated and will not be supported in Werkzeug 3.0. """ - salt = gen_salt(salt_length) if method != "plain" else "" + salt = gen_salt(salt_length) h, actual_method = _hash_internal(method, salt, password) return f"{actual_method}${salt}${h}" def check_password_hash(pwhash: str, password: str) -> bool: - """Check a password against a given salted and hashed password value. - In order to support unsalted legacy passwords this method supports - plain text passwords, md5 and sha1 hashes (both salted and unsalted). + """Securely check that the given stored password hash, previously generated using + :func:`generate_password_hash`, matches the given password. + + Methods may be deprecated and removed if they are no longer considered secure. To + migrate old hashes, you may generate a new hash when checking an old hash, or you + may contact users with a link to reset their password. - Returns `True` if the password matched, `False` otherwise. + :param pwhash: The hashed password. + :param password: The plaintext password. - :param pwhash: a hashed string like returned by - :func:`generate_password_hash`. - :param password: the plaintext password to compare against the hash. + .. versionchanged:: 2.3 + All plain hashes are deprecated and will not be supported in Werkzeug 3.0. 
""" - if pwhash.count("$") < 2: + try: + method, salt, hashval = pwhash.split("$", 2) + except ValueError: return False - method, salt, hashval = pwhash.split("$", 2) return hmac.compare_digest(_hash_internal(method, salt, password)[0], hashval) -def safe_join(directory: str, *pathnames: str) -> t.Optional[str]: +def safe_join(directory: str, *pathnames: str) -> str | None: """Safely join zero or more untrusted path components to a base directory to avoid escaping the base directory. diff --git a/src/werkzeug/serving.py b/src/werkzeug/serving.py index c482469..859f9aa 100644 --- a/src/werkzeug/serving.py +++ b/src/werkzeug/serving.py @@ -11,9 +11,13 @@ from myapp import create_app from werkzeug import run_simple """ + +from __future__ import annotations + import errno import io import os +import selectors import socket import socketserver import sys @@ -23,13 +27,13 @@ from datetime import timezone from http.server import BaseHTTPRequestHandler from http.server import HTTPServer +from urllib.parse import unquote +from urllib.parse import urlsplit from ._internal import _log from ._internal import _wsgi_encoding_dance from .exceptions import InternalServerError from .urls import uri_to_iri -from .urls import url_parse -from .urls import url_unquote try: import ssl @@ -70,11 +74,10 @@ class ForkingMixIn: # type: ignore LISTEN_QUEUE = 128 _TSSLContextArg = t.Optional[ - t.Union["ssl.SSLContext", t.Tuple[str, t.Optional[str]], "te.Literal['adhoc']"] + t.Union["ssl.SSLContext", t.Tuple[str, t.Optional[str]], t.Literal["adhoc"]] ] if t.TYPE_CHECKING: - import typing_extensions as te # noqa: F401 from _typeshed.wsgi import WSGIApplication from _typeshed.wsgi import WSGIEnvironment from cryptography.hazmat.primitives.asymmetric.rsa import ( @@ -148,16 +151,14 @@ def readinto(self, buf: bytearray) -> int: # type: ignore class WSGIRequestHandler(BaseHTTPRequestHandler): """A request handler that implements WSGI dispatching.""" - server: "BaseWSGIServer" + server: BaseWSGIServer @property def server_version(self) -> str: # type: ignore - from . import __version__ - - return f"Werkzeug/{__version__}" + return self.server._server_version - def make_environ(self) -> "WSGIEnvironment": - request_url = url_parse(self.path) + def make_environ(self) -> WSGIEnvironment: + request_url = urlsplit(self.path) url_scheme = "http" if self.server.ssl_context is None else "https" if not self.client_address: @@ -173,9 +174,9 @@ def make_environ(self) -> "WSGIEnvironment": else: path_info = request_url.path - path_info = url_unquote(path_info) + path_info = unquote(path_info) - environ: "WSGIEnvironment" = { + environ: WSGIEnvironment = { "wsgi.version": (1, 0), "wsgi.url_scheme": url_scheme, "wsgi.input": self.rfile, @@ -201,6 +202,9 @@ def make_environ(self) -> "WSGIEnvironment": } for key, value in self.headers.items(): + if "_" in key: + continue + key = key.upper().replace("-", "_") value = value.replace("\r\n", "") if key not in ("CONTENT_TYPE", "CONTENT_LENGTH"): @@ -221,9 +225,7 @@ def make_environ(self) -> "WSGIEnvironment": try: # binary_form=False gives nicer information, but wouldn't be compatible with # what Nginx or Apache could return. - peer_cert = self.connection.getpeercert( # type: ignore[attr-defined] - binary_form=True - ) + peer_cert = self.connection.getpeercert(binary_form=True) if peer_cert is not None: # Nginx and Apache use PEM format. 
environ["SSL_CLIENT_CERT"] = ssl.DER_cert_to_PEM_cert(peer_cert) @@ -241,10 +243,10 @@ def run_wsgi(self) -> None: self.wfile.write(b"HTTP/1.1 100 Continue\r\n\r\n") self.environ = environ = self.make_environ() - status_set: t.Optional[str] = None - headers_set: t.Optional[t.List[t.Tuple[str, str]]] = None - status_sent: t.Optional[str] = None - headers_sent: t.Optional[t.List[t.Tuple[str, str]]] = None + status_set: str | None = None + headers_set: list[tuple[str, str]] | None = None + status_sent: str | None = None + headers_sent: list[tuple[str, str]] | None = None chunk_response: bool = False def write(data: bytes) -> None: @@ -318,7 +320,7 @@ def start_response(status, headers, exc_info=None): # type: ignore headers_set = headers return write - def execute(app: "WSGIApplication") -> None: + def execute(app: WSGIApplication) -> None: application_iter = app(environ, start_response) try: for data in application_iter: @@ -328,8 +330,34 @@ def execute(app: "WSGIApplication") -> None: if chunk_response: self.wfile.write(b"0\r\n\r\n") finally: + # Check for any remaining data in the read socket, and discard it. This + # will read past request.max_content_length, but lets the client see a + # 413 response instead of a connection reset failure. If we supported + # keep-alive connections, this naive approach would break by reading the + # next request line. Since we know that write (above) closes every + # connection we can read everything. + selector = selectors.DefaultSelector() + selector.register(self.connection, selectors.EVENT_READ) + total_size = 0 + total_reads = 0 + + # A timeout of 0 tends to fail because a client needs a small amount of + # time to continue sending its data. + while selector.select(timeout=0.01): + # Only read 10MB into memory at a time. + data = self.rfile.read(10_000_000) + total_size += len(data) + total_reads += 1 + + # Stop reading on no data, >=10GB, or 1000 reads. If a client sends + # more than that, they'll get a connection reset failure. + if not data or total_size >= 10_000_000_000 or total_reads > 1000: + break + + selector.close() + if hasattr(application_iter, "close"): - application_iter.close() # type: ignore + application_iter.close() try: execute(self.server.app) @@ -370,7 +398,7 @@ def handle(self) -> None: raise def connection_dropped( - self, error: BaseException, environ: t.Optional["WSGIEnvironment"] = None + self, error: BaseException, environ: WSGIEnvironment | None = None ) -> None: """Called if the connection was closed by the client. By default nothing happens. @@ -396,9 +424,13 @@ def address_string(self) -> str: def port_integer(self) -> int: return self.client_address[1] - def log_request( - self, code: t.Union[int, str] = "-", size: t.Union[int, str] = "-" - ) -> None: + # Escape control characters. This is defined (but private) in Python 3.12. + _control_char_table = str.maketrans( + {c: rf"\x{c:02x}" for c in [*range(0x20), *range(0x7F, 0xA0)]} + ) + _control_char_table[ord("\\")] = r"\\" + + def log_request(self, code: int | str = "-", size: int | str = "-") -> None: try: path = uri_to_iri(self.path) msg = f"{self.command} {path} {self.request_version}" @@ -406,6 +438,8 @@ def log_request( # path isn't set if the requestline was bad msg = self.requestline + # Escape control characters that may be in the decoded path. 
+ msg = msg.translate(self._control_char_table) code = str(code) if code[0] == "1": # 1xx - Informational @@ -459,14 +493,14 @@ def _ansi_style(value: str, *styles: str) -> str: def generate_adhoc_ssl_pair( - cn: t.Optional[str] = None, -) -> t.Tuple["Certificate", "RSAPrivateKeyWithSerialization"]: + cn: str | None = None, +) -> tuple[Certificate, RSAPrivateKeyWithSerialization]: try: from cryptography import x509 - from cryptography.x509.oid import NameOID from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.x509.oid import NameOID except ImportError: raise TypeError( "Using ad-hoc certificates requires the cryptography library." @@ -498,15 +532,18 @@ def generate_adhoc_ssl_pair( .not_valid_before(dt.now(timezone.utc)) .not_valid_after(dt.now(timezone.utc) + timedelta(days=365)) .add_extension(x509.ExtendedKeyUsage([x509.OID_SERVER_AUTH]), critical=False) - .add_extension(x509.SubjectAlternativeName([x509.DNSName(cn)]), critical=False) + .add_extension( + x509.SubjectAlternativeName([x509.DNSName(cn), x509.DNSName(f"*.{cn}")]), + critical=False, + ) .sign(pkey, hashes.SHA256(), backend) ) return cert, pkey def make_ssl_devcert( - base_path: str, host: t.Optional[str] = None, cn: t.Optional[str] = None -) -> t.Tuple[str, str]: + base_path: str, host: str | None = None, cn: str | None = None +) -> tuple[str, str]: """Creates an SSL key for development. This should be used instead of the ``'adhoc'`` key which generates a new cert on each server start. It accepts a path for where it should store the key and cert and @@ -526,7 +563,7 @@ def make_ssl_devcert( """ if host is not None: - cn = f"*.{host}/CN={host}" + cn = host cert, pkey = generate_adhoc_ssl_pair(cn=cn) from cryptography.hazmat.primitives import serialization @@ -548,10 +585,10 @@ def make_ssl_devcert( return cert_file, pkey_file -def generate_adhoc_ssl_context() -> "ssl.SSLContext": +def generate_adhoc_ssl_context() -> ssl.SSLContext: """Generates an adhoc SSL context for the development server.""" - import tempfile import atexit + import tempfile cert, pkey = generate_adhoc_ssl_pair() @@ -579,8 +616,8 @@ def generate_adhoc_ssl_context() -> "ssl.SSLContext": def load_ssl_context( - cert_file: str, pkey_file: t.Optional[str] = None, protocol: t.Optional[int] = None -) -> "ssl.SSLContext": + cert_file: str, pkey_file: str | None = None, protocol: int | None = None +) -> ssl.SSLContext: """Loads SSL context from cert/private key files and optional protocol. Many parameters are directly taken from the API of :py:class:`ssl.SSLContext`. @@ -599,7 +636,7 @@ def load_ssl_context( return ctx -def is_ssl_error(error: t.Optional[Exception] = None) -> bool: +def is_ssl_error(error: Exception | None = None) -> bool: """Checks if the given error (or the current one) is an SSL error.""" if error is None: error = t.cast(Exception, sys.exc_info()[1]) @@ -618,11 +655,12 @@ def select_address_family(host: str, port: int) -> socket.AddressFamily: def get_sockaddr( host: str, port: int, family: socket.AddressFamily -) -> t.Union[t.Tuple[str, int], str]: +) -> tuple[str, int] | str: """Return a fully qualified socket address that can be passed to :func:`socket.bind`.""" if family == af_unix: - return host.split("://", 1)[1] + # Absolute path avoids IDNA encoding error when path starts with dot. 
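
Usage sketch for the dev-cert helpers touched above (requires the
``cryptography`` package; the path and port are illustrative):

.. code-block:: python

    from werkzeug.serving import make_ssl_devcert, run_simple
    from werkzeug.wrappers import Request, Response

    @Request.application
    def app(request):
        return Response("hello over https")

    # Persistent dev cert; with the cn fix the cert is issued for the host
    # itself, and the SAN now also covers "*.<host>".
    cert_file, pkey_file = make_ssl_devcert("/tmp/devcert", host="localhost")
    run_simple("localhost", 8443, app, ssl_context=(cert_file, pkey_file))
    # Or: ssl_context="adhoc" for a throwaway pair generated on each start.
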
+ return os.path.abspath(host.partition("://")[2]) try: res = socket.getaddrinfo( host, port, family, socket.SOCK_STREAM, socket.IPPROTO_TCP @@ -659,16 +697,17 @@ class BaseWSGIServer(HTTPServer): multithread = False multiprocess = False request_queue_size = LISTEN_QUEUE + allow_reuse_address = True def __init__( self, host: str, port: int, - app: "WSGIApplication", - handler: t.Optional[t.Type[WSGIRequestHandler]] = None, + app: WSGIApplication, + handler: type[WSGIRequestHandler] | None = None, passthrough_errors: bool = False, - ssl_context: t.Optional[_TSSLContextArg] = None, - fd: t.Optional[int] = None, + ssl_context: _TSSLContextArg | None = None, + fd: int | None = None, ) -> None: if handler is None: handler = WSGIRequestHandler @@ -710,10 +749,36 @@ def __init__( try: self.server_bind() self.server_activate() + except OSError as e: + # Catch connection issues and show them without the traceback. Show + # extra instructions for address not found, and for macOS. + self.server_close() + print(e.strerror, file=sys.stderr) + + if e.errno == errno.EADDRINUSE: + print( + f"Port {port} is in use by another program. Either identify and" + " stop that program, or start the server with a different" + " port.", + file=sys.stderr, + ) + + if sys.platform == "darwin" and port == 5000: + print( + "On macOS, try disabling the 'AirPlay Receiver' service" + " from System Preferences -> General -> AirDrop & Handoff.", + file=sys.stderr, + ) + + sys.exit(1) except BaseException: self.server_close() raise else: + # TCPServer automatically opens a socket even if bind_and_activate is False. + # Close it to silence a ResourceWarning. + self.server_close() + # Use the passed in socket directly. self.socket = socket.fromfd(fd, address_family, socket.SOCK_STREAM) self.server_address = self.socket.getsockname() @@ -729,10 +794,14 @@ def __init__( ssl_context = generate_adhoc_ssl_context() self.socket = ssl_context.wrap_socket(self.socket, server_side=True) - self.ssl_context: t.Optional["ssl.SSLContext"] = ssl_context + self.ssl_context: ssl.SSLContext | None = ssl_context else: self.ssl_context = None + import importlib.metadata + + self._server_version = f"Werkzeug/{importlib.metadata.version('werkzeug')}" + def log(self, type: str, message: str, *args: t.Any) -> None: _log(type, message, *args) @@ -745,7 +814,7 @@ def serve_forever(self, poll_interval: float = 0.5) -> None: self.server_close() def handle_error( - self, request: t.Any, client_address: t.Union[t.Tuple[str, int], str] + self, request: t.Any, client_address: tuple[str, int] | str ) -> None: if self.passthrough_errors: raise @@ -811,12 +880,12 @@ def __init__( self, host: str, port: int, - app: "WSGIApplication", + app: WSGIApplication, processes: int = 40, - handler: t.Optional[t.Type[WSGIRequestHandler]] = None, + handler: type[WSGIRequestHandler] | None = None, passthrough_errors: bool = False, - ssl_context: t.Optional[_TSSLContextArg] = None, - fd: t.Optional[int] = None, + ssl_context: _TSSLContextArg | None = None, + fd: int | None = None, ) -> None: if not can_fork: raise ValueError("Your platform does not support forking.") @@ -828,13 +897,13 @@ def __init__( def make_server( host: str, port: int, - app: "WSGIApplication", + app: WSGIApplication, threaded: bool = False, processes: int = 1, - request_handler: t.Optional[t.Type[WSGIRequestHandler]] = None, + request_handler: type[WSGIRequestHandler] | None = None, passthrough_errors: bool = False, - ssl_context: t.Optional[_TSSLContextArg] = None, - fd: t.Optional[int] = None, + 
ssl_context: _TSSLContextArg | None = None, + fd: int | None = None, ) -> BaseWSGIServer: """Create an appropriate WSGI server instance based on the value of ``threaded`` and ``processes``. @@ -879,77 +948,23 @@ def is_running_from_reloader() -> bool: return os.environ.get("WERKZEUG_RUN_MAIN") == "true" -def prepare_socket(hostname: str, port: int) -> socket.socket: - """Prepare a socket for use by the WSGI server and reloader. - - The socket is marked inheritable so that it can be kept across - reloads instead of breaking connections. - - Catch errors during bind and show simpler error messages. For - "address already in use", show instructions for resolving the issue, - with special instructions for macOS. - - This is called from :func:`run_simple`, but can be used separately - to control server creation with :func:`make_server`. - """ - address_family = select_address_family(hostname, port) - server_address = get_sockaddr(hostname, port, address_family) - s = socket.socket(address_family, socket.SOCK_STREAM) - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - s.set_inheritable(True) - - # Remove the socket file if it already exists. - if address_family == af_unix: - server_address = t.cast(str, server_address) - - if os.path.exists(server_address): - os.unlink(server_address) - - # Catch connection issues and show them without the traceback. Show - # extra instructions for address not found, and for macOS. - try: - s.bind(server_address) - except OSError as e: - print(e.strerror, file=sys.stderr) - - if e.errno == errno.EADDRINUSE: - print( - f"Port {port} is in use by another program. Either" - " identify and stop that program, or start the" - " server with a different port.", - file=sys.stderr, - ) - - if sys.platform == "darwin" and port == 5000: - print( - "On macOS, try disabling the 'AirPlay Receiver'" - " service from System Preferences -> Sharing.", - file=sys.stderr, - ) - - sys.exit(1) - - s.listen(LISTEN_QUEUE) - return s - - def run_simple( hostname: str, port: int, - application: "WSGIApplication", + application: WSGIApplication, use_reloader: bool = False, use_debugger: bool = False, use_evalex: bool = True, - extra_files: t.Optional[t.Iterable[str]] = None, - exclude_patterns: t.Optional[t.Iterable[str]] = None, + extra_files: t.Iterable[str] | None = None, + exclude_patterns: t.Iterable[str] | None = None, reloader_interval: int = 1, reloader_type: str = "auto", threaded: bool = False, processes: int = 1, - request_handler: t.Optional[t.Type[WSGIRequestHandler]] = None, - static_files: t.Optional[t.Dict[str, t.Union[str, t.Tuple[str, str]]]] = None, + request_handler: type[WSGIRequestHandler] | None = None, + static_files: dict[str, str | tuple[str, str]] | None = None, passthrough_errors: bool = False, - ssl_context: t.Optional[_TSSLContextArg] = None, + ssl_context: _TSSLContextArg | None = None, ) -> None: """Start a development server for a WSGI application. Various optional features can be enabled. @@ -997,7 +1012,7 @@ def run_simple( serve static files from using :class:`~werkzeug.middleware.SharedDataMiddleware`. :param passthrough_errors: Don't catch unhandled exceptions at the - server level, let the serve crash instead. If ``use_debugger`` + server level, let the server crash instead. If ``use_debugger`` is enabled, the debugger will still catch such errors. :param ssl_context: Configure TLS to serve over HTTPS. 
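
For embedding or tests, ``make_server`` hands back the server object directly;
a minimal sketch of driving it from a thread:

.. code-block:: python

    import threading

    from werkzeug.serving import make_server
    from werkzeug.wrappers import Request, Response

    @Request.application
    def app(request):
        return Response("hello")

    srv = make_server("127.0.0.1", 0, app, threaded=True)  # port 0: any free port
    t = threading.Thread(target=srv.serve_forever)
    t.start()
    print("listening on", srv.server_address)
    srv.shutdown()  # stops serve_forever; the server closes itself afterwards
    t.join()
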
Can be an :class:`ssl.SSLContext` object, a ``(cert_file, key_file)`` @@ -1057,14 +1072,12 @@ def run_simple( from .debug import DebuggedApplication application = DebuggedApplication(application, evalex=use_evalex) + # Allow the specified hostname to use the debugger, in addition to + # localhost domains. + application.trusted_hosts.append(hostname) if not is_running_from_reloader(): - s = prepare_socket(hostname, port) - fd = s.fileno() - # Silence a ResourceWarning about an unclosed socket. This object is no longer - # used, the server will create another with fromfd. - s.detach() - os.environ["WERKZEUG_SERVER_FD"] = str(fd) + fd = None else: fd = int(os.environ["WERKZEUG_SERVER_FD"]) @@ -1079,6 +1092,8 @@ def run_simple( ssl_context, fd=fd, ) + srv.socket.set_inheritable(True) + os.environ["WERKZEUG_SERVER_FD"] = str(srv.fileno()) if not is_running_from_reloader(): srv.log_startup() @@ -1087,12 +1102,15 @@ def run_simple( if use_reloader: from ._reloader import run_with_reloader - run_with_reloader( - srv.serve_forever, - extra_files=extra_files, - exclude_patterns=exclude_patterns, - interval=reloader_interval, - reloader_type=reloader_type, - ) + try: + run_with_reloader( + srv.serve_forever, + extra_files=extra_files, + exclude_patterns=exclude_patterns, + interval=reloader_interval, + reloader_type=reloader_type, + ) + finally: + srv.server_close() else: srv.serve_forever() diff --git a/src/werkzeug/test.py b/src/werkzeug/test.py index edb4d4a..38f69bf 100644 --- a/src/werkzeug/test.py +++ b/src/werkzeug/test.py @@ -1,19 +1,21 @@ +from __future__ import annotations + +import dataclasses import mimetypes import sys import typing as t from collections import defaultdict from datetime import datetime -from datetime import timedelta -from http.cookiejar import CookieJar from io import BytesIO from itertools import chain from random import random from tempfile import TemporaryFile from time import time -from urllib.request import Request as _UrllibRequest +from urllib.parse import unquote +from urllib.parse import urlsplit +from urllib.parse import urlunsplit from ._internal import _get_environ -from ._internal import _make_encode_wrapper from ._internal import _wsgi_decoding_dance from ._internal import _wsgi_encoding_dance from .datastructures import Authorization @@ -25,6 +27,8 @@ from .datastructures import MultiDict from .http import dump_cookie from .http import dump_options_header +from .http import parse_cookie +from .http import parse_date from .http import parse_options_header from .sansio.multipart import Data from .sansio.multipart import Epilogue @@ -32,12 +36,8 @@ from .sansio.multipart import File from .sansio.multipart import MultipartEncoder from .sansio.multipart import Preamble +from .urls import _urlencode from .urls import iri_to_uri -from .urls import url_encode -from .urls import url_fix -from .urls import url_parse -from .urls import url_unparse -from .urls import url_unquote from .utils import cached_property from .utils import get_content_type from .wrappers.request import Request @@ -46,6 +46,7 @@ from .wsgi import get_current_url if t.TYPE_CHECKING: + import typing_extensions as te from _typeshed.wsgi import WSGIApplication from _typeshed.wsgi import WSGIEnvironment @@ -54,12 +55,14 @@ def stream_encode_multipart( data: t.Mapping[str, t.Any], use_tempfile: bool = True, threshold: int = 1024 * 500, - boundary: t.Optional[str] = None, - charset: str = "utf-8", -) -> t.Tuple[t.IO[bytes], int, str]: + boundary: str | None = None, +) -> tuple[t.IO[bytes], int, 
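
With the reloader enabled there are two processes: the watcher and the serving
child it respawns, which now inherits the listening socket through
``WERKZEUG_SERVER_FD``. One-time setup is commonly guarded so it runs only in
the serving child; a sketch (``start_background_jobs`` is hypothetical):

.. code-block:: python

    from werkzeug.serving import is_running_from_reloader

    if is_running_from_reloader():
        # True only in the respawned child, i.e. the process that
        # actually serves requests.
        start_background_jobs()  # hypothetical one-time setup
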
str]: """Encode a dict of values (either strings or file descriptors or :class:`FileStorage` objects.) into a multipart encoded string stored in a file descriptor. + + .. versionchanged:: 3.0 + The ``charset`` parameter was removed. """ if boundary is None: boundary = f"---------------WerkzeugFormPart_{time()}{random()}" @@ -107,7 +110,8 @@ def write_binary(s: bytes) -> int: and mimetypes.guess_type(filename)[0] or "application/octet-stream" ) - headers = Headers([("Content-Type", content_type)]) + headers = value.headers + headers.update([("Content-Type", content_type)]) if filename is None: write_binary(encoder.send_event(Field(name=key, headers=headers))) else: @@ -120,6 +124,7 @@ def write_binary(s: bytes) -> int: chunk = reader(16384) if not chunk: + write_binary(encoder.send_event(Data(data=chunk, more_data=False))) break write_binary(encoder.send_event(Data(data=chunk, more_data=True))) @@ -127,9 +132,7 @@ def write_binary(s: bytes) -> int: if not isinstance(value, str): value = str(value) write_binary(encoder.send_event(Field(name=key, headers=Headers()))) - write_binary( - encoder.send_event(Data(data=value.encode(charset), more_data=False)) - ) + write_binary(encoder.send_event(Data(data=value.encode(), more_data=False))) write_binary(encoder.send_event(Epilogue(data=b""))) @@ -139,87 +142,21 @@ def write_binary(s: bytes) -> int: def encode_multipart( - values: t.Mapping[str, t.Any], - boundary: t.Optional[str] = None, - charset: str = "utf-8", -) -> t.Tuple[str, bytes]: + values: t.Mapping[str, t.Any], boundary: str | None = None +) -> tuple[str, bytes]: """Like `stream_encode_multipart` but returns a tuple in the form (``boundary``, ``data``) where data is bytes. + + .. versionchanged:: 3.0 + The ``charset`` parameter was removed. """ stream, length, boundary = stream_encode_multipart( - values, use_tempfile=False, boundary=boundary, charset=charset + values, use_tempfile=False, boundary=boundary ) return boundary, stream.read() -class _TestCookieHeaders: - """A headers adapter for cookielib""" - - def __init__(self, headers: t.Union[Headers, t.List[t.Tuple[str, str]]]) -> None: - self.headers = headers - - def getheaders(self, name: str) -> t.Iterable[str]: - headers = [] - name = name.lower() - for k, v in self.headers: - if k.lower() == name: - headers.append(v) - return headers - - def get_all( - self, name: str, default: t.Optional[t.Iterable[str]] = None - ) -> t.Iterable[str]: - headers = self.getheaders(name) - - if not headers: - return default # type: ignore - - return headers - - -class _TestCookieResponse: - """Something that looks like a httplib.HTTPResponse, but is actually just an - adapter for our test responses to make them available for cookielib. - """ - - def __init__(self, headers: t.Union[Headers, t.List[t.Tuple[str, str]]]) -> None: - self.headers = _TestCookieHeaders(headers) - - def info(self) -> _TestCookieHeaders: - return self.headers - - -class _TestCookieJar(CookieJar): - """A cookielib.CookieJar modified to inject and read cookie headers from - and to wsgi environments, and wsgi application responses. - """ - - def inject_wsgi(self, environ: "WSGIEnvironment") -> None: - """Inject the cookies as client headers into the server's wsgi - environment. 
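
A usage sketch for ``encode_multipart`` after the ``charset`` removal (string
values now always encode as UTF-8; file parts need a ``filename``, so a
``FileStorage`` wrapper is used here):

.. code-block:: python

    from io import BytesIO

    from werkzeug.datastructures import FileStorage
    from werkzeug.test import encode_multipart

    boundary, body = encode_multipart(
        {
            "name": "example",
            "upload": FileStorage(BytesIO(b"file contents"), filename="hello.txt"),
        }
    )
    # The matching request header would be:
    content_type = f'multipart/form-data; boundary="{boundary}"'
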
- """ - cvals = [f"{c.name}={c.value}" for c in self] - - if cvals: - environ["HTTP_COOKIE"] = "; ".join(cvals) - else: - environ.pop("HTTP_COOKIE", None) - - def extract_wsgi( - self, - environ: "WSGIEnvironment", - headers: t.Union[Headers, t.List[t.Tuple[str, str]]], - ) -> None: - """Extract the server's set-cookie headers as cookies into the - cookie jar. - """ - self.extract_cookies( - _TestCookieResponse(headers), # type: ignore - _UrllibRequest(get_current_url(environ)), - ) - - -def _iter_data(data: t.Mapping[str, t.Any]) -> t.Iterator[t.Tuple[str, t.Any]]: +def _iter_data(data: t.Mapping[str, t.Any]) -> t.Iterator[tuple[str, t.Any]]: """Iterate over a mapping that might have a list of values, yielding all key, value pairs. Almost like iter_multi_items but only allows lists, not tuples, of values so tuples can be used for files. @@ -235,7 +172,7 @@ def _iter_data(data: t.Mapping[str, t.Any]) -> t.Iterator[t.Tuple[str, t.Any]]: yield key, value -_TAnyMultiDict = t.TypeVar("_TAnyMultiDict", bound=MultiDict) +_TAnyMultiDict = t.TypeVar("_TAnyMultiDict", bound="MultiDict[t.Any, t.Any]") class EnvironBuilder: @@ -302,11 +239,13 @@ class EnvironBuilder: Serialized with the function assigned to :attr:`json_dumps`. :param environ_base: an optional dict of environment defaults. :param environ_overrides: an optional dict of environment overrides. - :param charset: the charset used to encode string data. :param auth: An authorization object to use for the ``Authorization`` header value. A ``(username, password)`` tuple is a shortcut for ``Basic`` authorization. + .. versionchanged:: 3.0 + The ``charset`` parameter was removed. + .. versionchanged:: 2.1 ``CONTENT_TYPE`` and ``CONTENT_LENGTH`` are not duplicated as header keys in the environ. @@ -350,49 +289,45 @@ class EnvironBuilder: json_dumps = staticmethod(json.dumps) del json - _args: t.Optional[MultiDict] - _query_string: t.Optional[str] - _input_stream: t.Optional[t.IO[bytes]] - _form: t.Optional[MultiDict] - _files: t.Optional[FileMultiDict] + _args: MultiDict[str, str] | None + _query_string: str | None + _input_stream: t.IO[bytes] | None + _form: MultiDict[str, str] | None + _files: FileMultiDict | None def __init__( self, path: str = "/", - base_url: t.Optional[str] = None, - query_string: t.Optional[t.Union[t.Mapping[str, str], str]] = None, + base_url: str | None = None, + query_string: t.Mapping[str, str] | str | None = None, method: str = "GET", - input_stream: t.Optional[t.IO[bytes]] = None, - content_type: t.Optional[str] = None, - content_length: t.Optional[int] = None, - errors_stream: t.Optional[t.IO[str]] = None, + input_stream: t.IO[bytes] | None = None, + content_type: str | None = None, + content_length: int | None = None, + errors_stream: t.IO[str] | None = None, multithread: bool = False, multiprocess: bool = False, run_once: bool = False, - headers: t.Optional[t.Union[Headers, t.Iterable[t.Tuple[str, str]]]] = None, - data: t.Optional[ - t.Union[t.IO[bytes], str, bytes, t.Mapping[str, t.Any]] - ] = None, - environ_base: t.Optional[t.Mapping[str, t.Any]] = None, - environ_overrides: t.Optional[t.Mapping[str, t.Any]] = None, - charset: str = "utf-8", - mimetype: t.Optional[str] = None, - json: t.Optional[t.Mapping[str, t.Any]] = None, - auth: t.Optional[t.Union[Authorization, t.Tuple[str, str]]] = None, + headers: Headers | t.Iterable[tuple[str, str]] | None = None, + data: None | (t.IO[bytes] | str | bytes | t.Mapping[str, t.Any]) = None, + environ_base: t.Mapping[str, t.Any] | None = None, + environ_overrides: 
t.Mapping[str, t.Any] | None = None, + mimetype: str | None = None, + json: t.Mapping[str, t.Any] | None = None, + auth: Authorization | tuple[str, str] | None = None, ) -> None: - path_s = _make_encode_wrapper(path) - if query_string is not None and path_s("?") in path: + if query_string is not None and "?" in path: raise ValueError("Query string is defined in the path and as an argument") - request_uri = url_parse(path) - if query_string is None and path_s("?") in path: + request_uri = urlsplit(path) + if query_string is None and "?" in path: query_string = request_uri.query - self.charset = charset + self.path = iri_to_uri(request_uri.path) self.request_uri = path if base_url is not None: - base_url = url_fix(iri_to_uri(base_url, charset), charset) + base_url = iri_to_uri(base_url) self.base_url = base_url # type: ignore - if isinstance(query_string, (bytes, str)): + if isinstance(query_string, str): self.query_string = query_string else: if query_string is None: @@ -441,15 +376,15 @@ def __init__( if input_stream is not None: raise TypeError("can't provide input stream and data") if hasattr(data, "read"): - data = data.read() # type: ignore + data = data.read() if isinstance(data, str): - data = data.encode(self.charset) + data = data.encode() if isinstance(data, bytes): self.input_stream = BytesIO(data) if self.content_length is None: self.content_length = len(data) else: - for key, value in _iter_data(data): # type: ignore + for key, value in _iter_data(data): if isinstance(value, (tuple, dict)) or hasattr(value, "read"): self._add_file_from_data(key, value) else: @@ -459,9 +394,7 @@ def __init__( self.mimetype = mimetype @classmethod - def from_environ( - cls, environ: "WSGIEnvironment", **kwargs: t.Any - ) -> "EnvironBuilder": + def from_environ(cls, environ: WSGIEnvironment, **kwargs: t.Any) -> EnvironBuilder: """Turn an environ dict back into a builder. Any extra kwargs override the args extracted from the environ. @@ -496,9 +429,7 @@ def from_environ( def _add_file_from_data( self, key: str, - value: t.Union[ - t.IO[bytes], t.Tuple[t.IO[bytes], str], t.Tuple[t.IO[bytes], str, str] - ], + value: (t.IO[bytes] | tuple[t.IO[bytes], str] | tuple[t.IO[bytes], str, str]), ) -> None: """Called in the EnvironBuilder to add files from the data dict.""" if isinstance(value, tuple): @@ -508,7 +439,7 @@ def _add_file_from_data( @staticmethod def _make_base_url(scheme: str, host: str, script_root: str) -> str: - return url_unparse((scheme, host, script_root, "", "")).rstrip("/") + "/" + return urlunsplit((scheme, host, script_root, "", "")).rstrip("/") + "/" @property def base_url(self) -> str: @@ -518,13 +449,13 @@ def base_url(self) -> str: return self._make_base_url(self.url_scheme, self.host, self.script_root) @base_url.setter - def base_url(self, value: t.Optional[str]) -> None: + def base_url(self, value: str | None) -> None: if value is None: scheme = "http" netloc = "localhost" script_root = "" else: - scheme, netloc, script_root, qs, anchor = url_parse(value) + scheme, netloc, script_root, qs, anchor = urlsplit(value) if qs or anchor: raise ValueError("base url must not contain a query string or fragment") self.script_root = script_root.rstrip("/") @@ -532,7 +463,7 @@ def base_url(self, value: t.Optional[str]) -> None: self.url_scheme = scheme @property - def content_type(self) -> t.Optional[str]: + def content_type(self) -> str | None: """The content type for the request. Reflected from and to the :attr:`headers`. 
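
How the builder splits the URL pieces above (host and paths are illustrative):

.. code-block:: python

    from werkzeug.test import EnvironBuilder

    # The query string may be given in the path, or via query_string=...,
    # but not both (that raises ValueError, per the check above).
    builder = EnvironBuilder(
        "/search?q=werkzeug", base_url="https://example.test/app/"
    )
    environ = builder.get_environ()
    assert environ["PATH_INFO"] == "/search"
    assert environ["QUERY_STRING"] == "q=werkzeug"
    assert environ["SCRIPT_NAME"] == "/app"
    assert environ["wsgi.url_scheme"] == "https"
    builder.close()
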
Do not set if you set :attr:`files` or :attr:`form` for auto detection. @@ -547,14 +478,14 @@ def content_type(self) -> t.Optional[str]: return ct @content_type.setter - def content_type(self, value: t.Optional[str]) -> None: + def content_type(self, value: str | None) -> None: if value is None: self.headers.pop("Content-Type", None) else: self.headers["Content-Type"] = value @property - def mimetype(self) -> t.Optional[str]: + def mimetype(self) -> str | None: """The mimetype (content type without charset etc.) .. versionadded:: 0.14 @@ -564,7 +495,7 @@ def mimetype(self) -> t.Optional[str]: @mimetype.setter def mimetype(self, value: str) -> None: - self.content_type = get_content_type(value, self.charset) + self.content_type = get_content_type(value, "utf-8") @property def mimetype_params(self) -> t.Mapping[str, str]: @@ -575,14 +506,14 @@ def mimetype_params(self) -> t.Mapping[str, str]: .. versionadded:: 0.14 """ - def on_update(d: CallbackDict) -> None: + def on_update(d: CallbackDict[str, str]) -> None: self.headers["Content-Type"] = dump_options_header(self.mimetype, d) d = parse_options_header(self.headers.get("content-type", ""))[1] return CallbackDict(d, on_update) @property - def content_length(self) -> t.Optional[int]: + def content_length(self) -> int | None: """The content length as integer. Reflected from and to the :attr:`headers`. Do not set if you set :attr:`files` or :attr:`form` for auto detection. @@ -590,13 +521,13 @@ def content_length(self) -> t.Optional[int]: return self.headers.get("Content-Length", type=int) @content_length.setter - def content_length(self, value: t.Optional[int]) -> None: + def content_length(self, value: int | None) -> None: if value is None: self.headers.pop("Content-Length", None) else: self.headers["Content-Length"] = str(value) - def _get_form(self, name: str, storage: t.Type[_TAnyMultiDict]) -> _TAnyMultiDict: + def _get_form(self, name: str, storage: type[_TAnyMultiDict]) -> _TAnyMultiDict: """Common behavior for getting the :attr:`form` and :attr:`files` properties. @@ -614,7 +545,7 @@ def _get_form(self, name: str, storage: t.Type[_TAnyMultiDict]) -> _TAnyMultiDic return rv # type: ignore - def _set_form(self, name: str, value: MultiDict) -> None: + def _set_form(self, name: str, value: MultiDict[str, t.Any]) -> None: """Common behavior for setting the :attr:`form` and :attr:`files` properties. @@ -625,12 +556,12 @@ def _set_form(self, name: str, value: MultiDict) -> None: setattr(self, name, value) @property - def form(self) -> MultiDict: + def form(self) -> MultiDict[str, str]: """A :class:`MultiDict` of form values.""" return self._get_form("_form", MultiDict) @form.setter - def form(self, value: MultiDict) -> None: + def form(self, value: MultiDict[str, str]) -> None: self._set_form("_form", value) @property @@ -645,7 +576,7 @@ def files(self, value: FileMultiDict) -> None: self._set_form("_files", value) @property - def input_stream(self) -> t.Optional[t.IO[bytes]]: + def input_stream(self) -> t.IO[bytes] | None: """An optional input stream. This is mutually exclusive with setting :attr:`form` and :attr:`files`, setting it will clear those. 
Do not provide this if the method is not ``POST`` or @@ -654,7 +585,7 @@ def input_stream(self) -> t.Optional[t.IO[bytes]]: return self._input_stream @input_stream.setter - def input_stream(self, value: t.Optional[t.IO[bytes]]) -> None: + def input_stream(self, value: t.IO[bytes] | None) -> None: self._input_stream = value self._form = None self._files = None @@ -666,17 +597,17 @@ def query_string(self) -> str: """ if self._query_string is None: if self._args is not None: - return url_encode(self._args, charset=self.charset) + return _urlencode(self._args) return "" return self._query_string @query_string.setter - def query_string(self, value: t.Optional[str]) -> None: + def query_string(self, value: str | None) -> None: self._query_string = value self._args = None @property - def args(self) -> MultiDict: + def args(self) -> MultiDict[str, str]: """The URL arguments as :class:`MultiDict`.""" if self._query_string is not None: raise AttributeError("a query string is defined") @@ -685,7 +616,7 @@ def args(self) -> MultiDict: return self._args @args.setter - def args(self, value: t.Optional[MultiDict]) -> None: + def args(self, value: MultiDict[str, str] | None) -> None: self._query_string = None self._args = value @@ -733,7 +664,7 @@ def close(self) -> None: pass self.closed = True - def get_environ(self) -> "WSGIEnvironment": + def get_environ(self) -> WSGIEnvironment: """Return the built environ. .. versionchanged:: 0.15 @@ -755,30 +686,30 @@ def get_environ(self) -> "WSGIEnvironment": content_length = end_pos - start_pos elif mimetype == "multipart/form-data": input_stream, content_length, boundary = stream_encode_multipart( - CombinedMultiDict([self.form, self.files]), charset=self.charset + CombinedMultiDict([self.form, self.files]) ) content_type = f'{mimetype}; boundary="{boundary}"' elif mimetype == "application/x-www-form-urlencoded": - form_encoded = url_encode(self.form, charset=self.charset).encode("ascii") + form_encoded = _urlencode(self.form).encode("ascii") content_length = len(form_encoded) input_stream = BytesIO(form_encoded) else: input_stream = BytesIO() - result: "WSGIEnvironment" = {} + result: WSGIEnvironment = {} if self.environ_base: result.update(self.environ_base) def _path_encode(x: str) -> str: - return _wsgi_encoding_dance(url_unquote(x, self.charset), self.charset) + return _wsgi_encoding_dance(unquote(x)) - raw_uri = _wsgi_encoding_dance(self.request_uri, self.charset) + raw_uri = _wsgi_encoding_dance(self.request_uri) result.update( { "REQUEST_METHOD": self.method, "SCRIPT_NAME": _path_encode(self.script_root), "PATH_INFO": _path_encode(self.path), - "QUERY_STRING": _wsgi_encoding_dance(self.query_string, self.charset), + "QUERY_STRING": _wsgi_encoding_dance(self.query_string), # Non-standard, added by mod_wsgi, uWSGI "REQUEST_URI": raw_uri, # Non-standard, added by gunicorn @@ -821,7 +752,7 @@ def _path_encode(x: str) -> str: return result - def get_request(self, cls: t.Optional[t.Type[Request]] = None) -> Request: + def get_request(self, cls: type[Request] | None = None) -> Request: """Returns a request with the data. If the request class is not specified :attr:`request_class` is used. @@ -840,24 +771,28 @@ class ClientRedirectError(Exception): class Client: - """This class allows you to send requests to a wrapped application. - - The use_cookies parameter indicates whether cookies should be stored and - sent for subsequent requests. This is True by default, but passing False - will disable this behaviour. 
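
For request bodies, the ``data`` shortcut in ``__init__`` feeds ``form`` and
``files``, and ``get_environ`` serializes them; a sketch:

.. code-block:: python

    from io import BytesIO

    from werkzeug.test import EnvironBuilder

    builder = EnvironBuilder(
        method="POST",
        data={
            "field": "value",
            # Tuples become file parts: (stream, filename[, content type]).
            "file": (BytesIO(b"contents"), "upload.txt", "text/plain"),
        },
    )
    request = builder.get_request()
    assert request.form["field"] == "value"
    assert request.files["file"].filename == "upload.txt"
    builder.close()
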
- - If you want to request some subdomain of your application you may set - `allow_subdomain_redirects` to `True` as if not no external redirects - are allowed. + """Simulate sending requests to a WSGI application without running a WSGI or HTTP + server. + + :param application: The WSGI application to make requests to. + :param response_wrapper: A :class:`.Response` class to wrap response data with. + Defaults to :class:`.TestResponse`. If it's not a subclass of ``TestResponse``, + one will be created. + :param use_cookies: Persist cookies from ``Set-Cookie`` response headers to the + ``Cookie`` header in subsequent requests. Domain and path matching is supported, + but other cookie parameters are ignored. + :param allow_subdomain_redirects: Allow requests to follow redirects to subdomains. + Enable this if the application handles subdomains and redirects between them. + + .. versionchanged:: 2.3 + Simplify cookie implementation, support domain and path matching. .. versionchanged:: 2.1 - Removed deprecated behavior of treating the response as a - tuple. All data is available as properties on the returned - response object. + All data is available as properties on the returned response object. The + response cannot be returned as a tuple. .. versionchanged:: 2.0 - ``response_wrapper`` is always a subclass of - :class:``TestResponse``. + ``response_wrapper`` is always a subclass of :class:``TestResponse``. .. versionchanged:: 0.5 Added the ``use_cookies`` parameter. @@ -865,8 +800,8 @@ class Client: def __init__( self, - application: "WSGIApplication", - response_wrapper: t.Optional[t.Type["Response"]] = None, + application: WSGIApplication, + response_wrapper: type[Response] | None = None, use_cookies: bool = True, allow_subdomain_redirects: bool = False, ) -> None: @@ -874,106 +809,198 @@ def __init__( if response_wrapper in {None, Response}: response_wrapper = TestResponse - elif not isinstance(response_wrapper, TestResponse): + elif response_wrapper is not None and not issubclass( + response_wrapper, TestResponse + ): response_wrapper = type( "WrapperTestResponse", - (TestResponse, response_wrapper), # type: ignore + (TestResponse, response_wrapper), {}, ) self.response_wrapper = t.cast(t.Type["TestResponse"], response_wrapper) if use_cookies: - self.cookie_jar: t.Optional[_TestCookieJar] = _TestCookieJar() + self._cookies: dict[tuple[str, str, str], Cookie] | None = {} else: - self.cookie_jar = None + self._cookies = None self.allow_subdomain_redirects = allow_subdomain_redirects + def get_cookie( + self, key: str, domain: str = "localhost", path: str = "/" + ) -> Cookie | None: + """Return a :class:`.Cookie` if it exists. Cookies are uniquely identified by + ``(domain, path, key)``. + + :param key: The decoded form of the key for the cookie. + :param domain: The domain the cookie was set for. + :param path: The path the cookie was set for. + + .. versionadded:: 2.3 + """ + if self._cookies is None: + raise TypeError( + "Cookies are disabled. Create a client with 'use_cookies=True'." 
+ ) + + return self._cookies.get((domain, path, key)) + def set_cookie( self, - server_name: str, key: str, value: str = "", - max_age: t.Optional[t.Union[timedelta, int]] = None, - expires: t.Optional[t.Union[str, datetime, int, float]] = None, + *, + domain: str = "localhost", + origin_only: bool = True, path: str = "/", - domain: t.Optional[str] = None, - secure: bool = False, - httponly: bool = False, - samesite: t.Optional[str] = None, - charset: str = "utf-8", + **kwargs: t.Any, ) -> None: - """Sets a cookie in the client's cookie jar. The server name - is required and has to match the one that is also passed to - the open call. + """Set a cookie to be sent in subsequent requests. + + This is a convenience to skip making a test request to a route that would set + the cookie. To test the cookie, make a test request to a route that uses the + cookie value. + + The client uses ``domain``, ``origin_only``, and ``path`` to determine which + cookies to send with a request. It does not use other cookie parameters that + browsers use, since they're not applicable in tests. + + :param key: The key part of the cookie. + :param value: The value part of the cookie. + :param domain: Send this cookie with requests that match this domain. If + ``origin_only`` is true, it must be an exact match, otherwise it may be a + suffix match. + :param origin_only: Whether the domain must be an exact match to the request. + :param path: Send this cookie with requests that match this path either exactly + or as a prefix. + :param kwargs: Passed to :func:`.dump_cookie`. + + .. versionchanged:: 3.0 + The parameter ``server_name`` is removed. The first parameter is + ``key``. Use the ``domain`` and ``origin_only`` parameters instead. + + .. versionchanged:: 2.3 + The ``origin_only`` parameter was added. + + .. versionchanged:: 2.3 + The ``domain`` parameter defaults to ``localhost``. """ - assert self.cookie_jar is not None, "cookies disabled" - header = dump_cookie( - key, - value, - max_age, - expires, - path, - domain, - secure, - httponly, - charset, - samesite=samesite, + if self._cookies is None: + raise TypeError( + "Cookies are disabled. Create a client with 'use_cookies=True'." + ) + + cookie = Cookie._from_response_header( + domain, "/", dump_cookie(key, value, domain=domain, path=path, **kwargs) ) - environ = create_environ(path, base_url=f"http://{server_name}") - headers = [("Set-Cookie", header)] - self.cookie_jar.extract_wsgi(environ, headers) + cookie.origin_only = origin_only + + if cookie._should_delete: + self._cookies.pop(cookie._storage_key, None) + else: + self._cookies[cookie._storage_key] = cookie def delete_cookie( self, - server_name: str, key: str, + *, + domain: str = "localhost", path: str = "/", - domain: t.Optional[str] = None, - secure: bool = False, - httponly: bool = False, - samesite: t.Optional[str] = None, ) -> None: - """Deletes a cookie in the test client.""" - self.set_cookie( - server_name, - key, - expires=0, - max_age=0, - path=path, - domain=domain, - secure=secure, - httponly=httponly, - samesite=samesite, + """Delete a cookie if it exists. Cookies are uniquely identified by + ``(domain, path, key)``. + + :param key: The decoded form of the key for the cookie. + :param domain: The domain the cookie was set for. + :param path: The path the cookie was set for. + + .. versionchanged:: 3.0 + The ``server_name`` parameter is removed. The first parameter is + ``key``. Use the ``domain`` parameter instead. + + .. 
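
An end-to-end sketch of the reworked cookie API (no ``server_name`` positional
argument; ``domain`` defaults to ``localhost``; the echo app is illustrative):

.. code-block:: python

    from werkzeug.test import Client
    from werkzeug.wrappers import Request, Response

    @Request.application
    def app(request):
        # Echo back whatever Cookie header the client sent.
        return Response(request.headers.get("Cookie", ""))

    client = Client(app)
    client.set_cookie("session", "abc123")
    assert client.get_cookie("session") is not None
    assert client.get("/").get_data(as_text=True) == "session=abc123"

    client.delete_cookie("session")
    assert client.get_cookie("session") is None
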
versionchanged:: 3.0 + The ``secure``, ``httponly`` and ``samesite`` parameters are removed. + + .. versionchanged:: 2.3 + The ``domain`` parameter defaults to ``localhost``. + """ + if self._cookies is None: + raise TypeError( + "Cookies are disabled. Create a client with 'use_cookies=True'." + ) + + self._cookies.pop((domain, path, key), None) + + def _add_cookies_to_wsgi(self, environ: WSGIEnvironment) -> None: + """If cookies are enabled, set the ``Cookie`` header in the environ to the + cookies that are applicable to the request host and path. + + :meta private: + + .. versionadded:: 2.3 + """ + if self._cookies is None: + return + + url = urlsplit(get_current_url(environ)) + server_name = url.hostname or "localhost" + value = "; ".join( + c._to_request_header() + for c in self._cookies.values() + if c._matches_request(server_name, url.path) ) - def run_wsgi_app( - self, environ: "WSGIEnvironment", buffered: bool = False - ) -> t.Tuple[t.Iterable[bytes], str, Headers]: - """Runs the wrapped WSGI app with the given environment. + if value: + environ["HTTP_COOKIE"] = value + else: + environ.pop("HTTP_COOKIE", None) + + def _update_cookies_from_response( + self, server_name: str, path: str, headers: list[str] + ) -> None: + """If cookies are enabled, update the stored cookies from any ``Set-Cookie`` + headers in the response. :meta private: + + .. versionadded:: 2.3 """ - if self.cookie_jar is not None: - self.cookie_jar.inject_wsgi(environ) + if self._cookies is None: + return - rv = run_wsgi_app(self.application, environ, buffered=buffered) + for header in headers: + cookie = Cookie._from_response_header(server_name, path, header) + + if cookie._should_delete: + self._cookies.pop(cookie._storage_key, None) + else: + self._cookies[cookie._storage_key] = cookie - if self.cookie_jar is not None: - self.cookie_jar.extract_wsgi(environ, rv[2]) + def run_wsgi_app( + self, environ: WSGIEnvironment, buffered: bool = False + ) -> tuple[t.Iterable[bytes], str, Headers]: + """Runs the wrapped WSGI app with the given environment. + :meta private: + """ + self._add_cookies_to_wsgi(environ) + rv = run_wsgi_app(self.application, environ, buffered=buffered) + url = urlsplit(get_current_url(environ)) + self._update_cookies_from_response( + url.hostname or "localhost", url.path, rv[2].getlist("Set-Cookie") + ) return rv def resolve_redirect( - self, response: "TestResponse", buffered: bool = False - ) -> "TestResponse": + self, response: TestResponse, buffered: bool = False + ) -> TestResponse: """Perform a new request to the location given by the redirect response to the previous request. :meta private: """ - scheme, netloc, path, qs, anchor = url_parse(response.location) + scheme, netloc, path, qs, anchor = urlsplit(response.location) builder = EnvironBuilder.from_environ( response.request.environ, path=path, query_string=qs ) @@ -1034,7 +1061,7 @@ def open( buffered: bool = False, follow_redirects: bool = False, **kwargs: t.Any, - ) -> "TestResponse": + ) -> TestResponse: """Generate an environ dict from the given arguments, make a request to the application using it, and return the response. @@ -1052,11 +1079,6 @@ def open( .. versionchanged:: 2.1 Removed the ``as_tuple`` parameter. - .. versionchanged:: 2.0 - ``as_tuple`` is deprecated and will be removed in Werkzeug - 2.1. Use :attr:`TestResponse.request` and - ``request.environ`` instead. - .. versionchanged:: 2.0 The request input stream is closed when calling ``response.close()``. 
Input streams for redirects are @@ -1071,7 +1093,7 @@ def open( .. versionchanged:: 0.5 Added the ``follow_redirects`` parameter. """ - request: t.Optional["Request"] = None + request: Request | None = None if not kwargs and len(args) == 1: arg = args[0] @@ -1091,11 +1113,11 @@ def open( finally: builder.close() - response = self.run_wsgi_app(request.environ, buffered=buffered) - response = self.response_wrapper(*response, request=request) + response_parts = self.run_wsgi_app(request.environ, buffered=buffered) + response = self.response_wrapper(*response_parts, request=request) redirects = set() - history: t.List["TestResponse"] = [] + history: list[TestResponse] = [] if not follow_redirects: return response @@ -1134,42 +1156,42 @@ def open( response.call_on_close(request.input_stream.close) return response - def get(self, *args: t.Any, **kw: t.Any) -> "TestResponse": + def get(self, *args: t.Any, **kw: t.Any) -> TestResponse: """Call :meth:`open` with ``method`` set to ``GET``.""" kw["method"] = "GET" return self.open(*args, **kw) - def post(self, *args: t.Any, **kw: t.Any) -> "TestResponse": + def post(self, *args: t.Any, **kw: t.Any) -> TestResponse: """Call :meth:`open` with ``method`` set to ``POST``.""" kw["method"] = "POST" return self.open(*args, **kw) - def put(self, *args: t.Any, **kw: t.Any) -> "TestResponse": + def put(self, *args: t.Any, **kw: t.Any) -> TestResponse: """Call :meth:`open` with ``method`` set to ``PUT``.""" kw["method"] = "PUT" return self.open(*args, **kw) - def delete(self, *args: t.Any, **kw: t.Any) -> "TestResponse": + def delete(self, *args: t.Any, **kw: t.Any) -> TestResponse: """Call :meth:`open` with ``method`` set to ``DELETE``.""" kw["method"] = "DELETE" return self.open(*args, **kw) - def patch(self, *args: t.Any, **kw: t.Any) -> "TestResponse": + def patch(self, *args: t.Any, **kw: t.Any) -> TestResponse: """Call :meth:`open` with ``method`` set to ``PATCH``.""" kw["method"] = "PATCH" return self.open(*args, **kw) - def options(self, *args: t.Any, **kw: t.Any) -> "TestResponse": + def options(self, *args: t.Any, **kw: t.Any) -> TestResponse: """Call :meth:`open` with ``method`` set to ``OPTIONS``.""" kw["method"] = "OPTIONS" return self.open(*args, **kw) - def head(self, *args: t.Any, **kw: t.Any) -> "TestResponse": + def head(self, *args: t.Any, **kw: t.Any) -> TestResponse: """Call :meth:`open` with ``method`` set to ``HEAD``.""" kw["method"] = "HEAD" return self.open(*args, **kw) - def trace(self, *args: t.Any, **kw: t.Any) -> "TestResponse": + def trace(self, *args: t.Any, **kw: t.Any) -> TestResponse: """Call :meth:`open` with ``method`` set to ``TRACE``.""" kw["method"] = "TRACE" return self.open(*args, **kw) @@ -1178,7 +1200,7 @@ def __repr__(self) -> str: return f"<{type(self).__name__} {self.application!r}>" -def create_environ(*args: t.Any, **kwargs: t.Any) -> "WSGIEnvironment": +def create_environ(*args: t.Any, **kwargs: t.Any) -> WSGIEnvironment: """Create a new WSGI environ dict based on the values passed. The first parameter should be the path of the request which defaults to '/'. 
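# ---------------------------------------------------------------------------
# A minimal usage sketch of the Client API above, assuming Werkzeug 2.3+.
# The echo application and the "lang" cookie are hypothetical; the cookie
# methods follow the new signatures shown in this diff (key first, domain
# defaulting to "localhost", path to "/").
from werkzeug.test import Client
from werkzeug.wrappers import Request, Response

@Request.application
def echo_app(request: Request) -> Response:
    # Echo the cookie back so the test can observe what the client sent.
    return Response(request.cookies.get("lang", "<unset>"))

client = Client(echo_app)
client.set_cookie("lang", "en")               # stored under (domain, path, key)
assert client.get_cookie("lang") is not None  # lookup by the same triple
assert client.get("/").text == "en"           # sent via the Cookie header
client.delete_cookie("lang")
assert client.get("/").text == "<unset>"      # cookie no longer sent
# ---------------------------------------------------------------------------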
The second one can either be an absolute path (in that case the host is @@ -1202,8 +1224,8 @@ def create_environ(*args: t.Any, **kwargs: t.Any) -> "WSGIEnvironment": def run_wsgi_app( - app: "WSGIApplication", environ: "WSGIEnvironment", buffered: bool = False -) -> t.Tuple[t.Iterable[bytes], str, Headers]: + app: WSGIApplication, environ: WSGIEnvironment, buffered: bool = False +) -> tuple[t.Iterable[bytes], str, Headers]: """Return a tuple in the form (app_iter, status, headers) of the application output. This works best if you pass it an application that returns an iterator all the time. @@ -1224,8 +1246,8 @@ def run_wsgi_app( # example) don't affect subsequent requests (such as redirects). environ = _get_environ(environ).copy() status: str - response: t.Optional[t.Tuple[str, t.List[t.Tuple[str, str]]]] = None - buffer: t.List[bytes] = [] + response: tuple[str, list[tuple[str, str]]] | None = None + buffer: list[bytes] = [] def start_response(status, headers, exc_info=None): # type: ignore nonlocal response @@ -1290,8 +1312,7 @@ class TestResponse(Response): assumed if missing. .. versionchanged:: 2.1 - Removed deprecated behavior for treating the response instance - as a tuple. + Response instances cannot be treated as tuples. .. versionadded:: 2.0 Test client methods always return instances of this class. @@ -1305,7 +1326,7 @@ class TestResponse(Response): resulted in this response. """ - history: t.Tuple["TestResponse", ...] + history: tuple[TestResponse, ...] """A list of intermediate responses. Populated when the test request is made with ``follow_redirects`` enabled. """ @@ -1319,7 +1340,7 @@ def __init__( status: str, headers: Headers, request: Request, - history: t.Tuple["TestResponse"] = (), # type: ignore + history: tuple[TestResponse] = (), # type: ignore **kwargs: t.Any, ) -> None: super().__init__(response, status, headers, **kwargs) @@ -1335,3 +1356,109 @@ def text(self) -> str: .. versionadded:: 2.1 """ return self.get_data(as_text=True) + + +@dataclasses.dataclass +class Cookie: + """A cookie key, value, and parameters. + + The class itself is not a public API. Its attributes are documented for inspection + with :meth:`.Client.get_cookie` only. + + .. versionadded:: 2.3 + """ + + key: str + """The cookie key, encoded as a client would see it.""" + + value: str + """The cookie key, encoded as a client would see it.""" + + decoded_key: str + """The cookie key, decoded as the application would set and see it.""" + + decoded_value: str + """The cookie value, decoded as the application would set and see it.""" + + expires: datetime | None + """The time at which the cookie is no longer valid.""" + + max_age: int | None + """The number of seconds from when the cookie was set at which it is + no longer valid. + """ + + domain: str + """The domain that the cookie was set for, or the request domain if not set.""" + + origin_only: bool + """Whether the cookie will be sent for exact domain matches only. This is ``True`` + if the ``Domain`` parameter was not present. 
+ """ + + path: str + """The path that the cookie was set for.""" + + secure: bool | None + """The ``Secure`` parameter.""" + + http_only: bool | None + """The ``HttpOnly`` parameter.""" + + same_site: str | None + """The ``SameSite`` parameter.""" + + def _matches_request(self, server_name: str, path: str) -> bool: + return ( + server_name == self.domain + or ( + not self.origin_only + and server_name.endswith(self.domain) + and server_name[: -len(self.domain)].endswith(".") + ) + ) and ( + path == self.path + or ( + path.startswith(self.path) + and path[len(self.path) - self.path.endswith("/") :].startswith("/") + ) + ) + + def _to_request_header(self) -> str: + return f"{self.key}={self.value}" + + @classmethod + def _from_response_header(cls, server_name: str, path: str, header: str) -> te.Self: + header, _, parameters_str = header.partition(";") + key, _, value = header.partition("=") + decoded_key, decoded_value = next(parse_cookie(header).items()) + params = {} + + for item in parameters_str.split(";"): + k, sep, v = item.partition("=") + params[k.strip().lower()] = v.strip() if sep else None + + return cls( + key=key.strip(), + value=value.strip(), + decoded_key=decoded_key, + decoded_value=decoded_value, + expires=parse_date(params.get("expires")), + max_age=int(params["max-age"] or 0) if "max-age" in params else None, + domain=params.get("domain") or server_name, + origin_only="domain" not in params, + path=params.get("path") or path.rpartition("/")[0] or "/", + secure="secure" in params, + http_only="httponly" in params, + same_site=params.get("samesite"), + ) + + @property + def _storage_key(self) -> tuple[str, str, str]: + return self.domain, self.path, self.decoded_key + + @property + def _should_delete(self) -> bool: + return self.max_age == 0 or ( + self.expires is not None and self.expires.timestamp() == 0 + ) diff --git a/src/werkzeug/testapp.py b/src/werkzeug/testapp.py index 0d7ffbb..cdf7fac 100644 --- a/src/werkzeug/testapp.py +++ b/src/werkzeug/testapp.py @@ -1,7 +1,10 @@ """A small application that can be used to test a WSGI server and check it for WSGI compliance. """ -import base64 + +from __future__ import annotations + +import importlib.metadata import os import sys import typing as t @@ -9,57 +12,9 @@ from markupsafe import escape -from . 
import __version__ as _werkzeug_version from .wrappers.request import Request from .wrappers.response import Response -if t.TYPE_CHECKING: - from _typeshed.wsgi import StartResponse - from _typeshed.wsgi import WSGIEnvironment - - -logo = Response( - base64.b64decode( - """ -R0lGODlhoACgAOMIAAEDACwpAEpCAGdgAJaKAM28AOnVAP3rAP///////// -//////////////////////yH5BAEKAAgALAAAAACgAKAAAAT+EMlJq704680R+F0ojmRpnuj0rWnrv -nB8rbRs33gu0bzu/0AObxgsGn3D5HHJbCUFyqZ0ukkSDlAidctNFg7gbI9LZlrBaHGtzAae0eloe25 -7w9EDOX2fst/xenyCIn5/gFqDiVVDV4aGeYiKkhSFjnCQY5OTlZaXgZp8nJ2ekaB0SQOjqphrpnOiq -ncEn65UsLGytLVmQ6m4sQazpbtLqL/HwpnER8bHyLrLOc3Oz8PRONPU1crXN9na263dMt/g4SzjMeX -m5yDpLqgG7OzJ4u8lT/P69ej3JPn69kHzN2OIAHkB9RUYSFCFQYQJFTIkCDBiwoXWGnowaLEjRm7+G -p9A7Hhx4rUkAUaSLJlxHMqVMD/aSycSZkyTplCqtGnRAM5NQ1Ly5OmzZc6gO4d6DGAUKA+hSocWYAo -SlM6oUWX2O/o0KdaVU5vuSQLAa0ADwQgMEMB2AIECZhVSnTno6spgbtXmHcBUrQACcc2FrTrWS8wAf -78cMFBgwIBgbN+qvTt3ayikRBk7BoyGAGABAdYyfdzRQGV3l4coxrqQ84GpUBmrdR3xNIDUPAKDBSA -ADIGDhhqTZIWaDcrVX8EsbNzbkvCOxG8bN5w8ly9H8jyTJHC6DFndQydbguh2e/ctZJFXRxMAqqPVA -tQH5E64SPr1f0zz7sQYjAHg0In+JQ11+N2B0XXBeeYZgBZFx4tqBToiTCPv0YBgQv8JqA6BEf6RhXx -w1ENhRBnWV8ctEX4Ul2zc3aVGcQNC2KElyTDYyYUWvShdjDyMOGMuFjqnII45aogPhz/CodUHFwaDx -lTgsaOjNyhGWJQd+lFoAGk8ObghI0kawg+EV5blH3dr+digkYuAGSaQZFHFz2P/cTaLmhF52QeSb45 -Jwxd+uSVGHlqOZpOeJpCFZ5J+rkAkFjQ0N1tah7JJSZUFNsrkeJUJMIBi8jyaEKIhKPomnC91Uo+NB -yyaJ5umnnpInIFh4t6ZSpGaAVmizqjpByDegYl8tPE0phCYrhcMWSv+uAqHfgH88ak5UXZmlKLVJhd -dj78s1Fxnzo6yUCrV6rrDOkluG+QzCAUTbCwf9SrmMLzK6p+OPHx7DF+bsfMRq7Ec61Av9i6GLw23r -idnZ+/OO0a99pbIrJkproCQMA17OPG6suq3cca5ruDfXCCDoS7BEdvmJn5otdqscn+uogRHHXs8cbh -EIfYaDY1AkrC0cqwcZpnM6ludx72x0p7Fo/hZAcpJDjax0UdHavMKAbiKltMWCF3xxh9k25N/Viud8 -ba78iCvUkt+V6BpwMlErmcgc502x+u1nSxJSJP9Mi52awD1V4yB/QHONsnU3L+A/zR4VL/indx/y64 -gqcj+qgTeweM86f0Qy1QVbvmWH1D9h+alqg254QD8HJXHvjQaGOqEqC22M54PcftZVKVSQG9jhkv7C -JyTyDoAJfPdu8v7DRZAxsP/ky9MJ3OL36DJfCFPASC3/aXlfLOOON9vGZZHydGf8LnxYJuuVIbl83y -Az5n/RPz07E+9+zw2A2ahz4HxHo9Kt79HTMx1Q7ma7zAzHgHqYH0SoZWyTuOLMiHwSfZDAQTn0ajk9 -YQqodnUYjByQZhZak9Wu4gYQsMyEpIOAOQKze8CmEF45KuAHTvIDOfHJNipwoHMuGHBnJElUoDmAyX -c2Qm/R8Ah/iILCCJOEokGowdhDYc/yoL+vpRGwyVSCWFYZNljkhEirGXsalWcAgOdeAdoXcktF2udb -qbUhjWyMQxYO01o6KYKOr6iK3fE4MaS+DsvBsGOBaMb0Y6IxADaJhFICaOLmiWTlDAnY1KzDG4ambL -cWBA8mUzjJsN2KjSaSXGqMCVXYpYkj33mcIApyhQf6YqgeNAmNvuC0t4CsDbSshZJkCS1eNisKqlyG -cF8G2JeiDX6tO6Mv0SmjCa3MFb0bJaGPMU0X7c8XcpvMaOQmCajwSeY9G0WqbBmKv34DsMIEztU6Y2 -KiDlFdt6jnCSqx7Dmt6XnqSKaFFHNO5+FmODxMCWBEaco77lNDGXBM0ECYB/+s7nKFdwSF5hgXumQe -EZ7amRg39RHy3zIjyRCykQh8Zo2iviRKyTDn/zx6EefptJj2Cw+Ep2FSc01U5ry4KLPYsTyWnVGnvb -UpyGlhjBUljyjHhWpf8OFaXwhp9O4T1gU9UeyPPa8A2l0p1kNqPXEVRm1AOs1oAGZU596t6SOR2mcB -Oco1srWtkaVrMUzIErrKri85keKqRQYX9VX0/eAUK1hrSu6HMEX3Qh2sCh0q0D2CtnUqS4hj62sE/z -aDs2Sg7MBS6xnQeooc2R2tC9YrKpEi9pLXfYXp20tDCpSP8rKlrD4axprb9u1Df5hSbz9QU0cRpfgn -kiIzwKucd0wsEHlLpe5yHXuc6FrNelOl7pY2+11kTWx7VpRu97dXA3DO1vbkhcb4zyvERYajQgAADs -=""" - ), - mimetype="image/png", -) - - TEMPLATE = """\ @@ -70,7 +25,6 @@ body { font-family: 'Lucida Grande', 'Lucida Sans Unicode', 'Geneva', 'Verdana', sans-serif; background-color: white; color: #000; font-size: 15px; text-align: center; } - #logo { float: right; padding: 0 0 10px 10px; } div.box { text-align: left; width: 45em; margin: auto; padding: 50px 0; background-color: white; } h1, h2 { font-family: 'Ubuntu', 'Lucida Grande', 'Lucida Sans Unicode', @@ -92,7 +46,6 @@ li.exp { background: white; }
 </style>
 <div class="box">
-  <img src="?resource=logo" id="logo" alt="[The Werkzeug Logo]" />
   <h1>WSGI Information</h1>
   <p>
This page displays all available information about the WSGI server and @@ -139,7 +92,7 @@ """ -def iter_sys_path() -> t.Iterator[t.Tuple[str, bool, bool]]: +def iter_sys_path() -> t.Iterator[tuple[str, bool, bool]]: if os.name == "posix": def strip(x: str) -> str: @@ -159,7 +112,21 @@ def strip(x: str) -> str: yield strip(os.path.normpath(path)), not os.path.isdir(path), path != item -def render_testapp(req: Request) -> bytes: +@Request.application +def test_app(req: Request) -> Response: + """Simple test application that dumps the environment. You can use + it to check if Werkzeug is working properly: + + .. sourcecode:: pycon + + >>> from werkzeug.serving import run_simple + >>> from werkzeug.testapp import test_app + >>> run_simple('localhost', 3000, test_app) + * Running on http://localhost:3000/ + + The application displays important information from the WSGI environment, + the Python interpreter and the installed libraries. + """ try: import pkg_resources except ImportError: @@ -167,7 +134,7 @@ def render_testapp(req: Request) -> bytes: else: eggs = sorted( pkg_resources.working_set, - key=lambda x: x.project_name.lower(), # type: ignore + key=lambda x: x.project_name.lower(), ) python_eggs = [] for egg in eggs: @@ -187,52 +154,38 @@ def render_testapp(req: Request) -> bytes: sys_path = [] for item, virtual, expanded in iter_sys_path(): - class_ = [] + css = [] if virtual: - class_.append("virtual") + css.append("virtual") if expanded: - class_.append("exp") - class_ = f' class="{" ".join(class_)}"' if class_ else "" - sys_path.append(f"{escape(item)}") - - return ( - TEMPLATE - % { - "python_version": "
".join(escape(sys.version).splitlines()), - "platform": escape(sys.platform), - "os": escape(os.name), - "api_version": sys.api_version, - "byteorder": sys.byteorder, - "werkzeug_version": _werkzeug_version, - "python_eggs": "\n".join(python_eggs), - "wsgi_env": "\n".join(wsgi_env), - "sys_path": "\n".join(sys_path), - } - ).encode("utf-8") - - -def test_app( - environ: "WSGIEnvironment", start_response: "StartResponse" -) -> t.Iterable[bytes]: - """Simple test application that dumps the environment. You can use - it to check if Werkzeug is working properly: + css.append("exp") + class_str = f' class="{" ".join(css)}"' if css else "" + sys_path.append(f"{escape(item)}") - .. sourcecode:: pycon + context = { + "python_version": "
".join(escape(sys.version).splitlines()), + "platform": escape(sys.platform), + "os": escape(os.name), + "api_version": sys.api_version, + "byteorder": sys.byteorder, + "werkzeug_version": _get_werkzeug_version(), + "python_eggs": "\n".join(python_eggs), + "wsgi_env": "\n".join(wsgi_env), + "sys_path": "\n".join(sys_path), + } + return Response(TEMPLATE % context, mimetype="text/html") - >>> from werkzeug.serving import run_simple - >>> from werkzeug.testapp import test_app - >>> run_simple('localhost', 3000, test_app) - * Running on http://localhost:3000/ - The application displays important information from the WSGI environment, - the Python interpreter and the installed libraries. - """ - req = Request(environ, populate_request=False) - if req.args.get("resource") == "logo": - response = logo - else: - response = Response(render_testapp(req), mimetype="text/html") - return response(environ, start_response) +_werkzeug_version = "" + + +def _get_werkzeug_version() -> str: + global _werkzeug_version + + if not _werkzeug_version: + _werkzeug_version = importlib.metadata.version("werkzeug") + + return _werkzeug_version if __name__ == "__main__": diff --git a/src/werkzeug/urls.py b/src/werkzeug/urls.py index 67c08b0..5bffe39 100644 --- a/src/werkzeug/urls.py +++ b/src/werkzeug/urls.py @@ -1,722 +1,64 @@ -"""Functions for working with URLs. +from __future__ import annotations -Contains implementations of functions from :mod:`urllib.parse` that -handle bytes and strings. -""" import codecs -import os import re import typing as t +import urllib.parse +from urllib.parse import quote +from urllib.parse import unquote +from urllib.parse import urlencode +from urllib.parse import urlsplit +from urllib.parse import urlunsplit -from ._internal import _check_str_tuple -from ._internal import _decode_idna -from ._internal import _encode_idna -from ._internal import _make_encode_wrapper -from ._internal import _to_str +from .datastructures import iter_multi_items -if t.TYPE_CHECKING: - from . import datastructures as ds -# A regular expression for what a valid schema looks like -_scheme_re = re.compile(r"^[a-zA-Z0-9+-.]+$") - -# Characters that are safe in any part of an URL. -_always_safe = frozenset( - bytearray( - b"abcdefghijklmnopqrstuvwxyz" - b"ABCDEFGHIJKLMNOPQRSTUVWXYZ" - b"0123456789" - b"-._~" - b"$!'()*+,;" # RFC3986 sub-delims set, not including query string delimiters &= - ) -) - -_hexdigits = "0123456789ABCDEFabcdef" -_hextobyte = { - f"{a}{b}".encode("ascii"): int(f"{a}{b}", 16) - for a in _hexdigits - for b in _hexdigits -} -_bytetohex = [f"%{char:02X}".encode("ascii") for char in range(256)] - - -class _URLTuple(t.NamedTuple): - scheme: str - netloc: str - path: str - query: str - fragment: str - - -class BaseURL(_URLTuple): - """Superclass of :py:class:`URL` and :py:class:`BytesURL`.""" - - __slots__ = () - _at: str - _colon: str - _lbracket: str - _rbracket: str - - def __str__(self) -> str: - return self.to_url() - - def replace(self, **kwargs: t.Any) -> "BaseURL": - """Return an URL with the same values, except for those parameters - given new values by whichever keyword arguments are specified.""" - return self._replace(**kwargs) - - @property - def host(self) -> t.Optional[str]: - """The host part of the URL if available, otherwise `None`. The - host is either the hostname or the IP address mentioned in the - URL. It will not contain the port. 
- """ - return self._split_host()[0] - - @property - def ascii_host(self) -> t.Optional[str]: - """Works exactly like :attr:`host` but will return a result that - is restricted to ASCII. If it finds a netloc that is not ASCII - it will attempt to idna decode it. This is useful for socket - operations when the URL might include internationalized characters. - """ - rv = self.host - if rv is not None and isinstance(rv, str): - try: - rv = _encode_idna(rv) # type: ignore - except UnicodeError: - rv = rv.encode("ascii", "ignore") # type: ignore - return _to_str(rv, "ascii", "ignore") - - @property - def port(self) -> t.Optional[int]: - """The port in the URL as an integer if it was present, `None` - otherwise. This does not fill in default ports. - """ - try: - rv = int(_to_str(self._split_host()[1])) - if 0 <= rv <= 65535: - return rv - except (ValueError, TypeError): - pass - return None - - @property - def auth(self) -> t.Optional[str]: - """The authentication part in the URL if available, `None` - otherwise. - """ - return self._split_netloc()[0] - - @property - def username(self) -> t.Optional[str]: - """The username if it was part of the URL, `None` otherwise. - This undergoes URL decoding and will always be a string. - """ - rv = self._split_auth()[0] - if rv is not None: - return _url_unquote_legacy(rv) - return None - - @property - def raw_username(self) -> t.Optional[str]: - """The username if it was part of the URL, `None` otherwise. - Unlike :attr:`username` this one is not being decoded. - """ - return self._split_auth()[0] - - @property - def password(self) -> t.Optional[str]: - """The password if it was part of the URL, `None` otherwise. - This undergoes URL decoding and will always be a string. - """ - rv = self._split_auth()[1] - if rv is not None: - return _url_unquote_legacy(rv) - return None - - @property - def raw_password(self) -> t.Optional[str]: - """The password if it was part of the URL, `None` otherwise. - Unlike :attr:`password` this one is not being decoded. - """ - return self._split_auth()[1] - - def decode_query(self, *args: t.Any, **kwargs: t.Any) -> "ds.MultiDict[str, str]": - """Decodes the query part of the URL. Ths is a shortcut for - calling :func:`url_decode` on the query argument. The arguments and - keyword arguments are forwarded to :func:`url_decode` unchanged. - """ - return url_decode(self.query, *args, **kwargs) - - def join(self, *args: t.Any, **kwargs: t.Any) -> "BaseURL": - """Joins this URL with another one. This is just a convenience - function for calling into :meth:`url_join` and then parsing the - return value again. - """ - return url_parse(url_join(self, *args, **kwargs)) - - def to_url(self) -> str: - """Returns a URL string or bytes depending on the type of the - information stored. This is just a convenience function - for calling :meth:`url_unparse` for this URL. 
- """ - return url_unparse(self) - - def encode_netloc(self) -> str: - """Encodes the netloc part to an ASCII safe URL as bytes.""" - rv = self.ascii_host or "" - if ":" in rv: - rv = f"[{rv}]" - port = self.port - if port is not None: - rv = f"{rv}:{port}" - auth = ":".join( - filter( - None, - [ - url_quote(self.raw_username or "", "utf-8", "strict", "/:%"), - url_quote(self.raw_password or "", "utf-8", "strict", "/:%"), - ], - ) - ) - if auth: - rv = f"{auth}@{rv}" - return rv - - def decode_netloc(self) -> str: - """Decodes the netloc part into a string.""" - rv = _decode_idna(self.host or "") - - if ":" in rv: - rv = f"[{rv}]" - port = self.port - if port is not None: - rv = f"{rv}:{port}" - auth = ":".join( - filter( - None, - [ - _url_unquote_legacy(self.raw_username or "", "/:%@"), - _url_unquote_legacy(self.raw_password or "", "/:%@"), - ], - ) - ) - if auth: - rv = f"{auth}@{rv}" - return rv - - def to_uri_tuple(self) -> "BaseURL": - """Returns a :class:`BytesURL` tuple that holds a URI. This will - encode all the information in the URL properly to ASCII using the - rules a web browser would follow. - - It's usually more interesting to directly call :meth:`iri_to_uri` which - will return a string. - """ - return url_parse(iri_to_uri(self)) - - def to_iri_tuple(self) -> "BaseURL": - """Returns a :class:`URL` tuple that holds a IRI. This will try - to decode as much information as possible in the URL without - losing information similar to how a web browser does it for the - URL bar. - - It's usually more interesting to directly call :meth:`uri_to_iri` which - will return a string. - """ - return url_parse(uri_to_iri(self)) - - def get_file_location( - self, pathformat: t.Optional[str] = None - ) -> t.Tuple[t.Optional[str], t.Optional[str]]: - """Returns a tuple with the location of the file in the form - ``(server, location)``. If the netloc is empty in the URL or - points to localhost, it's represented as ``None``. - - The `pathformat` by default is autodetection but needs to be set - when working with URLs of a specific system. The supported values - are ``'windows'`` when working with Windows or DOS paths and - ``'posix'`` when working with posix paths. - - If the URL does not point to a local file, the server and location - are both represented as ``None``. - - :param pathformat: The expected format of the path component. - Currently ``'windows'`` and ``'posix'`` are - supported. Defaults to ``None`` which is - autodetect. - """ - if self.scheme != "file": - return None, None - - path = url_unquote(self.path) - host = self.netloc or None - - if pathformat is None: - if os.name == "nt": - pathformat = "windows" - else: - pathformat = "posix" - - if pathformat == "windows": - if path[:1] == "/" and path[1:2].isalpha() and path[2:3] in "|:": - path = f"{path[1:2]}:{path[3:]}" - windows_share = path[:3] in ("\\" * 3, "/" * 3) - import ntpath - - path = ntpath.normpath(path) - # Windows shared drives are represented as ``\\host\\directory``. - # That results in a URL like ``file://///host/directory``, and a - # path like ``///host/directory``. We need to special-case this - # because the path contains the hostname. 
- if windows_share and host is None: - parts = path.lstrip("\\").split("\\", 1) - if len(parts) == 2: - host, path = parts - else: - host = parts[0] - path = "" - elif pathformat == "posix": - import posixpath - - path = posixpath.normpath(path) - else: - raise TypeError(f"Invalid path format {pathformat!r}") - - if host in ("127.0.0.1", "::1", "localhost"): - host = None - - return host, path - - def _split_netloc(self) -> t.Tuple[t.Optional[str], str]: - if self._at in self.netloc: - auth, _, netloc = self.netloc.partition(self._at) - return auth, netloc - return None, self.netloc - - def _split_auth(self) -> t.Tuple[t.Optional[str], t.Optional[str]]: - auth = self._split_netloc()[0] - if not auth: - return None, None - if self._colon not in auth: - return auth, None - - username, _, password = auth.partition(self._colon) - return username, password - - def _split_host(self) -> t.Tuple[t.Optional[str], t.Optional[str]]: - rv = self._split_netloc()[1] - if not rv: - return None, None - - if not rv.startswith(self._lbracket): - if self._colon in rv: - host, _, port = rv.partition(self._colon) - return host, port - return rv, None - - idx = rv.find(self._rbracket) - if idx < 0: - return rv, None - - host = rv[1:idx] - rest = rv[idx + 1 :] - if rest.startswith(self._colon): - return host, rest[1:] - return host, None - - -class URL(BaseURL): - """Represents a parsed URL. This behaves like a regular tuple but - also has some extra attributes that give further insight into the - URL. - """ - - __slots__ = () - _at = "@" - _colon = ":" - _lbracket = "[" - _rbracket = "]" - - def encode(self, charset: str = "utf-8", errors: str = "replace") -> "BytesURL": - """Encodes the URL to a tuple made out of bytes. The charset is - only being used for the path, query and fragment. - """ - return BytesURL( - self.scheme.encode("ascii"), # type: ignore - self.encode_netloc(), - self.path.encode(charset, errors), # type: ignore - self.query.encode(charset, errors), # type: ignore - self.fragment.encode(charset, errors), # type: ignore - ) - - -class BytesURL(BaseURL): - """Represents a parsed URL in bytes.""" - - __slots__ = () - _at = b"@" # type: ignore - _colon = b":" # type: ignore - _lbracket = b"[" # type: ignore - _rbracket = b"]" # type: ignore - - def __str__(self) -> str: - return self.to_url().decode("utf-8", "replace") # type: ignore - - def encode_netloc(self) -> bytes: # type: ignore - """Returns the netloc unchanged as bytes.""" - return self.netloc # type: ignore - - def decode(self, charset: str = "utf-8", errors: str = "replace") -> "URL": - """Decodes the URL to a tuple made out of strings. The charset is - only being used for the path, query and fragment. 
- """ - return URL( - self.scheme.decode("ascii"), # type: ignore - self.decode_netloc(), - self.path.decode(charset, errors), # type: ignore - self.query.decode(charset, errors), # type: ignore - self.fragment.decode(charset, errors), # type: ignore - ) - - -_unquote_maps: t.Dict[t.FrozenSet[int], t.Dict[bytes, int]] = {frozenset(): _hextobyte} - - -def _unquote_to_bytes( - string: t.Union[str, bytes], unsafe: t.Union[str, bytes] = "" -) -> bytes: - if isinstance(string, str): - string = string.encode("utf-8") - - if isinstance(unsafe, str): - unsafe = unsafe.encode("utf-8") - - unsafe = frozenset(bytearray(unsafe)) - groups = iter(string.split(b"%")) - result = bytearray(next(groups, b"")) - - try: - hex_to_byte = _unquote_maps[unsafe] - except KeyError: - hex_to_byte = _unquote_maps[unsafe] = { - h: b for h, b in _hextobyte.items() if b not in unsafe - } - - for group in groups: - code = group[:2] - - if code in hex_to_byte: - result.append(hex_to_byte[code]) - result.extend(group[2:]) - else: - result.append(37) # % - result.extend(group) - - return bytes(result) - - -def _url_encode_impl( - obj: t.Union[t.Mapping[str, str], t.Iterable[t.Tuple[str, str]]], - charset: str, - sort: bool, - key: t.Optional[t.Callable[[t.Tuple[str, str]], t.Any]], -) -> t.Iterator[str]: - from .datastructures import iter_multi_items - - iterable: t.Iterable[t.Tuple[str, str]] = iter_multi_items(obj) - - if sort: - iterable = sorted(iterable, key=key) - - for key_str, value_str in iterable: - if value_str is None: - continue - - if not isinstance(key_str, bytes): - key_bytes = str(key_str).encode(charset) - else: - key_bytes = key_str - - if not isinstance(value_str, bytes): - value_bytes = str(value_str).encode(charset) - else: - value_bytes = value_str - - yield f"{_fast_url_quote_plus(key_bytes)}={_fast_url_quote_plus(value_bytes)}" - - -def _url_unquote_legacy(value: str, unsafe: str = "") -> str: - try: - return url_unquote(value, charset="utf-8", errors="strict", unsafe=unsafe) - except UnicodeError: - return url_unquote(value, charset="latin1", unsafe=unsafe) - - -def url_parse( - url: str, scheme: t.Optional[str] = None, allow_fragments: bool = True -) -> BaseURL: - """Parses a URL from a string into a :class:`URL` tuple. If the URL - is lacking a scheme it can be provided as second argument. Otherwise, - it is ignored. Optionally fragments can be stripped from the URL - by setting `allow_fragments` to `False`. - - The inverse of this function is :func:`url_unparse`. - - :param url: the URL to parse. - :param scheme: the default schema to use if the URL is schemaless. - :param allow_fragments: if set to `False` a fragment will be removed - from the URL. 
- """ - s = _make_encode_wrapper(url) - is_text_based = isinstance(url, str) - - if scheme is None: - scheme = s("") - netloc = query = fragment = s("") - i = url.find(s(":")) - if i > 0 and _scheme_re.match(_to_str(url[:i], errors="replace")): - # make sure "iri" is not actually a port number (in which case - # "scheme" is really part of the path) - rest = url[i + 1 :] - if not rest or any(c not in s("0123456789") for c in rest): - # not a port number - scheme, url = url[:i].lower(), rest - - if url[:2] == s("//"): - delim = len(url) - for c in s("/?#"): - wdelim = url.find(c, 2) - if wdelim >= 0: - delim = min(delim, wdelim) - netloc, url = url[2:delim], url[delim:] - if (s("[") in netloc and s("]") not in netloc) or ( - s("]") in netloc and s("[") not in netloc - ): - raise ValueError("Invalid IPv6 URL") - - if allow_fragments and s("#") in url: - url, fragment = url.split(s("#"), 1) - if s("?") in url: - url, query = url.split(s("?"), 1) - - result_type = URL if is_text_based else BytesURL - return result_type(scheme, netloc, url, query, fragment) - - -def _make_fast_url_quote( - charset: str = "utf-8", - errors: str = "strict", - safe: t.Union[str, bytes] = "/:", - unsafe: t.Union[str, bytes] = "", -) -> t.Callable[[bytes], str]: - """Precompile the translation table for a URL encoding function. - - Unlike :func:`url_quote`, the generated function only takes the - string to quote. - - :param charset: The charset to encode the result with. - :param errors: How to handle encoding errors. - :param safe: An optional sequence of safe characters to never encode. - :param unsafe: An optional sequence of unsafe characters to always encode. - """ - if isinstance(safe, str): - safe = safe.encode(charset, errors) - - if isinstance(unsafe, str): - unsafe = unsafe.encode(charset, errors) - - safe = (frozenset(bytearray(safe)) | _always_safe) - frozenset(bytearray(unsafe)) - table = [chr(c) if c in safe else f"%{c:02X}" for c in range(256)] - - def quote(string: bytes) -> str: - return "".join([table[c] for c in string]) - - return quote - - -_fast_url_quote = _make_fast_url_quote() -_fast_quote_plus = _make_fast_url_quote(safe=" ", unsafe="+") - - -def _fast_url_quote_plus(string: bytes) -> str: - return _fast_quote_plus(string).replace(" ", "+") - - -def url_quote( - string: t.Union[str, bytes], - charset: str = "utf-8", - errors: str = "strict", - safe: t.Union[str, bytes] = "/:", - unsafe: t.Union[str, bytes] = "", -) -> str: - """URL encode a single string with a given encoding. - - :param s: the string to quote. - :param charset: the charset to be used. - :param safe: an optional sequence of safe characters. - :param unsafe: an optional sequence of unsafe characters. - - .. versionadded:: 0.9.2 - The `unsafe` parameter was added. - """ - if not isinstance(string, (str, bytes, bytearray)): - string = str(string) - if isinstance(string, str): - string = string.encode(charset, errors) - if isinstance(safe, str): - safe = safe.encode(charset, errors) - if isinstance(unsafe, str): - unsafe = unsafe.encode(charset, errors) - safe = (frozenset(bytearray(safe)) | _always_safe) - frozenset(bytearray(unsafe)) - rv = bytearray() - for char in bytearray(string): - if char in safe: - rv.append(char) - else: - rv.extend(_bytetohex[char]) - return bytes(rv).decode(charset) - - -def url_quote_plus( - string: str, charset: str = "utf-8", errors: str = "strict", safe: str = "" -) -> str: - """URL encode a single string with the given encoding and convert - whitespace to "+". - - :param s: The string to quote. 
- :param charset: The charset to be used. - :param safe: An optional sequence of safe characters. - """ - return url_quote(string, charset, errors, safe + " ", "+").replace(" ", "+") - - -def url_unparse(components: t.Tuple[str, str, str, str, str]) -> str: - """The reverse operation to :meth:`url_parse`. This accepts arbitrary - as well as :class:`URL` tuples and returns a URL as a string. - - :param components: the parsed URL as tuple which should be converted - into a URL string. - """ - _check_str_tuple(components) - scheme, netloc, path, query, fragment = components - s = _make_encode_wrapper(scheme) - url = s("") - - # We generally treat file:///x and file:/x the same which is also - # what browsers seem to do. This also allows us to ignore a schema - # register for netloc utilization or having to differentiate between - # empty and missing netloc. - if netloc or (scheme and path.startswith(s("/"))): - if path and path[:1] != s("/"): - path = s("/") + path - url = s("//") + (netloc or s("")) + path - elif path: - url += path - if scheme: - url = scheme + s(":") + url - if query: - url = url + s("?") + query - if fragment: - url = url + s("#") + fragment - return url - - -def url_unquote( - s: t.Union[str, bytes], - charset: str = "utf-8", - errors: str = "replace", - unsafe: str = "", -) -> str: - """URL decode a single string with a given encoding. If the charset - is set to `None` no decoding is performed and raw bytes are - returned. - - :param s: the string to unquote. - :param charset: the charset of the query string. If set to `None` - no decoding will take place. - :param errors: the error handling for the charset decoding. - """ - rv = _unquote_to_bytes(s, unsafe) - if charset is None: - return rv - return rv.decode(charset, errors) - - -def url_unquote_plus( - s: t.Union[str, bytes], charset: str = "utf-8", errors: str = "replace" -) -> str: - """URL decode a single string with the given `charset` and decode "+" to - whitespace. - - Per default encoding errors are ignored. If you want a different behavior - you can set `errors` to ``'replace'`` or ``'strict'``. - - :param s: The string to unquote. - :param charset: the charset of the query string. If set to `None` - no decoding will take place. - :param errors: The error handling for the `charset` decoding. +def _codec_error_url_quote(e: UnicodeError) -> tuple[str, int]: + """Used in :func:`uri_to_iri` after unquoting to re-quote any + invalid bytes. """ - if isinstance(s, str): - s = s.replace("+", " ") - else: - s = s.replace(b"+", b" ") - return url_unquote(s, charset, errors) + # the docs state that UnicodeError does have these attributes, + # but mypy isn't picking them up + out = quote(e.object[e.start : e.end], safe="") # type: ignore + return out, e.end # type: ignore -def url_fix(s: str, charset: str = "utf-8") -> str: - r"""Sometimes you get an URL by a user that just isn't a real URL because - it contains unsafe characters like ' ' and so on. This function can fix - some of the problems in a similar way browsers handle data entered by the - user: +codecs.register_error("werkzeug.url_quote", _codec_error_url_quote) - >>> url_fix('http://de.wikipedia.org/wiki/Elf (Begriffskl\xe4rung)') - 'http://de.wikipedia.org/wiki/Elf%20(Begriffskl%C3%A4rung)' - :param s: the string with the URL to fix. - :param charset: The target charset for the URL if the url was given - as a string. 
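# ---------------------------------------------------------------------------
# A sketch of the behavior the rewritten helpers below provide through the
# public functions, reusing the docstring examples from this file; assumes
# Werkzeug 3.x, where both functions accept and return only strings.
from werkzeug.urls import iri_to_uri, uri_to_iri

# Non-ASCII and unsafe characters are quoted; the domain becomes punycode.
assert iri_to_uri("http://\u2603.net/p\xe5th?q=\xe8ry%DF") == (
    "http://xn--n3h.net/p%C3%A5th?q=%C3%A8ry%DF"
)
# Valid UTF-8 escapes are decoded; %DF alone is not valid UTF-8, so it
# remains quoted instead of being decoded with a guessed charset.
assert uri_to_iri("http://xn--n3h.net/p%C3%A5th?q=%C3%A8ry%DF") == (
    "http://\u2603.net/p\xe5th?q=\xe8ry%DF"
)
# ---------------------------------------------------------------------------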
+def _make_unquote_part(name: str, chars: str) -> t.Callable[[str], str]: + """Create a function that unquotes all percent encoded characters except those + given. This allows working with unquoted characters if possible while not changing + the meaning of a given part of a URL. """ - # First step is to switch to text processing and to convert - # backslashes (which are invalid in URLs anyways) to slashes. This is - # consistent with what Chrome does. - s = _to_str(s, charset, "replace").replace("\\", "/") + choices = "|".join(f"{ord(c):02X}" for c in sorted(chars)) + pattern = re.compile(f"((?:%(?:{choices}))+)", re.I) - # For the specific case that we look like a malformed windows URL - # we want to fix this up manually: - if s.startswith("file://") and s[7:8].isalpha() and s[8:10] in (":/", "|/"): - s = f"file:///{s[7:]}" + def _unquote_partial(value: str) -> str: + parts = iter(pattern.split(value)) + out = [] - url = url_parse(s) - path = url_quote(url.path, charset, safe="/%+$!*'(),") - qs = url_quote_plus(url.query, charset, safe=":&%=+$!*'(),") - anchor = url_quote_plus(url.fragment, charset, safe=":&%=+$!*'(),") - return url_unparse((url.scheme, url.encode_netloc(), path, qs, anchor)) + for part in parts: + out.append(unquote(part, "utf-8", "werkzeug.url_quote")) + out.append(next(parts, "")) + return "".join(out) -# not-unreserved characters remain quoted when unquoting to IRI -_to_iri_unsafe = "".join([chr(c) for c in range(128) if c not in _always_safe]) + _unquote_partial.__name__ = f"_unquote_{name}" + return _unquote_partial -def _codec_error_url_quote(e: UnicodeError) -> t.Tuple[str, int]: - """Used in :func:`uri_to_iri` after unquoting to re-quote any - invalid bytes. - """ - # the docs state that UnicodeError does have these attributes, - # but mypy isn't picking them up - out = _fast_url_quote(e.object[e.start : e.end]) # type: ignore - return out, e.end # type: ignore +# characters that should remain quoted in URL parts +# based on https://url.spec.whatwg.org/#percent-encoded-bytes +# always keep all controls, space, and % quoted +_always_unsafe = bytes((*range(0x21), 0x25, 0x7F)).decode() +_unquote_fragment = _make_unquote_part("fragment", _always_unsafe) +_unquote_query = _make_unquote_part("query", _always_unsafe + "&=+#") +_unquote_path = _make_unquote_part("path", _always_unsafe + "/?#") +_unquote_user = _make_unquote_part("user", _always_unsafe + ":@/?#") -codecs.register_error("werkzeug.url_quote", _codec_error_url_quote) - - -def uri_to_iri( - uri: t.Union[str, t.Tuple[str, str, str, str, str]], - charset: str = "utf-8", - errors: str = "werkzeug.url_quote", -) -> str: +def uri_to_iri(uri: str) -> str: """Convert a URI to an IRI. All valid UTF-8 characters are unquoted, leaving all reserved and invalid characters quoted. If the URL has a domain, it is decoded from Punycode. @@ -725,9 +67,13 @@ def uri_to_iri( 'http://\\u2603.net/p\\xe5th?q=\\xe8ry%DF' :param uri: The URI to convert. - :param charset: The encoding to encode unquoted bytes with. - :param errors: Error handler to use during ``bytes.encode``. By - default, invalid bytes are left quoted. + + .. versionchanged:: 3.0 + Passing a tuple or bytes, and the ``charset`` and ``errors`` parameters, + are removed. + + .. versionchanged:: 2.3 + Which characters remain quoted is specific to each part of the URL. .. versionchanged:: 0.15 All reserved and invalid characters remain quoted. Previously, @@ -736,26 +82,35 @@ def uri_to_iri( .. 
versionadded:: 0.6 """ - if isinstance(uri, tuple): - uri = url_unparse(uri) + parts = urlsplit(uri) + path = _unquote_path(parts.path) + query = _unquote_query(parts.query) + fragment = _unquote_fragment(parts.fragment) + + if parts.hostname: + netloc = _decode_idna(parts.hostname) + else: + netloc = "" + + if ":" in netloc: + netloc = f"[{netloc}]" - uri = url_parse(_to_str(uri, charset)) - path = url_unquote(uri.path, charset, errors, _to_iri_unsafe) - query = url_unquote(uri.query, charset, errors, _to_iri_unsafe) - fragment = url_unquote(uri.fragment, charset, errors, _to_iri_unsafe) - return url_unparse((uri.scheme, uri.decode_netloc(), path, query, fragment)) + if parts.port: + netloc = f"{netloc}:{parts.port}" + if parts.username: + auth = _unquote_user(parts.username) -# reserved characters remain unquoted when quoting to URI -_to_uri_safe = ":/?#[]@!$&'()*+,;=%" + if parts.password: + password = _unquote_user(parts.password) + auth = f"{auth}:{password}" + netloc = f"{auth}@{netloc}" -def iri_to_uri( - iri: t.Union[str, t.Tuple[str, str, str, str, str]], - charset: str = "utf-8", - errors: str = "strict", - safe_conversion: bool = False, -) -> str: + return urlunsplit((parts.scheme, netloc, path, query, fragment)) + + +def iri_to_uri(iri: str) -> str: """Convert an IRI to a URI. All non-ASCII and unsafe characters are quoted. If the URL has a domain, it is encoded to Punycode. @@ -763,305 +118,86 @@ def iri_to_uri( 'http://xn--n3h.net/p%C3%A5th?q=%C3%A8ry%DF' :param iri: The IRI to convert. - :param charset: The encoding of the IRI. - :param errors: Error handler to use during ``bytes.encode``. - :param safe_conversion: Return the URL unchanged if it only contains - ASCII characters and no whitespace. See the explanation below. - - There is a general problem with IRI conversion with some protocols - that are in violation of the URI specification. Consider the - following two IRIs:: - magnet:?xt=uri:whatever - itms-services://?action=download-manifest + .. versionchanged:: 3.0 + Passing a tuple or bytes, the ``charset`` and ``errors`` parameters, + and the ``safe_conversion`` parameter, are removed. - After parsing, we don't know if the scheme requires the ``//``, - which is dropped if empty, but conveys different meanings in the - final URL if it's present or not. In this case, you can use - ``safe_conversion``, which will return the URL unchanged if it only - contains ASCII characters and no whitespace. This can result in a - URI with unquoted characters if it was not already quoted correctly, - but preserves the URL's semantics. Werkzeug uses this for the - ``Location`` header for redirects. + .. versionchanged:: 2.3 + Which characters remain unquoted is specific to each part of the URL. .. versionchanged:: 0.15 - All reserved characters remain unquoted. Previously, only some - reserved characters were left unquoted. + All reserved characters remain unquoted. Previously, only some reserved + characters were left unquoted. .. versionchanged:: 0.9.6 The ``safe_conversion`` parameter was added. .. versionadded:: 0.6 """ - if isinstance(iri, tuple): - iri = url_unparse(iri) - - if safe_conversion: - # If we're not sure if it's safe to convert the URL, and it only - # contains ASCII characters, return it unconverted. - try: - native_iri = _to_str(iri) - ascii_iri = native_iri.encode("ascii") - - # Only return if it doesn't have whitespace. (Why?) 
- if len(ascii_iri.split()) == 1: - return native_iri - except UnicodeError: - pass - - iri = url_parse(_to_str(iri, charset, errors)) - path = url_quote(iri.path, charset, errors, _to_uri_safe) - query = url_quote(iri.query, charset, errors, _to_uri_safe) - fragment = url_quote(iri.fragment, charset, errors, _to_uri_safe) - return url_unparse((iri.scheme, iri.encode_netloc(), path, query, fragment)) - - -def url_decode( - s: t.AnyStr, - charset: str = "utf-8", - include_empty: bool = True, - errors: str = "replace", - separator: str = "&", - cls: t.Optional[t.Type["ds.MultiDict"]] = None, -) -> "ds.MultiDict[str, str]": - """Parse a query string and return it as a :class:`MultiDict`. - - :param s: The query string to parse. - :param charset: Decode bytes to string with this charset. If not - given, bytes are returned as-is. - :param include_empty: Include keys with empty values in the dict. - :param errors: Error handling behavior when decoding bytes. - :param separator: Separator character between pairs. - :param cls: Container to hold result instead of :class:`MultiDict`. - - .. versionchanged:: 2.0 - The ``decode_keys`` parameter is deprecated and will be removed - in Werkzeug 2.1. - - .. versionchanged:: 0.5 - In previous versions ";" and "&" could be used for url decoding. - Now only "&" is supported. If you want to use ";", a different - ``separator`` can be provided. - - .. versionchanged:: 0.5 - The ``cls`` parameter was added. - """ - if cls is None: - from .datastructures import MultiDict # noqa: F811 - - cls = MultiDict - if isinstance(s, str) and not isinstance(separator, str): - separator = separator.decode(charset or "ascii") - elif isinstance(s, bytes) and not isinstance(separator, bytes): - separator = separator.encode(charset or "ascii") # type: ignore - return cls( - _url_decode_impl( - s.split(separator), charset, include_empty, errors # type: ignore - ) - ) - - -def url_decode_stream( - stream: t.IO[bytes], - charset: str = "utf-8", - include_empty: bool = True, - errors: str = "replace", - separator: bytes = b"&", - cls: t.Optional[t.Type["ds.MultiDict"]] = None, - limit: t.Optional[int] = None, -) -> "ds.MultiDict[str, str]": - """Works like :func:`url_decode` but decodes a stream. The behavior - of stream and limit follows functions like - :func:`~werkzeug.wsgi.make_line_iter`. The generator of pairs is - directly fed to the `cls` so you can consume the data while it's - parsed. + parts = urlsplit(iri) + # safe = https://url.spec.whatwg.org/#url-path-segment-string + # as well as percent for things that are already quoted + path = quote(parts.path, safe="%!$&'()*+,/:;=@") + query = quote(parts.query, safe="%!$&'()*+,/:;=?@") + fragment = quote(parts.fragment, safe="%!#$&'()*+,/:;=?@") - :param stream: a stream with the encoded querystring - :param charset: the charset of the query string. If set to `None` - no decoding will take place. - :param include_empty: Set to `False` if you don't want empty values to - appear in the dict. - :param errors: the decoding error behavior. - :param separator: the pair separator to be used, defaults to ``&`` - :param cls: an optional dict class to use. If this is not specified - or `None` the default :class:`MultiDict` is used. - :param limit: the content length of the URL data. Not necessary if - a limited stream is provided. - - .. versionchanged:: 2.0 - The ``decode_keys`` and ``return_iterator`` parameters are - deprecated and will be removed in Werkzeug 2.1. - - .. 
versionadded:: 0.8 - """ - from .wsgi import make_chunk_iter - - pair_iter = make_chunk_iter(stream, separator, limit) - decoder = _url_decode_impl(pair_iter, charset, include_empty, errors) - - if cls is None: - from .datastructures import MultiDict # noqa: F811 - - cls = MultiDict - - return cls(decoder) - - -def _url_decode_impl( - pair_iter: t.Iterable[t.AnyStr], charset: str, include_empty: bool, errors: str -) -> t.Iterator[t.Tuple[str, str]]: - for pair in pair_iter: - if not pair: - continue - s = _make_encode_wrapper(pair) - equal = s("=") - if equal in pair: - key, value = pair.split(equal, 1) - else: - if not include_empty: - continue - key = pair - value = s("") - yield ( - url_unquote_plus(key, charset, errors), - url_unquote_plus(value, charset, errors), - ) - - -def url_encode( - obj: t.Union[t.Mapping[str, str], t.Iterable[t.Tuple[str, str]]], - charset: str = "utf-8", - sort: bool = False, - key: t.Optional[t.Callable[[t.Tuple[str, str]], t.Any]] = None, - separator: str = "&", -) -> str: - """URL encode a dict/`MultiDict`. If a value is `None` it will not appear - in the result string. Per default only values are encoded into the target - charset strings. - - :param obj: the object to encode into a query string. - :param charset: the charset of the query string. - :param sort: set to `True` if you want parameters to be sorted by `key`. - :param separator: the separator to be used for the pairs. - :param key: an optional function to be used for sorting. For more details - check out the :func:`sorted` documentation. - - .. versionchanged:: 2.0 - The ``encode_keys`` parameter is deprecated and will be removed - in Werkzeug 2.1. + if parts.hostname: + netloc = parts.hostname.encode("idna").decode("ascii") + else: + netloc = "" - .. versionchanged:: 0.5 - Added the ``sort``, ``key``, and ``separator`` parameters. - """ - separator = _to_str(separator, "ascii") - return separator.join(_url_encode_impl(obj, charset, sort, key)) + if ":" in netloc: + netloc = f"[{netloc}]" + if parts.port: + netloc = f"{netloc}:{parts.port}" -def url_encode_stream( - obj: t.Union[t.Mapping[str, str], t.Iterable[t.Tuple[str, str]]], - stream: t.Optional[t.IO[str]] = None, - charset: str = "utf-8", - sort: bool = False, - key: t.Optional[t.Callable[[t.Tuple[str, str]], t.Any]] = None, - separator: str = "&", -) -> None: - """Like :meth:`url_encode` but writes the results to a stream - object. If the stream is `None` a generator over all encoded - pairs is returned. + if parts.username: + auth = quote(parts.username, safe="%!$&'()*+,;=") - :param obj: the object to encode into a query string. - :param stream: a stream to write the encoded object into or `None` if - an iterator over the encoded pairs should be returned. In - that case the separator argument is ignored. - :param charset: the charset of the query string. - :param sort: set to `True` if you want parameters to be sorted by `key`. - :param separator: the separator to be used for the pairs. - :param key: an optional function to be used for sorting. For more details - check out the :func:`sorted` documentation. + if parts.password: + password = quote(parts.password, safe="%!$&'()*+,;=") + auth = f"{auth}:{password}" - .. versionchanged:: 2.0 - The ``encode_keys`` parameter is deprecated and will be removed - in Werkzeug 2.1. + netloc = f"{auth}@{netloc}" - .. 
versionadded:: 0.8 - """ - separator = _to_str(separator, "ascii") - gen = _url_encode_impl(obj, charset, sort, key) - if stream is None: - return gen # type: ignore - for idx, chunk in enumerate(gen): - if idx: - stream.write(separator) - stream.write(chunk) - return None + return urlunsplit((parts.scheme, netloc, path, query, fragment)) -def url_join( - base: t.Union[str, t.Tuple[str, str, str, str, str]], - url: t.Union[str, t.Tuple[str, str, str, str, str]], - allow_fragments: bool = True, -) -> str: - """Join a base URL and a possibly relative URL to form an absolute - interpretation of the latter. +# Python < 3.12 +# itms-services was worked around in previous iri_to_uri implementations, but +# we can tell Python directly that it needs to preserve the //. +if "itms-services" not in urllib.parse.uses_netloc: + urllib.parse.uses_netloc.append("itms-services") - :param base: the base URL for the join operation. - :param url: the URL to join. - :param allow_fragments: indicates whether fragments should be allowed. - """ - if isinstance(base, tuple): - base = url_unparse(base) - if isinstance(url, tuple): - url = url_unparse(url) - _check_str_tuple((base, url)) - s = _make_encode_wrapper(base) +def _decode_idna(domain: str) -> str: + try: + data = domain.encode("ascii") + except UnicodeEncodeError: + # If the domain is not ASCII, it's decoded already. + return domain - if not base: - return url - if not url: - return base + try: + # Try decoding in one shot. + return data.decode("idna") + except UnicodeDecodeError: + pass - bscheme, bnetloc, bpath, bquery, bfragment = url_parse( - base, allow_fragments=allow_fragments - ) - scheme, netloc, path, query, fragment = url_parse(url, bscheme, allow_fragments) - if scheme != bscheme: - return url - if netloc: - return url_unparse((scheme, netloc, path, query, fragment)) - netloc = bnetloc + # Decode each part separately, leaving invalid parts as punycode. + parts = [] - if path[:1] == s("/"): - segments = path.split(s("/")) - elif not path: - segments = bpath.split(s("/")) - if not query: - query = bquery - else: - segments = bpath.split(s("/"))[:-1] + path.split(s("/")) - - # If the rightmost part is "./" we want to keep the slash but - # remove the dot. - if segments[-1] == s("."): - segments[-1] = s("") + for part in data.split(b"."): + try: + parts.append(part.decode("idna")) + except UnicodeDecodeError: + parts.append(part.decode("ascii")) - # Resolve ".." and "." - segments = [segment for segment in segments if segment != s(".")] - while True: - i = 1 - n = len(segments) - 1 - while i < n: - if segments[i] == s("..") and segments[i - 1] not in (s(""), s("..")): - del segments[i - 1 : i + 1] - break - i += 1 - else: - break + return ".".join(parts) - # Remove trailing ".." 
if the URL is absolute - unwanted_marker = [s(""), s("..")] - while segments[:2] == unwanted_marker: - del segments[1] - path = s("/").join(segments) - return url_unparse((scheme, netloc, path, query, fragment)) +def _urlencode(query: t.Mapping[str, str] | t.Iterable[tuple[str, str]]) -> str: + items = [x for x in iter_multi_items(query) if x[1] is not None] + # safe = https://url.spec.whatwg.org/#percent-encoded-bytes + return urlencode(items, safe="!$'()*,/:;?@") diff --git a/src/werkzeug/user_agent.py b/src/werkzeug/user_agent.py index 66ffcbe..17e5d3f 100644 --- a/src/werkzeug/user_agent.py +++ b/src/werkzeug/user_agent.py @@ -1,4 +1,4 @@ -import typing as t +from __future__ import annotations class UserAgent: @@ -17,16 +17,16 @@ class UserAgent: provide a built-in parser. """ - platform: t.Optional[str] = None + platform: str | None = None """The OS name, if it could be parsed from the string.""" - browser: t.Optional[str] = None + browser: str | None = None """The browser name, if it could be parsed from the string.""" - version: t.Optional[str] = None + version: str | None = None """The browser version, if it could be parsed from the string.""" - language: t.Optional[str] = None + language: str | None = None """The browser language, if it could be parsed from the string.""" def __init__(self, string: str) -> None: diff --git a/src/werkzeug/utils.py b/src/werkzeug/utils.py index 672e6e5..59b97b7 100644 --- a/src/werkzeug/utils.py +++ b/src/werkzeug/utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import io import mimetypes import os @@ -8,6 +10,7 @@ import unicodedata from datetime import datetime from time import time +from urllib.parse import quote from zlib import adler32 from markupsafe import escape @@ -19,11 +22,11 @@ from .exceptions import NotFound from .exceptions import RequestedRangeNotSatisfiable from .security import safe_join -from .urls import url_quote from .wsgi import wrap_file if t.TYPE_CHECKING: from _typeshed.wsgi import WSGIEnvironment + from .wrappers.request import Request from .wrappers.response import Response @@ -31,19 +34,14 @@ _entity_re = re.compile(r"&([^;]+);") _filename_ascii_strip_re = re.compile(r"[^A-Za-z0-9_.-]") -_windows_device_files = ( +_windows_device_files = { "CON", - "AUX", - "COM1", - "COM2", - "COM3", - "COM4", - "LPT1", - "LPT2", - "LPT3", "PRN", + "AUX", "NUL", -) + *(f"COM{i}" for i in range(10)), + *(f"LPT{i}" for i in range(10)), +} class cached_property(property, t.Generic[_T]): @@ -80,8 +78,8 @@ def value(self): def __init__( self, fget: t.Callable[[t.Any], _T], - name: t.Optional[str] = None, - doc: t.Optional[str] = None, + name: str | None = None, + doc: str | None = None, ) -> None: super().__init__(fget, doc=doc) self.__name__ = name or fget.__name__ @@ -145,14 +143,14 @@ class environ_property(_DictAccessorProperty[_TAccessorValue]): read_only = True - def lookup(self, obj: "Request") -> "WSGIEnvironment": + def lookup(self, obj: Request) -> WSGIEnvironment: return obj.environ class header_property(_DictAccessorProperty[_TAccessorValue]): """Like `environ_property` but for headers.""" - def lookup(self, obj: t.Union["Request", "Response"]) -> Headers: + def lookup(self, obj: Request | Response) -> Headers: return obj.headers @@ -221,7 +219,7 @@ def secure_filename(filename: str) -> str: filename = unicodedata.normalize("NFKD", filename) filename = filename.encode("ascii", "ignore").decode("ascii") - for sep in os.path.sep, os.path.altsep: + for sep in os.sep, os.path.altsep: if sep: filename = 
filename.replace(sep, " ") filename = str(_filename_ascii_strip_re.sub("", "_".join(filename.split()))).strip( @@ -242,8 +240,8 @@ def secure_filename(filename: str) -> str: def redirect( - location: str, code: int = 302, Response: t.Optional[t.Type["Response"]] = None -) -> "Response": + location: str, code: int = 302, Response: type[Response] | None = None +) -> Response: """Returns a response object (a WSGI application) that, if called, redirects the client to the target location. Supported codes are 301, 302, 303, 305, 307, and 308. 300 is not supported because @@ -264,24 +262,16 @@ def redirect( unspecified. """ if Response is None: - from .wrappers import Response # type: ignore - - display_location = escape(location) - if isinstance(location, str): - # Safe conversion is necessary here as we might redirect - # to a broken URI scheme (for instance itms-services). - from .urls import iri_to_uri - - location = iri_to_uri(location, safe_conversion=True) + from .wrappers import Response - response = Response( # type: ignore + html_location = escape(location) + response = Response( # type: ignore[misc] "<!doctype html>\n" "<html lang=en>\n" "<title>Redirecting...</title>\n" "<h1>Redirecting...</h1>\n" "<p>You should be redirected automatically to the target URL: " - f'<a href="{display_location}">{display_location}</a>. If' - " not, click the link.\n", + f'<a href="{html_location}">{html_location}</a>. If not, click the link.\n', code, mimetype="text/html", )
@@ -289,7 +279,7 @@ return response -def append_slash_redirect(environ: "WSGIEnvironment", code: int = 308) -> "Response": +def append_slash_redirect(environ: WSGIEnvironment, code: int = 308) -> Response: """Redirect to the current URL with a slash appended. If the current URL is ``/user/42``, the redirect URL will be @@ -327,21 +317,19 @@ def append_slash_redirect(environ: "WSGIEnvironment", code: int = 308) -> "Respo def send_file( - path_or_file: t.Union[os.PathLike, str, t.IO[bytes]], - environ: "WSGIEnvironment", - mimetype: t.Optional[str] = None, as_attachment: bool = False, - download_name: t.Optional[str] = None, conditional: bool = True, - etag: t.Union[bool, str] = True, - last_modified: t.Optional[t.Union[datetime, int, float]] = None, - max_age: t.Optional[ - t.Union[int, t.Callable[[t.Optional[str]], t.Optional[int]]] - ] = None, use_x_sendfile: bool = False, - response_class: t.Optional[t.Type["Response"]] = None, - _root_path: t.Optional[t.Union[os.PathLike, str]] = None, -) -> "Response": + path_or_file: os.PathLike[str] | str | t.IO[bytes], + environ: WSGIEnvironment, + mimetype: str | None = None, as_attachment: bool = False, + download_name: str | None = None, conditional: bool = True, + etag: bool | str = True, + last_modified: datetime | int | float | None = None, + max_age: None | (int | t.Callable[[str | None], int | None]) = None, use_x_sendfile: bool = False, + response_class: type[Response] | None = None, + _root_path: os.PathLike[str] | str | None = None, +) -> Response: """Send the contents of a file to the client. The first argument can be a file path or a file-like object. Paths @@ -352,7 +340,7 @@ def send_file( Never pass file paths provided by a user. The path is assumed to be trusted, so a user could craft a path to access a file you didn't - intend. + intend. Use :func:`send_from_directory` to safely serve user-provided paths. If the WSGI server sets a ``file_wrapper`` in ``environ``, it is used, otherwise Werkzeug's built-in wrapper is used. Alternatively, @@ -419,16 +407,16 @@ def send_file( response_class = Response - path: t.Optional[str] = None - file: t.Optional[t.IO[bytes]] = None - size: t.Optional[int] = None - mtime: t.Optional[float] = None + path: str | None = None + file: t.IO[bytes] | None = None + size: int | None = None + mtime: float | None = None headers = Headers() if isinstance(path_or_file, (os.PathLike, str)) or hasattr( path_or_file, "__fspath__" ): - path_or_file = t.cast(t.Union[os.PathLike, str], path_or_file) + path_or_file = t.cast("t.Union[os.PathLike[str], str]", path_or_file) # Flask will pass app.root_path, allowing its send_file wrapper # to not have to deal with paths.
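A usage sketch of the reworked ``redirect()`` (not part of the patch; assumes Werkzeug 3.0 behavior, where ``Location`` is percent-encoded later by ``get_wsgi_headers()`` rather than inside ``redirect()``):

from werkzeug.utils import redirect

resp = redirect("/fruit/äpfel", code=302)
assert resp.status_code == 302
# The Location header keeps the raw IRI; get_wsgi_headers() applies
# iri_to_uri when the response is finalized against a WSGI environ.
assert resp.headers["Location"] == "/fruit/äpfel"
assert "Redirecting" in resp.get_data(as_text=True)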
@@ -470,7 +458,8 @@ def send_file( except UnicodeEncodeError: simple = unicodedata.normalize("NFKD", download_name) simple = simple.encode("ascii", "ignore").decode("ascii") - quoted = url_quote(download_name, safe="") + # safe = RFC 5987 attr-char + quoted = quote(download_name, safe="!#$&+-.^_`|~") names = {"filename": simple, "filename*": f"UTF-8''{quoted}"} else: names = {"filename": download_name} @@ -526,7 +515,7 @@ def send_file( if isinstance(etag, str): rv.set_etag(etag) elif etag and path is not None: - check = adler32(path.encode("utf-8")) & 0xFFFFFFFF + check = adler32(path.encode()) & 0xFFFFFFFF rv.set_etag(f"{mtime}-{size}-{check}") if conditional: @@ -547,11 +536,11 @@ def send_file( def send_from_directory( - directory: t.Union[os.PathLike, str], - path: t.Union[os.PathLike, str], - environ: "WSGIEnvironment", + directory: os.PathLike[str] | str, + path: os.PathLike[str] | str, + environ: WSGIEnvironment, **kwargs: t.Any, -) -> "Response": +) -> Response: """Send a file from within a directory using :func:`send_file`. This is a secure way to serve files from a folder, such as static @@ -562,33 +551,30 @@ def send_from_directory( If the final path does not point to an existing regular file, returns a 404 :exc:`~werkzeug.exceptions.NotFound` error. - :param directory: The directory that ``path`` must be located under. - :param path: The path to the file to send, relative to - ``directory``. + :param directory: The directory that ``path`` must be located under. This *must not* + be a value provided by the client, otherwise it becomes insecure. + :param path: The path to the file to send, relative to ``directory``. This is the + part of the path provided by the client, which is checked for security. :param environ: The WSGI environ for the current request. :param kwargs: Arguments to pass to :func:`send_file`. .. versionadded:: 2.0 Adapted from Flask's implementation. """ - path = safe_join(os.fspath(directory), os.fspath(path)) + path_str = safe_join(os.fspath(directory), os.fspath(path)) - if path is None: + if path_str is None: raise NotFound() # Flask will pass app.root_path, allowing its send_from_directory # wrapper to not have to deal with paths. 
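The ``download_name`` fallback in the hunk above can be illustrated standalone; ``disposition_names`` is a hypothetical helper written for this note, mirroring the hunk's logic rather than an API Werkzeug exposes:

import unicodedata
from urllib.parse import quote

def disposition_names(download_name):
    # Hypothetical helper: derive Content-Disposition filename parameters.
    try:
        download_name.encode("ascii")
    except UnicodeEncodeError:
        # ASCII-only fallback for the plain "filename" parameter.
        simple = unicodedata.normalize("NFKD", download_name)
        simple = simple.encode("ascii", "ignore").decode("ascii")
        # safe = RFC 5987 attr-char, as in the change above.
        quoted = quote(download_name, safe="!#$&+-.^_`|~")
        return {"filename": simple, "filename*": f"UTF-8''{quoted}"}
    return {"filename": download_name}

print(disposition_names("Füße.txt"))
# {'filename': 'Fue.txt', 'filename*': "UTF-8''F%C3%BC%C3%9Fe.txt"}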
if "_root_path" in kwargs: - path = os.path.join(kwargs["_root_path"], path) + path_str = os.path.join(kwargs["_root_path"], path_str) - try: - if not os.path.isfile(path): - raise NotFound() - except ValueError: - # path contains null byte on Python < 3.8 - raise NotFound() from None + if not os.path.isfile(path_str): + raise NotFound() - return send_file(path, environ, **kwargs) + return send_file(path_str, environ, **kwargs) def import_string(import_name: str, silent: bool = False) -> t.Any: diff --git a/src/werkzeug/wrappers/__init__.py b/src/werkzeug/wrappers/__init__.py index b8c45d7..b36f228 100644 --- a/src/werkzeug/wrappers/__init__.py +++ b/src/werkzeug/wrappers/__init__.py @@ -1,3 +1,3 @@ from .request import Request as Request from .response import Response as Response -from .response import ResponseStream +from .response import ResponseStream as ResponseStream diff --git a/src/werkzeug/wrappers/request.py b/src/werkzeug/wrappers/request.py index 57b739c..38053c2 100644 --- a/src/werkzeug/wrappers/request.py +++ b/src/werkzeug/wrappers/request.py @@ -1,6 +1,8 @@ +from __future__ import annotations + +import collections.abc as cabc import functools import json -import typing import typing as t from io import BytesIO @@ -11,6 +13,8 @@ from ..datastructures import ImmutableMultiDict from ..datastructures import iter_multi_items from ..datastructures import MultiDict +from ..exceptions import BadRequest +from ..exceptions import UnsupportedMediaType from ..formparser import default_stream_factory from ..formparser import FormDataParser from ..sansio.request import Request as _SansIORequest @@ -18,10 +22,8 @@ from ..utils import environ_property from ..wsgi import _get_server from ..wsgi import get_input_stream -from werkzeug.exceptions import BadRequest if t.TYPE_CHECKING: - import typing_extensions as te from _typeshed.wsgi import WSGIApplication from _typeshed.wsgi import WSGIEnvironment @@ -49,13 +51,19 @@ class Request(_SansIORequest): prevent consuming the form data in middleware, which would make it unavailable to the final application. + .. versionchanged:: 3.0 + The ``charset``, ``url_charset``, and ``encoding_errors`` parameters + were removed. + + .. versionchanged:: 2.1 + Old ``BaseRequest`` and mixin classes were removed. + .. versionchanged:: 2.1 Remove the ``disable_data_descriptor`` attribute. .. versionchanged:: 2.0 Combine ``BaseRequest`` and mixins into a single ``Request`` - class. Using the old classes is deprecated and will be removed - in Werkzeug 2.1. + class. .. versionchanged:: 0.5 Read-only mode is enforced with immutable classes for all data. @@ -67,10 +75,8 @@ class Request(_SansIORequest): #: parsing fails because more than the specified value is transmitted #: a :exc:`~werkzeug.exceptions.RequestEntityTooLarge` exception is raised. #: - #: Have a look at :doc:`/request_data` for more details. - #: #: .. versionadded:: 0.5 - max_content_length: t.Optional[int] = None + max_content_length: int | None = None #: the maximum form field size. This is forwarded to the form data #: parsing function (:func:`parse_form_data`). When set and the @@ -78,18 +84,23 @@ class Request(_SansIORequest): #: data in memory for post data is longer than the specified value a #: :exc:`~werkzeug.exceptions.RequestEntityTooLarge` exception is raised. #: - #: Have a look at :doc:`/request_data` for more details. - #: #: .. 
versionadded:: 0.5 - max_form_memory_size: t.Optional[int] = None + max_form_memory_size: int | None = None + + #: The maximum number of multipart parts to parse, passed to + #: :attr:`form_data_parser_class`. Parsing form data with more than this + #: many parts will raise :exc:`~.RequestEntityTooLarge`. + #: + #: .. versionadded:: 2.2.3 + max_form_parts = 1000 #: The form data parser that should be used. Can be replaced to customize #: the form data parsing. - form_data_parser_class: t.Type[FormDataParser] = FormDataParser + form_data_parser_class: type[FormDataParser] = FormDataParser #: The WSGI environment containing HTTP headers and information from #: the WSGI server. - environ: "WSGIEnvironment" + environ: WSGIEnvironment #: Set when creating the request object. If ``True``, reading from #: the request body will cause a ``RuntimeError``. Useful to @@ -98,7 +109,7 @@ def __init__( def __init__( self, - environ: "WSGIEnvironment", + environ: WSGIEnvironment, populate_request: bool = True, shallow: bool = False, ) -> None: @@ -106,12 +117,8 @@ def __init__( method=environ.get("REQUEST_METHOD", "GET"), scheme=environ.get("wsgi.url_scheme", "http"), server=_get_server(environ), - root_path=_wsgi_decoding_dance( - environ.get("SCRIPT_NAME") or "", self.charset, self.encoding_errors - ), - path=_wsgi_decoding_dance( - environ.get("PATH_INFO") or "", self.charset, self.encoding_errors - ), + root_path=_wsgi_decoding_dance(environ.get("SCRIPT_NAME") or ""), + path=_wsgi_decoding_dance(environ.get("PATH_INFO") or ""), query_string=environ.get("QUERY_STRING", "").encode("latin1"), headers=EnvironHeaders(environ), remote_addr=environ.get("REMOTE_ADDR"), @@ -123,7 +130,7 @@ def __init__( self.environ["werkzeug.request"] = self @classmethod - def from_values(cls, *args: t.Any, **kwargs: t.Any) -> "Request": + def from_values(cls, *args: t.Any, **kwargs: t.Any) -> Request: """Create a new request object based on the values provided. If environ is given missing values are filled from there. This method is useful for small scripts when you need to simulate a request from a URL. @@ -143,8 +150,6 @@ def from_values(cls, *args: t.Any, **kwargs: t.Any) -> "Request": """ from ..test import EnvironBuilder - charset = kwargs.pop("charset", cls.charset) - kwargs["charset"] = charset builder = EnvironBuilder(*args, **kwargs) try: return builder.get_request(cls) @@ -152,9 +157,7 @@ def from_values(cls, *args: t.Any, **kwargs: t.Any) -> "Request": builder.close() @classmethod - def application( - cls, f: t.Callable[["Request"], "WSGIApplication"] - ) -> "WSGIApplication": + def application(cls, f: t.Callable[[Request], WSGIApplication]) -> WSGIApplication: """Decorate a function as responder that accepts the request as the last argument.
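A usage sketch of this decorator (not part of the patch; the greeting app is illustrative):

from werkzeug.wrappers import Request, Response

@Request.application
def app(request: Request) -> Response:
    return Response(f"Hello {request.args.get('name', 'World')}!")

# `app` is now a plain WSGI callable: app(environ, start_response)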
This works like the :func:`responder` decorator but the function is passed the request object as the @@ -180,23 +183,23 @@ def my_wsgi_app(request): from ..exceptions import HTTPException @functools.wraps(f) - def application(*args): # type: ignore + def application(*args: t.Any) -> cabc.Iterable[bytes]: request = cls(args[-2]) with request: try: resp = f(*args[:-2] + (request,)) except HTTPException as e: - resp = e.get_response(args[-2]) + resp = t.cast("WSGIApplication", e.get_response(args[-2])) return resp(*args[-2:]) return t.cast("WSGIApplication", application) def _get_file_stream( self, - total_content_length: t.Optional[int], - content_type: t.Optional[str], - filename: t.Optional[str] = None, - content_length: t.Optional[int] = None, + total_content_length: int | None, + content_type: str | None, + filename: str | None = None, + content_length: int | None = None, ) -> t.IO[bytes]: """Called to get a stream for the file upload. @@ -240,12 +243,11 @@ def make_form_data_parser(self) -> FormDataParser: .. versionadded:: 0.8 """ return self.form_data_parser_class( - self._get_file_stream, - self.charset, - self.encoding_errors, - self.max_form_memory_size, - self.max_content_length, - self.parameter_storage_class, + stream_factory=self._get_file_stream, + max_form_memory_size=self.max_form_memory_size, + max_content_length=self.max_content_length, + max_form_parts=self.max_form_parts, + cls=self.parameter_storage_class, ) def _load_form_data(self) -> None: @@ -304,7 +306,7 @@ def close(self) -> None: for _key, value in iter_multi_items(files or ()): value.close() - def __enter__(self) -> "Request": + def __enter__(self) -> Request: return self def __exit__(self, exc_type, exc_value, tb) -> None: # type: ignore @@ -312,21 +314,30 @@ def __exit__(self, exc_type, exc_value, tb) -> None: # type: ignore @cached_property def stream(self) -> t.IO[bytes]: - """ - If the incoming form data was not encoded with a known mimetype - the data is stored unmodified in this stream for consumption. Most - of the time it is a better idea to use :attr:`data` which will give - you that data as a string. The stream only returns the data once. + """The WSGI input stream, with safety checks. This stream can only be consumed + once. + + Use :meth:`get_data` to get the full data as bytes or text. The :attr:`data` + attribute will contain the full bytes only if they do not represent form data. + The :attr:`form` attribute will contain the parsed form data in that case. + + Unlike :attr:`input_stream`, this stream guards against infinite streams or + reading past :attr:`content_length` or :attr:`max_content_length`. + + If ``max_content_length`` is set, it can be enforced on streams if + ``wsgi.input_terminated`` is set. Otherwise, an empty stream is returned. - Unlike :attr:`input_stream` this stream is properly guarded that you - can't accidentally read past the length of the input. Werkzeug will - internally always refer to this stream to read data which makes it - possible to wrap this object with a stream that does filtering. + If the limit is reached before the underlying stream is exhausted (such as a + file that is too large, or an infinite stream), the remaining contents of the + stream cannot be read safely. Depending on how the server handles this, clients + may show a "connection reset" failure instead of seeing the 413 response. + + .. versionchanged:: 2.3 + Check ``max_content_length`` preemptively and while reading. .. 
versionchanged:: 0.9 - This stream is now always available but might be consumed by the - form parser later on. Previously the stream was only set if no - parsing happened. + The stream is always set (but may be consumed) even if form parsing was + accessed first. """ if self.shallow: raise RuntimeError( " from the input stream is disabled." ) - return get_input_stream(self.environ) + return get_input_stream( + self.environ, max_content_length=self.max_content_length + ) input_stream = environ_property[t.IO[bytes]]( "wsgi.input", - doc="""The WSGI input stream. + doc="""The raw WSGI input stream, without any safety checks. + + This is dangerous to use. It does not guard against infinite streams or reading + past :attr:`content_length` or :attr:`max_content_length`. - In general it's a bad idea to use this one because you can - easily read past the boundary. Use the :attr:`stream` - instead.""", + Use :attr:`stream` instead. + """, ) @cached_property def data(self) -> bytes: - """ - Contains the incoming request data as string in case it came with - a mimetype Werkzeug does not handle. + """The raw data read from :attr:`stream`. Will be empty if the request + represents form data. + + To get the raw data even if it represents form data, use :meth:`get_data`. """ return self.get_data(parse_form_data=True) - @typing.overload + @t.overload def get_data( # type: ignore self, cache: bool = True, - as_text: "te.Literal[False]" = False, + as_text: t.Literal[False] = False, parse_form_data: bool = False, - ) -> bytes: - ... + ) -> bytes: ... - @typing.overload + @t.overload def get_data( self, cache: bool = True, - as_text: "te.Literal[True]" = ..., + as_text: t.Literal[True] = ..., parse_form_data: bool = False, - ) -> str: - ... + ) -> str: ... def get_data( self, cache: bool = True, as_text: bool = False, parse_form_data: bool = False - ) -> t.Union[bytes, str]: + ) -> bytes | str: """This reads the buffered incoming data from the client into one bytes object. By default this is cached but that behavior can be changed by setting `cache` to `False`. @@ -406,11 +420,11 @@ def get_data( if cache: self._cached_data = rv if as_text: - rv = rv.decode(self.charset, self.encoding_errors) + rv = rv.decode(errors="replace") return rv @cached_property - def form(self) -> "ImmutableMultiDict[str, str]": + def form(self) -> ImmutableMultiDict[str, str]: """The form parameters. By default an :class:`~werkzeug.datastructures.ImmutableMultiDict` is returned from this function. This can be changed by setting @@ -429,7 +443,7 @@ def form(self) -> "ImmutableMultiDict[str, str]": return self.form @cached_property - def values(self) -> "CombinedMultiDict[str, str]": + def values(self) -> CombinedMultiDict[str, str]: """A :class:`werkzeug.datastructures.CombinedMultiDict` that combines :attr:`args` and :attr:`form`. @@ -458,7 +472,7 @@ def values(self) -> "CombinedMultiDict[str, str]": return CombinedMultiDict(args) @cached_property - def files(self) -> "ImmutableMultiDict[str, FileStorage]": + def files(self) -> ImmutableMultiDict[str, FileStorage]: """:class:`~werkzeug.datastructures.MultiDict` object containing all uploaded files. Each key in :attr:`files` is the name from the ``<input type="file" name="">``.
Each value in :attr:`files` is a @@ -525,14 +539,17 @@ def url_root(self) -> str: json_module = json @property - def json(self) -> t.Optional[t.Any]: + def json(self) -> t.Any | None: """The parsed JSON data if :attr:`mimetype` indicates JSON (:mimetype:`application/json`, see :attr:`is_json`). Calls :meth:`get_json` with default arguments. If the request content type is not ``application/json``, this - will raise a 400 Bad Request error. + will raise a 415 Unsupported Media Type error. + + .. versionchanged:: 2.3 + Raise a 415 error instead of 400. .. versionchanged:: 2.1 Raise a 400 error if the content type is incorrect. @@ -541,18 +558,28 @@ def json(self) -> t.Optional[t.Any]: # Cached values for ``(silent=False, silent=True)``. Initialized # with sentinel values. - _cached_json: t.Tuple[t.Any, t.Any] = (Ellipsis, Ellipsis) + _cached_json: tuple[t.Any, t.Any] = (Ellipsis, Ellipsis) + + @t.overload + def get_json( + self, force: bool = ..., silent: t.Literal[False] = ..., cache: bool = ... + ) -> t.Any: ... + + @t.overload + def get_json( + self, force: bool = ..., silent: bool = ..., cache: bool = ... + ) -> t.Any | None: ... def get_json( self, force: bool = False, silent: bool = False, cache: bool = True - ) -> t.Optional[t.Any]: + ) -> t.Any | None: """Parse :attr:`data` as JSON. If the mimetype does not indicate JSON (:mimetype:`application/json`, see :attr:`is_json`), or parsing fails, :meth:`on_json_loading_failed` is called and its return value is used as the return value. By default this - raises a 400 Bad Request error. + raises a 415 Unsupported Media Type response. :param force: Ignore the mimetype and always try to parse JSON. :param silent: Silence mimetype and parsing errors, and @@ -560,6 +587,9 @@ def get_json( :param cache: Store the parsed JSON to return for subsequent calls. + .. versionchanged:: 2.3 + Raise a 415 error instead of 400. + .. versionchanged:: 2.1 Raise a 400 error if the content type is incorrect. """ @@ -595,7 +625,7 @@ def get_json( return rv - def on_json_loading_failed(self, e: t.Optional[ValueError]) -> t.Any: + def on_json_loading_failed(self, e: ValueError | None) -> t.Any: """Called if :meth:`get_json` fails and isn't silenced. If this method returns a value, it is used as the return value @@ -604,11 +634,14 @@ def on_json_loading_failed(self, e: t.Optional[ValueError]) -> t.Any: :param e: If parsing failed, this is the exception. It will be ``None`` if the content type wasn't ``application/json``. + + .. versionchanged:: 2.3 + Raise a 415 error instead of 400. """ if e is not None: raise BadRequest(f"Failed to decode JSON object: {e}") - raise BadRequest( + raise UnsupportedMediaType( "Did not attempt to load JSON data because the request" " Content-Type was not 'application/json'."
) diff --git a/src/werkzeug/wrappers/response.py b/src/werkzeug/wrappers/response.py index 7e888cb..7f01287 100644 --- a/src/werkzeug/wrappers/response.py +++ b/src/werkzeug/wrappers/response.py @@ -1,69 +1,41 @@ +from __future__ import annotations + import json -import typing import typing as t -import warnings from http import HTTPStatus +from urllib.parse import urljoin -from .._internal import _to_bytes +from .._internal import _get_environ from ..datastructures import Headers +from ..http import generate_etag +from ..http import http_date +from ..http import is_resource_modified +from ..http import parse_etags +from ..http import parse_range_header from ..http import remove_entity_headers from ..sansio.response import Response as _SansIOResponse from ..urls import iri_to_uri -from ..urls import url_join from ..utils import cached_property +from ..wsgi import _RangeWrapper from ..wsgi import ClosingIterator from ..wsgi import get_current_url -from werkzeug._internal import _get_environ -from werkzeug.http import generate_etag -from werkzeug.http import http_date -from werkzeug.http import is_resource_modified -from werkzeug.http import parse_etags -from werkzeug.http import parse_range_header -from werkzeug.wsgi import _RangeWrapper if t.TYPE_CHECKING: - import typing_extensions as te from _typeshed.wsgi import StartResponse from _typeshed.wsgi import WSGIApplication from _typeshed.wsgi import WSGIEnvironment - from .request import Request - -def _warn_if_string(iterable: t.Iterable) -> None: - """Helper for the response objects to check if the iterable returned - to the WSGI server is not a string. - """ - if isinstance(iterable, str): - warnings.warn( - "Response iterable was set to a string. This will appear to" - " work but means that the server will send the data to the" - " client one character at a time. This is almost never" - " intended behavior, use 'response.data' to assign strings" - " to the response object.", - stacklevel=2, - ) + from .request import Request -def _iter_encoded( - iterable: t.Iterable[t.Union[str, bytes]], charset: str -) -> t.Iterator[bytes]: +def _iter_encoded(iterable: t.Iterable[str | bytes]) -> t.Iterator[bytes]: for item in iterable: if isinstance(item, str): - yield item.encode(charset) + yield item.encode() else: yield item -def _clean_accept_ranges(accept_ranges: t.Union[bool, str]) -> str: - if accept_ranges is True: - return "bytes" - elif accept_ranges is False: - return "none" - elif isinstance(accept_ranges, str): - return accept_ranges - raise ValueError("Invalid accept_ranges value") - - class Response(_SansIOResponse): """Represents an outgoing WSGI HTTP response with body, status, and headers. Has properties and methods for using the functionality @@ -123,10 +95,12 @@ def application(environ, start_response): checks. Use :func:`~werkzeug.utils.send_file` instead of setting this manually. + .. versionchanged:: 2.1 + Old ``BaseResponse`` and mixin classes were removed. + .. versionchanged:: 2.0 Combine ``BaseResponse`` and mixins into a single ``Response`` - class. Using the old classes is deprecated and will be removed - in Werkzeug 2.1. + class. .. versionchanged:: 0.5 The ``direct_passthrough`` parameter was added. @@ -165,22 +139,17 @@ def application(environ, start_response): #: Do not set to a plain string or bytes, that will cause sending #: the response to be very inefficient as it will iterate one byte #: at a time. 
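A sketch of the ``response`` attribute contract documented above (not part of the patch; assumes the post-patch UTF-8-only encoding):

from werkzeug.wrappers import Response

resp = Response()
resp.set_data("hello")  # encodes to UTF-8 and sets Content-Length
assert resp.response == [b"hello"]
assert resp.headers["Content-Length"] == "5"

resp.response = [b"chunk-1", b"chunk-2"]  # an iterable of bytes is fine
resp.response = "hello"  # works, but a server would send one character at a time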
- response: t.Union[t.Iterable[str], t.Iterable[bytes]] + response: t.Iterable[str] | t.Iterable[bytes] def __init__( self, - response: t.Optional[ - t.Union[t.Iterable[bytes], bytes, t.Iterable[str], str] - ] = None, - status: t.Optional[t.Union[int, str, HTTPStatus]] = None, - headers: t.Optional[ - t.Union[ - t.Mapping[str, t.Union[str, int, t.Iterable[t.Union[str, int]]]], - t.Iterable[t.Tuple[str, t.Union[str, int]]], - ] - ] = None, - mimetype: t.Optional[str] = None, - content_type: t.Optional[str] = None, + response: t.Iterable[bytes] | bytes | t.Iterable[str] | str | None = None, + status: int | str | HTTPStatus | None = None, + headers: t.Mapping[str, str | t.Iterable[str]] + | t.Iterable[tuple[str, str]] + | None = None, + mimetype: str | None = None, + content_type: str | None = None, direct_passthrough: bool = False, ) -> None: super().__init__( @@ -196,7 +165,7 @@ def __init__( #: :func:`~werkzeug.utils.send_file` instead of setting this #: manually. self.direct_passthrough = direct_passthrough - self._on_close: t.List[t.Callable[[], t.Any]] = [] + self._on_close: list[t.Callable[[], t.Any]] = [] # we set the response after the headers so that if a class changes # the charset attribute, the data is set in the correct charset. @@ -227,8 +196,8 @@ def __repr__(self) -> str: @classmethod def force_type( - cls, response: "Response", environ: t.Optional["WSGIEnvironment"] = None - ) -> "Response": + cls, response: Response, environ: WSGIEnvironment | None = None + ) -> Response: """Enforce that the WSGI response is a response object of the current type. Werkzeug will use the :class:`Response` internally in many situations like the exceptions. If you call :meth:`get_response` on an @@ -272,8 +241,8 @@ def force_type( @classmethod def from_app( - cls, app: "WSGIApplication", environ: "WSGIEnvironment", buffered: bool = False - ) -> "Response": + cls, app: WSGIApplication, environ: WSGIEnvironment, buffered: bool = False + ) -> Response: """Create a new response object from an application output. This works best if you pass it an application that returns a generator all the time. Sometimes applications may use the `write()` callable @@ -290,15 +259,13 @@ def from_app( return cls(*run_wsgi_app(app, environ, buffered)) - @typing.overload - def get_data(self, as_text: "te.Literal[False]" = False) -> bytes: - ... + @t.overload + def get_data(self, as_text: t.Literal[False] = False) -> bytes: ... - @typing.overload - def get_data(self, as_text: "te.Literal[True]") -> str: - ... + @t.overload + def get_data(self, as_text: t.Literal[True]) -> str: ... - def get_data(self, as_text: bool = False) -> t.Union[bytes, str]: + def get_data(self, as_text: bool = False) -> bytes | str: """The string representation of the response body. Whenever you call this property the response iterable is encoded and flattened. This can lead to unwanted behavior if you stream big data. @@ -315,23 +282,19 @@ def get_data(self, as_text: bool = False) -> t.Union[bytes, str]: rv = b"".join(self.iter_encoded()) if as_text: - return rv.decode(self.charset) + return rv.decode() return rv - def set_data(self, value: t.Union[bytes, str]) -> None: + def set_data(self, value: bytes | str) -> None: """Sets a new string as response. The value must be a string or bytes. If a string is set it's encoded to the charset of the response (utf-8 by default). .. 
versionadded:: 0.9 """ - # if a string is set, it's encoded directly so that we - # can set the content length if isinstance(value, str): - value = value.encode(self.charset) - else: - value = bytes(value) + value = value.encode() self.response = [value] if self.automatically_set_content_length: self.headers["Content-Length"] = str(len(value)) @@ -342,7 +305,7 @@ def set_data(self, value: t.Union[bytes, str]) -> None: doc="A descriptor that calls :meth:`get_data` and :meth:`set_data`.", ) - def calculate_content_length(self) -> t.Optional[int]: + def calculate_content_length(self) -> int | None: """Returns the content length if available or `None` otherwise.""" try: self._ensure_sequence() @@ -398,12 +361,10 @@ def iter_encoded(self) -> t.Iterator[bytes]: value of this method is used as application iterator unless :attr:`direct_passthrough` was activated. """ - if __debug__: - _warn_if_string(self.response) # Encode in a separate function so that self.response is fetched # early. This allows us to wrap the response with the return # value from get_app_iter or iter_encoded. - return _iter_encoded(self.response, self.charset) + return _iter_encoded(self.response) @property def is_streamed(self) -> bool: @@ -439,11 +400,11 @@ def close(self) -> None: Can now be used in a with statement. """ if hasattr(self.response, "close"): - self.response.close() # type: ignore + self.response.close() for func in self._on_close: func() - def __enter__(self) -> "Response": + def __enter__(self) -> Response: return self def __exit__(self, exc_type, exc_value, tb): # type: ignore @@ -463,8 +424,7 @@ def freeze(self) -> None: Removed the ``no_etag`` parameter. .. versionchanged:: 2.0 - An ``ETag`` header is added, the ``no_etag`` parameter is - deprecated and will be removed in Werkzeug 2.1. + An ``ETag`` header is always added. .. versionchanged:: 0.6 The ``Content-Length`` header is set. @@ -475,7 +435,7 @@ def freeze(self) -> None: self.headers["Content-Length"] = str(sum(map(len, self.response))) self.add_etag() - def get_wsgi_headers(self, environ: "WSGIEnvironment") -> Headers: + def get_wsgi_headers(self, environ: WSGIEnvironment) -> Headers: """This is automatically called right before the response is started and returns headers modified for the given environment. It returns a copy of the headers from the response with some modifications applied @@ -500,9 +460,9 @@ def get_wsgi_headers(self, environ: "WSGIEnvironment") -> Headers: object. """ headers = Headers(self.headers) - location: t.Optional[str] = None - content_location: t.Optional[str] = None - content_length: t.Optional[t.Union[str, int]] = None + location: str | None = None + content_location: str | None = None + content_length: str | int | None = None status = self.status_code # iterate over the headers to find all values in one go. Because @@ -517,24 +477,19 @@ def get_wsgi_headers(self, environ: "WSGIEnvironment") -> Headers: elif ikey == "content-length": content_length = value - # make sure the location header is an absolute URL if location is not None: - old_location = location - if isinstance(location, str): - # Safe conversion is necessary here as we might redirect - # to a broken URI scheme (for instance itms-services). - location = iri_to_uri(location, safe_conversion=True) + location = iri_to_uri(location) if self.autocorrect_location_header: + # Make the location header an absolute URL. 
current_url = get_current_url(environ, strip_querystring=True) - if isinstance(current_url, str): - current_url = iri_to_uri(current_url) - location = url_join(current_url, location) - if location != old_location: - headers["Location"] = location + current_url = iri_to_uri(current_url) + location = urljoin(current_url, location) + + headers["Location"] = location # make sure the content location is a URL - if content_location is not None and isinstance(content_location, str): + if content_location is not None: headers["Content-Location"] = iri_to_uri(content_location) if 100 <= status < 200 or status == 204: @@ -557,18 +512,12 @@ def get_wsgi_headers(self, environ: "WSGIEnvironment") -> Headers: and status not in (204, 304) and not (100 <= status < 200) ): - try: - content_length = sum(len(_to_bytes(x, "ascii")) for x in self.response) - except UnicodeError: - # Something other than bytes, can't safely figure out - # the length of the response. - pass - else: - headers["Content-Length"] = str(content_length) + content_length = sum(len(x) for x in self.iter_encoded()) + headers["Content-Length"] = str(content_length) return headers - def get_app_iter(self, environ: "WSGIEnvironment") -> t.Iterable[bytes]: + def get_app_iter(self, environ: WSGIEnvironment) -> t.Iterable[bytes]: """Returns the application iterator for the given environ. Depending on the request method and the current status code the return value might be an empty response rather than the one from the response. @@ -590,16 +539,14 @@ def get_app_iter(self, environ: "WSGIEnvironment") -> t.Iterable[bytes]: ): iterable: t.Iterable[bytes] = () elif self.direct_passthrough: - if __debug__: - _warn_if_string(self.response) return self.response # type: ignore else: iterable = self.iter_encoded() return ClosingIterator(iterable, self.close) def get_wsgi_response( - self, environ: "WSGIEnvironment" - ) -> t.Tuple[t.Iterable[bytes], str, t.List[t.Tuple[str, str]]]: + self, environ: WSGIEnvironment + ) -> tuple[t.Iterable[bytes], str, list[tuple[str, str]]]: """Returns the final WSGI response as tuple. The first item in the tuple is the application iterator, the second the status and the third the list of headers. The response returned is created @@ -617,7 +564,7 @@ def get_wsgi_response( return app_iter, self.status, headers.to_wsgi_list() def __call__( - self, environ: "WSGIEnvironment", start_response: "StartResponse" + self, environ: WSGIEnvironment, start_response: StartResponse ) -> t.Iterable[bytes]: """Process this response as WSGI application. @@ -637,7 +584,7 @@ def __call__( json_module = json @property - def json(self) -> t.Optional[t.Any]: + def json(self) -> t.Any | None: """The parsed JSON data if :attr:`mimetype` indicates JSON (:mimetype:`application/json`, see :attr:`is_json`). @@ -645,7 +592,13 @@ def json(self) -> t.Optional[t.Any]: """ return self.get_json() - def get_json(self, force: bool = False, silent: bool = False) -> t.Optional[t.Any]: + @t.overload + def get_json(self, force: bool = ..., silent: t.Literal[False] = ...) -> t.Any: ... + + @t.overload + def get_json(self, force: bool = ..., silent: bool = ...) -> t.Any | None: ... + + def get_json(self, force: bool = False, silent: bool = False) -> t.Any | None: """Parse :attr:`data` as JSON. Useful during testing. 
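For example (a sketch, not from the patch), ``Response.get_json`` simply returns ``None`` for non-JSON mimetypes instead of raising:

from werkzeug.wrappers import Response

resp = Response('{"ok": true}', mimetype="application/json")
assert resp.get_json() == {"ok": True}

resp = Response("plain text", mimetype="text/plain")
assert resp.get_json() is None  # mimetype is not JSON, no error raised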
If the mimetype does not indicate JSON @@ -674,7 +627,7 @@ def get_json(self, force: bool = False, silent: bool = False) -> t.Optional[t.An # Stream @cached_property - def stream(self) -> "ResponseStream": + def stream(self) -> ResponseStream: """The response iterable as write-only stream.""" return ResponseStream(self) @@ -683,7 +636,7 @@ def _wrap_range_response(self, start: int, length: int) -> None: if self.status_code == 206: self.response = _RangeWrapper(self.response, start, length) # type: ignore - def _is_range_request_processable(self, environ: "WSGIEnvironment") -> bool: + def _is_range_request_processable(self, environ: WSGIEnvironment) -> bool: """Return ``True`` if `Range` header is present and if underlying resource is considered unchanged when compared with `If-Range` header. """ @@ -700,9 +653,9 @@ def _is_range_request_processable(self, environ: "WSGIEnvironment") -> bool: def _process_range_request( self, - environ: "WSGIEnvironment", - complete_length: t.Optional[int] = None, - accept_ranges: t.Optional[t.Union[bool, str]] = None, + environ: WSGIEnvironment, + complete_length: int | None, + accept_ranges: bool | str, ) -> bool: """Handle Range Request related headers (RFC7233). If `Accept-Ranges` header is valid, and Range Request is processable, we set the headers @@ -720,13 +673,16 @@ def _process_range_request( from ..exceptions import RequestedRangeNotSatisfiable if ( - accept_ranges is None + not accept_ranges or complete_length is None or complete_length == 0 or not self._is_range_request_processable(environ) ): return False + if accept_ranges is True: + accept_ranges = "bytes" + parsed_range = parse_range_header(environ.get("HTTP_RANGE")) if parsed_range is None: @@ -739,7 +695,7 @@ def _process_range_request( raise RequestedRangeNotSatisfiable(complete_length) content_length = range_tuple[1] - range_tuple[0] - self.headers["Content-Length"] = content_length + self.headers["Content-Length"] = str(content_length) self.headers["Accept-Ranges"] = accept_ranges self.content_range = content_range_header # type: ignore self.status_code = 206 @@ -748,10 +704,10 @@ def _process_range_request( def make_conditional( self, - request_or_environ: t.Union["WSGIEnvironment", "Request"], - accept_ranges: t.Union[bool, str] = False, - complete_length: t.Optional[int] = None, - ) -> "Response": + request_or_environ: WSGIEnvironment | Request, + accept_ranges: bool | str = False, + complete_length: int | None = None, + ) -> Response: """Make the response conditional to the request. This method works best if an etag was defined for the response already. The `add_etag` method can be used to do that. If called without etag just the date @@ -777,8 +733,7 @@ def make_conditional( :param accept_ranges: This parameter dictates the value of `Accept-Ranges` header. If ``False`` (default), the header is not set. If ``True``, it will be set - to ``"bytes"``. If ``None``, it will be set to - ``"none"``. If it's a string, it will use this + to ``"bytes"``. If it's a string, it will use this value. :param complete_length: Will be used only in valid Range Requests. It will set `Content-Range` complete length @@ -800,7 +755,6 @@ def make_conditional( # wsgiref. 
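The range handling changed in this hunk can be exercised directly; a sketch (not part of the patch; assumes the post-patch behavior where ``accept_ranges=True`` maps to ``"bytes"`` inside ``_process_range_request``):

from werkzeug.test import EnvironBuilder
from werkzeug.wrappers import Response

environ = EnvironBuilder(headers={"Range": "bytes=0-4"}).get_environ()
resp = Response(b"hello world")
resp = resp.make_conditional(environ, accept_ranges=True, complete_length=11)
assert resp.status_code == 206  # partial content
assert resp.headers["Content-Range"] == "bytes 0-4/11"
assert resp.get_data() == b"hello"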
if "date" not in self.headers: self.headers["Date"] = http_date() - accept_ranges = _clean_accept_ranges(accept_ranges) is206 = self._process_range_request(environ, complete_length, accept_ranges) if not is206 and not is_resource_modified( environ, @@ -818,7 +772,7 @@ def make_conditional( ): length = self.calculate_content_length() if length is not None: - self.headers["Content-Length"] = length + self.headers["Content-Length"] = str(length) return self def add_etag(self, overwrite: bool = False, weak: bool = False) -> None: @@ -874,4 +828,4 @@ def tell(self) -> int: @property def encoding(self) -> str: - return self.response.charset + return "utf-8" diff --git a/src/werkzeug/wsgi.py b/src/werkzeug/wsgi.py index 24ece0b..01d40af 100644 --- a/src/werkzeug/wsgi.py +++ b/src/werkzeug/wsgi.py @@ -1,28 +1,21 @@ +from __future__ import annotations + import io -import re import typing as t -import warnings from functools import partial from functools import update_wrapper -from itertools import chain -from ._internal import _make_encode_wrapper -from ._internal import _to_bytes -from ._internal import _to_str +from .exceptions import ClientDisconnected +from .exceptions import RequestEntityTooLarge from .sansio import utils as _sansio_utils from .sansio.utils import host_is_trusted # noqa: F401 # Imported as part of API -from .urls import _URLTuple -from .urls import uri_to_iri -from .urls import url_join -from .urls import url_parse -from .urls import url_quote if t.TYPE_CHECKING: from _typeshed.wsgi import WSGIApplication from _typeshed.wsgi import WSGIEnvironment -def responder(f: t.Callable[..., "WSGIApplication"]) -> "WSGIApplication": +def responder(f: t.Callable[..., WSGIApplication]) -> WSGIApplication: """Marks a function as responder. Decorate a function with it and it will automatically call the return value as WSGI application. @@ -36,11 +29,11 @@ def application(environ, start_response): def get_current_url( - environ: "WSGIEnvironment", + environ: WSGIEnvironment, root_only: bool = False, strip_querystring: bool = False, host_only: bool = False, - trusted_hosts: t.Optional[t.Iterable[str]] = None, + trusted_hosts: t.Iterable[str] | None = None, ) -> str: """Recreate the URL for a request from the parts in a WSGI environment. @@ -74,15 +67,15 @@ def get_current_url( def _get_server( - environ: "WSGIEnvironment", -) -> t.Optional[t.Tuple[str, t.Optional[int]]]: + environ: WSGIEnvironment, +) -> tuple[str, int | None] | None: name = environ.get("SERVER_NAME") if name is None: return None try: - port: t.Optional[int] = int(environ.get("SERVER_PORT", None)) + port: int | None = int(environ.get("SERVER_PORT", None)) except (TypeError, ValueError): # unix socket port = None @@ -91,7 +84,7 @@ def _get_server( def get_host( - environ: "WSGIEnvironment", trusted_hosts: t.Optional[t.Iterable[str]] = None + environ: WSGIEnvironment, trusted_hosts: t.Iterable[str] | None = None ) -> str: """Return the host for the given WSGI environment. @@ -118,337 +111,101 @@ def get_host( ) -def get_content_length(environ: "WSGIEnvironment") -> t.Optional[int]: - """Returns the content length from the WSGI environment as - integer. If it's not available or chunked transfer encoding is used, - ``None`` is returned. +def get_content_length(environ: WSGIEnvironment) -> int | None: + """Return the ``Content-Length`` header value as an int. If the header is not given + or the ``Transfer-Encoding`` header is ``chunked``, ``None`` is returned to indicate + a streaming request. 
If the value is not an integer, or negative, 0 is returned. - .. versionadded:: 0.9 + :param environ: The WSGI environ to get the content length from. - :param environ: the WSGI environ to fetch the content length from. + .. versionadded:: 0.9 """ return _sansio_utils.get_content_length( http_content_length=environ.get("CONTENT_LENGTH"), - http_transfer_encoding=environ.get("HTTP_TRANSFER_ENCODING", ""), + http_transfer_encoding=environ.get("HTTP_TRANSFER_ENCODING"), ) def get_input_stream( - environ: "WSGIEnvironment", safe_fallback: bool = True + environ: WSGIEnvironment, + safe_fallback: bool = True, + max_content_length: int | None = None, ) -> t.IO[bytes]: - """Returns the input stream from the WSGI environment and wraps it - in the most sensible way possible. The stream returned is not the - raw WSGI stream in most cases but one that is safe to read from - without taking into account the content length. - - If content length is not set, the stream will be empty for safety reasons. - If the WSGI server supports chunked or infinite streams, it should set - the ``wsgi.input_terminated`` value in the WSGI environ to indicate that. - - .. versionadded:: 0.9 - - :param environ: the WSGI environ to fetch the stream from. - :param safe_fallback: use an empty stream as a safe fallback when the - content length is not set. Disabling this allows infinite streams, - which can be a denial-of-service risk. - """ - stream = t.cast(t.IO[bytes], environ["wsgi.input"]) - content_length = get_content_length(environ) + """Return the WSGI input stream, wrapped so that it may be read safely without going + past the ``Content-Length`` header value or ``max_content_length``. - # A wsgi extension that tells us if the input is terminated. In - # that case we return the stream unchanged as we know we can safely - # read it until the end. - if environ.get("wsgi.input_terminated"): - return stream + If ``Content-Length`` exceeds ``max_content_length``, a + :exc:`RequestEntityTooLarge`` ``413 Content Too Large`` error is raised. - # If the request doesn't specify a content length, returning the stream is - # potentially dangerous because it could be infinite, malicious or not. If - # safe_fallback is true, return an empty stream instead for safety. - if content_length is None: - return io.BytesIO() if safe_fallback else stream + If the WSGI server sets ``environ["wsgi.input_terminated"]``, it indicates that the + server handles terminating the stream, so it is safe to read directly. For example, + a server that knows how to handle chunked requests safely would set this. - # Otherwise limit the stream to the content length - return t.cast(t.IO[bytes], LimitedStream(stream, content_length)) + If ``max_content_length`` is set, it can be enforced on streams if + ``wsgi.input_terminated`` is set. Otherwise, an empty stream is returned unless the + user explicitly disables this safe fallback. + If the limit is reached before the underlying stream is exhausted (such as a file + that is too large, or an infinite stream), the remaining contents of the stream + cannot be read safely. Depending on how the server handles this, clients may show a + "connection reset" failure instead of seeing the 413 response. -def get_query_string(environ: "WSGIEnvironment") -> str: - """Returns the ``QUERY_STRING`` from the WSGI environment. This also - takes care of the WSGI decoding dance. The string returned will be - restricted to ASCII characters. + :param environ: The WSGI environ containing the stream. 
+ :param safe_fallback: Return an empty stream when ``Content-Length`` is not set. + Disabling this allows infinite streams, which can be a denial-of-service risk. + :param max_content_length: The maximum length that content-length or streaming + requests may not exceed. - :param environ: WSGI environment to get the query string from. + .. versionchanged:: 2.3.2 + ``max_content_length`` is only applied to streaming requests if the server sets + ``wsgi.input_terminated``. - .. deprecated:: 2.2 - Will be removed in Werkzeug 2.3. + .. versionchanged:: 2.3 + Check ``max_content_length`` and raise an error if it is exceeded. .. versionadded:: 0.9 """ - warnings.warn( - "'get_query_string' is deprecated and will be removed in Werkzeug 2.3.", - DeprecationWarning, - stacklevel=2, - ) - qs = environ.get("QUERY_STRING", "").encode("latin1") - # QUERY_STRING really should be ascii safe but some browsers - # will send us some unicode stuff (I am looking at you IE). - # In that case we want to urllib quote it badly. - return url_quote(qs, safe=":&%=+$!*'(),") + stream = t.cast(t.IO[bytes], environ["wsgi.input"]) + content_length = get_content_length(environ) + if content_length is not None and max_content_length is not None: + if content_length > max_content_length: + raise RequestEntityTooLarge() + + # A WSGI server can set this to indicate that it terminates the input stream. In + # that case the stream is safe without wrapping, or can enforce a max length. + if "wsgi.input_terminated" in environ: + if max_content_length is not None: + # If this is moved above, it can cause the stream to hang if a read attempt + # is made when the client sends no data. For example, the development server + # does not handle buffering except for chunked encoding. + return t.cast( + t.IO[bytes], LimitedStream(stream, max_content_length, is_max=True) + ) -def get_path_info( - environ: "WSGIEnvironment", charset: str = "utf-8", errors: str = "replace" -) -> str: - """Return the ``PATH_INFO`` from the WSGI environment and decode it - unless ``charset`` is ``None``. + return stream - :param environ: WSGI environment to get the path from. - :param charset: The charset for the path info, or ``None`` if no - decoding should be performed. - :param errors: The decoding error handling. + # No limit given, return an empty stream unless the user explicitly allows the + # potentially infinite stream. An infinite stream is dangerous if it's not expected, + # as it can tie up a worker indefinitely. + if content_length is None: + return io.BytesIO() if safe_fallback else stream - .. versionadded:: 0.9 - """ - path = environ.get("PATH_INFO", "").encode("latin1") - return _to_str(path, charset, errors, allow_none_charset=True) # type: ignore + return t.cast(t.IO[bytes], LimitedStream(stream, content_length)) -def get_script_name( - environ: "WSGIEnvironment", charset: str = "utf-8", errors: str = "replace" -) -> str: - """Return the ``SCRIPT_NAME`` from the WSGI environment and decode - it unless `charset` is set to ``None``. +def get_path_info(environ: WSGIEnvironment) -> str: + """Return ``PATH_INFO`` from the WSGI environment. :param environ: WSGI environment to get the path from. - :param charset: The charset for the path, or ``None`` if no decoding - should be performed. - :param errors: The decoding error handling. - .. deprecated:: 2.2 - Will be removed in Werkzeug 2.3. + .. versionchanged:: 3.0 + The ``charset`` and ``errors`` parameters were removed. .. 
versionadded:: 0.9 """ - warnings.warn( - "'get_script_name' is deprecated and will be removed in Werkzeug 2.3.", - DeprecationWarning, - stacklevel=2, - ) - path = environ.get("SCRIPT_NAME", "").encode("latin1") - return _to_str(path, charset, errors, allow_none_charset=True) # type: ignore - - -def pop_path_info( - environ: "WSGIEnvironment", charset: str = "utf-8", errors: str = "replace" -) -> t.Optional[str]: - """Removes and returns the next segment of `PATH_INFO`, pushing it onto - `SCRIPT_NAME`. Returns `None` if there is nothing left on `PATH_INFO`. - - If the `charset` is set to `None` bytes are returned. - - If there are empty segments (``'/foo//bar``) these are ignored but - properly pushed to the `SCRIPT_NAME`: - - >>> env = {'SCRIPT_NAME': '/foo', 'PATH_INFO': '/a/b'} - >>> pop_path_info(env) - 'a' - >>> env['SCRIPT_NAME'] - '/foo/a' - >>> pop_path_info(env) - 'b' - >>> env['SCRIPT_NAME'] - '/foo/a/b' - - .. deprecated:: 2.2 - Will be removed in Werkzeug 2.3. - - .. versionadded:: 0.5 - - .. versionchanged:: 0.9 - The path is now decoded and a charset and encoding - parameter can be provided. - - :param environ: the WSGI environment that is modified. - :param charset: The ``encoding`` parameter passed to - :func:`bytes.decode`. - :param errors: The ``errors`` paramater passed to - :func:`bytes.decode`. - """ - warnings.warn( - "'pop_path_info' is deprecated and will be removed in Werkzeug 2.3.", - DeprecationWarning, - stacklevel=2, - ) - - path = environ.get("PATH_INFO") - if not path: - return None - - script_name = environ.get("SCRIPT_NAME", "") - - # shift multiple leading slashes over - old_path = path - path = path.lstrip("/") - if path != old_path: - script_name += "/" * (len(old_path) - len(path)) - - if "/" not in path: - environ["PATH_INFO"] = "" - environ["SCRIPT_NAME"] = script_name + path - rv = path.encode("latin1") - else: - segment, path = path.split("/", 1) - environ["PATH_INFO"] = f"/{path}" - environ["SCRIPT_NAME"] = script_name + segment - rv = segment.encode("latin1") - - return _to_str(rv, charset, errors, allow_none_charset=True) # type: ignore - - -def peek_path_info( - environ: "WSGIEnvironment", charset: str = "utf-8", errors: str = "replace" -) -> t.Optional[str]: - """Returns the next segment on the `PATH_INFO` or `None` if there - is none. Works like :func:`pop_path_info` without modifying the - environment: - - >>> env = {'SCRIPT_NAME': '/foo', 'PATH_INFO': '/a/b'} - >>> peek_path_info(env) - 'a' - >>> peek_path_info(env) - 'a' - - If the `charset` is set to `None` bytes are returned. - - .. deprecated:: 2.2 - Will be removed in Werkzeug 2.3. - - .. versionadded:: 0.5 - - .. versionchanged:: 0.9 - The path is now decoded and a charset and encoding - parameter can be provided. - - :param environ: the WSGI environment that is checked. - """ - warnings.warn( - "'peek_path_info' is deprecated and will be removed in Werkzeug 2.3.", - DeprecationWarning, - stacklevel=2, - ) - - segments = environ.get("PATH_INFO", "").lstrip("/").split("/", 1) - if segments: - return _to_str( # type: ignore - segments[0].encode("latin1"), charset, errors, allow_none_charset=True - ) - return None - - -def extract_path_info( - environ_or_baseurl: t.Union[str, "WSGIEnvironment"], - path_or_url: t.Union[str, _URLTuple], - charset: str = "utf-8", - errors: str = "werkzeug.url_quote", - collapse_http_schemes: bool = True, -) -> t.Optional[str]: - """Extracts the path info from the given URL (or WSGI environment) and - path. The path info returned is a string. 
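A sketch of the simplified ``get_path_info`` introduced in this file (not part of the patch); the PEP 3333 latin-1 round-trip is what is being illustrated:

from werkzeug.wsgi import get_path_info

# PATH_INFO arrives as latin-1-decoded bytes per PEP 3333; the helper
# re-encodes it and decodes as UTF-8, replacing invalid bytes.
environ = {"PATH_INFO": "/caf\xc3\xa9"}
assert get_path_info(environ) == "/café"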
The URLs might also be IRIs. - - If the path info could not be determined, `None` is returned. - - Some examples: - - >>> extract_path_info('http://example.com/app', '/app/hello') - '/hello' - >>> extract_path_info('http://example.com/app', - ... 'https://example.com/app/hello') - '/hello' - >>> extract_path_info('http://example.com/app', - ... 'https://example.com/app/hello', - ... collapse_http_schemes=False) is None - True - - Instead of providing a base URL you can also pass a WSGI environment. - - :param environ_or_baseurl: a WSGI environment dict, a base URL or - base IRI. This is the root of the - application. - :param path_or_url: an absolute path from the server root, a - relative path (in which case it's the path info) - or a full URL. - :param charset: the charset for byte data in URLs - :param errors: the error handling on decode - :param collapse_http_schemes: if set to `False` the algorithm does - not assume that http and https on the - same server point to the same - resource. - - .. deprecated:: 2.2 - Will be removed in Werkzeug 2.3. - - .. versionchanged:: 0.15 - The ``errors`` parameter defaults to leaving invalid bytes - quoted instead of replacing them. - - .. versionadded:: 0.6 - - """ - warnings.warn( - "'extract_path_info' is deprecated and will be removed in Werkzeug 2.3.", - DeprecationWarning, - stacklevel=2, - ) - - def _normalize_netloc(scheme: str, netloc: str) -> str: - parts = netloc.split("@", 1)[-1].split(":", 1) - port: t.Optional[str] - - if len(parts) == 2: - netloc, port = parts - if (scheme == "http" and port == "80") or ( - scheme == "https" and port == "443" - ): - port = None - else: - netloc = parts[0] - port = None - - if port is not None: - netloc += f":{port}" - - return netloc - - # make sure whatever we are working on is a IRI and parse it - path = uri_to_iri(path_or_url, charset, errors) - if isinstance(environ_or_baseurl, dict): - environ_or_baseurl = get_current_url(environ_or_baseurl, root_only=True) - base_iri = uri_to_iri(environ_or_baseurl, charset, errors) - base_scheme, base_netloc, base_path = url_parse(base_iri)[:3] - cur_scheme, cur_netloc, cur_path = url_parse(url_join(base_iri, path))[:3] - - # normalize the network location - base_netloc = _normalize_netloc(base_scheme, base_netloc) - cur_netloc = _normalize_netloc(cur_scheme, cur_netloc) - - # is that IRI even on a known HTTP scheme? - if collapse_http_schemes: - for scheme in base_scheme, cur_scheme: - if scheme not in ("http", "https"): - return None - else: - if not (base_scheme in ("http", "https") and base_scheme == cur_scheme): - return None - - # are the netlocs compatible? - if base_netloc != cur_netloc: - return None - - # are we below the application path? 
- base_path = base_path.rstrip("/") - if not cur_path.startswith(base_path): - return None - - return f"/{cur_path[len(base_path) :].lstrip('/')}" + path: bytes = environ.get("PATH_INFO", "").encode("latin1") + return path.decode(errors="replace") class ClosingIterator: @@ -476,9 +233,8 @@ class ClosingIterator: def __init__( self, iterable: t.Iterable[bytes], - callbacks: t.Optional[ - t.Union[t.Callable[[], None], t.Iterable[t.Callable[[], None]]] - ] = None, + callbacks: None + | (t.Callable[[], None] | t.Iterable[t.Callable[[], None]]) = None, ) -> None: iterator = iter(iterable) self._next = t.cast(t.Callable[[], bytes], partial(next, iterator)) @@ -493,7 +249,7 @@ def __init__( callbacks.insert(0, iterable_close) self._callbacks = callbacks - def __iter__(self) -> "ClosingIterator": + def __iter__(self) -> ClosingIterator: return self def __next__(self) -> bytes: @@ -505,7 +261,7 @@ def close(self) -> None: def wrap_file( - environ: "WSGIEnvironment", file: t.IO[bytes], buffer_size: int = 8192 + environ: WSGIEnvironment, file: t.IO[bytes], buffer_size: int = 8192 ) -> t.Iterable[bytes]: """Wraps a file. This uses the WSGI server's file wrapper if available or otherwise the generic :class:`FileWrapper`. @@ -564,12 +320,12 @@ def seek(self, *args: t.Any) -> None: if hasattr(self.file, "seek"): self.file.seek(*args) - def tell(self) -> t.Optional[int]: + def tell(self) -> int | None: if hasattr(self.file, "tell"): return self.file.tell() return None - def __iter__(self) -> "FileWrapper": + def __iter__(self) -> FileWrapper: return self def __next__(self) -> bytes: @@ -598,9 +354,9 @@ class _RangeWrapper: def __init__( self, - iterable: t.Union[t.Iterable[bytes], t.IO[bytes]], + iterable: t.Iterable[bytes] | t.IO[bytes], start_byte: int = 0, - byte_range: t.Optional[int] = None, + byte_range: int | None = None, ): self.iterable = iter(iterable) self.byte_range = byte_range @@ -611,12 +367,10 @@ def __init__( self.end_byte = start_byte + byte_range self.read_length = 0 - self.seekable = ( - hasattr(iterable, "seekable") and iterable.seekable() # type: ignore - ) + self.seekable = hasattr(iterable, "seekable") and iterable.seekable() self.end_reached = False - def __iter__(self) -> "_RangeWrapper": + def __iter__(self) -> _RangeWrapper: return self def _next_chunk(self) -> bytes: @@ -628,7 +382,7 @@ def _next_chunk(self) -> bytes: self.end_reached = True raise - def _first_iteration(self) -> t.Tuple[t.Optional[bytes], int]: + def _first_iteration(self) -> tuple[bytes | None, int]: chunk = None if self.seekable: self.iterable.seek(self.start_byte) # type: ignore @@ -665,356 +419,177 @@ def __next__(self) -> bytes: def close(self) -> None: if hasattr(self.iterable, "close"): - self.iterable.close() # type: ignore - - -def _make_chunk_iter( - stream: t.Union[t.Iterable[bytes], t.IO[bytes]], - limit: t.Optional[int], - buffer_size: int, -) -> t.Iterator[bytes]: - """Helper for the line and chunk iter functions.""" - if isinstance(stream, (bytes, bytearray, str)): - raise TypeError( - "Passed a string or byte object instead of true iterator or stream." 
- ) - if not hasattr(stream, "read"): - for item in stream: - if item: - yield item - return - stream = t.cast(t.IO[bytes], stream) - if not isinstance(stream, LimitedStream) and limit is not None: - stream = t.cast(t.IO[bytes], LimitedStream(stream, limit)) - _read = stream.read - while True: - item = _read(buffer_size) - if not item: - break - yield item - - -def make_line_iter( - stream: t.Union[t.Iterable[bytes], t.IO[bytes]], - limit: t.Optional[int] = None, - buffer_size: int = 10 * 1024, - cap_at_buffer: bool = False, -) -> t.Iterator[bytes]: - """Safely iterates line-based over an input stream. If the input stream - is not a :class:`LimitedStream` the `limit` parameter is mandatory. - - This uses the stream's :meth:`~file.read` method internally as opposite - to the :meth:`~file.readline` method that is unsafe and can only be used - in violation of the WSGI specification. The same problem applies to the - `__iter__` function of the input stream which calls :meth:`~file.readline` - without arguments. - - If you need line-by-line processing it's strongly recommended to iterate - over the input stream using this helper function. - - .. versionchanged:: 0.8 - This function now ensures that the limit was reached. + self.iterable.close() - .. versionadded:: 0.9 - added support for iterators as input stream. - - .. versionadded:: 0.11.10 - added support for the `cap_at_buffer` parameter. - - :param stream: the stream or iterate to iterate over. - :param limit: the limit in bytes for the stream. (Usually - content length. Not necessary if the `stream` - is a :class:`LimitedStream`. - :param buffer_size: The optional buffer size. - :param cap_at_buffer: if this is set chunks are split if they are longer - than the buffer size. Internally this is implemented - that the buffer size might be exhausted by a factor - of two however. - """ - _iter = _make_chunk_iter(stream, limit, buffer_size) - - first_item = next(_iter, "") - if not first_item: - return - - s = _make_encode_wrapper(first_item) - empty = t.cast(bytes, s("")) - cr = t.cast(bytes, s("\r")) - lf = t.cast(bytes, s("\n")) - crlf = t.cast(bytes, s("\r\n")) - - _iter = t.cast(t.Iterator[bytes], chain((first_item,), _iter)) - - def _iter_basic_lines() -> t.Iterator[bytes]: - _join = empty.join - buffer: t.List[bytes] = [] - while True: - new_data = next(_iter, "") - if not new_data: - break - new_buf: t.List[bytes] = [] - buf_size = 0 - for item in t.cast( - t.Iterator[bytes], chain(buffer, new_data.splitlines(True)) - ): - new_buf.append(item) - buf_size += len(item) - if item and item[-1:] in crlf: - yield _join(new_buf) - new_buf = [] - elif cap_at_buffer and buf_size >= buffer_size: - rv = _join(new_buf) - while len(rv) >= buffer_size: - yield rv[:buffer_size] - rv = rv[buffer_size:] - new_buf = [rv] - buffer = new_buf - if buffer: - yield _join(buffer) - - # This hackery is necessary to merge 'foo\r' and '\n' into one item - # of 'foo\r\n' if we were unlucky and we hit a chunk boundary. - previous = empty - for item in _iter_basic_lines(): - if item == lf and previous[-1:] == cr: - previous += item - item = empty - if previous: - yield previous - previous = item - if previous: - yield previous - - -def make_chunk_iter( - stream: t.Union[t.Iterable[bytes], t.IO[bytes]], - separator: bytes, - limit: t.Optional[int] = None, - buffer_size: int = 10 * 1024, - cap_at_buffer: bool = False, -) -> t.Iterator[bytes]: - """Works like :func:`make_line_iter` but accepts a separator - which divides chunks. 
If you want newline based processing - you should use :func:`make_line_iter` instead as it - supports arbitrary newline markers. - - .. versionadded:: 0.8 - .. versionadded:: 0.9 - added support for iterators as input stream. - - .. versionadded:: 0.11.10 - added support for the `cap_at_buffer` parameter. - - :param stream: the stream or iterate to iterate over. - :param separator: the separator that divides chunks. - :param limit: the limit in bytes for the stream. (Usually - content length. Not necessary if the `stream` - is otherwise already limited). - :param buffer_size: The optional buffer size. - :param cap_at_buffer: if this is set chunks are split if they are longer - than the buffer size. Internally this is implemented - that the buffer size might be exhausted by a factor - of two however. - """ - _iter = _make_chunk_iter(stream, limit, buffer_size) - - first_item = next(_iter, b"") - if not first_item: - return - - _iter = t.cast(t.Iterator[bytes], chain((first_item,), _iter)) - if isinstance(first_item, str): - separator = _to_str(separator) - _split = re.compile(f"({re.escape(separator)})").split - _join = "".join - else: - separator = _to_bytes(separator) - _split = re.compile(b"(" + re.escape(separator) + b")").split - _join = b"".join - - buffer: t.List[bytes] = [] - while True: - new_data = next(_iter, b"") - if not new_data: - break - chunks = _split(new_data) - new_buf: t.List[bytes] = [] - buf_size = 0 - for item in chain(buffer, chunks): - if item == separator: - yield _join(new_buf) - new_buf = [] - buf_size = 0 - else: - buf_size += len(item) - new_buf.append(item) - - if cap_at_buffer and buf_size >= buffer_size: - rv = _join(new_buf) - while len(rv) >= buffer_size: - yield rv[:buffer_size] - rv = rv[buffer_size:] - new_buf = [rv] - buf_size = len(rv) - - buffer = new_buf - if buffer: - yield _join(buffer) - - -class LimitedStream(io.IOBase): - """Wraps a stream so that it doesn't read more than n bytes. If the - stream is exhausted and the caller tries to get more bytes from it - :func:`on_exhausted` is called which by default returns an empty - string. The return value of that function is forwarded - to the reader function. So if it returns an empty string - :meth:`read` will return an empty string as well. - - The limit however must never be higher than what the stream can - output. Otherwise :meth:`readlines` will try to read past the - limit. - - .. admonition:: Note on WSGI compliance - - calls to :meth:`readline` and :meth:`readlines` are not - WSGI compliant because it passes a size argument to the - readline methods. Unfortunately the WSGI PEP is not safely - implementable without a size argument to :meth:`readline` - because there is no EOF marker in the stream. As a result - of that the use of :meth:`readline` is discouraged. - - For the same reason iterating over the :class:`LimitedStream` - is not portable. It internally calls :meth:`readline`. - - We strongly suggest using :meth:`read` only or using the - :func:`make_line_iter` which safely iterates line-based - over a WSGI input stream. - - :param stream: the stream to wrap. - :param limit: the limit for the stream, must not be longer than - what the string can provide if the stream does not - end with `EOF` (like `wsgi.input`) +class LimitedStream(io.RawIOBase): + """Wrap a stream so that it doesn't read more than a given limit. This is used to + limit ``wsgi.input`` to the ``Content-Length`` header value or + :attr:`.Request.max_content_length`. 
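+
+    A minimal usage sketch, assuming ``environ`` is a WSGI environ dict and
+    ``content_length`` has already been parsed from the request headers::
+
+        stream = LimitedStream(environ["wsgi.input"], content_length)
+        data = stream.read()  # reads at most ``content_length`` bytes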
+ + When attempting to read after the limit has been reached, :meth:`on_exhausted` is + called. When the limit is a maximum, this raises :exc:`.RequestEntityTooLarge`. + + If reading from the stream returns zero bytes or raises an error, + :meth:`on_disconnect` is called, which raises :exc:`.ClientDisconnected`. When the + limit is a maximum and zero bytes were read, no error is raised, since it may be the + end of the stream. + + If the limit is reached before the underlying stream is exhausted (such as a file + that is too large, or an infinite stream), the remaining contents of the stream + cannot be read safely. Depending on how the server handles this, clients may show a + "connection reset" failure instead of seeing the 413 response. + + :param stream: The stream to read from. Must be a readable binary IO object. + :param limit: The limit in bytes to not read past. Should be either the + ``Content-Length`` header value or ``request.max_content_length``. + :param is_max: Whether the given ``limit`` is ``request.max_content_length`` instead + of the ``Content-Length`` header value. This changes how exhausted and + disconnect events are handled. + + .. versionchanged:: 2.3 + Handle ``max_content_length`` differently than ``Content-Length``. + + .. versionchanged:: 2.3 + Implements ``io.RawIOBase`` rather than ``io.IOBase``. """ - def __init__(self, stream: t.IO[bytes], limit: int) -> None: - self._read = stream.read - self._readline = stream.readline + def __init__(self, stream: t.IO[bytes], limit: int, is_max: bool = False) -> None: + self._stream = stream self._pos = 0 self.limit = limit - - def __iter__(self) -> "LimitedStream": - return self + self._limit_is_max = is_max @property def is_exhausted(self) -> bool: - """If the stream is exhausted this attribute is `True`.""" + """Whether the current stream position has reached the limit.""" return self._pos >= self.limit - def on_exhausted(self) -> bytes: - """This is called when the stream tries to read past the limit. - The return value of this function is returned from the reading - function. - """ - # Read null bytes from the stream so that we get the - # correct end of stream marker. - return self._read(0) - - def on_disconnect(self) -> bytes: - """What should happen if a disconnect is detected? The return - value of this function is returned from read functions in case - the client went away. By default a - :exc:`~werkzeug.exceptions.ClientDisconnected` exception is raised. - """ - from .exceptions import ClientDisconnected + def on_exhausted(self) -> None: + """Called when attempting to read after the limit has been reached. - raise ClientDisconnected() + The default behavior is to do nothing, unless the limit is a maximum, in which + case it raises :exc:`.RequestEntityTooLarge`. - def exhaust(self, chunk_size: int = 1024 * 64) -> None: - """Exhaust the stream. This consumes all the data left until the - limit is reached. + .. versionchanged:: 2.3 + Raises ``RequestEntityTooLarge`` if the limit is a maximum. - :param chunk_size: the size for a chunk. It will read the chunk - until the stream is exhausted and throw away - the results. + .. versionchanged:: 2.3 + Any return value is ignored. 
""" - to_read = self.limit - self._pos - chunk = chunk_size - while to_read > 0: - chunk = min(to_read, chunk) - self.read(chunk) - to_read -= chunk + if self._limit_is_max: + raise RequestEntityTooLarge() + + def on_disconnect(self, error: Exception | None = None) -> None: + """Called when an attempted read receives zero bytes before the limit was + reached. This indicates that the client disconnected before sending the full + request body. + + The default behavior is to raise :exc:`.ClientDisconnected`, unless the limit is + a maximum and no error was raised. - def read(self, size: t.Optional[int] = None) -> bytes: - """Read `size` bytes or if size is not provided everything is read. + .. versionchanged:: 2.3 + Added the ``error`` parameter. Do nothing if the limit is a maximum and no + error was raised. - :param size: the number of bytes read. + .. versionchanged:: 2.3 + Any return value is ignored. """ - if self._pos >= self.limit: - return self.on_exhausted() - if size is None or size == -1: # -1 is for consistence with file - size = self.limit - to_read = min(self.limit - self._pos, size) - try: - read = self._read(to_read) - except (OSError, ValueError): - return self.on_disconnect() - if to_read and len(read) != to_read: - return self.on_disconnect() - self._pos += len(read) - return read - - def readline(self, size: t.Optional[int] = None) -> bytes: - """Reads one line from the stream.""" - if self._pos >= self.limit: - return self.on_exhausted() - if size is None: - size = self.limit - self._pos - else: - size = min(size, self.limit - self._pos) - try: - line = self._readline(size) - except (ValueError, OSError): - return self.on_disconnect() - if size and not line: - return self.on_disconnect() - self._pos += len(line) - return line - - def readlines(self, size: t.Optional[int] = None) -> t.List[bytes]: - """Reads a file into a list of strings. It calls :meth:`readline` - until the file is read to the end. It does support the optional - `size` argument if the underlying stream supports it for - `readline`. + if not self._limit_is_max or error is not None: + raise ClientDisconnected() + + # If the limit is a maximum, then we may have read zero bytes because the + # streaming body is complete. There's no way to distinguish that from the + # client disconnecting early. + + def exhaust(self) -> bytes: + """Exhaust the stream by reading until the limit is reached or the client + disconnects, returning the remaining data. + + .. versionchanged:: 2.3 + Return the remaining data. + + .. versionchanged:: 2.2.3 + Handle case where wrapped stream returns fewer bytes than requested. """ - last_pos = self._pos - result = [] - if size is not None: - end = min(self.limit, last_pos + size) + if not self.is_exhausted: + return self.readall() + + return b"" + + def readinto(self, b: bytearray) -> int | None: # type: ignore[override] + size = len(b) + remaining = self.limit - self._pos + + if remaining <= 0: + self.on_exhausted() + return 0 + + if hasattr(self._stream, "readinto"): + # Use stream.readinto if it's available. + if size <= remaining: + # The size fits in the remaining limit, use the buffer directly. + try: + out_size: int | None = self._stream.readinto(b) + except (OSError, ValueError) as e: + self.on_disconnect(error=e) + return 0 + else: + # Use a temp buffer with the remaining limit as the size. 
+ temp_b = bytearray(remaining) + + try: + out_size = self._stream.readinto(temp_b) + except (OSError, ValueError) as e: + self.on_disconnect(error=e) + return 0 + + if out_size: + b[:out_size] = temp_b else: - end = self.limit - while True: - if size is not None: - size -= last_pos - self._pos - if self._pos >= end: + # WSGI requires that stream.read is available. + try: + data = self._stream.read(min(size, remaining)) + except (OSError, ValueError) as e: + self.on_disconnect(error=e) + return 0 + + out_size = len(data) + b[:out_size] = data + + if not out_size: + # Read zero bytes from the stream. + self.on_disconnect() + return 0 + + self._pos += out_size + return out_size + + def readall(self) -> bytes: + if self.is_exhausted: + self.on_exhausted() + return b"" + + out = bytearray() + + # The parent implementation uses "while True", which results in an extra read. + while not self.is_exhausted: + data = self.read(1024 * 64) + + # Stream may return empty before a max limit is reached. + if not data: break - result.append(self.readline(size)) - if size is not None: - last_pos = self._pos - return result + + out.extend(data) + + return bytes(out) def tell(self) -> int: - """Returns the position of the stream. + """Return the current stream position. .. versionadded:: 0.9 """ return self._pos - def __next__(self) -> bytes: - line = self.readline() - if not line: - raise StopIteration() - return line - def readable(self) -> bool: return True diff --git a/tests/conftest.py b/tests/conftest.py index 7ce0896..b73202c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -41,7 +41,9 @@ def __init__(self, kwargs): self.log = None def tail_log(self, path): - self.log = open(path) + # surrogateescape allows for handling of file streams + # containing junk binary values as normal text streams + self.log = open(path, errors="surrogateescape") self.log.read() def connect(self, **kwargs): diff --git a/tests/live_apps/data_app.py b/tests/live_apps/data_app.py index a7158c7..9b2e78b 100644 --- a/tests/live_apps/data_app.py +++ b/tests/live_apps/data_app.py @@ -5,13 +5,13 @@ @Request.application -def app(request): +def app(request: Request) -> Response: return Response( json.dumps( { "environ": request.environ, - "form": request.form, - "files": {k: v.read().decode("utf8") for k, v in request.files.items()}, + "form": request.form.to_dict(), + "files": {k: v.read().decode() for k, v in request.files.items()}, }, default=lambda x: str(x), ), diff --git a/tests/middleware/test_dispatcher.py b/tests/middleware/test_dispatcher.py index 5a25a6c..b7b9a77 100644 --- a/tests/middleware/test_dispatcher.py +++ b/tests/middleware/test_dispatcher.py @@ -1,4 +1,3 @@ -from werkzeug._internal import _to_bytes from werkzeug.middleware.dispatcher import DispatcherMiddleware from werkzeug.test import create_environ from werkzeug.test import run_wsgi_app @@ -11,7 +10,7 @@ def null_application(environ, start_response): def dummy_application(environ, start_response): start_response("200 OK", [("Content-Type", "text/plain")]) - yield _to_bytes(environ["SCRIPT_NAME"]) + yield environ["SCRIPT_NAME"].encode() app = DispatcherMiddleware( null_application, @@ -27,7 +26,7 @@ def dummy_application(environ, start_response): environ = create_environ(p) app_iter, status, headers = run_wsgi_app(app, environ) assert status == "200 OK" - assert b"".join(app_iter).strip() == _to_bytes(name) + assert b"".join(app_iter).strip() == name.encode() app_iter, status, headers = run_wsgi_app(app, create_environ("/missing")) assert status == 
"404 NOT FOUND" diff --git a/tests/middleware/test_profiler.py b/tests/middleware/test_profiler.py new file mode 100644 index 0000000..585aeb5 --- /dev/null +++ b/tests/middleware/test_profiler.py @@ -0,0 +1,50 @@ +import datetime +import os +from unittest.mock import ANY +from unittest.mock import MagicMock +from unittest.mock import patch + +from werkzeug.middleware.profiler import Profile +from werkzeug.middleware.profiler import ProfilerMiddleware +from werkzeug.test import Client + + +def dummy_application(environ, start_response): + start_response("200 OK", [("Content-Type", "text/plain")]) + return [b"Foo"] + + +def test_filename_format_function(): + # This should be called once with the generated file name + mock_capture_name = MagicMock() + + def filename_format(env): + now = datetime.datetime.fromtimestamp(env["werkzeug.profiler"]["time"]) + timestamp = now.strftime("%Y-%m-%d:%H:%M:%S") + path = ( + "_".join(token for token in env["PATH_INFO"].split("/") if token) or "ROOT" + ) + elapsed = env["werkzeug.profiler"]["elapsed"] + name = f"{timestamp}.{env['REQUEST_METHOD']}.{path}.{elapsed:.0f}ms.prof" + mock_capture_name(name=name) + return name + + client = Client( + ProfilerMiddleware( + dummy_application, + stream=None, + profile_dir="profiles", + filename_format=filename_format, + ) + ) + + # Replace the Profile class with a function that simulates an __init__() + # call and returns our mock instance. + mock_profile = MagicMock(wraps=Profile()) + mock_profile.dump_stats = MagicMock() + with patch("werkzeug.middleware.profiler.Profile", lambda: mock_profile): + client.get("/foo/bar") + + mock_capture_name.assert_called_once_with(name=ANY) + name = mock_capture_name.mock_calls[0].kwargs["name"] + mock_profile.dump_stats.assert_called_once_with(os.path.join("profiles", name)) diff --git a/tests/sansio/test_multipart.py b/tests/sansio/test_multipart.py index f9c48b4..cf36fef 100644 --- a/tests/sansio/test_multipart.py +++ b/tests/sansio/test_multipart.py @@ -1,3 +1,5 @@ +import pytest + from werkzeug.datastructures import Headers from werkzeug.sansio.multipart import Data from werkzeug.sansio.multipart import Epilogue @@ -22,15 +24,11 @@ def test_decoder_simple() -> None: asdasd -----------------------------9704338192090380615194531385$-- - """.replace( - "\n", "\r\n" - ).encode( - "utf-8" - ) + """.replace("\n", "\r\n").encode() decoder.receive_data(data) decoder.receive_data(None) events = [decoder.next_event()] - while not isinstance(events[-1], Epilogue) and len(events) < 6: + while not isinstance(events[-1], Epilogue): events.append(decoder.next_event()) assert events == [ Preamble(data=b""), @@ -56,6 +54,57 @@ def test_decoder_simple() -> None: assert data == result +@pytest.mark.parametrize( + "data_start", + [ + b"A", + b"\n", + b"\r", + b"\r\n", + b"\n\r", + b"A\n", + b"A\r", + b"A\r\n", + b"A\n\r", + ], +) +@pytest.mark.parametrize("data_end", [b"", b"\r\n--foo"]) +def test_decoder_data_start_with_different_newline_positions( + data_start: bytes, data_end: bytes +) -> None: + boundary = b"foo" + data = ( + b"\r\n--foo\r\n" + b'Content-Disposition: form-data; name="test"; filename="testfile"\r\n' + b"Content-Type: application/octet-stream\r\n\r\n" + b"" + data_start + b"\r\nBCDE" + data_end + ) + decoder = MultipartDecoder(boundary) + decoder.receive_data(data) + events = [decoder.next_event()] + # We want to check up to data start event + while not isinstance(events[-1], Data): + events.append(decoder.next_event()) + expected = data_start if data_end == b"" else 
data_start + b"\r\nBCDE" + assert events == [ + Preamble(data=b""), + File( + name="test", + filename="testfile", + headers=Headers( + [ + ( + "Content-Disposition", + 'form-data; name="test"; filename="testfile"', + ), + ("Content-Type", "application/octet-stream"), + ] + ), + ), + Data(data=expected, more_data=True), + ] + + def test_chunked_boundaries() -> None: boundary = b"--boundary" decoder = MultipartDecoder(boundary) @@ -78,3 +127,54 @@ def test_chunked_boundaries() -> None: assert not event.more_data decoder.receive_data(None) assert isinstance(decoder.next_event(), Epilogue) + + +def test_empty_field() -> None: + boundary = b"foo" + decoder = MultipartDecoder(boundary) + data = """ +--foo +Content-Disposition: form-data; name="text" +Content-Type: text/plain; charset="UTF-8" + +Some Text +--foo +Content-Disposition: form-data; name="empty" +Content-Type: text/plain; charset="UTF-8" + +--foo-- + """.replace("\n", "\r\n").encode() + decoder.receive_data(data) + decoder.receive_data(None) + events = [decoder.next_event()] + while not isinstance(events[-1], Epilogue): + events.append(decoder.next_event()) + assert events == [ + Preamble(data=b""), + Field( + name="text", + headers=Headers( + [ + ("Content-Disposition", 'form-data; name="text"'), + ("Content-Type", 'text/plain; charset="UTF-8"'), + ] + ), + ), + Data(data=b"Some Text", more_data=False), + Field( + name="empty", + headers=Headers( + [ + ("Content-Disposition", 'form-data; name="empty"'), + ("Content-Type", 'text/plain; charset="UTF-8"'), + ] + ), + ), + Data(data=b"", more_data=False), + Epilogue(data=b" "), + ] + encoder = MultipartEncoder(boundary) + result = b"" + for event in events: + result += encoder.send_event(event) + assert data == result diff --git a/tests/sansio/test_request.py b/tests/sansio/test_request.py index 310b244..4f4bbd6 100644 --- a/tests/sansio/test_request.py +++ b/tests/sansio/test_request.py @@ -12,6 +12,10 @@ (Headers({"Transfer-Encoding": "chunked", "Content-Length": "6"}), None), (Headers({"Transfer-Encoding": "something", "Content-Length": "6"}), 6), (Headers({"Content-Length": "6"}), 6), + (Headers({"Content-Length": "-6"}), 0), + (Headers({"Content-Length": "+123"}), 0), + (Headers({"Content-Length": "1_23"}), 0), + (Headers({"Content-Length": "🯱🯲🯳"}), 0), (Headers(), None), ], ) diff --git a/tests/sansio/test_utils.py b/tests/sansio/test_utils.py index 8c8faa6..d43de66 100644 --- a/tests/sansio/test_utils.py +++ b/tests/sansio/test_utils.py @@ -1,7 +1,8 @@ -import typing as t +from __future__ import annotations import pytest +from werkzeug.sansio.utils import get_content_length from werkzeug.sansio.utils import get_host @@ -25,8 +26,28 @@ ) def test_get_host( scheme: str, - host_header: t.Optional[str], - server: t.Optional[t.Tuple[str, t.Optional[int]]], + host_header: str | None, + server: tuple[str, int | None] | None, expected: str, ) -> None: assert get_host(scheme, host_header, server) == expected + + +@pytest.mark.parametrize( + ("http_content_length", "http_transfer_encoding", "expected"), + [ + ("2", None, 2), + (" 2", None, 2), + ("2 ", None, 2), + (None, None, None), + (None, "chunked", None), + ("a", None, 0), + ("-2", None, 0), + ], +) +def test_get_content_length( + http_content_length: str | None, + http_transfer_encoding: str | None, + expected: int | None, +) -> None: + assert get_content_length(http_content_length, http_transfer_encoding) == expected diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index 7f63b64..64330e1 100644 --- 
a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -63,7 +63,7 @@ def create_instance(module=None): d = create_instance() s = pickle.dumps(d, protocol) ud = pickle.loads(s) - assert type(ud) == type(d) + assert type(ud) == type(d) # noqa: E721 assert ud == d alternative = pickle.dumps(create_instance("werkzeug"), protocol) assert pickle.loads(alternative) == d @@ -550,8 +550,9 @@ def test_value_conversion(self): assert d.get("foo", type=int) == 1 def test_return_default_when_conversion_is_not_possible(self): - d = self.storage_class(foo="bar") + d = self.storage_class(foo="bar", baz=None) assert d.get("foo", default=-1, type=int) == -1 + assert d.get("baz", default=-1, type=int) == -1 def test_propagate_exceptions_in_conversion(self): d = self.storage_class(foo="bar") @@ -731,16 +732,6 @@ def test_slicing(self): h[:] = [(k, v) for k, v in h if k.startswith("X-")] assert list(h) == [("X-Foo-Poo", "bleh"), ("X-Forwarded-For", "192.168.0.123")] - def test_bytes_operations(self): - h = self.storage_class() - h.set("X-Foo-Poo", "bleh") - h.set("X-Whoops", b"\xff") - h.set(b"X-Bytes", b"something") - - assert h.get("x-foo-poo", as_bytes=True) == b"bleh" - assert h.get("x-whoops", as_bytes=True) == b"\xff" - assert h.get("x-bytes") == "something" - def test_extend(self): h = self.storage_class([("a", "0"), ("b", "1"), ("c", "2")]) h.extend(ds.Headers([("a", "3"), ("a", "4")])) @@ -791,13 +782,6 @@ def test_to_wsgi_list(self): assert key == "Key" assert value == "Value" - def test_to_wsgi_list_bytes(self): - h = self.storage_class() - h.set(b"Key", b"Value") - for key, value in h.to_wsgi_list(): - assert key == "Key" - assert value == "Value" - def test_equality(self): # test equality, given keys are case insensitive h1 = self.storage_class() @@ -853,13 +837,6 @@ def test_return_type_is_str(self): assert headers["Foo"] == "\xe2\x9c\x93" assert next(iter(headers)) == ("Foo", "\xe2\x9c\x93") - def test_bytes_operations(self): - foo_val = "\xff" - h = self.storage_class({"HTTP_X_FOO": foo_val}) - - assert h.get("x-foo", as_bytes=True) == b"\xff" - assert h.get("x-foo") == "\xff" - class TestHeaderSet: storage_class = ds.HeaderSet diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index d8fed96..ad20b3f 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -7,6 +7,7 @@ from werkzeug import exceptions from werkzeug.datastructures import Headers from werkzeug.datastructures import WWWAuthenticate +from werkzeug.exceptions import default_exceptions from werkzeug.exceptions import HTTPException from werkzeug.wrappers import Response @@ -96,10 +97,8 @@ def test_method_not_allowed_methods(): def test_unauthorized_www_authenticate(): - basic = WWWAuthenticate() - basic.set_basic("test") - digest = WWWAuthenticate() - digest.set_digest("test", "test") + basic = WWWAuthenticate("basic", {"realm": "test"}) + digest = WWWAuthenticate("digest", {"realm": "test", "nonce": "test"}) exc = exceptions.Unauthorized(www_authenticate=basic) h = Headers(exc.get_headers({})) @@ -140,7 +139,7 @@ def test_retry_after_mixin(cls, value, expect): @pytest.mark.parametrize( "cls", sorted( - (e for e in HTTPException.__subclasses__() if e.code and e.code >= 400), + (e for e in default_exceptions.values() if e.code and e.code >= 400), key=lambda e: e.code, # type: ignore ), ) @@ -160,7 +159,7 @@ def test_description_none(): @pytest.mark.parametrize( "cls", sorted( - (e for e in HTTPException.__subclasses__() if e.code), + (e for e in default_exceptions.values() if e.code), 
key=lambda e: e.code, # type: ignore ), ) diff --git a/tests/test_formparser.py b/tests/test_formparser.py index 49010b4..1ecb012 100644 --- a/tests/test_formparser.py +++ b/tests/test_formparser.py @@ -69,6 +69,17 @@ def test_limiting(self): req.max_form_memory_size = 400 assert req.form["foo"] == "Hello World" + input_stream = io.BytesIO(b"foo=123456") + req = Request.from_values( + input_stream=input_stream, + content_type="application/x-www-form-urlencoded", + method="POST", + ) + req.max_content_length = 4 + pytest.raises(RequestEntityTooLarge, lambda: req.form["foo"]) + # content-length was set, so request could exit early without reading anything + assert input_stream.read() == b"foo=123456" + data = ( b"--foo\r\nContent-Disposition: form-field; name=foo\r\n\r\n" b"Hello World\r\n" @@ -81,24 +92,17 @@ def test_limiting(self): content_type="multipart/form-data; boundary=foo", method="POST", ) - req.max_content_length = 4 - pytest.raises(RequestEntityTooLarge, lambda: req.form["foo"]) + req.max_content_length = 400 + assert req.form["foo"] == "Hello World" - # when the request entity is too large, the input stream should be - # drained so that firefox (and others) do not report connection reset - # when run through gunicorn - # a sufficiently large stream is necessary for block-based reads - input_stream = io.BytesIO(b"foo=" + b"x" * 128 * 1024) req = Request.from_values( - input_stream=input_stream, + input_stream=io.BytesIO(data), content_length=len(data), content_type="multipart/form-data; boundary=foo", method="POST", ) - req.max_content_length = 4 + req.max_form_memory_size = 7 pytest.raises(RequestEntityTooLarge, lambda: req.form["foo"]) - # ensure that the stream is exhausted - assert input_stream.read() == b"" req = Request.from_values( input_stream=io.BytesIO(data), @@ -106,7 +110,7 @@ def test_limiting(self): content_type="multipart/form-data; boundary=foo", method="POST", ) - req.max_content_length = 400 + req.max_form_memory_size = 400 assert req.form["foo"] == "Hello World" req = Request.from_values( @@ -115,17 +119,15 @@ def test_limiting(self): content_type="multipart/form-data; boundary=foo", method="POST", ) - req.max_form_memory_size = 7 + req.max_form_parts = 1 pytest.raises(RequestEntityTooLarge, lambda: req.form["foo"]) - req = Request.from_values( - input_stream=io.BytesIO(data), - content_length=len(data), - content_type="multipart/form-data; boundary=foo", - method="POST", - ) - req.max_form_memory_size = 400 - assert req.form["foo"] == "Hello World" + def test_x_www_urlencoded_max_form_parts(self): + r = Request.from_values(method="POST", data={"a": 1, "b": 2}) + r.max_form_parts = 1 + + assert r.form["a"] == "1" + assert r.form["b"] == "2" def test_missing_multipart_boundary(self): data = ( @@ -271,7 +273,7 @@ def test_basic(self): content_type=f'multipart/form-data; boundary="{boundary}"', content_length=len(data), ) as response: - assert response.get_data() == repr(text).encode("utf-8") + assert response.get_data() == repr(text).encode() @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") def test_ie7_unc_path(self): diff --git a/tests/test_http.py b/tests/test_http.py index 3760dc1..bbd51ba 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -1,4 +1,5 @@ import base64 +import urllib.parse from datetime import date from datetime import datetime from datetime import timedelta @@ -9,6 +10,8 @@ from werkzeug import datastructures from werkzeug import http from werkzeug._internal import _wsgi_encoding_dance +from 
werkzeug.datastructures import Authorization +from werkzeug.datastructures import WWWAuthenticate from werkzeug.test import create_environ @@ -21,6 +24,10 @@ def test_accept(self): pytest.raises(ValueError, a.index, "de") assert a.to_header() == "en-us,ru;q=0.5" + def test_accept_parameter_with_space(self): + a = http.parse_accept_header('application/x-special; z="a b";q=0.5') + assert a['application/x-special; z="a b"'] == 0.5 + def test_mime_accept(self): a = http.parse_accept_header( "text/xml,application/xml," @@ -88,9 +95,17 @@ def test_set_header(self): hs.add("Foo") assert hs.to_header() == 'foo, Bar, "Blah baz", Hehe' - def test_list_header(self): - hl = http.parse_list_header("foo baz, blah") - assert hl == ["foo baz", "blah"] + @pytest.mark.parametrize( + ("value", "expect"), + [ + ("a b", ["a b"]), + ("a b, c", ["a b", "c"]), + ('a b, "c, d"', ["a b", "c, d"]), + ('"a\\"b", c', ['a"b', "c"]), + ], + ) + def test_list_header(self, value, expect): + assert http.parse_list_header(value) == expect def test_dict_header(self): d = http.parse_dict_header('foo="bar baz", blah=42') @@ -133,33 +148,30 @@ def test_csp_header(self): assert csp.img_src is None def test_authorization_header(self): - a = http.parse_authorization_header("Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==") + a = Authorization.from_header("Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==") assert a.type == "basic" assert a.username == "Aladdin" assert a.password == "open sesame" - a = http.parse_authorization_header( - "Basic 0YDRg9GB0YHQutC40IE60JHRg9C60LLRiw==" - ) + a = Authorization.from_header("Basic 0YDRg9GB0YHQutC40IE60JHRg9C60LLRiw==") assert a.type == "basic" assert a.username == "русскиЁ" assert a.password == "Буквы" - a = http.parse_authorization_header("Basic 5pmu6YCa6K+dOuS4reaWhw==") + a = Authorization.from_header("Basic 5pmu6YCa6K+dOuS4reaWhw==") assert a.type == "basic" assert a.username == "普通话" assert a.password == "中文" - a = http.parse_authorization_header( - '''Digest username="Mufasa", - realm="testrealm@host.invalid", - nonce="dcd98b7102dd2f0e8b11d0f600bfb0c093", - uri="/dir/index.html", - qop=auth, - nc=00000001, - cnonce="0a4f113b", - response="6629fae49393a05397450978507c4ef1", - opaque="5ccc069c403ebaf9f0171e9517f40e41"''' + a = Authorization.from_header( + 'Digest username="Mufasa",' + ' realm="testrealm@host.invalid",' + ' nonce="dcd98b7102dd2f0e8b11d0f600bfb0c093",' + ' uri="/dir/index.html",' + " qop=auth, nc=00000001," + ' cnonce="0a4f113b",' + ' response="6629fae49393a05397450978507c4ef1",' + ' opaque="5ccc069c403ebaf9f0171e9517f40e41"' ) assert a.type == "digest" assert a.username == "Mufasa" @@ -172,13 +184,13 @@ def test_authorization_header(self): assert a.response == "6629fae49393a05397450978507c4ef1" assert a.opaque == "5ccc069c403ebaf9f0171e9517f40e41" - a = http.parse_authorization_header( - '''Digest username="Mufasa", - realm="testrealm@host.invalid", - nonce="dcd98b7102dd2f0e8b11d0f600bfb0c093", - uri="/dir/index.html", - response="e257afa1414a3340d93d30955171dd0e", - opaque="5ccc069c403ebaf9f0171e9517f40e41"''' + a = Authorization.from_header( + 'Digest username="Mufasa",' + ' realm="testrealm@host.invalid",' + ' nonce="dcd98b7102dd2f0e8b11d0f600bfb0c093",' + ' uri="/dir/index.html",' + ' response="e257afa1414a3340d93d30955171dd0e",' + ' opaque="5ccc069c403ebaf9f0171e9517f40e41"' ) assert a.type == "digest" assert a.username == "Mufasa" @@ -188,41 +200,87 @@ def test_authorization_header(self): assert a.response == "e257afa1414a3340d93d30955171dd0e" assert a.opaque == 
"5ccc069c403ebaf9f0171e9517f40e41" - assert http.parse_authorization_header("") is None - assert http.parse_authorization_header(None) is None - assert http.parse_authorization_header("foo") is None + assert Authorization.from_header("") is None + assert Authorization.from_header(None) is None + assert Authorization.from_header("foo").type == "foo" + + def test_authorization_token_padding(self): + # padded with = + token = base64.b64encode(b"This has base64 padding").decode() + a = Authorization.from_header(f"Token {token}") + assert a.type == "token" + assert a.token == token + + # padded with == + token = base64.b64encode(b"This has base64 padding..").decode() + a = Authorization.from_header(f"Token {token}") + assert a.type == "token" + assert a.token == token + + def test_authorization_basic_incorrect_padding(self): + assert Authorization.from_header("Basic foo") is None def test_bad_authorization_header_encoding(self): """If the base64 encoded bytes can't be decoded as UTF-8""" content = base64.b64encode(b"\xffser:pass").decode() - assert http.parse_authorization_header(f"Basic {content}") is None + assert Authorization.from_header(f"Basic {content}") is None + + def test_authorization_eq(self): + basic1 = Authorization.from_header("Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==") + basic2 = Authorization( + "basic", {"username": "Aladdin", "password": "open sesame"} + ) + assert basic1 == basic2 + bearer1 = Authorization.from_header("Bearer abc") + bearer2 = Authorization("bearer", token="abc") + assert bearer1 == bearer2 + assert basic1 != bearer1 + assert basic1 != object() def test_www_authenticate_header(self): - wa = http.parse_www_authenticate_header('Basic realm="WallyWorld"') + wa = WWWAuthenticate.from_header('Basic realm="WallyWorld"') assert wa.type == "basic" assert wa.realm == "WallyWorld" wa.realm = "Foo Bar" assert wa.to_header() == 'Basic realm="Foo Bar"' - wa = http.parse_www_authenticate_header( - '''Digest - realm="testrealm@host.com", - qop="auth,auth-int", - nonce="dcd98b7102dd2f0e8b11d0f600bfb0c093", - opaque="5ccc069c403ebaf9f0171e9517f40e41"''' + wa = WWWAuthenticate.from_header( + 'Digest realm="testrealm@host.com",' + ' qop="auth,auth-int",' + ' nonce="dcd98b7102dd2f0e8b11d0f600bfb0c093",' + ' opaque="5ccc069c403ebaf9f0171e9517f40e41"' ) assert wa.type == "digest" assert wa.realm == "testrealm@host.com" - assert "auth" in wa.qop - assert "auth-int" in wa.qop + assert wa.parameters["qop"] == "auth,auth-int" assert wa.nonce == "dcd98b7102dd2f0e8b11d0f600bfb0c093" assert wa.opaque == "5ccc069c403ebaf9f0171e9517f40e41" - wa = http.parse_www_authenticate_header("broken") - assert wa.type == "broken" - - assert not http.parse_www_authenticate_header("").type - assert not http.parse_www_authenticate_header("") + assert WWWAuthenticate.from_header("broken").type == "broken" + assert WWWAuthenticate.from_header("") is None + + def test_www_authenticate_token_padding(self): + # padded with = + token = base64.b64encode(b"This has base64 padding").decode() + a = WWWAuthenticate.from_header(f"Token {token}") + assert a.type == "token" + assert a.token == token + + # padded with == + token = base64.b64encode(b"This has base64 padding..").decode() + a = WWWAuthenticate.from_header(f"Token {token}") + assert a.type == "token" + assert a.token == token + + def test_www_authenticate_eq(self): + basic1 = WWWAuthenticate.from_header("Basic realm=abc") + basic2 = WWWAuthenticate("basic", {"realm": "abc"}) + assert basic1 == basic2 + token1 = WWWAuthenticate.from_header("Token abc") + token2 
= WWWAuthenticate("token", token="abc") + assert token1 == token2 + assert basic1 != token1 + assert basic1 != object() def test_etags(self): assert http.quote_etag("foo") == '"foo"' @@ -274,68 +332,63 @@ def test_remove_hop_by_hop_headers(self): http.remove_hop_by_hop_headers(headers2) assert headers2 == datastructures.Headers([("Foo", "bar")]) - def test_parse_options_header(self): - assert http.parse_options_header(None) == ("", {}) - assert http.parse_options_header("") == ("", {}) - assert http.parse_options_header(r'something; foo="other\"thing"') == ( - "something", - {"foo": 'other"thing'}, - ) - assert http.parse_options_header(r'something; foo="other\"thing"; meh=42') == ( - "something", - {"foo": 'other"thing', "meh": "42"}, - ) - assert http.parse_options_header( - r'something; foo="other\"thing"; meh=42; bleh' - ) == ("something", {"foo": 'other"thing', "meh": "42", "bleh": None}) - assert http.parse_options_header( - 'something; foo="other;thing"; meh=42; bleh' - ) == ("something", {"foo": "other;thing", "meh": "42", "bleh": None}) - assert http.parse_options_header('something; foo="otherthing"; meh=; bleh') == ( - "something", - {"foo": "otherthing", "meh": None, "bleh": None}, - ) - # Issue #404 - assert http.parse_options_header( - 'multipart/form-data; name="foo bar"; filename="bar foo"' - ) == ("multipart/form-data", {"name": "foo bar", "filename": "bar foo"}) - # Examples from RFC - assert http.parse_options_header("audio/*; q=0.2, audio/basic") == ( - "audio/*", - {"q": "0.2"}, - ) - - assert http.parse_options_header( - "text/plain; q=0.5, text/html\n text/x-dvi; q=0.8, text/x-c" - ) == ("text/plain", {"q": "0.5"}) - # Issue #932 - assert http.parse_options_header( - "form-data; name=\"a_file\"; filename*=UTF-8''" - '"%c2%a3%20and%20%e2%82%ac%20rates"' - ) == ("form-data", {"name": "a_file", "filename": "\xa3 and \u20ac rates"}) - assert http.parse_options_header( - "form-data; name*=UTF-8''\"%C5%AAn%C4%ADc%C5%8Dde%CC%BD\"; " - 'filename="some_file.txt"' - ) == ( - "form-data", - {"name": "\u016an\u012dc\u014dde\u033d", "filename": "some_file.txt"}, - ) + @pytest.mark.parametrize( + ("value", "expect"), + [ + (None, ""), + ("", ""), + (";a=b", ""), + ("v", "v"), + ("v;", "v"), + ], + ) + def test_parse_options_header_empty(self, value, expect): + assert http.parse_options_header(value) == (expect, {}) - def test_parse_options_header_value_with_quotes(self): - assert http.parse_options_header( - 'form-data; name="file"; filename="t\'es\'t.txt"' - ) == ("form-data", {"name": "file", "filename": "t'es't.txt"}) - assert http.parse_options_header( - "form-data; name=\"file\"; filename*=UTF-8''\"'🐍'.txt\"" - ) == ("form-data", {"name": "file", "filename": "'🐍'.txt"}) + @pytest.mark.parametrize( + ("value", "expect"), + [ + ("v;a=b;c=d;", {"a": "b", "c": "d"}), + ("v; ; a=b ; ", {"a": "b"}), + ("v;a", {}), + ("v;a=", {}), + ("v;=b", {}), + ('v;a="b"', {"a": "b"}), + ("v;a=µ", {}), + ('v;a="\';\'";b="µ";', {"a": "';'", "b": "µ"}), + ('v;a="b c"', {"a": "b c"}), + # HTTP headers use \" for internal " + ('v;a="b\\"c";d=e', {"a": 'b"c', "d": "e"}), + # HTTP headers use \\ for internal \ + ('v;a="c:\\\\"', {"a": "c:\\"}), + # Invalid trailing slash in quoted part is left as-is. 
+ ('v;a="c:\\"', {"a": "c:\\"}), + ('v;a="b\\\\\\"c"', {"a": 'b\\"c'}), + # multipart form data uses %22 for internal " + ('v;a="b%22c"', {"a": 'b"c'}), + ("v;a*=b", {"a": "b"}), + ("v;a*=ASCII'en'b", {"a": "b"}), + ("v;a*=US-ASCII''%62", {"a": "b"}), + ("v;a*=UTF-8''%C2%B5", {"a": "µ"}), + ("v;a*=US-ASCII''%C2%B5", {"a": "��"}), + ("v;a*=BAD''%62", {"a": "%62"}), + ("v;a*=UTF-8'''%F0%9F%90%8D'.txt", {"a": "'🐍'.txt"}), + ('v;a="🐍.txt"', {"a": "🐍.txt"}), + ("v;a*0=b;a*1=c;d=e", {"a": "bc", "d": "e"}), + ("v;a*0*=b", {"a": "b"}), + ("v;a*0*=UTF-8''b;a*1=c;a*2*=%C2%B5", {"a": "bcµ"}), + ], + ) + def test_parse_options_header(self, value, expect) -> None: + assert http.parse_options_header(value) == ("v", expect) def test_parse_options_header_broken_values(self): # Issue #995 assert http.parse_options_header(" ") == ("", {}) - assert http.parse_options_header(" , ") == ("", {}) + assert http.parse_options_header(" , ") == (",", {}) assert http.parse_options_header(" ; ") == ("", {}) - assert http.parse_options_header(" ,; ") == ("", {}) - assert http.parse_options_header(" , a ") == ("", {}) + assert http.parse_options_header(" ,; ") == (",", {}) + assert http.parse_options_header(" , a ") == (", a", {}) assert http.parse_options_header(" ; a ") == ("", {}) def test_parse_options_header_case_insensitive(self): @@ -344,16 +397,12 @@ def test_parse_options_header_case_insensitive(self): def test_dump_options_header(self): assert http.dump_options_header("foo", {"bar": 42}) == "foo; bar=42" - assert http.dump_options_header("foo", {"bar": 42, "fizz": None}) in ( - "foo; bar=42; fizz", - "foo; fizz; bar=42", - ) + assert "fizz" not in http.dump_options_header("foo", {"bar": 42, "fizz": None}) def test_dump_header(self): assert http.dump_header([1, 2, 3]) == "1, 2, 3" - assert http.dump_header([1, 2, 3], allow_token=False) == '"1", "2", "3"' - assert http.dump_header({"foo": "bar"}, allow_token=False) == 'foo="bar"' assert http.dump_header({"foo": "bar"}) == "foo=bar" + assert http.dump_header({"foo*": "UTF-8''bar"}) == "foo*=UTF-8''bar" def test_is_resource_modified(self): env = create_environ() @@ -411,7 +460,8 @@ def test_is_resource_modified_for_range_requests(self): def test_parse_cookie(self): cookies = http.parse_cookie( "dismiss-top=6; CP=null*; PHPSESSID=0a539d42abc001cdc762809248d4beed;" - 'a=42; b="\\";"; ; fo234{=bar;blub=Blah; "__Secure-c"=d' + 'a=42; b="\\";"; ; fo234{=bar;blub=Blah; "__Secure-c"=d;' + "==__Host-eq=bad;__Host-eq=good;" ) assert cookies.to_dict() == { "CP": "null*", @@ -422,6 +472,7 @@ def test_parse_cookie(self): "fo234{": "bar", "blub": "Blah", '"__Secure-c"': "d", + "__Host-eq": "good", } def test_dump_cookie(self): @@ -435,7 +486,7 @@ def test_dump_cookie(self): 'foo="bar baz blub"', } assert http.dump_cookie("key", "xxx/") == "key=xxx/; Path=/" - assert http.dump_cookie("key", "xxx=") == "key=xxx=; Path=/" + assert http.dump_cookie("key", "xxx=", path=None) == "key=xxx=" def test_bad_cookies(self): cookies = http.parse_cookie( @@ -458,9 +509,9 @@ def test_empty_keys_are_ignored(self): def test_cookie_quoting(self): val = http.dump_cookie("foo", "?foo") - assert val == 'foo="?foo"; Path=/' - assert http.parse_cookie(val).to_dict() == {"foo": "?foo", "Path": "/"} - assert http.parse_cookie(r'foo="foo\054bar"').to_dict(), {"foo": "foo,bar"} + assert val == "foo=?foo; Path=/" + assert http.parse_cookie(val)["foo"] == "?foo" + assert http.parse_cookie(r'foo="foo\054bar"')["foo"] == "foo,bar" def test_parse_set_cookie_directive(self): val = 'foo="?foo"; version="0.1";' 
@@ -482,7 +533,7 @@ def test_cookie_unicode_dumping(self): def test_cookie_unicode_keys(self): # Yes, this is technically against the spec but happens val = http.dump_cookie("fö", "fö") - assert val == _wsgi_encoding_dance('fö="f\\303\\266"; Path=/', "utf-8") + assert val == _wsgi_encoding_dance('fö="f\\303\\266"; Path=/') cookies = http.parse_cookie(val) assert cookies["fö"] == "fö" @@ -495,38 +546,30 @@ def test_cookie_domain_encoding(self): val = http.dump_cookie("foo", "bar", domain="\N{SNOWMAN}.com") assert val == "foo=bar; Domain=xn--n3h.com; Path=/" - val = http.dump_cookie("foo", "bar", domain=".\N{SNOWMAN}.com") - assert val == "foo=bar; Domain=.xn--n3h.com; Path=/" - - val = http.dump_cookie("foo", "bar", domain=".foo.com") - assert val == "foo=bar; Domain=.foo.com; Path=/" + val = http.dump_cookie("foo", "bar", domain="foo.com") + assert val == "foo=bar; Domain=foo.com; Path=/" - def test_cookie_maxsize(self, recwarn): + def test_cookie_maxsize(self): val = http.dump_cookie("foo", "bar" * 1360 + "b") - assert len(recwarn) == 0 assert len(val) == 4093 - http.dump_cookie("foo", "bar" * 1360 + "ba") - assert len(recwarn) == 1 - w = recwarn.pop() - assert "cookie is too large" in str(w.message) + with pytest.warns(UserWarning, match="cookie is too large"): + http.dump_cookie("foo", "bar" * 1360 + "ba") - http.dump_cookie("foo", b"w" * 502, max_size=512) - assert len(recwarn) == 1 - w = recwarn.pop() - assert "the limit is 512 bytes" in str(w.message) + with pytest.warns(UserWarning, match="the limit is 512 bytes"): + http.dump_cookie("foo", "w" * 501, max_size=512) @pytest.mark.parametrize( ("samesite", "expected"), ( - ("strict", "foo=bar; Path=/; SameSite=Strict"), - ("lax", "foo=bar; Path=/; SameSite=Lax"), - ("none", "foo=bar; Path=/; SameSite=None"), - (None, "foo=bar; Path=/"), + ("strict", "foo=bar; SameSite=Strict"), + ("lax", "foo=bar; SameSite=Lax"), + ("none", "foo=bar; SameSite=None"), + (None, "foo=bar"), ), ) def test_cookie_samesite_attribute(self, samesite, expected): - value = http.dump_cookie("foo", "bar", samesite=samesite) + value = http.dump_cookie("foo", "bar", samesite=samesite, path=None) assert value == expected def test_cookie_samesite_invalid(self): @@ -619,6 +662,9 @@ def test_content_range_parsing(self): rv = http.parse_content_range_header("bytes 0-98/*asdfsa") assert rv is None + rv = http.parse_content_range_header("bytes */-1") + assert rv is None + rv = http.parse_content_range_header("bytes 0-99/100") assert rv.to_header() == "bytes 0-99/100" rv.start = None @@ -656,7 +702,7 @@ def test_best_match_works(self): ], ) def test_authorization_to_header(value: str) -> None: - parsed = http.parse_authorization_header(value) + parsed = Authorization.from_header(value) assert parsed is not None assert parsed.to_header() == value @@ -715,3 +761,32 @@ def test_parse_date(value, expect): ) def test_http_date(value, expect): assert http.http_date(value) == expect + + +@pytest.mark.parametrize("value", [".5", "+0.5", "0.5_1", "🯰.🯵"]) +def test_accept_invalid_float(value): + quoted = urllib.parse.quote(value) + + if quoted == value: + q = f"q={value}" + else: + q = f"q*=UTF-8''{value}" + + a = http.parse_accept_header(f"en,jp;{q}") + assert list(a.values()) == ["en"] + + +def test_accept_valid_int_one_zero(): + assert http.parse_accept_header("en;q=1") == http.parse_accept_header("en;q=1.0") + assert http.parse_accept_header("en;q=0") == http.parse_accept_header("en;q=0.0") + assert http.parse_accept_header("en;q=5") == http.parse_accept_header("en;q=5.0") + + 
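# Sketch of the q-value handling the surrounding tests pin down: a q value may
# omit the decimal part, while malformed values such as "+0.5" cause the entry
# to be dropped entirely (illustrative, reusing values from the tests above).
from werkzeug.http import parse_accept_header

assert parse_accept_header("en;q=1") == parse_accept_header("en;q=1.0")
assert list(parse_accept_header("en,jp;q=+0.5").values()) == ["en"]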
+@pytest.mark.parametrize("value", ["🯱🯲🯳", "+1-", "1-1_23"]) +def test_range_invalid_int(value): + assert http.parse_range_header(value) is None + + +@pytest.mark.parametrize("value", ["*/🯱🯲🯳", "1-+2/3", "1_23-125/*"]) +def test_content_range_invalid_int(value): + assert http.parse_content_range_header(f"bytes {value}") is None diff --git a/tests/test_internal.py b/tests/test_internal.py index 6e673fd..edae35b 100644 --- a/tests/test_internal.py +++ b/tests/test_internal.py @@ -1,21 +1,8 @@ -from warnings import filterwarnings -from warnings import resetwarnings - -import pytest - -from werkzeug import _internal as internal from werkzeug.test import create_environ from werkzeug.wrappers import Request from werkzeug.wrappers import Response -def test_easteregg(): - req = Request.from_values("/?macgybarchakku") - resp = Response.force_type(internal._easteregg(None), req) - assert b"About Werkzeug" in resp.get_data() - assert b"the Swiss Army knife of Python web development" in resp.get_data() - - def test_wrapper_internals(): req = Request.from_values(data={"foo": "bar"}, method="POST") req._load_form_data() @@ -34,23 +21,10 @@ def test_wrapper_internals(): resp.response = iter(["Test"]) assert repr(resp) == "" - # string data does not set content length response = Response(["Hällo Wörld"]) headers = response.get_wsgi_headers(create_environ()) - assert "Content-Length" not in headers + assert "Content-Length" in headers response = Response(["Hällo Wörld".encode()]) headers = response.get_wsgi_headers(create_environ()) assert "Content-Length" in headers - - # check for internal warnings - filterwarnings("error", category=Warning) - response = Response() - environ = create_environ() - response.response = "What the...?" - pytest.raises(Warning, lambda: list(response.iter_encoded())) - pytest.raises(Warning, lambda: list(response.get_app_iter(environ))) - response.direct_passthrough = True - pytest.raises(Warning, lambda: list(response.iter_encoded())) - pytest.raises(Warning, lambda: list(response.get_app_iter(environ))) - resetwarnings() diff --git a/tests/test_local.py b/tests/test_local.py index 2af69d2..2250a5b 100644 --- a/tests/test_local.py +++ b/tests/test_local.py @@ -170,7 +170,7 @@ class SomeClassWithWrapped: _cv_val.set(42) with pytest.raises(AttributeError): - proxy.__wrapped__ + proxy.__wrapped__ # noqa: B018 ns = local.Local(_cv_ns) ns.foo = SomeClassWithWrapped() @@ -179,7 +179,7 @@ class SomeClassWithWrapped: assert ns("foo").__wrapped__ == "wrapped" with pytest.raises(AttributeError): - ns("bar").__wrapped__ + ns("bar").__wrapped__ # noqa: B018 def test_proxy_doc(): diff --git a/tests/test_routing.py b/tests/test_routing.py index 15d25a7..02db898 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -95,6 +95,7 @@ def test_merge_slashes_match(): r.Rule("/yes/tail/", endpoint="yes_tail"), r.Rule("/with/", endpoint="with_path"), r.Rule("/no//merge", endpoint="no_merge", merge_slashes=False), + r.Rule("/no/merging", endpoint="no_merging", merge_slashes=False), ] ) adapter = url_map.bind("localhost", "/") @@ -124,6 +125,9 @@ def test_merge_slashes_match(): assert adapter.match("/no//merge")[0] == "no_merge" + assert adapter.match("/no/merging")[0] == "no_merging" + pytest.raises(NotFound, lambda: adapter.match("/no//merging")) + @pytest.mark.parametrize( ("path", "expected"), @@ -163,6 +167,7 @@ def test_strict_slashes_redirect(): r.Rule("/bar/", endpoint="get", methods=["GET"]), r.Rule("/bar", endpoint="post", methods=["POST"]), r.Rule("/foo/", endpoint="foo", 
methods=["POST"]), + r.Rule("//", endpoint="path", methods=["GET"]), ] ) adapter = map.bind("example.org", "/") @@ -170,6 +175,7 @@ def test_strict_slashes_redirect(): # Check if the actual routes works assert adapter.match("/bar/", method="GET") == ("get", {}) assert adapter.match("/bar", method="POST") == ("post", {}) + assert adapter.match("/abc/", method="GET") == ("path", {"var": "abc"}) # Check if exceptions are correct pytest.raises(r.RequestRedirect, adapter.match, "/bar", method="GET") @@ -177,6 +183,9 @@ def test_strict_slashes_redirect(): with pytest.raises(r.RequestRedirect) as error_info: adapter.match("/foo", method="POST") assert error_info.value.code == 308 + with pytest.raises(r.RequestRedirect) as error_info: + adapter.match("/abc", method="GET") + assert error_info.value.new_url == "http://example.org/abc/" # Check differently defined order map = r.Map( @@ -581,7 +590,8 @@ def test_server_name_interpolation(): with pytest.warns(UserWarning): adapter = map.bind_to_environ(env, server_name="foo") - assert adapter.subdomain == "" + + assert adapter.subdomain == "" def test_rule_emptying(): @@ -742,7 +752,7 @@ def test_uuid_converter(): m = r.Map([r.Rule("/a/", endpoint="a")]) a = m.bind("example.org", "/") route, kwargs = a.match("/a/a8098c1a-f86e-11da-bd1a-00112444be1e") - assert type(kwargs["a_uuid"]) == uuid.UUID + assert type(kwargs["a_uuid"]) == uuid.UUID # noqa: E721 def test_converter_with_tuples(): @@ -773,6 +783,35 @@ def to_url(self, values): assert kwargs["foo"] == ("qwert", "yuiop") +def test_nested_regex_groups(): + """ + Regression test for https://github.com/pallets/werkzeug/issues/2590 + """ + + class RegexConverter(r.BaseConverter): + def __init__(self, url_map, *items): + super().__init__(url_map) + self.part_isolating = False + self.regex = items[0] + + # This is a regex pattern with nested groups + DATE_PATTERN = r"((\d{8}T\d{6}([.,]\d{1,3})?)|(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}([.,]\d{1,3})?))Z" # noqa: E501 + + map = r.Map( + [ + r.Rule( + f"///", + endpoint="handler", + ) + ], + converters={"regex": RegexConverter}, + ) + a = map.bind("example.org", "/") + route, kwargs = a.match("/2023-02-16T23:36:36.266Z/2023-02-16T23:46:36.266Z/") + assert kwargs["start"] == "2023-02-16T23:36:36.266Z" + assert kwargs["end"] == "2023-02-16T23:46:36.266Z" + + def test_anyconverter(): m = r.Map( [ @@ -800,6 +839,20 @@ def test_any_converter_build_validates_value() -> None: assert str(exc.value) == "'invalid' is not one of 'patient', 'provider'" +def test_part_isolating_default() -> None: + class TwoConverter(r.BaseConverter): + regex = r"\w+/\w+" + + def to_python(self, value: str) -> t.Any: + return value.split("/") + + m = r.Map( + [r.Rule("//", endpoint="two")], converters={"two": TwoConverter} + ) + a = m.bind("localhost") + assert a.match("/a/b/") == ("two", {"values": ["a", "b"]}) + + @pytest.mark.parametrize( ("endpoint", "value", "expect"), [ @@ -874,6 +927,7 @@ def to_url(self, value: t.Any) -> str: ([1, 2], "?v=1&v=2"), ([1, None, 2], "?v=1&v=2"), ([1, "", 2], "?v=1&v=&v=2"), + ("1+2", "?v=1%2B2"), ], ) def test_build_append_unknown_dict(value, expect): @@ -910,8 +964,7 @@ def test_build_drop_none(): adapter = map.bind("", "/") params = {"flub": None, "flop": None} with pytest.raises(r.BuildError): - x = adapter.build("endp", params) - assert not x + adapter.build("endp", params) params = {"flub": "x", "flop": None} url = adapter.build("endp", params) assert "flop" not in url @@ -992,7 +1045,16 @@ def 
test_external_building_with_port_bind_to_environ_wrong_servername(): with pytest.warns(UserWarning): adapter = map.bind_to_environ(environ, server_name="example.org") - assert adapter.subdomain == "" + + assert adapter.subdomain == "" + + +def test_bind_long_idna_name_with_port(): + map = r.Map([r.Rule("/", endpoint="index")]) + adapter = map.bind("🐍" + "a" * 52 + ":8443") + name, _, port = adapter.server_name.partition(":") + assert len(name) == 63 + assert port == "8443" def test_converter_parser(): @@ -1014,6 +1076,9 @@ def test_converter_parser(): args, kwargs = r.parse_converter_args('"foo", "bar"') assert args == ("foo", "bar") + with pytest.raises(ValueError): + r.parse_converter_args("min=0;max=500") + def test_alias_redirects(): m = r.Map( @@ -1071,18 +1136,6 @@ def test_double_defaults(prefix): assert a.build("x", {"bar": True}) == f"{prefix}/bar/" -def test_building_bytes(): - m = r.Map( - [ - r.Rule("/", endpoint="a"), - r.Rule("/", defaults={"b": b"\x01\x02\x03"}, endpoint="b"), - ] - ) - a = m.bind("example.org", "/") - assert a.build("a", {"a": b"\x01\x02\x03"}) == "/%01%02%03" - assert a.build("b") == "/%01%02%03" - - def test_host_matching(): m = r.Map( [ @@ -1434,6 +1487,9 @@ def test_strict_slashes_false(): [ r.Rule("/path1", endpoint="leaf_path", strict_slashes=False), r.Rule("/path2/", endpoint="branch_path", strict_slashes=False), + r.Rule( + "/", endpoint="leaf_path_converter", strict_slashes=False + ), ], ) @@ -1443,12 +1499,19 @@ def test_strict_slashes_false(): assert adapter.match("/path1/", method="GET") == ("leaf_path", {}) assert adapter.match("/path2", method="GET") == ("branch_path", {}) assert adapter.match("/path2/", method="GET") == ("branch_path", {}) + assert adapter.match("/any", method="GET") == ( + "leaf_path_converter", + {"path": "any"}, + ) + assert adapter.match("/any/", method="GET") == ( + "leaf_path_converter", + {"path": "any/"}, + ) def test_invalid_rule(): with pytest.raises(ValueError): - map_ = r.Map([r.Rule("/", endpoint="test")]) - map_.bind("localhost") + r.Map([r.Rule("/", endpoint="test")]) def test_multiple_converters_per_part(): diff --git a/tests/test_security.py b/tests/test_security.py index 3e797fc..6fad089 100644 --- a/tests/test_security.py +++ b/tests/test_security.py @@ -1,5 +1,6 @@ import os import posixpath +import sys import pytest @@ -8,25 +9,42 @@ from werkzeug.security import safe_join -def test_password_hashing(): - hash0 = generate_password_hash("default") - assert check_password_hash(hash0, "default") - assert hash0.startswith("pbkdf2:sha256:260000$") +def test_default_password_method(): + value = generate_password_hash("secret") + assert value.startswith("scrypt:") - hash1 = generate_password_hash("default", "sha1") - hash2 = generate_password_hash("default", method="sha1") + +@pytest.mark.xfail( + sys.implementation.name == "pypy", reason="scrypt unavailable on pypy" +) +def test_scrypt(): + value = generate_password_hash("secret", method="scrypt") + assert check_password_hash(value, "secret") + assert value.startswith("scrypt:32768:8:1$") + + +def test_pbkdf2(): + value = generate_password_hash("secret", method="pbkdf2") + assert check_password_hash(value, "secret") + assert value.startswith("pbkdf2:sha256:600000$") + + +def test_salted_hashes(): + hash1 = generate_password_hash("secret") + hash2 = generate_password_hash("secret") assert hash1 != hash2 - assert check_password_hash(hash1, "default") - assert check_password_hash(hash2, "default") - assert hash1.startswith("sha1$") - assert 
hash2.startswith("sha1$") + assert check_password_hash(hash1, "secret") + assert check_password_hash(hash2, "secret") + +def test_require_salt(): with pytest.raises(ValueError): - generate_password_hash("default", "sha1", salt_length=0) + generate_password_hash("secret", salt_length=0) + - fakehash = generate_password_hash("default", method="plain") - assert fakehash == "plain$$default" - assert check_password_hash(fakehash, "default") +def test_invalid_method(): + with pytest.raises(ValueError, match="Invalid hash method"): + generate_password_hash("secret", "sha256") def test_safe_join(): diff --git a/tests/test_send_file.py b/tests/test_send_file.py index fc4299a..4aa69f2 100644 --- a/tests/test_send_file.py +++ b/tests/test_send_file.py @@ -107,6 +107,9 @@ def test_object_attachment_requires_name(): ("Vögel.txt", "Vogel.txt", "V%C3%B6gel.txt"), # ":/" are not safe in filename* value ("те:/ст", '":/"', "%D1%82%D0%B5%3A%2F%D1%81%D1%82"), + # general test of extended parameter (non-quoted) + ("(тест.txt", '"(.txt"', "%28%D1%82%D0%B5%D1%81%D1%82.txt"), + ("(test.txt", '"(test.txt"', None), ), ) def test_non_ascii_name(name, ascii, utf8): diff --git a/tests/test_serving.py b/tests/test_serving.py index 0494828..4abc755 100644 --- a/tests/test_serving.py +++ b/tests/test_serving.py @@ -7,13 +7,18 @@ import sys from io import BytesIO from pathlib import Path +from unittest.mock import patch import pytest +from watchdog.events import EVENT_TYPE_MODIFIED +from watchdog.events import EVENT_TYPE_OPENED +from watchdog.events import FileModifiedEvent from werkzeug import run_simple from werkzeug._reloader import _find_stat_paths from werkzeug._reloader import _find_watchdog_paths from werkzeug._reloader import _get_args_for_reloading +from werkzeug._reloader import WatchdogReloaderLoop from werkzeug.datastructures import FileStorage from werkzeug.serving import make_ssl_devcert from werkzeug.test import stream_encode_multipart @@ -115,6 +120,23 @@ def test_reloader_sys_path(tmp_path, dev_server, reloader_type): assert client.request().status == 200 +@patch.object(WatchdogReloaderLoop, "trigger_reload") +def test_watchdog_reloader_ignores_opened(mock_trigger_reload): + reloader = WatchdogReloaderLoop() + modified_event = FileModifiedEvent("") + modified_event.event_type = EVENT_TYPE_MODIFIED + reloader.event_handler.on_any_event(modified_event) + mock_trigger_reload.assert_called_once() + + reloader.trigger_reload.reset_mock() + + opened_event = FileModifiedEvent("") + opened_event.event_type = EVENT_TYPE_OPENED + reloader.event_handler.on_any_event(opened_event) + reloader.trigger_reload.assert_not_called() + + +@pytest.mark.skipif(sys.version_info >= (3, 10), reason="not needed on >= 3.10") def test_windows_get_args_for_reloading(monkeypatch, tmp_path): argv = [str(tmp_path / "test.exe"), "run"] monkeypatch.setattr("sys.executable", str(tmp_path / "python.exe")) @@ -125,14 +147,18 @@ def test_windows_get_args_for_reloading(monkeypatch, tmp_path): assert rv == argv +@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.parametrize("find", [_find_stat_paths, _find_watchdog_paths]) def test_exclude_patterns(find): - # Imported paths under sys.prefix will be included by default. + # Select a path to exclude from the unfiltered list, assert that it is present and + # then gets excluded. 
paths = find(set(), set()) - assert any(p.startswith(sys.prefix) for p in paths) + path_to_exclude = next(iter(paths)) + assert any(p.startswith(path_to_exclude) for p in paths) + # Those paths should be excluded due to the pattern. - paths = find(set(), {f"{sys.prefix}*"}) - assert not any(p.startswith(sys.prefix) for p in paths) + paths = find(set(), {f"{path_to_exclude}*"}) + assert not any(p.startswith(path_to_exclude) for p in paths) @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @@ -254,6 +280,7 @@ def test_multiline_header_folding(standard_app): @pytest.mark.parametrize("endpoint", ["", "crash"]) +@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.dev_server def test_streaming_close_response(dev_server, endpoint): """When using HTTP/1.0, chunked encoding is not supported. Fall @@ -265,6 +292,7 @@ def test_streaming_close_response(dev_server, endpoint): assert r.data == "".join(str(x) + "\n" for x in range(5)).encode() +@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.dev_server def test_streaming_chunked_response(dev_server): """When using HTTP/1.1, use Transfer-Encoding: chunked for streamed diff --git a/tests/test_test.py b/tests/test_test.py index 02d637e..d317d69 100644 --- a/tests/test_test.py +++ b/tests/test_test.py @@ -10,13 +10,13 @@ from werkzeug.datastructures import Headers from werkzeug.datastructures import MultiDict from werkzeug.formparser import parse_form_data -from werkzeug.http import parse_authorization_header from werkzeug.test import Client from werkzeug.test import ClientRedirectError from werkzeug.test import create_environ from werkzeug.test import EnvironBuilder from werkzeug.test import run_wsgi_app from werkzeug.test import stream_encode_multipart +from werkzeug.test import TestResponse from werkzeug.utils import redirect from werkzeug.wrappers import Request from werkzeug.wrappers import Response @@ -74,7 +74,7 @@ def multi_value_post_app(environ, start_response): def test_cookie_forging(): c = Client(cookie_app) - c.set_cookie("localhost", "foo", "bar") + c.set_cookie("foo", "bar") response = c.open() assert response.text == "foo=bar" @@ -88,7 +88,7 @@ def test_set_cookie_app(): def test_cookiejar_stores_cookie(): c = Client(cookie_app) c.open() - assert "test" in c.cookie_jar._cookies["localhost.local"]["/"] + assert c.get_cookie("test") is not None def test_no_initial_cookie(): @@ -118,6 +118,25 @@ def test_cookie_for_different_path(): assert response.text == "test=test" +def test_cookie_default_path() -> None: + """When no path is set for a cookie, the default uses everything up to but not + including the first slash. 
+ """ + + @Request.application + def app(request: Request) -> Response: + r = Response() + r.set_cookie("k", "v", path=None) + return r + + c = Client(app) + c.get("/nested/leaf") + assert c.get_cookie("k") is None + assert c.get_cookie("k", path="/nested") is not None + c.get("/nested/dir/") + assert c.get_cookie("k", path="/nested/dir") is not None + + def test_environ_builder_basics(): b = EnvironBuilder() assert b.content_type is None @@ -284,9 +303,8 @@ def test_environ_builder_content_type(): def test_basic_auth(): builder = EnvironBuilder(auth=("username", "password")) request = builder.get_request() - auth = parse_authorization_header(request.headers["Authorization"]) - assert auth.username == "username" - assert auth.password == "password" + assert request.authorization.username == "username" + assert request.authorization.password == "password" def test_auth_object(): @@ -340,6 +358,23 @@ def test_environ_builder_unicode_file_mix(): files["f"].close() +def test_environ_builder_empty_file(): + f = FileStorage(BytesIO(rb""), "empty.txt") + d = MultiDict(dict(f=f, s="")) + stream, length, boundary = stream_encode_multipart(d) + _, form, files = parse_form_data( + { + "wsgi.input": stream, + "CONTENT_LENGTH": str(length), + "CONTENT_TYPE": f'multipart/form-data; boundary="{boundary}"', + } + ) + assert form["s"] == "" + assert files["f"].read() == rb"" + stream.close() + files["f"].close() + + def test_create_environ(): env = create_environ("/foo?bar=baz", "http://example.org/") expected = { @@ -392,7 +427,7 @@ def test_file_closing(): class SpecialInput: def read(self, size): - return "" + return b"" def close(self): closed.append(self) @@ -752,8 +787,8 @@ def test_multiple_cookies(): @Request.application def test_app(request): response = Response(repr(sorted(request.cookies.items()))) - response.set_cookie("test1", b"foo") - response.set_cookie("test2", b"bar") + response.set_cookie("test1", "foo") + response.set_cookie("test2", "bar") return response client = Client(test_app) @@ -869,3 +904,24 @@ def test_no_content_type_header_addition(): c = Client(no_response_headers_app) response = c.open() assert response.headers == Headers([("Content-Length", "8")]) + + +def test_client_response_wrapper(): + class CustomResponse(Response): + pass + + class CustomTestResponse(TestResponse, Response): + pass + + c1 = Client(Response(), CustomResponse) + r1 = c1.open() + + assert isinstance(r1, CustomResponse) + assert type(r1) is not CustomResponse # Got subclassed + assert issubclass(type(r1), CustomResponse) + + c2 = Client(Response(), CustomTestResponse) + r2 = c2.open() + + assert isinstance(r2, CustomTestResponse) + assert type(r2) is CustomTestResponse # Did not get subclassed diff --git a/tests/test_urls.py b/tests/test_urls.py index a409709..101b886 100644 --- a/tests/test_urls.py +++ b/tests/test_urls.py @@ -1,240 +1,26 @@ -import io - import pytest from werkzeug import urls -from werkzeug.datastructures import OrderedMultiDict - - -def test_parsing(): - url = urls.url_parse("http://anon:hunter2@[2001:db8:0:1]:80/a/b/c") - assert url.netloc == "anon:hunter2@[2001:db8:0:1]:80" - assert url.username == "anon" - assert url.password == "hunter2" - assert url.port == 80 - assert url.ascii_host == "2001:db8:0:1" - - assert url.get_file_location() == (None, None) # no file scheme - - -@pytest.mark.parametrize("implicit_format", (True, False)) -@pytest.mark.parametrize("localhost", ("127.0.0.1", "::1", "localhost")) -def test_fileurl_parsing_windows(implicit_format, localhost, monkeypatch): - 
if implicit_format: - pathformat = None - monkeypatch.setattr("os.name", "nt") - else: - pathformat = "windows" - monkeypatch.delattr("os.name") # just to make sure it won't get used - - url = urls.url_parse("file:///C:/Documents and Settings/Foobar/stuff.txt") - assert url.netloc == "" - assert url.scheme == "file" - assert url.get_file_location(pathformat) == ( - None, - r"C:\Documents and Settings\Foobar\stuff.txt", - ) - - url = urls.url_parse("file://///server.tld/file.txt") - assert url.get_file_location(pathformat) == ("server.tld", r"file.txt") - - url = urls.url_parse("file://///server.tld") - assert url.get_file_location(pathformat) == ("server.tld", "") - - url = urls.url_parse(f"file://///{localhost}") - assert url.get_file_location(pathformat) == (None, "") - - url = urls.url_parse(f"file://///{localhost}/file.txt") - assert url.get_file_location(pathformat) == (None, r"file.txt") - - -def test_replace(): - url = urls.url_parse("http://de.wikipedia.org/wiki/Troll") - assert url.replace(query="foo=bar") == urls.url_parse( - "http://de.wikipedia.org/wiki/Troll?foo=bar" - ) - assert url.replace(scheme="https") == urls.url_parse( - "https://de.wikipedia.org/wiki/Troll" - ) - - -def test_quoting(): - assert urls.url_quote("\xf6\xe4\xfc") == "%C3%B6%C3%A4%C3%BC" - assert urls.url_unquote(urls.url_quote('#%="\xf6')) == '#%="\xf6' - assert urls.url_quote_plus("foo bar") == "foo+bar" - assert urls.url_unquote_plus("foo+bar") == "foo bar" - assert urls.url_quote_plus("foo+bar") == "foo%2Bbar" - assert urls.url_unquote_plus("foo%2Bbar") == "foo+bar" - assert urls.url_encode({b"a": None, b"b": b"foo bar"}) == "b=foo+bar" - assert urls.url_encode({"a": None, "b": "foo bar"}) == "b=foo+bar" - assert ( - urls.url_fix("http://de.wikipedia.org/wiki/Elf (Begriffsklärung)") - == "http://de.wikipedia.org/wiki/Elf%20(Begriffskl%C3%A4rung)" - ) - assert urls.url_quote_plus(42) == "42" - assert urls.url_quote(b"\xff") == "%FF" - - -def test_bytes_unquoting(): - assert ( - urls.url_unquote(urls.url_quote('#%="\xf6', charset="latin1"), charset=None) - == b'#%="\xf6' - ) - - -def test_url_decoding(): - x = urls.url_decode(b"foo=42&bar=23&uni=H%C3%A4nsel") - assert x["foo"] == "42" - assert x["bar"] == "23" - assert x["uni"] == "Hänsel" - - x = urls.url_decode(b"foo=42;bar=23;uni=H%C3%A4nsel", separator=b";") - assert x["foo"] == "42" - assert x["bar"] == "23" - assert x["uni"] == "Hänsel" - - x = urls.url_decode(b"%C3%9Ch=H%C3%A4nsel") - assert x["Üh"] == "Hänsel" - - -def test_url_bytes_decoding(): - x = urls.url_decode(b"foo=42&bar=23&uni=H%C3%A4nsel", charset=None) - assert x[b"foo"] == b"42" - assert x[b"bar"] == b"23" - assert x[b"uni"] == "Hänsel".encode() - - -def test_stream_decoding_string_fails(): - pytest.raises(TypeError, urls.url_decode_stream, "testing") - - -def test_url_encoding(): - assert urls.url_encode({"foo": "bar 45"}) == "foo=bar+45" - d = {"foo": 1, "bar": 23, "blah": "Hänsel"} - assert urls.url_encode(d, sort=True) == "bar=23&blah=H%C3%A4nsel&foo=1" - assert ( - urls.url_encode(d, sort=True, separator=";") == "bar=23;blah=H%C3%A4nsel;foo=1" - ) - - -def test_sorted_url_encode(): - assert ( - urls.url_encode( - {"a": 42, "b": 23, 1: 1, 2: 2}, sort=True, key=lambda i: str(i[0]) - ) - == "1=1&2=2&a=42&b=23" - ) - assert ( - urls.url_encode( - {"A": 1, "a": 2, "B": 3, "b": 4}, - sort=True, - key=lambda x: x[0].lower() + x[0], - ) - == "A=1&a=2&B=3&b=4" - ) - - -def test_streamed_url_encoding(): - out = io.StringIO() - urls.url_encode_stream({"foo": "bar 45"}, out) - assert 
out.getvalue() == "foo=bar+45" - - d = {"foo": 1, "bar": 23, "blah": "Hänsel"} - out = io.StringIO() - urls.url_encode_stream(d, out, sort=True) - assert out.getvalue() == "bar=23&blah=H%C3%A4nsel&foo=1" - out = io.StringIO() - urls.url_encode_stream(d, out, sort=True, separator=";") - assert out.getvalue() == "bar=23;blah=H%C3%A4nsel;foo=1" - - gen = urls.url_encode_stream(d, sort=True) - assert next(gen) == "bar=23" - assert next(gen) == "blah=H%C3%A4nsel" - assert next(gen) == "foo=1" - pytest.raises(StopIteration, lambda: next(gen)) - - -def test_url_fixing(): - x = urls.url_fix("http://de.wikipedia.org/wiki/Elf (Begriffskl\xe4rung)") - assert x == "http://de.wikipedia.org/wiki/Elf%20(Begriffskl%C3%A4rung)" - - x = urls.url_fix("http://just.a.test/$-_.+!*'(),") - assert x == "http://just.a.test/$-_.+!*'()," - - x = urls.url_fix("http://höhöhö.at/höhöhö/hähähä") - assert x == r"http://xn--hhh-snabb.at/h%C3%B6h%C3%B6h%C3%B6/h%C3%A4h%C3%A4h%C3%A4" - - -def test_url_fixing_filepaths(): - x = urls.url_fix(r"file://C:\Users\Administrator\My Documents\ÑÈáÇíí") - assert x == ( - r"file:///C%3A/Users/Administrator/My%20Documents/" - r"%C3%91%C3%88%C3%A1%C3%87%C3%AD%C3%AD" - ) - - a = urls.url_fix(r"file:/C:/") - b = urls.url_fix(r"file://C:/") - c = urls.url_fix(r"file:///C:/") - assert a == b == c == r"file:///C%3A/" - - x = urls.url_fix(r"file://host/sub/path") - assert x == r"file://host/sub/path" - - x = urls.url_fix(r"file:///") - assert x == r"file:///" - - -def test_url_fixing_qs(): - x = urls.url_fix(b"http://example.com/?foo=%2f%2f") - assert x == "http://example.com/?foo=%2f%2f" - - x = urls.url_fix( - "http://acronyms.thefreedictionary.com/" - "Algebraic+Methods+of+Solving+the+Schr%C3%B6dinger+Equation" - ) - assert x == ( - "http://acronyms.thefreedictionary.com/" - "Algebraic+Methods+of+Solving+the+Schr%C3%B6dinger+Equation" - ) def test_iri_support(): assert urls.uri_to_iri("http://xn--n3h.net/") == "http://\u2603.net/" - assert ( - urls.uri_to_iri(b"http://%C3%BCser:p%C3%A4ssword@xn--n3h.net/p%C3%A5th") - == "http://\xfcser:p\xe4ssword@\u2603.net/p\xe5th" - ) assert urls.iri_to_uri("http://☃.net/") == "http://xn--n3h.net/" assert ( urls.iri_to_uri("http://üser:pässword@☃.net/påth") == "http://%C3%BCser:p%C3%A4ssword@xn--n3h.net/p%C3%A5th" ) - assert ( urls.uri_to_iri("http://test.com/%3Fmeh?foo=%26%2F") - == "http://test.com/%3Fmeh?foo=%26%2F" + == "http://test.com/%3Fmeh?foo=%26/" ) - - # this should work as well, might break on 2.4 because of a broken - # idna codec - assert urls.uri_to_iri(b"/foo") == "/foo" assert urls.iri_to_uri("/foo") == "/foo" - assert ( urls.iri_to_uri("http://föö.com:8080/bam/baz") == "http://xn--f-1gaa.com:8080/bam/baz" ) -def test_iri_safe_conversion(): - assert urls.iri_to_uri("magnet:?foo=bar") == "magnet:?foo=bar" - assert urls.iri_to_uri("itms-service://?foo=bar") == "itms-service:?foo=bar" - assert ( - urls.iri_to_uri("itms-service://?foo=bar", safe_conversion=True) - == "itms-service://?foo=bar" - ) - - def test_iri_safe_quoting(): uri = "http://xn--f-1gaa.com/%2F%25?q=%C3%B6&x=%3D%25#%25" iri = "http://föö.com/%2F%25?q=ö&x=%3D%25#%25" @@ -242,83 +28,11 @@ def test_iri_safe_quoting(): assert urls.iri_to_uri(urls.uri_to_iri(uri)) == uri -def test_ordered_multidict_encoding(): - d = OrderedMultiDict() - d.add("foo", 1) - d.add("foo", 2) - d.add("foo", 3) - d.add("bar", 0) - d.add("foo", 4) - assert urls.url_encode(d) == "foo=1&foo=2&foo=3&bar=0&foo=4" - - -def test_multidict_encoding(): - d = OrderedMultiDict() - 
d.add("2013-10-10T23:26:05.657975+0000", "2013-10-10T23:26:05.657975+0000") - assert ( - urls.url_encode(d) - == "2013-10-10T23%3A26%3A05.657975%2B0000=2013-10-10T23%3A26%3A05.657975%2B0000" - ) - - -def test_url_unquote_plus_unicode(): - # was broken in 0.6 - assert urls.url_unquote_plus("\x6d") == "\x6d" - - def test_quoting_of_local_urls(): rv = urls.iri_to_uri("/foo\x8f") assert rv == "/foo%C2%8F" -def test_url_attributes(): - rv = urls.url_parse("http://foo%3a:bar%3a@[::1]:80/123?x=y#frag") - assert rv.scheme == "http" - assert rv.auth == "foo%3a:bar%3a" - assert rv.username == "foo:" - assert rv.password == "bar:" - assert rv.raw_username == "foo%3a" - assert rv.raw_password == "bar%3a" - assert rv.host == "::1" - assert rv.port == 80 - assert rv.path == "/123" - assert rv.query == "x=y" - assert rv.fragment == "frag" - - rv = urls.url_parse("http://\N{SNOWMAN}.com/") - assert rv.host == "\N{SNOWMAN}.com" - assert rv.ascii_host == "xn--n3h.com" - - -def test_url_attributes_bytes(): - rv = urls.url_parse(b"http://foo%3a:bar%3a@[::1]:80/123?x=y#frag") - assert rv.scheme == b"http" - assert rv.auth == b"foo%3a:bar%3a" - assert rv.username == "foo:" - assert rv.password == "bar:" - assert rv.raw_username == b"foo%3a" - assert rv.raw_password == b"bar%3a" - assert rv.host == b"::1" - assert rv.port == 80 - assert rv.path == b"/123" - assert rv.query == b"x=y" - assert rv.fragment == b"frag" - - -def test_url_joining(): - assert urls.url_join("/foo", "/bar") == "/bar" - assert urls.url_join("http://example.com/foo", "/bar") == "http://example.com/bar" - assert urls.url_join("file:///tmp/", "test.html") == "file:///tmp/test.html" - assert urls.url_join("file:///tmp/x", "test.html") == "file:///tmp/test.html" - assert urls.url_join("file:///tmp/x", "../../../x.html") == "file:///x.html" - - -def test_partial_unencoded_decode(): - ref = "foo=정상처리".encode("euc-kr") - x = urls.url_decode(ref, charset="euc-kr") - assert x["foo"] == "정상처리" - - def test_iri_to_uri_idempotence_ascii_only(): uri = "http://www.idempoten.ce" uri = urls.iri_to_uri(uri) @@ -355,31 +69,38 @@ def test_uri_to_iri_to_uri(): assert urls.iri_to_uri(iri) == uri -def test_uri_iri_normalization(): - uri = "http://xn--f-rgao.com/%E2%98%90/fred?utf8=%E2%9C%93" - iri = "http://föñ.com/\N{BALLOT BOX}/fred?utf8=\u2713" - - tests = [ +@pytest.mark.parametrize( + "value", + [ "http://föñ.com/\N{BALLOT BOX}/fred?utf8=\u2713", "http://xn--f-rgao.com/\u2610/fred?utf8=\N{CHECK MARK}", - b"http://xn--f-rgao.com/%E2%98%90/fred?utf8=%E2%9C%93", + "http://xn--f-rgao.com/%E2%98%90/fred?utf8=%E2%9C%93", "http://xn--f-rgao.com/%E2%98%90/fred?utf8=%E2%9C%93", "http://föñ.com/\u2610/fred?utf8=%E2%9C%93", - b"http://xn--f-rgao.com/\xe2\x98\x90/fred?utf8=\xe2\x9c\x93", - ] - - for test in tests: - assert urls.uri_to_iri(test) == iri - assert urls.iri_to_uri(test) == uri - assert urls.uri_to_iri(urls.iri_to_uri(test)) == iri - assert urls.iri_to_uri(urls.uri_to_iri(test)) == uri - assert urls.uri_to_iri(urls.uri_to_iri(test)) == iri - assert urls.iri_to_uri(urls.iri_to_uri(test)) == uri + ], +) +def test_uri_iri_normalization(value): + uri = "http://xn--f-rgao.com/%E2%98%90/fred?utf8=%E2%9C%93" + iri = "http://föñ.com/\N{BALLOT BOX}/fred?utf8=\u2713" + assert urls.uri_to_iri(value) == iri + assert urls.iri_to_uri(value) == uri + assert urls.uri_to_iri(urls.iri_to_uri(value)) == iri + assert urls.iri_to_uri(urls.uri_to_iri(value)) == uri + assert urls.uri_to_iri(urls.uri_to_iri(value)) == iri + assert urls.iri_to_uri(urls.iri_to_uri(value)) == uri def 
test_uri_to_iri_dont_unquote_space():
     assert urls.uri_to_iri("abc%20def") == "abc%20def"
 
 
-def test_iri_to_uri_dont_quote_reserved():
-    assert urls.iri_to_uri("/path[bracket]?(paren)") == "/path[bracket]?(paren)"
+def test_iri_to_uri_dont_quote_valid_code_points():
+    # [] are not valid URL code points according to WhatWG URL Standard
+    # https://url.spec.whatwg.org/#url-code-points
+    assert urls.iri_to_uri("/path[bracket]?(paren)") == "/path%5Bbracket%5D?(paren)"
+
+
+# Python < 3.12
+def test_itms_services() -> None:
+    url = "itms-services://?action=download-manifest&url=https://test.example/path"
+    assert urls.iri_to_uri(url) == url
diff --git a/tests/test_utils.py b/tests/test_utils.py
index ed8d8d0..c48eba5 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import inspect
 
 from datetime import datetime
@@ -9,48 +11,32 @@
 from werkzeug.http import http_date
 from werkzeug.http import parse_date
 from werkzeug.test import Client
+from werkzeug.test import EnvironBuilder
 from werkzeug.wrappers import Response
 
 
-def test_redirect():
-    resp = utils.redirect("/füübär")
-    assert resp.headers["Location"] == "/f%C3%BC%C3%BCb%C3%A4r"
-    assert resp.status_code == 302
-    assert resp.get_data() == (
-        b"<!doctype html>\n"
-        b"<html lang=en>\n"
-        b"<title>Redirecting...</title>\n"
-        b"<h1>Redirecting...</h1>\n"
-        b"<p>You should be redirected automatically to the target URL: "
-        b'<a href="/f%C3%BC%C3%BCb%C3%A4r">/f\xc3\xbc\xc3\xbcb\xc3\xa4r</a>. '
-        b"If not, click the link.\n"
-    )
+@pytest.mark.parametrize(
+    ("url", "code", "expect"),
+    [
+        ("http://example.com", None, "http://example.com"),
+        ("/füübär", 305, "/f%C3%BC%C3%BCb%C3%A4r"),
+        ("http://☃.example.com/", 307, "http://xn--n3h.example.com/"),
+        ("itms-services://?url=abc", None, "itms-services://?url=abc"),
+    ],
+)
+def test_redirect(url: str, code: int | None, expect: str) -> None:
+    environ = EnvironBuilder().get_environ()
 
-    resp = utils.redirect("http://☃.net/", 307)
-    assert resp.headers["Location"] == "http://xn--n3h.net/"
-    assert resp.status_code == 307
-    assert resp.get_data() == (
-        b"<!doctype html>\n"
-        b"<html lang=en>\n"
-        b"<title>Redirecting...</title>\n"
-        b"<h1>Redirecting...</h1>\n"
-        b"<p>You should be redirected automatically to the target URL: "
-        b'<a href="http://xn--n3h.net/">http://\xe2\x98\x83.net/</a>. '
-        b"If not, click the link.\n"
-    )
+    if code is None:
+        resp = utils.redirect(url)
+        assert resp.status_code == 302
+    else:
+        resp = utils.redirect(url, code)
+        assert resp.status_code == code
 
-    resp = utils.redirect("http://example.com/", 305)
-    assert resp.headers["Location"] == "http://example.com/"
-    assert resp.status_code == 305
-    assert resp.get_data() == (
-        b"<!doctype html>\n"
-        b"<html lang=en>\n"
-        b"<title>Redirecting...</title>\n"
-        b"<h1>Redirecting...</h1>\n"
-        b"<p>You should be redirected automatically to the target URL: "
-        b'<a href="http://example.com/">http://example.com/</a>. '
-        b"If not, click the link.\n"
-    )
+    assert resp.headers["Location"] == url
+    assert resp.get_wsgi_headers(environ)["Location"] == expect
+    assert resp.get_data(as_text=True).count(url) == 2
 
 
 def test_redirect_xss():
@@ -190,6 +176,7 @@ def test_assign():
 
 def test_import_string():
     from datetime import date
+
     from werkzeug.debug import DebuggedApplication
 
     assert utils.import_string("datetime.date") is date
diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py
index b769a38..f756944 100644
--- a/tests/test_wrappers.py
+++ b/tests/test_wrappers.py
@@ -20,9 +20,11 @@
 from werkzeug.datastructures import LanguageAccept
 from werkzeug.datastructures import MIMEAccept
 from werkzeug.datastructures import MultiDict
+from werkzeug.datastructures import WWWAuthenticate
 from werkzeug.exceptions import BadRequest
 from werkzeug.exceptions import RequestedRangeNotSatisfiable
 from werkzeug.exceptions import SecurityError
+from werkzeug.exceptions import UnsupportedMediaType
 from werkzeug.http import COEP
 from werkzeug.http import COOP
 from werkzeug.http import generate_etag
@@ -136,11 +138,12 @@ def test_url_request_descriptors():
 
 
 def test_url_request_descriptors_query_quoting():
-    next = "http%3A%2F%2Fwww.example.com%2F%3Fnext%3D%2Fbaz%23my%3Dhash"
-    req = wrappers.Request.from_values(f"/bar?next={next}", "http://example.com/")
+    quoted = "http%3A%2F%2Fwww.example.com%2F%3Fnext%3D%2Fbaz%23my%3Dhash"
+    unquoted = "http://www.example.com/?next%3D/baz%23my%3Dhash"
+    req = wrappers.Request.from_values(f"/bar?next={quoted}", "http://example.com/")
     assert req.path == "/bar"
-    assert req.full_path == f"/bar?next={next}"
-    assert req.url == f"http://example.com/bar?next={next}"
+    assert req.full_path == f"/bar?next={quoted}"
+    assert req.url == f"http://example.com/bar?next={unquoted}"
 
 
 def test_url_request_descriptors_hosts():
@@ -349,13 +352,6 @@ def test_response_init_status_empty_string():
     assert "Empty status argument" in str(info.value)
 
 
-def test_response_init_status_tuple():
-    with pytest.raises(TypeError) as info:
-        wrappers.Response(None, tuple())
-
-    assert "Invalid status argument" in str(info.value)
-
-
 def test_type_forcing():
     def wsgi_application(environ, start_response):
         start_response("200 OK", [("Content-Type", "text/html")])
@@ -686,27 +682,26 @@ def test_etag_response_freezing():
 
 def test_authenticate():
     resp = wrappers.Response()
-    resp.www_authenticate.type = "basic"
     resp.www_authenticate.realm = "Testing"
-    assert resp.headers["WWW-Authenticate"] == 'Basic realm="Testing"'
-    resp.www_authenticate.realm = None
-    resp.www_authenticate.type = None
+    assert resp.headers["WWW-Authenticate"] == "Basic realm=Testing"
+    del resp.www_authenticate
    assert "WWW-Authenticate" not in resp.headers
 
 
 def test_authenticate_quoted_qop():
     # Example taken from https://github.com/pallets/werkzeug/issues/633
     resp = wrappers.Response()
-    resp.www_authenticate.set_digest("REALM", "NONCE", qop=("auth", "auth-int"))
+    resp.www_authenticate = WWWAuthenticate(
+        "digest", {"realm": "REALM", "nonce": "NONCE", "qop": "auth, auth-int"}
+    )
 
-    actual = set(f"{resp.headers['WWW-Authenticate']},".split())
-    expected = set('Digest nonce="NONCE", realm="REALM", qop="auth, auth-int",'.split())
+    actual = resp.headers["WWW-Authenticate"]
+    expected = 'Digest realm="REALM", nonce="NONCE", qop="auth, auth-int"'
     assert actual == expected
 
-    resp.www_authenticate.set_digest("REALM", "NONCE", qop=("auth",))
-
-    actual = 
set(f"{resp.headers['WWW-Authenticate']},".split()) - expected = set('Digest nonce="NONCE", realm="REALM", qop="auth",'.split()) + resp.www_authenticate.parameters["qop"] = "auth" + actual = resp.headers["WWW-Authenticate"] + expected = 'Digest realm="REALM", nonce="NONCE", qop="auth"' assert actual == expected @@ -875,12 +870,6 @@ def test_file_closing_with(): assert foo.closed is True -def test_url_charset_reflection(): - req = wrappers.Request.from_values() - req.charset = "utf-7" - assert req.url_charset == "utf-7" - - def test_response_streamed(): r = wrappers.Response() assert not r.is_streamed @@ -1048,25 +1037,25 @@ class MyRequest(wrappers.Request): parameter_storage_class = dict req = MyRequest.from_values("/?foo=baz", headers={"Cookie": "foo=bar"}) - assert type(req.cookies) is dict + assert type(req.cookies) is dict # noqa: E721 assert req.cookies == {"foo": "bar"} - assert type(req.access_route) is list + assert type(req.access_route) is list # noqa: E721 - assert type(req.args) is dict - assert type(req.values) is CombinedMultiDict + assert type(req.args) is dict # noqa: E721 + assert type(req.values) is CombinedMultiDict # noqa: E721 assert req.values["foo"] == "baz" req = wrappers.Request.from_values(headers={"Cookie": "foo=bar;foo=baz"}) - assert type(req.cookies) is ImmutableMultiDict + assert type(req.cookies) is ImmutableMultiDict # noqa: E721 assert req.cookies.to_dict() == {"foo": "bar"} # it is possible to have multiple cookies with the same name assert req.cookies.getlist("foo") == ["bar", "baz"] - assert type(req.access_route) is ImmutableList + assert type(req.access_route) is ImmutableList # noqa: E721 MyRequest.list_storage_class = tuple req = MyRequest.from_values() - assert type(req.access_route) is tuple + assert type(req.access_route) is tuple # noqa: E721 def test_response_headers_passthrough(): @@ -1165,6 +1154,7 @@ class MyResponse(wrappers.Response): ("auto", "location", "expect"), ( (False, "/test", "/test"), + (False, "/\\\\test.example?q", "/%5C%5Ctest.example?q"), (True, "/test", "http://localhost/test"), (True, "test", "http://localhost/a/b/test"), (True, "./test", "http://localhost/a/b/test"), @@ -1206,14 +1196,6 @@ def test_malformed_204_response_has_no_content_length(): assert b"".join(app_iter) == b"" # ensure data will not be sent -def test_modified_url_encoding(): - class ModifiedRequest(wrappers.Request): - url_charset = "euc-kr" - - req = ModifiedRequest.from_values(query_string={"foo": "정상처리"}, charset="euc-kr") - assert req.args["foo"] == "정상처리" - - def test_request_method_case_sensitivity(): req = wrappers.Request( {"REQUEST_METHOD": "get", "SERVER_NAME": "eggs", "SERVER_PORT": "80"} @@ -1350,7 +1332,7 @@ def test_bad_content_type(self): value = [1, 2, 3] request = wrappers.Request.from_values(json=value, content_type="text/plain") - with pytest.raises(BadRequest): + with pytest.raises(UnsupportedMediaType): request.get_json() assert request.get_json(silent=True) is None diff --git a/tests/test_wsgi.py b/tests/test_wsgi.py index b0f71bc..7f4d2e9 100644 --- a/tests/test_wsgi.py +++ b/tests/test_wsgi.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import io import json import os +import typing as t import pytest @@ -84,7 +87,6 @@ def foo(environ, start_response): def test_path_info_and_script_name_fetching(): env = create_environ("/\N{SNOWMAN}", "http://example.com/\N{COMET}/") assert wsgi.get_path_info(env) == "/\N{SNOWMAN}" - assert wsgi.get_path_info(env, charset=None) == "/\N{SNOWMAN}".encode() def test_limited_stream(): @@ 
-117,11 +119,10 @@ def on_exhausted(self): stream = wsgi.LimitedStream(io_, 9) assert stream.readlines() == [b"123456\n", b"ab"] - io_ = io.BytesIO(b"123456\nabcdefg") + io_ = io.BytesIO(b"123\n456\nabcdefg") stream = wsgi.LimitedStream(io_, 9) - assert stream.readlines(2) == [b"12"] - assert stream.readlines(2) == [b"34"] - assert stream.readlines() == [b"56\n", b"ab"] + assert stream.readlines(2) == [b"123\n"] + assert stream.readlines() == [b"456\n", b"a"] io_ = io.BytesIO(b"123456\nabcdefg") stream = wsgi.LimitedStream(io_, 9) @@ -146,13 +147,8 @@ def on_exhausted(self): stream = wsgi.LimitedStream(io_, 0) assert stream.read(-1) == b"" - io_ = io.StringIO("123456") - stream = wsgi.LimitedStream(io_, 0) - assert stream.read(-1) == "" - - io_ = io.StringIO("123\n456\n") - stream = wsgi.LimitedStream(io_, 8) - assert list(stream) == ["123\n", "456\n"] + stream = wsgi.LimitedStream(io.BytesIO(b"123\n456\n"), 8) + assert list(stream) == [b"123\n", b"456\n"] def test_limited_stream_json_load(): @@ -165,21 +161,63 @@ def test_limited_stream_json_load(): def test_limited_stream_disconnection(): - io_ = io.BytesIO(b"A bit of content") - - # disconnect detection on out of bytes - stream = wsgi.LimitedStream(io_, 255) + # disconnect because stream returns zero bytes + stream = wsgi.LimitedStream(io.BytesIO(), 255) with pytest.raises(ClientDisconnected): stream.read() - # disconnect detection because file close - io_ = io.BytesIO(b"x" * 255) - io_.close() - stream = wsgi.LimitedStream(io_, 255) + # disconnect because stream is closed + data = io.BytesIO(b"x" * 255) + data.close() + stream = wsgi.LimitedStream(data, 255) + with pytest.raises(ClientDisconnected): stream.read() +def test_limited_stream_read_with_raw_io(): + class OneByteStream(t.BinaryIO): + def __init__(self, buf: bytes) -> None: + self.buf = buf + self.pos = 0 + + def read(self, size: int | None = None) -> bytes: + """Return one byte at a time regardless of requested size.""" + + if size is None or size == -1: + raise ValueError("expected read to be called with specific limit") + + if size == 0 or len(self.buf) < self.pos: + return b"" + + b = self.buf[self.pos : self.pos + 1] + self.pos += 1 + return b + + stream = wsgi.LimitedStream(OneByteStream(b"foo"), 4) + assert stream.read(5) == b"f" + assert stream.read(5) == b"o" + assert stream.read(5) == b"o" + + # The stream has fewer bytes (3) than the limit (4), therefore the read returns 0 + # bytes before the limit is reached. + with pytest.raises(ClientDisconnected): + stream.read(5) + + stream = wsgi.LimitedStream(OneByteStream(b"foo123"), 3) + assert stream.read(5) == b"f" + assert stream.read(5) == b"o" + assert stream.read(5) == b"o" + # The limit was reached, therefore the wrapper is exhausted, not disconnected. 
+ assert stream.read(5) == b"" + + stream = wsgi.LimitedStream(OneByteStream(b"foo"), 3) + assert stream.read() == b"foo" + + stream = wsgi.LimitedStream(OneByteStream(b"foo"), 2) + assert stream.read() == b"fo" + + def test_get_host_fallback(): assert ( wsgi.get_host( @@ -218,123 +256,6 @@ def test_get_current_url_invalid_utf8(): assert rv == "http://localhost/?foo=bar&baz=blah&meh=%CF" -def test_multi_part_line_breaks(): - data = "abcdef\r\nghijkl\r\nmnopqrstuvwxyz\r\nABCDEFGHIJK" - test_stream = io.StringIO(data) - lines = list(wsgi.make_line_iter(test_stream, limit=len(data), buffer_size=16)) - assert lines == ["abcdef\r\n", "ghijkl\r\n", "mnopqrstuvwxyz\r\n", "ABCDEFGHIJK"] - - data = "abc\r\nThis line is broken by the buffer length.\r\nFoo bar baz" - test_stream = io.StringIO(data) - lines = list(wsgi.make_line_iter(test_stream, limit=len(data), buffer_size=24)) - assert lines == [ - "abc\r\n", - "This line is broken by the buffer length.\r\n", - "Foo bar baz", - ] - - -def test_multi_part_line_breaks_bytes(): - data = b"abcdef\r\nghijkl\r\nmnopqrstuvwxyz\r\nABCDEFGHIJK" - test_stream = io.BytesIO(data) - lines = list(wsgi.make_line_iter(test_stream, limit=len(data), buffer_size=16)) - assert lines == [ - b"abcdef\r\n", - b"ghijkl\r\n", - b"mnopqrstuvwxyz\r\n", - b"ABCDEFGHIJK", - ] - - data = b"abc\r\nThis line is broken by the buffer length.\r\nFoo bar baz" - test_stream = io.BytesIO(data) - lines = list(wsgi.make_line_iter(test_stream, limit=len(data), buffer_size=24)) - assert lines == [ - b"abc\r\n", - b"This line is broken by the buffer length.\r\n", - b"Foo bar baz", - ] - - -def test_multi_part_line_breaks_problematic(): - data = "abc\rdef\r\nghi" - for _ in range(1, 10): - test_stream = io.StringIO(data) - lines = list(wsgi.make_line_iter(test_stream, limit=len(data), buffer_size=4)) - assert lines == ["abc\r", "def\r\n", "ghi"] - - -def test_iter_functions_support_iterators(): - data = ["abcdef\r\nghi", "jkl\r\nmnopqrstuvwxyz\r", "\nABCDEFGHIJK"] - lines = list(wsgi.make_line_iter(data)) - assert lines == ["abcdef\r\n", "ghijkl\r\n", "mnopqrstuvwxyz\r\n", "ABCDEFGHIJK"] - - -def test_make_chunk_iter(): - data = ["abcdefXghi", "jklXmnopqrstuvwxyzX", "ABCDEFGHIJK"] - rv = list(wsgi.make_chunk_iter(data, "X")) - assert rv == ["abcdef", "ghijkl", "mnopqrstuvwxyz", "ABCDEFGHIJK"] - - data = "abcdefXghijklXmnopqrstuvwxyzXABCDEFGHIJK" - test_stream = io.StringIO(data) - rv = list(wsgi.make_chunk_iter(test_stream, "X", limit=len(data), buffer_size=4)) - assert rv == ["abcdef", "ghijkl", "mnopqrstuvwxyz", "ABCDEFGHIJK"] - - -def test_make_chunk_iter_bytes(): - data = [b"abcdefXghi", b"jklXmnopqrstuvwxyzX", b"ABCDEFGHIJK"] - rv = list(wsgi.make_chunk_iter(data, "X")) - assert rv == [b"abcdef", b"ghijkl", b"mnopqrstuvwxyz", b"ABCDEFGHIJK"] - - data = b"abcdefXghijklXmnopqrstuvwxyzXABCDEFGHIJK" - test_stream = io.BytesIO(data) - rv = list(wsgi.make_chunk_iter(test_stream, "X", limit=len(data), buffer_size=4)) - assert rv == [b"abcdef", b"ghijkl", b"mnopqrstuvwxyz", b"ABCDEFGHIJK"] - - data = b"abcdefXghijklXmnopqrstuvwxyzXABCDEFGHIJK" - test_stream = io.BytesIO(data) - rv = list( - wsgi.make_chunk_iter( - test_stream, "X", limit=len(data), buffer_size=4, cap_at_buffer=True - ) - ) - assert rv == [ - b"abcd", - b"ef", - b"ghij", - b"kl", - b"mnop", - b"qrst", - b"uvwx", - b"yz", - b"ABCD", - b"EFGH", - b"IJK", - ] - - -def test_lines_longer_buffer_size(): - data = "1234567890\n1234567890\n" - for bufsize in range(1, 15): - lines = list( - wsgi.make_line_iter(io.StringIO(data), 
limit=len(data), buffer_size=bufsize) - ) - assert lines == ["1234567890\n", "1234567890\n"] - - -def test_lines_longer_buffer_size_cap(): - data = "1234567890\n1234567890\n" - for bufsize in range(1, 15): - lines = list( - wsgi.make_line_iter( - io.StringIO(data), - limit=len(data), - buffer_size=bufsize, - cap_at_buffer=True, - ) - ) - assert len(lines[0]) == bufsize or lines[0].endswith("\n") - - def test_range_wrapper(): response = Response(b"Hello World") range_wrapper = _RangeWrapper(response.response, 6, 4) diff --git a/tox.ini b/tox.ini index 056ca0d..f7bc0b3 100644 --- a/tox.ini +++ b/tox.ini @@ -1,19 +1,24 @@ [tox] envlist = - py3{11,10,9,8,7},pypy3{8,7} + py3{12,11,10,9,8} + pypy310 style typing docs skip_missing_interpreters = true [testenv] +package = wheel +wheel_build_env = .pkg +constrain_package_deps = true +use_frozen_constraints = true deps = -r requirements/tests.txt commands = pytest -v --tb=short --basetemp={envtmpdir} {posargs} [testenv:style] deps = pre-commit skip_install = true -commands = pre-commit run --all-files --show-diff-on-failure +commands = pre-commit run --all-files [testenv:typing] deps = -r requirements/typing.txt @@ -21,4 +26,18 @@ commands = mypy [testenv:docs] deps = -r requirements/docs.txt -commands = sphinx-build -W -b html -d {envtmpdir}/doctrees docs {envtmpdir}/html +commands = sphinx-build -E -W -b dirhtml docs docs/_build/dirhtml + +[testenv:update-requirements] +deps = + pip-tools + pre-commit +skip_install = true +change_dir = requirements +commands = + pre-commit autoupdate -j4 + pip-compile -U build.in + pip-compile -U docs.in + pip-compile -U tests.in + pip-compile -U typing.in + pip-compile -U dev.in