diff --git a/.github/workflows/bench.yml b/.github/workflows/bench.yml index af9700018..57f926a29 100644 --- a/.github/workflows/bench.yml +++ b/.github/workflows/bench.yml @@ -17,10 +17,10 @@ jobs: - uses: actions/checkout@v3 - uses: actions/setup-python@v3 with: - python-version: "3.11" + python-version: "3.12" - name: Install dependencies - run: pip install "numpy>=1.21,<2.0.0" + run: pip install "numpy>=1.23,<2.0.0" - name: Install bench dependencies run: pip install .[bench] diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fcf08696e..981f49e0c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,18 +7,22 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.9, "3.10", "3.11"] - numpy: [null, "numpy>=1.21,<2.0.0"] + python-version: ["3.9", "3.10", "3.11", "3.12"] + numpy: [null, "numpy>=1.23,<2.0.0", "numpy>=2.0.0rc1"] uncertainties: [null, "uncertainties==3.1.6", "uncertainties>=3.1.6,<4.0.0"] extras: [null] include: - - python-version: 3.9 # Minimal versions - numpy: "numpy" - extras: matplotlib==2.2.5 - - python-version: 3.9 + - python-version: "3.10" # Minimal versions + numpy: "numpy>=1.23,<2.0.0" + extras: matplotlib==3.5.3 + - python-version: "3.10" numpy: "numpy" uncertainties: "uncertainties" - extras: "sparse xarray netCDF4 dask[complete]==2023.4.0 graphviz babel==2.8 mip>=1.13" + extras: "sparse xarray netCDF4 dask[complete]==2024.5.1 graphviz babel==2.8 mip>=1.13" + - python-version: "3.10" + numpy: "numpy==1.26.1" + uncertainties: null + extras: "babel==2.15 matplotlib==3.9.0" runs-on: ubuntu-latest env: @@ -100,8 +104,8 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.9, "3.10", "3.11"] - numpy: [ "numpy>=1.21,<2.0.0" ] + python-version: ["3.10", "3.11", "3.12"] + numpy: [ "numpy>=1.23,<2.0.0" ] runs-on: windows-latest env: @@ -161,8 +165,8 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.9, "3.10", "3.11"] - numpy: [null, "numpy>=1.21,<2.0.0" ] + python-version: ["3.10", "3.11", "3.12"] + numpy: [null, "numpy>=1.23,<2.0.0" ] runs-on: macos-latest env: diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 5f17aba71..8ebea5e60 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -17,7 +17,7 @@ jobs: - name: Set up minimal Python version uses: actions/setup-python@v2 with: - python-version: 3.9 + python-version: "3.10" - name: Get pip cache dir id: pip-cache diff --git a/.gitignore b/.gitignore index ae702bac3..69fd3338d 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,8 @@ MANIFEST .mypy_cache pip-wheel-metadata pint/testsuite/dask-worker-space +venv +.envrc # WebDAV file system cache files .DAV/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a4a3f4aa9..75bfa6297 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,23 +1,21 @@ exclude: '^pint/_vendor' repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace -- repo: https://github.com/psf/black - rev: 23.1.0 - hooks: - - id: black - - id: black-jupyter -- repo: https://github.com/charliermarsh/ruff-pre-commit - rev: 'v0.0.240' +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.1.7 hooks: - id: ruff - args: ["--fix"] + args: ["--fix", "--show-fixes"] + types_or: [ python, pyi, jupyter ] + - id: ruff-format + types_or: [ python, pyi, jupyter ] - repo: https://github.com/executablebooks/mdformat - rev: 0.7.16 + rev: 0.7.17 hooks: - 
id: mdformat additional_dependencies: diff --git a/CHANGES b/CHANGES index 6b1d0484e..53e5a5fc5 100644 --- a/CHANGES +++ b/CHANGES @@ -1,10 +1,49 @@ Pint Changelog ============== -0.24 (unreleased) +0.25 (unreleased) ----------------- -- Nothing changed yet. +- Nothing added yet. + + +0.24.1 (2024-06-24) +------------------- + +- Fix custom formatter needing the registry object. (PR #2011) +- Support Python 3.9 following difficulties installing with NumPy 2. (PR #2019) +- Fix issue with default formatting of dimensionless units. (PR #2012) +- Fix bug preventing custom formatters with modifiers from working. (PR #2021) + + +0.24 (2024-06-07) +----------------- + +- Fix detection of invalid conversion between offset and delta units. (PR #1905) +- Added dBW, decibel watts, which is used in RF high-power applications. +- NumPy 2.0 support. + (PR #1985, #1971) +- Implement numpy roll. (Related to issue #981) +- Implement numpy correlate. + (PR #1990) +- Add `dim_sort` function to _formatter_helpers. +- Add `dim_order` and `default_sort_func` properties to FullFormatter. + (PR #1926, fixes Issue #1841) +- Documented packages using pint. + (PR #1960) +- Fixed bug causing operations between arrays of quantity scalars and a quantity holding an + array to result in incorrect units. + (PR #1677) +- Fix LaTeX siunitx formatting when using non_int_type=decimal.Decimal. + (PR #1977) +- Added refractive index units. + (PR #1816) +- Fix converting to offset units of higher dimension, e.g. gauge pressure. + (PR #1949) +- Fix unhandled TypeError when auto_reduce_dimensions=True and non_int_type=Decimal. + (PR #1853) +- Improved error message in `get_dimensionality()` when non-existent units are passed. + (PR #1874, Issue #1716) 0.23 (2023-12-08) ----------------- @@ -35,6 +74,7 @@ Pint Changelog - Add numpy.linalg.norm implementation. (PR #1251) + 0.22 (2023-05-25) ----------------- diff --git a/README.rst b/README.rst index 89f19f474..3c16a4541 100644 --- a/README.rst +++ b/README.rst @@ -2,8 +2,13 @@ :target: https://pypi.python.org/pypi/pint :alt: Latest Version -.. image:: https://img.shields.io/badge/code%20style-black-000000.svg - :target: https://github.com/python/black +.. image:: https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json + :target: https://github.com/astral-sh/ruff + :alt: Ruff + +.. image:: https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/format.json + :target: https://github.com/astral-sh/ruff + :alt: Ruff-Format .. image:: https://readthedocs.org/projects/pint/badge/ :target: https://pint.readthedocs.org/ diff --git a/docs/advanced/currencies.rst b/docs/advanced/currencies.rst index 26b66b531..addc94785 100644 --- a/docs/advanced/currencies.rst +++ b/docs/advanced/currencies.rst @@ -84,3 +84,16 @@ currency on its own dimension, and then implement transformations:: More sophisticated formulas, e.g. dealing with flat fees and thresholds, can be implemented with arbitrary python code by programmatically defining a context (see :ref:`contexts`). + +Currency Symbols +---------------- + +Many common currency symbols are not supported by the pint parser. A preprocessor can be used as a workaround: + +.. 
doctest:: + + >>> import pint + >>> ureg = pint.UnitRegistry(preprocessors = [lambda s: s.replace("€", "EUR")]) + >>> ureg.define("euro = [currency] = € = EUR") + >>> print(ureg.Quantity("1 €")) + 1 euro diff --git a/docs/advanced/pitheorem.rst b/docs/advanced/pitheorem.rst index cd3716528..06409d8b5 100644 --- a/docs/advanced/pitheorem.rst +++ b/docs/advanced/pitheorem.rst @@ -33,8 +33,10 @@ Which can be pretty printed using the `Pint` formatter: >>> from pint import formatter >>> result = pi_theorem({'V': '[length]/[time]', 'T': '[time]', 'L': '[length]'}) - >>> print(formatter(result[0].items())) - T * V / L + >>> numerator = [item for item in result[0].items() if item[1]>0] + >>> denominator = [item for item in result[0].items() if item[1]<0] + >>> print(formatter(numerator, denominator)) + V * T / L You can also apply the Buckingham π theorem associated to a Registry. In this case, you can use derived dimensions such as speed: diff --git a/docs/api/facets.rst b/docs/api/facets.rst index f4b6a54e8..d835f5cea 100644 --- a/docs/api/facets.rst +++ b/docs/api/facets.rst @@ -16,7 +16,7 @@ The default UnitRegistry inherits from all of them. :members: :exclude-members: Quantity, Unit, Measurement, Group, Context, System -.. automodule:: pint.facets.formatting +.. automodule:: pint.delegates.formatter :members: :exclude-members: Quantity, Unit, Measurement, Group, Context, System diff --git a/docs/conf.py b/docs/conf.py index ee74481f8..d856e1075 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -10,6 +10,7 @@ # # All configuration values have a default; values that are commented out # serve to show the default. +from __future__ import annotations import datetime from importlib.metadata import version diff --git a/docs/ecosystem.rst b/docs/ecosystem.rst index 7610fd019..c83c52f49 100644 --- a/docs/ecosystem.rst +++ b/docs/ecosystem.rst @@ -7,5 +7,18 @@ Here is a list of known projects, packages and integrations using pint. Pint integrations: ------------------ +- `ucumvert `_ `UCUM `_ (Unified Code for Units of Measure) integration - `pint-pandas `_ Pandas integration - `pint-xarray `_ Xarray integration + + +Packages using pint: +------------------ + +- `fluids `_ Practical fluid dynamics calculations +- `ht `_ Practical heat transfer calculations +- `chemicals `_ Chemical property calculations and lookups +- `thermo `_ Thermodynamic equilibrium calculations +- `Taurus `_ Control system UI creation +- `InstrumentKit `_ Interacting with laboratory equipment over various buses. +- `NEMO `_ Electricity production cost model diff --git a/docs/getting/tutorial.rst b/docs/getting/tutorial.rst index bb3505b51..d675860f2 100644 --- a/docs/getting/tutorial.rst +++ b/docs/getting/tutorial.rst @@ -428,7 +428,7 @@ If Babel_ is installed you can translate unit names to any language .. doctest:: >>> ureg.formatter.format_quantity(accel, locale='fr_FR') - '1,3 mètres/secondes²' + '1,3 mètres par seconde²' You can also specify the format locale at the registry level either at creation: @@ -449,11 +449,11 @@ and by doing that, string formatting is now localized: >>> ureg.default_format = 'P' >>> accel = 1.3 * ureg.parse_units('meter/second**2') >>> str(accel) - '1,3 mètres/secondes²' + '1,3 mètres par seconde²' >>> "%s" % accel - '1,3 mètres/secondes²' + '1,3 mètres par seconde²' >>> "{}".format(accel) - '1,3 mètres/secondes²' + '1,3 mètres par seconde²' If you want to customize string formatting, take a look at :ref:`formatting`. 
diff --git a/docs/user/angular_frequency.rst b/docs/user/angular_frequency.rst index 4fbb7bdce..61bdf1614 100644 --- a/docs/user/angular_frequency.rst +++ b/docs/user/angular_frequency.rst @@ -1,12 +1,43 @@ .. _angular_frequency: +Angles and Angular Frequency +============================= + +Angles +------ + +pint treats angle quantities as `dimensionless`, following the conventions of SI. The base unit for angle is the `radian`. +The SI BIPM Brochure (Bureau International des Poids et Mesures) states: + +.. note:: + + Plane and solid angles, when expressed in radians and steradians respectively, are in effect + also treated within the SI as quantities with the unit one (see section 5.4.8). The symbols rad + and sr are written explicitly where appropriate, in order to emphasize that, for radians or + steradians, the quantity being considered is, or involves the plane angle or solid angle + respectively. For steradians it emphasizes the distinction between units of flux and intensity + in radiometry and photometry for example. However, it is a long-established practice in + mathematics and across all areas of science to make use of rad = 1 and sr = 1. + + +This leads to behavior some users may find unintuitive. For example, since angles have no dimensionality, it is not possible to check whether a quantity has an angle dimension. + +.. code-block:: python + + >>> import pint + >>> ureg = pint.UnitRegistry() + >>> angle = ureg('1 rad') + >>> angle.dimensionality + + + Angular Frequency -================= +----------------- `Hertz` is a unit for frequency, that is often also used for angular frequency. For example, a shaft spinning at `60 revolutions per minute` will often be said to spin at `1 Hz`, rather than `1 revolution per second`. -By default, pint treats angle quantities as `dimensionless`, so allows conversions between frequencies and angular frequencies. The base unit for angle is the `radian`. This leads to some unintuitive behaviour, as pint will convert angular frequencies into frequencies by converting angles into `radians`, rather than `revolutions`. This leads to converted values `2 * pi` larger than expected: +Since pint treats angle quantities as `dimensionless`, it allows conversions between frequencies and angular frequencies. This leads to some unintuitive behaviour, as pint will convert angular frequencies into frequencies by converting angles into `radians`, rather than `revolutions`. This leads to converted values `2 * pi` larger than expected: .. code-block:: python @@ -16,7 +47,7 @@ By default, pint treats angle quantities as `dimensionless`, so allows conversio >>> angular_frequency.to('Hz') -pint follows the conventions of SI. The SI BIPM Brochure (Bureau International des Poids et Mesures) states: +The SI BIPM Brochure (Bureau International des Poids et Mesures) states: .. note:: diff --git a/docs/user/defining-quantities.rst b/docs/user/defining-quantities.rst index e40b08cf9..a7405151a 100644 --- a/docs/user/defining-quantities.rst +++ b/docs/user/defining-quantities.rst @@ -134,7 +134,7 @@ For example, the units of .. doctest:: >>> Q_('3 l / 100 km') - + may be unexpected at first but, are a consequence of applying this rule. Use brackets to get the expected result: diff --git a/docs/user/formatting.rst b/docs/user/formatting.rst index f17939a86..fbf2fae42 100644 --- a/docs/user/formatting.rst +++ b/docs/user/formatting.rst @@ -95,10 +95,11 @@ formats: ... def format_unit_simple(unit, registry, **options): ... 
return " * ".join(f"{u} ** {p}" for u, p in unit.items()) >>> f"{q:Z}" - '2.3e-06 meter ** 3 * second ** -2 * kilogram ** -1' + '2.3e-06 kilogram ** -1 * meter ** 3 * second ** -2' where ``unit`` is a :py:class:`dict` subclass containing the unit names and -their exponents. +their exponents, ``registry`` is the current instance of :py:class:``UnitRegistry`` and +``options`` is not yet implemented. You can choose to replace the complete formatter. Briefly, the formatter if an object with the following methods: `format_magnitude`, `format_unit`, `format_quantity`, `format_uncertainty`, @@ -111,10 +112,11 @@ following methods: `format_magnitude`, `format_unit`, `format_quantity`, `format ... ... default_format = "" ... - ... def format_unit(self, unit, uspec: str = "", **babel_kwds) -> str: + ... def format_unit(self, unit, uspec, sort_func, **babel_kwds) -> str: ... return "ups!" ... >>> ureg.formatter = MyFormatter() + >>> ureg.formatter._registry = ureg >>> str(q) '2.3e-06 ups!' diff --git a/docs/user/log_units.rst b/docs/user/log_units.rst index 03e007914..096397350 100644 --- a/docs/user/log_units.rst +++ b/docs/user/log_units.rst @@ -111,16 +111,16 @@ will not work: .. doctest:: >>> -161.0 * ureg('dBm/Hz') == (-161.0 * ureg.dBm / ureg.Hz) - False + np.False_ But this will: .. doctest:: >>> ureg('-161.0 dBm/Hz') == (-161.0 * ureg.dBm / ureg.Hz) - True + np.True_ >>> Q_(-161.0, 'dBm') / ureg.Hz == (-161.0 * ureg.dBm / ureg.Hz) - True + np.True_ To begin using this feature while avoiding problems, define logarithmic units as single-unit quantities and convert them to their base units as quickly as diff --git a/docs/user/numpy.ipynb b/docs/user/numpy.ipynb index 54910018e..0b1b22197 100644 --- a/docs/user/numpy.ipynb +++ b/docs/user/numpy.ipynb @@ -33,6 +33,8 @@ "outputs": [], "source": [ "# Import NumPy\n", + "from __future__ import annotations\n", + "\n", "import numpy as np\n", "\n", "# Import Pint\n", diff --git a/pint/__init__.py b/pint/__init__.py index 127a45ca6..abfef2703 100644 --- a/pint/__init__.py +++ b/pint/__init__.py @@ -16,7 +16,6 @@ from importlib.metadata import version from .delegates.formatter._format_helpers import formatter - from .errors import ( # noqa: F401 DefinitionSyntaxError, DimensionalityError, @@ -31,7 +30,6 @@ from .registry import ApplicationRegistry, LazyRegistry, UnitRegistry from .util import logger, pi_theorem # noqa: F401 - # Default Quantity, Unit and Measurement are the ones # build in the default registry. Quantity = UnitRegistry.Quantity diff --git a/pint/_typing.py b/pint/_typing.py index 7a67efc45..241459ef1 100644 --- a/pint/_typing.py +++ b/pint/_typing.py @@ -1,10 +1,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, Protocol +from collections.abc import Callable from decimal import Decimal from fractions import Fraction +from typing import TYPE_CHECKING, Any, Protocol, TypeVar, Union -from .compat import TypeAlias, Never +from .compat import Never, TypeAlias if TYPE_CHECKING: from .facets.plain import PlainQuantity as Quantity diff --git a/pint/_vendor/__init__.py b/pint/_vendor/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/pint/_vendor/appdirs.py b/pint/_vendor/appdirs.py deleted file mode 100644 index c32636a1a..000000000 --- a/pint/_vendor/appdirs.py +++ /dev/null @@ -1,608 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# Copyright (c) 2005-2010 ActiveState Software Inc. 
-# Copyright (c) 2013 Eddy Petrișor - -"""Utilities for determining application-specific dirs. - -See for details and usage. -""" -# Dev Notes: -# - MSDN on where to store app data files: -# http://support.microsoft.com/default.aspx?scid=kb;en-us;310294#XSLTH3194121123120121120120 -# - Mac OS X: http://developer.apple.com/documentation/MacOSX/Conceptual/BPFileSystem/index.html -# - XDG spec for Un*x: http://standards.freedesktop.org/basedir-spec/basedir-spec-latest.html - -__version__ = "1.4.4" -__version_info__ = tuple(int(segment) for segment in __version__.split(".")) - - -import sys -import os - -PY3 = sys.version_info[0] == 3 - -if PY3: - unicode = str - -if sys.platform.startswith('java'): - import platform - os_name = platform.java_ver()[3][0] - if os_name.startswith('Windows'): # "Windows XP", "Windows 7", etc. - system = 'win32' - elif os_name.startswith('Mac'): # "Mac OS X", etc. - system = 'darwin' - else: # "Linux", "SunOS", "FreeBSD", etc. - # Setting this to "linux2" is not ideal, but only Windows or Mac - # are actually checked for and the rest of the module expects - # *sys.platform* style strings. - system = 'linux2' -else: - system = sys.platform - - - -def user_data_dir(appname=None, appauthor=None, version=None, roaming=False): - r"""Return full path to the user-specific data dir for this application. - - "appname" is the name of application. - If None, just the system directory is returned. - "appauthor" (only used on Windows) is the name of the - appauthor or distributing body for this application. Typically - it is the owning company name. This falls back to appname. You may - pass False to disable it. - "version" is an optional version path element to append to the - path. You might want to use this if you want multiple versions - of your app to be able to run independently. If used, this - would typically be ".". - Only applied when appname is present. - "roaming" (boolean, default False) can be set True to use the Windows - roaming appdata directory. That means that for users on a Windows - network setup for roaming profiles, this user data will be - sync'd on login. See - - for a discussion of issues. - - Typical user data directories are: - Mac OS X: ~/Library/Application Support/ - Unix: ~/.local/share/ # or in $XDG_DATA_HOME, if defined - Win XP (not roaming): C:\Documents and Settings\\Application Data\\ - Win XP (roaming): C:\Documents and Settings\\Local Settings\Application Data\\ - Win 7 (not roaming): C:\Users\\AppData\Local\\ - Win 7 (roaming): C:\Users\\AppData\Roaming\\ - - For Unix, we follow the XDG spec and support $XDG_DATA_HOME. - That means, by default "~/.local/share/". - """ - if system == "win32": - if appauthor is None: - appauthor = appname - const = roaming and "CSIDL_APPDATA" or "CSIDL_LOCAL_APPDATA" - path = os.path.normpath(_get_win_folder(const)) - if appname: - if appauthor is not False: - path = os.path.join(path, appauthor, appname) - else: - path = os.path.join(path, appname) - elif system == 'darwin': - path = os.path.expanduser('~/Library/Application Support/') - if appname: - path = os.path.join(path, appname) - else: - path = os.getenv('XDG_DATA_HOME', os.path.expanduser("~/.local/share")) - if appname: - path = os.path.join(path, appname) - if appname and version: - path = os.path.join(path, version) - return path - - -def site_data_dir(appname=None, appauthor=None, version=None, multipath=False): - r"""Return full path to the user-shared data dir for this application. - - "appname" is the name of application. 
- If None, just the system directory is returned. - "appauthor" (only used on Windows) is the name of the - appauthor or distributing body for this application. Typically - it is the owning company name. This falls back to appname. You may - pass False to disable it. - "version" is an optional version path element to append to the - path. You might want to use this if you want multiple versions - of your app to be able to run independently. If used, this - would typically be ".". - Only applied when appname is present. - "multipath" is an optional parameter only applicable to *nix - which indicates that the entire list of data dirs should be - returned. By default, the first item from XDG_DATA_DIRS is - returned, or '/usr/local/share/', - if XDG_DATA_DIRS is not set - - Typical site data directories are: - Mac OS X: /Library/Application Support/ - Unix: /usr/local/share/ or /usr/share/ - Win XP: C:\Documents and Settings\All Users\Application Data\\ - Vista: (Fail! "C:\ProgramData" is a hidden *system* directory on Vista.) - Win 7: C:\ProgramData\\ # Hidden, but writeable on Win 7. - - For Unix, this is using the $XDG_DATA_DIRS[0] default. - - WARNING: Do not use this on Windows. See the Vista-Fail note above for why. - """ - if system == "win32": - if appauthor is None: - appauthor = appname - path = os.path.normpath(_get_win_folder("CSIDL_COMMON_APPDATA")) - if appname: - if appauthor is not False: - path = os.path.join(path, appauthor, appname) - else: - path = os.path.join(path, appname) - elif system == 'darwin': - path = os.path.expanduser('/Library/Application Support') - if appname: - path = os.path.join(path, appname) - else: - # XDG default for $XDG_DATA_DIRS - # only first, if multipath is False - path = os.getenv('XDG_DATA_DIRS', - os.pathsep.join(['/usr/local/share', '/usr/share'])) - pathlist = [os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep)] - if appname: - if version: - appname = os.path.join(appname, version) - pathlist = [os.sep.join([x, appname]) for x in pathlist] - - if multipath: - path = os.pathsep.join(pathlist) - else: - path = pathlist[0] - return path - - if appname and version: - path = os.path.join(path, version) - return path - - -def user_config_dir(appname=None, appauthor=None, version=None, roaming=False): - r"""Return full path to the user-specific config dir for this application. - - "appname" is the name of application. - If None, just the system directory is returned. - "appauthor" (only used on Windows) is the name of the - appauthor or distributing body for this application. Typically - it is the owning company name. This falls back to appname. You may - pass False to disable it. - "version" is an optional version path element to append to the - path. You might want to use this if you want multiple versions - of your app to be able to run independently. If used, this - would typically be ".". - Only applied when appname is present. - "roaming" (boolean, default False) can be set True to use the Windows - roaming appdata directory. That means that for users on a Windows - network setup for roaming profiles, this user data will be - sync'd on login. See - - for a discussion of issues. - - Typical user config directories are: - Mac OS X: same as user_data_dir - Unix: ~/.config/ # or in $XDG_CONFIG_HOME, if defined - Win *: same as user_data_dir - - For Unix, we follow the XDG spec and support $XDG_CONFIG_HOME. - That means, by default "~/.config/". 
- """ - if system in ["win32", "darwin"]: - path = user_data_dir(appname, appauthor, None, roaming) - else: - path = os.getenv('XDG_CONFIG_HOME', os.path.expanduser("~/.config")) - if appname: - path = os.path.join(path, appname) - if appname and version: - path = os.path.join(path, version) - return path - - -def site_config_dir(appname=None, appauthor=None, version=None, multipath=False): - r"""Return full path to the user-shared data dir for this application. - - "appname" is the name of application. - If None, just the system directory is returned. - "appauthor" (only used on Windows) is the name of the - appauthor or distributing body for this application. Typically - it is the owning company name. This falls back to appname. You may - pass False to disable it. - "version" is an optional version path element to append to the - path. You might want to use this if you want multiple versions - of your app to be able to run independently. If used, this - would typically be ".". - Only applied when appname is present. - "multipath" is an optional parameter only applicable to *nix - which indicates that the entire list of config dirs should be - returned. By default, the first item from XDG_CONFIG_DIRS is - returned, or '/etc/xdg/', if XDG_CONFIG_DIRS is not set - - Typical site config directories are: - Mac OS X: same as site_data_dir - Unix: /etc/xdg/ or $XDG_CONFIG_DIRS[i]/ for each value in - $XDG_CONFIG_DIRS - Win *: same as site_data_dir - Vista: (Fail! "C:\ProgramData" is a hidden *system* directory on Vista.) - - For Unix, this is using the $XDG_CONFIG_DIRS[0] default, if multipath=False - - WARNING: Do not use this on Windows. See the Vista-Fail note above for why. - """ - if system in ["win32", "darwin"]: - path = site_data_dir(appname, appauthor) - if appname and version: - path = os.path.join(path, version) - else: - # XDG default for $XDG_CONFIG_DIRS - # only first, if multipath is False - path = os.getenv('XDG_CONFIG_DIRS', '/etc/xdg') - pathlist = [os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep)] - if appname: - if version: - appname = os.path.join(appname, version) - pathlist = [os.sep.join([x, appname]) for x in pathlist] - - if multipath: - path = os.pathsep.join(pathlist) - else: - path = pathlist[0] - return path - - -def user_cache_dir(appname=None, appauthor=None, version=None, opinion=True): - r"""Return full path to the user-specific cache dir for this application. - - "appname" is the name of application. - If None, just the system directory is returned. - "appauthor" (only used on Windows) is the name of the - appauthor or distributing body for this application. Typically - it is the owning company name. This falls back to appname. You may - pass False to disable it. - "version" is an optional version path element to append to the - path. You might want to use this if you want multiple versions - of your app to be able to run independently. If used, this - would typically be ".". - Only applied when appname is present. - "opinion" (boolean) can be False to disable the appending of - "Cache" to the plain app data dir for Windows. See - discussion below. - - Typical user cache directories are: - Mac OS X: ~/Library/Caches/ - Unix: ~/.cache/ (XDG default) - Win XP: C:\Documents and Settings\\Local Settings\Application Data\\\Cache - Vista: C:\Users\\AppData\Local\\\Cache - - On Windows the only suggestion in the MSDN docs is that local settings go in - the `CSIDL_LOCAL_APPDATA` directory. 
This is identical to the non-roaming - app data dir (the default returned by `user_data_dir` above). Apps typically - put cache data somewhere *under* the given dir here. Some examples: - ...\Mozilla\Firefox\Profiles\\Cache - ...\Acme\SuperApp\Cache\1.0 - OPINION: This function appends "Cache" to the `CSIDL_LOCAL_APPDATA` value. - This can be disabled with the `opinion=False` option. - """ - if system == "win32": - if appauthor is None: - appauthor = appname - path = os.path.normpath(_get_win_folder("CSIDL_LOCAL_APPDATA")) - if appname: - if appauthor is not False: - path = os.path.join(path, appauthor, appname) - else: - path = os.path.join(path, appname) - if opinion: - path = os.path.join(path, "Cache") - elif system == 'darwin': - path = os.path.expanduser('~/Library/Caches') - if appname: - path = os.path.join(path, appname) - else: - path = os.getenv('XDG_CACHE_HOME', os.path.expanduser('~/.cache')) - if appname: - path = os.path.join(path, appname) - if appname and version: - path = os.path.join(path, version) - return path - - -def user_state_dir(appname=None, appauthor=None, version=None, roaming=False): - r"""Return full path to the user-specific state dir for this application. - - "appname" is the name of application. - If None, just the system directory is returned. - "appauthor" (only used on Windows) is the name of the - appauthor or distributing body for this application. Typically - it is the owning company name. This falls back to appname. You may - pass False to disable it. - "version" is an optional version path element to append to the - path. You might want to use this if you want multiple versions - of your app to be able to run independently. If used, this - would typically be ".". - Only applied when appname is present. - "roaming" (boolean, default False) can be set True to use the Windows - roaming appdata directory. That means that for users on a Windows - network setup for roaming profiles, this user data will be - sync'd on login. See - - for a discussion of issues. - - Typical user state directories are: - Mac OS X: same as user_data_dir - Unix: ~/.local/state/ # or in $XDG_STATE_HOME, if defined - Win *: same as user_data_dir - - For Unix, we follow this Debian proposal - to extend the XDG spec and support $XDG_STATE_HOME. - - That means, by default "~/.local/state/". - """ - if system in ["win32", "darwin"]: - path = user_data_dir(appname, appauthor, None, roaming) - else: - path = os.getenv('XDG_STATE_HOME', os.path.expanduser("~/.local/state")) - if appname: - path = os.path.join(path, appname) - if appname and version: - path = os.path.join(path, version) - return path - - -def user_log_dir(appname=None, appauthor=None, version=None, opinion=True): - r"""Return full path to the user-specific log dir for this application. - - "appname" is the name of application. - If None, just the system directory is returned. - "appauthor" (only used on Windows) is the name of the - appauthor or distributing body for this application. Typically - it is the owning company name. This falls back to appname. You may - pass False to disable it. - "version" is an optional version path element to append to the - path. You might want to use this if you want multiple versions - of your app to be able to run independently. If used, this - would typically be ".". - Only applied when appname is present. - "opinion" (boolean) can be False to disable the appending of - "Logs" to the plain app data dir for Windows, and "log" to the - plain cache dir for Unix. See discussion below. 
- - Typical user log directories are: - Mac OS X: ~/Library/Logs/ - Unix: ~/.cache//log # or under $XDG_CACHE_HOME if defined - Win XP: C:\Documents and Settings\\Local Settings\Application Data\\\Logs - Vista: C:\Users\\AppData\Local\\\Logs - - On Windows the only suggestion in the MSDN docs is that local settings - go in the `CSIDL_LOCAL_APPDATA` directory. (Note: I'm interested in - examples of what some windows apps use for a logs dir.) - - OPINION: This function appends "Logs" to the `CSIDL_LOCAL_APPDATA` - value for Windows and appends "log" to the user cache dir for Unix. - This can be disabled with the `opinion=False` option. - """ - if system == "darwin": - path = os.path.join( - os.path.expanduser('~/Library/Logs'), - appname) - elif system == "win32": - path = user_data_dir(appname, appauthor, version) - version = False - if opinion: - path = os.path.join(path, "Logs") - else: - path = user_cache_dir(appname, appauthor, version) - version = False - if opinion: - path = os.path.join(path, "log") - if appname and version: - path = os.path.join(path, version) - return path - - -class AppDirs(object): - """Convenience wrapper for getting application dirs.""" - def __init__(self, appname=None, appauthor=None, version=None, - roaming=False, multipath=False): - self.appname = appname - self.appauthor = appauthor - self.version = version - self.roaming = roaming - self.multipath = multipath - - @property - def user_data_dir(self): - return user_data_dir(self.appname, self.appauthor, - version=self.version, roaming=self.roaming) - - @property - def site_data_dir(self): - return site_data_dir(self.appname, self.appauthor, - version=self.version, multipath=self.multipath) - - @property - def user_config_dir(self): - return user_config_dir(self.appname, self.appauthor, - version=self.version, roaming=self.roaming) - - @property - def site_config_dir(self): - return site_config_dir(self.appname, self.appauthor, - version=self.version, multipath=self.multipath) - - @property - def user_cache_dir(self): - return user_cache_dir(self.appname, self.appauthor, - version=self.version) - - @property - def user_state_dir(self): - return user_state_dir(self.appname, self.appauthor, - version=self.version) - - @property - def user_log_dir(self): - return user_log_dir(self.appname, self.appauthor, - version=self.version) - - -#---- internal support stuff - -def _get_win_folder_from_registry(csidl_name): - """This is a fallback technique at best. I'm not sure if using the - registry for this guarantees us the correct answer for all CSIDL_* - names. - """ - if PY3: - import winreg as _winreg - else: - import _winreg - - shell_folder_name = { - "CSIDL_APPDATA": "AppData", - "CSIDL_COMMON_APPDATA": "Common AppData", - "CSIDL_LOCAL_APPDATA": "Local AppData", - }[csidl_name] - - key = _winreg.OpenKey( - _winreg.HKEY_CURRENT_USER, - r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders" - ) - dir, type = _winreg.QueryValueEx(key, shell_folder_name) - return dir - - -def _get_win_folder_with_pywin32(csidl_name): - from win32com.shell import shellcon, shell - dir = shell.SHGetFolderPath(0, getattr(shellcon, csidl_name), 0, 0) - # Try to make this a unicode path because SHGetFolderPath does - # not return unicode strings when there is unicode data in the - # path. - try: - dir = unicode(dir) - - # Downgrade to short path name if have highbit chars. See - # . 
- has_high_char = False - for c in dir: - if ord(c) > 255: - has_high_char = True - break - if has_high_char: - try: - import win32api - dir = win32api.GetShortPathName(dir) - except ImportError: - pass - except UnicodeError: - pass - return dir - - -def _get_win_folder_with_ctypes(csidl_name): - import ctypes - - csidl_const = { - "CSIDL_APPDATA": 26, - "CSIDL_COMMON_APPDATA": 35, - "CSIDL_LOCAL_APPDATA": 28, - }[csidl_name] - - buf = ctypes.create_unicode_buffer(1024) - ctypes.windll.shell32.SHGetFolderPathW(None, csidl_const, None, 0, buf) - - # Downgrade to short path name if have highbit chars. See - # . - has_high_char = False - for c in buf: - if ord(c) > 255: - has_high_char = True - break - if has_high_char: - buf2 = ctypes.create_unicode_buffer(1024) - if ctypes.windll.kernel32.GetShortPathNameW(buf.value, buf2, 1024): - buf = buf2 - - return buf.value - -def _get_win_folder_with_jna(csidl_name): - import array - from com.sun import jna - from com.sun.jna.platform import win32 - - buf_size = win32.WinDef.MAX_PATH * 2 - buf = array.zeros('c', buf_size) - shell = win32.Shell32.INSTANCE - shell.SHGetFolderPath(None, getattr(win32.ShlObj, csidl_name), None, win32.ShlObj.SHGFP_TYPE_CURRENT, buf) - dir = jna.Native.toString(buf.tostring()).rstrip("\0") - - # Downgrade to short path name if have highbit chars. See - # . - has_high_char = False - for c in dir: - if ord(c) > 255: - has_high_char = True - break - if has_high_char: - buf = array.zeros('c', buf_size) - kernel = win32.Kernel32.INSTANCE - if kernel.GetShortPathName(dir, buf, buf_size): - dir = jna.Native.toString(buf.tostring()).rstrip("\0") - - return dir - -if system == "win32": - try: - import win32com.shell - _get_win_folder = _get_win_folder_with_pywin32 - except ImportError: - try: - from ctypes import windll - _get_win_folder = _get_win_folder_with_ctypes - except ImportError: - try: - import com.sun.jna - _get_win_folder = _get_win_folder_with_jna - except ImportError: - _get_win_folder = _get_win_folder_from_registry - - -#---- self test code - -if __name__ == "__main__": - appname = "MyApp" - appauthor = "MyCompany" - - props = ("user_data_dir", - "user_config_dir", - "user_cache_dir", - "user_state_dir", - "user_log_dir", - "site_data_dir", - "site_config_dir") - - print("-- app dirs %s --" % __version__) - - print("-- app dirs (with optional 'version')") - dirs = AppDirs(appname, appauthor, version="1.0") - for prop in props: - print("%s: %s" % (prop, getattr(dirs, prop))) - - print("\n-- app dirs (without optional 'version')") - dirs = AppDirs(appname, appauthor) - for prop in props: - print("%s: %s" % (prop, getattr(dirs, prop))) - - print("\n-- app dirs (without optional 'appauthor')") - dirs = AppDirs(appname) - for prop in props: - print("%s: %s" % (prop, getattr(dirs, prop))) - - print("\n-- app dirs (with disabled 'appauthor')") - dirs = AppDirs(appname, appauthor=False) - for prop in props: - print("%s: %s" % (prop, getattr(dirs, prop))) diff --git a/pint/_vendor/flexcache.py b/pint/_vendor/flexcache.py deleted file mode 100644 index 7b3969846..000000000 --- a/pint/_vendor/flexcache.py +++ /dev/null @@ -1,427 +0,0 @@ -""" - flexcache.flexcache - ~~~~~~~~~~~~~~~~~~~ - - Classes for persistent caching and invalidating cached objects, - which are built from a source object and a (potentially expensive) - conversion function. - - Header - ------ - Contains summary information about the source object that will - be saved together with the cached file. 
- - It's capabilities are divided in three groups: - - The Header itself which contains the information that will - be saved alongside the cached file - - The Naming logic which indicates how the cached filename is - built. - - The Invalidation logic which indicates whether a cached file - is valid (i.e. truthful to the actual source file). - - DiskCache - --------- - Saves and loads to the cache a transformed versions of a source object. - - :copyright: 2022 by flexcache Authors, see AUTHORS for more details. - :license: BSD, see LICENSE for more details. -""" - -from __future__ import annotations - -import abc -import hashlib -import json -import pathlib -import pickle -import platform -import typing -from dataclasses import asdict as dc_asdict -from dataclasses import dataclass -from dataclasses import fields as dc_fields -from typing import Any, Iterable - -######### -# Header -######### - - -@dataclass(frozen=True) -class BaseHeader(abc.ABC): - """Header with no information except the converter_id - - All header files must inherit from this. - """ - - # The actual source of the data (or a reference to it) - # that is going to be converted. - source: Any - - # An identification of the function that is used to - # convert the source into the result object. - converter_id: str - - _source_type = object - - def __post_init__(self): - # TODO: In more modern python versions it would be - # good to check for things like tuple[str]. - if not isinstance(self.source, self._source_type): - raise TypeError( - f"Source must be {self._source_type}, " f"not {type(self.source)}" - ) - - def for_cache_name(self) -> typing.Generator[bytes]: - """The basename for the cache file is a hash hexdigest - built by feeding this collection of values. - - A class can provide it's own set of values by rewriting - `_for_cache_name`. - """ - for el in self._for_cache_name(): - if isinstance(el, str): - yield el.encode("utf-8") - else: - yield el - - def _for_cache_name(self) -> typing.Generator[bytes | str]: - """The basename for the cache file is a hash hexdigest - built by feeding this collection of values. - - Change the behavior by writing your own. - """ - yield self.converter_id - - @abc.abstractmethod - def is_valid(self, cache_path: pathlib.Path) -> bool: - """Return True if the cache_path is an cached version - of the source_object represented by this header. - """ - - -@dataclass(frozen=True) -class BasicPythonHeader(BaseHeader): - """Header with basic Python information.""" - - system: str = platform.system() - python_implementation: str = platform.python_implementation() - python_version: str = platform.python_version() - - -##################### -# Invalidation logic -##################### - - -class InvalidateByExist: - """The cached file is valid if exists and is newer than the source file.""" - - def is_valid(self, cache_path: pathlib.Path) -> bool: - return cache_path.exists() - - -class InvalidateByPathMTime(abc.ABC): - """The cached file is valid if exists and is newer than the source file.""" - - @property - @abc.abstractmethod - def source_path(self) -> pathlib.Path: - ... - - def is_valid(self, cache_path: pathlib.Path): - return ( - cache_path.exists() - and cache_path.stat().st_mtime > self.source_path.stat().st_mtime - ) - - -class InvalidateByMultiPathsMtime(abc.ABC): - """The cached file is valid if exists and is newer than the newest source file.""" - - @property - @abc.abstractmethod - def source_paths(self) -> pathlib.Path: - ... 
- - @property - def newest_date(self): - return max((t.stat().st_mtime for t in self.source_paths), default=0) - - def is_valid(self, cache_path: pathlib.Path): - return cache_path.exists() and cache_path.stat().st_mtime > self.newest_date - - -############### -# Naming logic -############### - - -class NameByFields: - """Name is built taking into account all fields in the Header - (except the source itself). - """ - - def _for_cache_name(self): - yield from super()._for_cache_name() - for field in dc_fields(self): - if field.name not in ("source", "converter_id"): - yield getattr(self, field.name) - - -class NameByFileContent: - """Given a file source object, the name is built from its content.""" - - _source_type = pathlib.Path - - @property - def source_path(self) -> pathlib.Path: - return self.source - - def _for_cache_name(self): - yield from super()._for_cache_name() - yield self.source_path.read_bytes() - - @classmethod - def from_string(cls, s: str, converter_id: str): - return cls(pathlib.Path(s), converter_id) - - -@dataclass(frozen=True) -class NameByObj: - """Given a pickable source object, the name is built from its content.""" - - pickle_protocol: int = pickle.HIGHEST_PROTOCOL - - def _for_cache_name(self): - yield from super()._for_cache_name() - yield pickle.dumps(self.source, protocol=self.pickle_protocol) - - -class NameByPath: - """Given a file source object, the name is built from its resolved path.""" - - _source_type = pathlib.Path - - @property - def source_path(self) -> pathlib.Path: - return self.source - - def _for_cache_name(self): - yield from super()._for_cache_name() - yield bytes(self.source_path.resolve()) - - @classmethod - def from_string(cls, s: str, converter_id: str): - return cls(pathlib.Path(s), converter_id) - - -class NameByMultiPaths: - """Given multiple file source object, the name is built from their resolved path - in ascending order. - """ - - _source_type = tuple - - @property - def source_paths(self) -> tuple[pathlib.Path]: - return self.source - - def _for_cache_name(self): - yield from super()._for_cache_name() - yield from sorted(bytes(p.resolve()) for p in self.source_paths) - - @classmethod - def from_strings(cls, ss: Iterable[str], converter_id: str): - return cls(tuple(pathlib.Path(s) for s in ss), converter_id) - - -class NameByHashIter: - """Given multiple hashes, the name is built from them in ascending order.""" - - _source_type = tuple - - def _for_cache_name(self): - yield from super()._for_cache_name() - yield from sorted(h for h in self.source) - - -class DiskCache: - """A class to store and load cached objects to disk, which - are built from a source object and conversion function. - - The basename for the cache file is a hash hexdigest - built by feeding a collection of values determined by - the Header object. - - Parameters - ---------- - cache_folder - indicates where the cache files will be saved. - """ - - # Maps classes to header class - _header_classes: dict[type, BaseHeader] = None - - # Hasher object constructor (e.g. a member of hashlib) - # must implement update(b: bytes) and hexdigest() methods - _hasher = hashlib.sha1 - - # If True, for each cached file the header is also stored. 
- _store_header: bool = True - - def __init__(self, cache_folder: str | pathlib.Path): - self.cache_folder = pathlib.Path(cache_folder) - self.cache_folder.mkdir(parents=True, exist_ok=True) - self._header_classes = self._header_classes or {} - - def register_header_class(self, object_class: type, header_class: BaseHeader): - self._header_classes[object_class] = header_class - - def cache_stem_for(self, header: BaseHeader) -> str: - """Generate a hash representing the basename of a memoized file - for a given header. - - The naming strategy is defined by the header class used. - """ - hd = self._hasher() - for value in header.for_cache_name(): - hd.update(value) - return hd.hexdigest() - - def cache_path_for(self, header: BaseHeader) -> pathlib.Path: - """Generate a Path representing the location of a memoized file - for a given filepath or object. - - The naming strategy is defined by the header class used. - """ - h = self.cache_stem_for(header) - return self.cache_folder.joinpath(h).with_suffix(".pickle") - - def _get_header_class(self, source_object) -> BaseHeader: - for k, v in self._header_classes.items(): - if isinstance(source_object, k): - return v - raise TypeError(f"Cannot find header class for {type(source_object)}") - - def load(self, source_object, converter=None, pass_hash=False) -> tuple[Any, str]: - """Given a source_object, return the converted value stored - in the cache together with the cached path stem - - When the cache is not found: - - If a converter callable is given, use it on the source - object, store the result in the cache and return it. - - Return None, otherwise. - - Two signatures for the converter are valid: - - source_object -> transformed object - - (source_object, cached_path_stem) -> transformed_object - - To use the second one, use `pass_hash=True`. - - If you want to do the conversion yourself outside this class, - use the converter argument to provide a name for it. This is - important as the cached_path_stem depends on the converter name. - """ - header_class = self._get_header_class(source_object) - - if isinstance(converter, str): - converter_id = converter - converter = None - else: - converter_id = getattr(converter, "__name__", "") - - header = header_class(source_object, converter_id) - - cache_path = self.cache_path_for(header) - - converted_object = self.rawload(header, cache_path) - - if converted_object: - return converted_object, cache_path.stem - if converter is None: - return None, cache_path.stem - - if pass_hash: - converted_object = converter(source_object, cache_path.stem) - else: - converted_object = converter(source_object) - - self.rawsave(header, converted_object, cache_path) - - return converted_object, cache_path.stem - - def save(self, converted_object, source_object, converter_id="") -> str: - """Given a converted_object and its corresponding source_object, - store it in the cache and return the cached_path_stem. - """ - - header_class = self._get_header_class(source_object) - header = header_class(source_object, converter_id) - return self.rawsave(header, converted_object, self.cache_path_for(header)).stem - - def rawload( - self, header: BaseHeader, cache_path: pathlib.Path = None - ) -> Any | None: - """Load the converted_object from the cache if it is valid. - - The invalidating strategy is defined by the header class used. - - The cache_path is optional, it will be calculated from the header - if not given. 
- """ - if cache_path is None: - cache_path = self.cache_path_for(header) - - if header.is_valid(cache_path): - with cache_path.open(mode="rb") as fi: - return pickle.load(fi) - - def rawsave( - self, header: BaseHeader, converted, cache_path: pathlib.Path = None - ) -> pathlib.Path: - """Save the converted object (in pickle format) and - its header (in json format) to the cache folder. - - The cache_path is optional, it will be calculated from the header - if not given. - """ - if cache_path is None: - cache_path = self.cache_path_for(header) - - if self._store_header: - with cache_path.with_suffix(".json").open("w", encoding="utf-8") as fo: - json.dump({k: str(v) for k, v in dc_asdict(header).items()}, fo) - with cache_path.open(mode="wb") as fo: - pickle.dump(converted, fo) - return cache_path - - -class DiskCacheByHash(DiskCache): - """Convenience class used for caching conversions that take a path, - naming by hashing its content. - """ - - @dataclass(frozen=True) - class Header(NameByFileContent, InvalidateByExist, BaseHeader): - pass - - _header_classes = { - pathlib.Path: Header, - str: Header.from_string, - } - - -class DiskCacheByMTime(DiskCache): - """Convenience class used for caching conversions that take a path, - naming by hashing its full path and invalidating by the file - modification time. - """ - - @dataclass(frozen=True) - class Header(NameByPath, InvalidateByPathMTime, BaseHeader): - pass - - _header_classes = { - pathlib.Path: Header, - str: Header.from_string, - } diff --git a/pint/_vendor/flexparser.py b/pint/_vendor/flexparser.py deleted file mode 100644 index cac3c2b49..000000000 --- a/pint/_vendor/flexparser.py +++ /dev/null @@ -1,1686 +0,0 @@ -""" - flexparser.flexparser - ~~~~~~~~~~~~~~~~~~~~~ - - Classes and functions to create parsers. - - The idea is quite simple. You write a class for every type of content - (called here ``ParsedStatement``) you need to parse. Each class should - have a ``from_string`` constructor. We used extensively the ``typing`` - module to make the output structure easy to use and less error prone. - - For more information, take a look at https://github.com/hgrecco/flexparser - - :copyright: 2022 by flexparser Authors, see AUTHORS for more details. - :license: BSD, see LICENSE for more details. -""" - -from __future__ import annotations - -import sys -import collections -import dataclasses -import enum -import functools -import hashlib -import hmac -import inspect -import logging -import pathlib -import re -import typing as ty -from dataclasses import dataclass -from functools import cached_property -from importlib import resources -from typing import Any, Union, Optional, no_type_check - -if sys.version_info >= (3, 10): - from typing import TypeAlias # noqa -else: - from typing_extensions import TypeAlias # noqa - - -if sys.version_info >= (3, 11): - from typing import Self # noqa -else: - from typing_extensions import Self # noqa - - -_LOGGER = logging.getLogger("flexparser") - -_SENTINEL = object() - - -class HasherProtocol(ty.Protocol): - @property - def name(self) -> str: - ... - - def hexdigest(self) -> str: - ... 
- - -class GenericInfo: - _specialized: Optional[ - dict[type, Optional[list[tuple[type, dict[ty.TypeVar, type]]]]] - ] = None - - @staticmethod - def _summarize(d: dict[ty.TypeVar, type]) -> dict[ty.TypeVar, type]: - d = d.copy() - while True: - for k, v in d.items(): - if isinstance(v, ty.TypeVar): - d[k] = d[v] - break - else: - return d - - del d[v] - - @classmethod - def _specialization(cls) -> dict[ty.TypeVar, type]: - if cls._specialized is None: - return dict() - - out: dict[ty.TypeVar, type] = {} - specialized = cls._specialized[cls] - - if specialized is None: - return {} - - for parent, content in specialized: - for tvar, typ in content.items(): - out[tvar] = typ - origin = getattr(parent, "__origin__", None) - if origin is not None and origin in cls._specialized: - out = {**origin._specialization(), **out} - - return out - - @classmethod - def specialization(cls) -> dict[ty.TypeVar, type]: - return GenericInfo._summarize(cls._specialization()) - - def __init_subclass__(cls) -> None: - if cls._specialized is None: - cls._specialized = {GenericInfo: None} - - tv: list[ty.TypeVar] = [] - entries: list[tuple[type, dict[ty.TypeVar, type]]] = [] - - for par in getattr(cls, "__parameters__", ()): - if isinstance(par, ty.TypeVar): - tv.append(par) - - for b in getattr(cls, "__orig_bases__", ()): - for k in cls._specialized.keys(): - if getattr(b, "__origin__", None) is k: - entries.append((b, {k: v for k, v in zip(tv, b.__args__)})) - break - - cls._specialized[cls] = entries - - return super().__init_subclass__() - - -################ -# Exceptions -################ - - -@dataclass(frozen=True) -class Statement: - """Base class for parsed elements within a source file.""" - - is_position_set: bool = dataclasses.field(init=False, default=False, repr=False) - - start_line: int = dataclasses.field(init=False, default=0) - start_col: int = dataclasses.field(init=False, default=0) - - end_line: int = dataclasses.field(init=False, default=0) - end_col: int = dataclasses.field(init=False, default=0) - - raw: Optional[str] = dataclasses.field(init=False, default=None) - - @classmethod - def from_statement(cls, statement: Statement) -> Self: - out = cls() - if statement.is_position_set: - out.set_position(*statement.get_position()) - if statement.raw is not None: - out.set_raw(statement.raw) - return out - - @classmethod - def from_statement_iterator_element( - cls, values: tuple[int, int, int, int, str] - ) -> Self: - out = cls() - out.set_position(*values[:-1]) - out.set_raw(values[-1]) - return out - - @property - def format_position(self) -> str: - if not self.is_position_set: - return "N/A" - return "%d,%d-%d,%d" % self.get_position() - - @property - def raw_strip(self) -> Optional[str]: - if self.raw is None: - return None - return self.raw.strip() - - def get_position(self) -> tuple[int, int, int, int]: - if self.is_position_set: - return self.start_line, self.start_col, self.end_line, self.end_col - return 0, 0, 0, 0 - - def set_position( - self: Self, start_line: int, start_col: int, end_line: int, end_col: int - ) -> Self: - object.__setattr__(self, "is_position_set", True) - object.__setattr__(self, "start_line", start_line) - object.__setattr__(self, "start_col", start_col) - object.__setattr__(self, "end_line", end_line) - object.__setattr__(self, "end_col", end_col) - return self - - def set_raw(self: Self, raw: str) -> Self: - object.__setattr__(self, "raw", raw) - return self - - def set_simple_position(self: Self, line: int, col: int, width: int) -> Self: - return 
self.set_position(line, col, line, col + width) - - -@dataclass(frozen=True) -class ParsingError(Statement, Exception): - """Base class for all parsing exceptions in this package.""" - - def __str__(self) -> str: - return Statement.__str__(self) - - -@dataclass(frozen=True) -class UnknownStatement(ParsingError): - """A string statement could not bee parsed.""" - - def __str__(self) -> str: - return f"Could not parse '{self.raw}' ({self.format_position})" - - -@dataclass(frozen=True) -class UnhandledParsingError(ParsingError): - """Base class for all parsing exceptions in this package.""" - - ex: Exception - - def __str__(self) -> str: - return f"Unhandled exception while parsing '{self.raw}' ({self.format_position}): {self.ex}" - - -@dataclass(frozen=True) -class UnexpectedEOS(ParsingError): - """End of file was found within an open block.""" - - -############################# -# Useful methods and classes -############################# - - -@dataclass(frozen=True) -class Hash: - algorithm_name: str - hexdigest: str - - def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, Hash) - and self.algorithm_name != "" - and self.algorithm_name == other.algorithm_name - and hmac.compare_digest(self.hexdigest, other.hexdigest) - ) - - @classmethod - def from_bytes( - cls, - algorithm: ty.Callable[ - [ - bytes, - ], - HasherProtocol, - ], - b: bytes, - ) -> Self: - hasher = algorithm(b) - return cls(hasher.name, hasher.hexdigest()) - - @classmethod - def from_file_pointer( - cls, - algorithm: ty.Callable[ - [ - bytes, - ], - HasherProtocol, - ], - fp: ty.BinaryIO, - ) -> Self: - return cls.from_bytes(algorithm, fp.read()) - - @classmethod - def nullhash(cls) -> Self: - return cls("", "") - - -def _yield_types( - obj: type, - valid_subclasses: tuple[type, ...] = (object,), - recurse_origin: tuple[Any, ...] = (tuple, list, Union), -) -> ty.Generator[type, None, None]: - """Recursively transverse type annotation if the - origin is any of the types in `recurse_origin` - and yield those type which are subclasses of `valid_subclasses`. - - """ - if ty.get_origin(obj) in recurse_origin: - for el in ty.get_args(obj): - yield from _yield_types(el, valid_subclasses, recurse_origin) - else: - if inspect.isclass(obj) and issubclass(obj, valid_subclasses): - yield obj - - -class classproperty: # noqa N801 - """Decorator for a class property - - In Python 3.9+ can be replaced by - - @classmethod - @property - def myprop(self): - return 42 - - """ - - def __init__(self, fget): # type: ignore - self.fget = fget - - def __get__(self, owner_self, owner_cls): # type: ignore - return self.fget(owner_cls) # type: ignore - - -class DelimiterInclude(enum.IntEnum): - """Specifies how to deal with delimiters while parsing.""" - - #: Split at delimiter, not including in any string - SPLIT = enum.auto() - - #: Split after, keeping the delimiter with previous string. - SPLIT_AFTER = enum.auto() - - #: Split before, keeping the delimiter with next string. - SPLIT_BEFORE = enum.auto() - - #: Do not split at delimiter. - DO_NOT_SPLIT = enum.auto() - - -class DelimiterAction(enum.IntEnum): - """Specifies how to deal with delimiters while parsing.""" - - #: Continue parsing normally. - CONTINUE = enum.auto() - - #: Capture everything til end of line as a whole. - CAPTURE_NEXT_TIL_EOL = enum.auto() - - #: Stop parsing line and move to next. - STOP_PARSING_LINE = enum.auto() - - #: Stop parsing content. 
- STOP_PARSING = enum.auto() - - -DO_NOT_SPLIT_EOL = { - "\r\n": (DelimiterInclude.DO_NOT_SPLIT, DelimiterAction.CONTINUE), - "\n": (DelimiterInclude.DO_NOT_SPLIT, DelimiterAction.CONTINUE), - "\r": (DelimiterInclude.DO_NOT_SPLIT, DelimiterAction.CONTINUE), -} - -SPLIT_EOL = { - "\r\n": (DelimiterInclude.SPLIT, DelimiterAction.CONTINUE), - "\n": (DelimiterInclude.SPLIT, DelimiterAction.CONTINUE), - "\r": (DelimiterInclude.SPLIT, DelimiterAction.CONTINUE), -} - -_EOLs_set = set(DO_NOT_SPLIT_EOL.keys()) - - -@functools.lru_cache -def _build_delimiter_pattern(delimiters: tuple[str, ...]) -> re.Pattern[str]: - """Compile a tuple of delimiters into a regex expression with a capture group - around the delimiter. - """ - return re.compile("|".join(f"({re.escape(el)})" for el in delimiters)) - - -############ -# Iterators -############ - -DelimiterDictT = dict[str, tuple[DelimiterInclude, DelimiterAction]] - - -class Spliter: - """Content iterator splitting according to given delimiters. - - The pattern can be changed dynamically sending a new pattern to the ty.Generator, - see DelimiterInclude and DelimiterAction for more information. - - The current scanning position can be changed at any time. - - Parameters - ---------- - content : str - delimiters : dict[str, tuple[DelimiterInclude, DelimiterAction]] - - Yields - ------ - start_line : int - line number of the start of the content (zero-based numbering). - start_col : int - column number of the start of the content (zero-based numbering). - end_line : int - line number of the end of the content (zero-based numbering). - end_col : int - column number of the end of the content (zero-based numbering). - part : str - part of the text between delimiters. - """ - - _pattern: Optional[re.Pattern[str]] - _delimiters: DelimiterDictT - - __stop_searching_in_line: bool = False - - __pending: str = "" - __first_line_col: Optional[tuple[int, int]] = None - - __lines: list[str] - __lineno: int = 0 - __colno: int = 0 - - def __init__(self, content: str, delimiters: DelimiterDictT): - self.set_delimiters(delimiters) - self.__lines = content.splitlines(keepends=True) - - def set_position(self, lineno: int, colno: int) -> None: - self.__lineno, self.__colno = lineno, colno - - def set_delimiters(self, delimiters: DelimiterDictT) -> None: - for k, v in delimiters.items(): - if v == (DelimiterInclude.DO_NOT_SPLIT, DelimiterAction.STOP_PARSING): - raise ValueError( - f"The delimiter action for {k} is not a valid combination ({v})" - ) - # Build a pattern but removing eols - _pat_dlm = tuple(set(delimiters.keys()) - _EOLs_set) - if _pat_dlm: - self._pattern = _build_delimiter_pattern(_pat_dlm) - else: - self._pattern = None - # We add the end of line as delimiters if not present. - self._delimiters = {**DO_NOT_SPLIT_EOL, **delimiters} - - def __iter__(self) -> Spliter: - return self - - def __next__(self) -> tuple[int, int, int, int, str]: - if self.__lineno >= len(self.__lines): - raise StopIteration - - while True: - if self.__stop_searching_in_line: - # There must be part of a line pending to parse - # due to stop - line = self.__lines[self.__lineno] - mo = None - self.__stop_searching_in_line = False - else: - # We get the current line and the find the first delimiter. 
- line = self.__lines[self.__lineno] - if self._pattern is None: - mo = None - else: - mo = self._pattern.search(line, self.__colno) - - if mo is None: - # No delimiter was found, - # which should happen at end of the content or end of line - for k in DO_NOT_SPLIT_EOL.keys(): - if line.endswith(k): - dlm = line[-len(k) :] - end_col, next_col = len(line) - len(k), 0 - break - else: - # No EOL found, this is end of content - dlm = None - end_col, next_col = len(line), 0 - - next_line = self.__lineno + 1 - - else: - next_line = self.__lineno - end_col, next_col = mo.span() - dlm = mo.group() - - part = line[self.__colno : end_col] - - if dlm is None: - include, action = DelimiterInclude.SPLIT, DelimiterAction.STOP_PARSING - else: - include, action = self._delimiters[dlm] - - if include == DelimiterInclude.SPLIT: - next_pending = "" - else: - # When dlm is None, DelimiterInclude.SPLIT - assert isinstance(dlm, str) - if include == DelimiterInclude.SPLIT_AFTER: - end_col += len(dlm) - part = part + dlm - next_pending = "" - elif include == DelimiterInclude.SPLIT_BEFORE: - next_pending = dlm - elif include == DelimiterInclude.DO_NOT_SPLIT: - self.__pending += line[self.__colno : end_col] + dlm - next_pending = "" - else: - raise ValueError(f"Unknown action {include}.") - - if action == DelimiterAction.STOP_PARSING: - # this will raise a StopIteration in the next call. - next_line = len(self.__lines) - elif action == DelimiterAction.STOP_PARSING_LINE: - next_line = self.__lineno + 1 - next_col = 0 - - start_line = self.__lineno - start_col = self.__colno - end_line = self.__lineno - - self.__lineno = next_line - self.__colno = next_col - - if action == DelimiterAction.CAPTURE_NEXT_TIL_EOL: - self.__stop_searching_in_line = True - - if include == DelimiterInclude.DO_NOT_SPLIT: - self.__first_line_col = start_line, start_col - else: - if self.__first_line_col is None: - out = ( - start_line, - start_col - len(self.__pending), - end_line, - end_col, - self.__pending + part, - ) - else: - out = ( - *self.__first_line_col, - end_line, - end_col, - self.__pending + part, - ) - self.__first_line_col = None - self.__pending = next_pending - return out - - -class StatementIterator: - """Content peekable iterator splitting according to given delimiters. - - The pattern can be changed dynamically sending a new pattern to the ty.Generator, - see DelimiterInclude and DelimiterAction for more information. - - Parameters - ---------- - content : str - delimiters : dict[str, tuple[DelimiterInclude, DelimiterAction]] - - Yields - ------ - Statement - """ - - _cache: ty.Deque[Statement] - - def __init__( - self, content: str, delimiters: DelimiterDictT, strip_spaces: bool = True - ): - self._cache = collections.deque() - self._spliter = Spliter(content, delimiters) - self._strip_spaces = strip_spaces - - def __iter__(self): - return self - - def set_delimiters(self, delimiters: DelimiterDictT) -> None: - self._spliter.set_delimiters(delimiters) - if self._cache: - value = self.peek() - # Elements are 1 based indexing, while splitter is 0 based. 
- self._spliter.set_position(value.start_line - 1, value.start_col) - self._cache.clear() - - def _get_next_strip(self) -> Statement: - part = "" - while not part: - start_line, start_col, end_line, end_col, part = next(self._spliter) - lo = len(part) - part = part.lstrip() - start_col += lo - len(part) - - lo = len(part) - part = part.rstrip() - end_col -= lo - len(part) - - return Statement.from_statement_iterator_element( - (start_line + 1, start_col, end_line + 1, end_col, part) # type: ignore - ) - - def _get_next(self) -> Statement: - if self._strip_spaces: - return self._get_next_strip() - - part = "" - while not part: - start_line, start_col, end_line, end_col, part = next(self._spliter) - - return Statement.from_statement_iterator_element( - (start_line + 1, start_col, end_line + 1, end_col, part) # type: ignore - ) - - def peek(self, default: Any = _SENTINEL) -> Statement: - """Return the item that will be next returned from ``next()``. - - Return ``default`` if there are no items left. If ``default`` is not - provided, raise ``StopIteration``. - - """ - if not self._cache: - try: - self._cache.append(self._get_next()) - except StopIteration: - if default is _SENTINEL: - raise - return default - return self._cache[0] - - def __next__(self) -> Statement: - if self._cache: - return self._cache.popleft() - return self._get_next() - - -########### -# Parsing -########### - -# Configuration type -T = ty.TypeVar("T") -CT = ty.TypeVar("CT") -PST = ty.TypeVar("PST", bound="ParsedStatement[Any]") -LineColStr: TypeAlias = tuple[int, int, str] - -ParsedResult: TypeAlias = Union[T, ParsingError] -NullableParsedResult: TypeAlias = Union[T, ParsingError, None] - - -class ConsumeProtocol(ty.Protocol): - @property - def is_position_set(self) -> bool: - ... - - @property - def start_line(self) -> int: - ... - - @property - def start_col(self) -> int: - ... - - @property - def end_line(self) -> int: - ... - - @property - def end_col(self) -> int: - ... - - @classmethod - def consume( - cls, statement_iterator: StatementIterator, config: Any - ) -> NullableParsedResult[Self]: - ... - - -@dataclass(frozen=True) -class ParsedStatement(ty.Generic[CT], Statement): - """A single parsed statement. - - In order to write your own, you need to subclass it as a - frozen dataclass and implement the parsing logic by overriding - `from_string` classmethod. - - Takes two arguments: the string to parse and an object given - by the parser which can be used to store configuration information. - - It should return an instance of this class if parsing - was successful or None otherwise - """ - - @classmethod - def from_string(cls, s: str) -> NullableParsedResult[Self]: - """Parse a string into a ParsedStatement. - - Return files and their meaning: - 1. None: the string cannot be parsed with this class. - 2. A subclass of ParsedStatement: the string was parsed successfully - 3. A subclass of ParsingError the string could be parsed with this class but there is - an error. - """ - raise NotImplementedError( - "ParsedStatement subclasses must implement " - "'from_string' or 'from_string_and_config'" - ) - - @classmethod - def from_string_and_config(cls, s: str, config: CT) -> NullableParsedResult[Self]: - """Parse a string into a ParsedStatement. - - Return files and their meaning: - 1. None: the string cannot be parsed with this class. - 2. A subclass of ParsedStatement: the string was parsed successfully - 3. A subclass of ParsingError the string could be parsed with this class but there is - an error. 
- """ - return cls.from_string(s) - - @classmethod - def from_statement_and_config( - cls, statement: Statement, config: CT - ) -> NullableParsedResult[Self]: - raw = statement.raw - if raw is None: - return None - - try: - out = cls.from_string_and_config(raw, config) - except Exception as ex: - out = UnhandledParsingError(ex) - - if out is None: - return None - - out.set_position(*statement.get_position()) - out.set_raw(raw) - return out - - @classmethod - def consume( - cls, statement_iterator: StatementIterator, config: CT - ) -> NullableParsedResult[Self]: - """Peek into the iterator and try to parse. - - Return files and their meaning: - 1. None: the string cannot be parsed with this class, the iterator is kept an the current place. - 2. a subclass of ParsedStatement: the string was parsed successfully, advance the iterator. - 3. a subclass of ParsingError: the string could be parsed with this class but there is - an error, advance the iterator. - """ - statement = statement_iterator.peek() - parsed_statement = cls.from_statement_and_config(statement, config) - if parsed_statement is None: - return None - next(statement_iterator) - return parsed_statement - - -OPST = ty.TypeVar("OPST", bound="ParsedStatement[Any]") -BPST = ty.TypeVar( - "BPST", bound="Union[ParsedStatement[Any], Block[Any, Any, Any, Any]]" -) -CPST = ty.TypeVar("CPST", bound="ParsedStatement[Any]") -RBT = ty.TypeVar("RBT", bound="RootBlock[Any, Any]") - - -@dataclass(frozen=True) -class Block(ty.Generic[OPST, BPST, CPST, CT], GenericInfo): - """A sequence of statements with an opening, body and closing.""" - - opening: ParsedResult[OPST] - body: tuple[ParsedResult[BPST], ...] - closing: Union[ParsedResult[CPST], EOS[CT]] - - delimiters: DelimiterDictT = dataclasses.field(default_factory=dict, init=False) - - def is_closed(self) -> bool: - return not isinstance(self.closing, EOS) - - @property - def is_position_set(self) -> bool: - return self.opening.is_position_set - - @property - def start_line(self) -> int: - return self.opening.start_line - - @property - def start_col(self) -> int: - return self.opening.start_col - - @property - def end_line(self) -> int: - return self.closing.end_line - - @property - def end_col(self) -> int: - return self.closing.end_col - - def get_position(self) -> tuple[int, int, int, int]: - return self.start_line, self.start_col, self.end_line, self.end_col - - @property - def format_position(self) -> str: - if not self.is_position_set: - return "N/A" - return "%d,%d-%d,%d" % self.get_position() - - def __iter__( - self, - ) -> ty.Generator[ - ParsedResult[Union[OPST, BPST, Union[CPST, EOS[CT]]]], None, None - ]: - yield self.opening - for el in self.body: - if isinstance(el, Block): - yield from el - else: - yield el - yield self.closing - - def iter_blocks( - self, - ) -> ty.Generator[ParsedResult[Union[OPST, BPST, CPST]], None, None]: - # raise RuntimeError("Is this used?") - yield self.opening - yield from self.body - yield self.closing - - ################################################### - # Convenience methods to iterate parsed statements - ################################################### - - _ElementT = ty.TypeVar("_ElementT", bound=Statement) - - def filter_by( - self, klass1: type[_ElementT], *klass: type[_ElementT] - ) -> ty.Generator[_ElementT, None, None]: - """Yield elements of a given class or classes.""" - yield from (el for el in self if isinstance(el, (klass1,) + klass)) # type: ignore[misc] - - @cached_property - def errors(self) -> tuple[ParsingError, ...]: - 
"""Tuple of errors found.""" - return tuple(self.filter_by(ParsingError)) - - @property - def has_errors(self) -> bool: - """True if errors were found during parsing.""" - return bool(self.errors) - - #################### - # Statement classes - #################### - - @classmethod - def opening_classes(cls) -> ty.Generator[type[OPST], None, None]: - """Classes representing any of the parsed statement that can open this block.""" - try: - opening = cls.specialization()[OPST] # type: ignore[misc] - except KeyError: - opening: type = ty.get_type_hints(cls)["opening"] # type: ignore[no-redef] - yield from _yield_types(opening, ParsedStatement) # type: ignore - - @classmethod - def body_classes(cls) -> ty.Generator[type[BPST], None, None]: - """Classes representing any of the parsed statement that can be in the body.""" - try: - body = cls.specialization()[BPST] # type: ignore[misc] - except KeyError: - body: type = ty.get_type_hints(cls)["body"] # type: ignore[no-redef] - yield from _yield_types(body, (ParsedStatement, Block)) # type: ignore - - @classmethod - def closing_classes(cls) -> ty.Generator[type[CPST], None, None]: - """Classes representing any of the parsed statement that can close this block.""" - try: - closing = cls.specialization()[CPST] # type: ignore[misc] - except KeyError: - closing: type = ty.get_type_hints(cls)["closing"] # type: ignore[no-redef] - yield from _yield_types(closing, ParsedStatement) # type: ignore - - ########## - # ParsedResult - ########## - - @classmethod - def consume_opening( - cls, statement_iterator: StatementIterator, config: CT - ) -> NullableParsedResult[OPST]: - """Peek into the iterator and try to parse with any of the opening classes. - - See `ParsedStatement.consume` for more details. - """ - for c in cls.opening_classes(): - el = c.consume(statement_iterator, config) - if el is not None: - return el - return None - - @classmethod - def consume_body( - cls, statement_iterator: StatementIterator, config: CT - ) -> ParsedResult[BPST]: - """Peek into the iterator and try to parse with any of the body classes. - - If the statement cannot be parsed, a UnknownStatement is returned. - """ - for c in cls.body_classes(): - el = c.consume(statement_iterator, config) - if el is not None: - return el - unkel = next(statement_iterator) - return UnknownStatement.from_statement(unkel) - - @classmethod - def consume_closing( - cls, statement_iterator: StatementIterator, config: CT - ) -> NullableParsedResult[CPST]: - """Peek into the iterator and try to parse with any of the opening classes. - - See `ParsedStatement.consume` for more details. 
- """ - for c in cls.closing_classes(): - el = c.consume(statement_iterator, config) - if el is not None: - return el - return None - - @classmethod - def consume_body_closing( - cls, opening: OPST, statement_iterator: StatementIterator, config: CT - ) -> Self: - body: list[ParsedResult[BPST]] = [] - closing: ty.Union[CPST, ParsingError, None] = None - last_line = opening.end_line - while closing is None: - try: - closing = cls.consume_closing(statement_iterator, config) - if closing is not None: - continue - el = cls.consume_body(statement_iterator, config) - body.append(el) - last_line = el.end_line - except StopIteration: - unexpected_end = cls.on_stop_iteration(config) - unexpected_end.set_position(last_line + 1, 0, last_line + 1, 0) - return cls(opening, tuple(body), unexpected_end) - - return cls(opening, tuple(body), closing) - - @classmethod - def consume( - cls, statement_iterator: StatementIterator, config: CT - ) -> Union[Self, None]: - """Try consume the block. - - Possible outcomes: - 1. The opening was not matched, return None. - 2. A subclass of Block, where body and closing migh contain errors. - """ - opening = cls.consume_opening(statement_iterator, config) - if opening is None: - return None - - if isinstance(opening, ParsingError): - return None - - return cls.consume_body_closing(opening, statement_iterator, config) - - @classmethod - def on_stop_iteration(cls, config: CT) -> ParsedResult[EOS[CT]]: - return UnexpectedEOS() - - -@dataclass(frozen=True) -class BOS(ty.Generic[CT], ParsedStatement[CT]): - """Beginning of source.""" - - # Hasher algorithm name and hexdigest - content_hash: Hash - - @classmethod - def from_string_and_config(cls, s: str, config: CT) -> NullableParsedResult[Self]: - raise RuntimeError("BOS cannot be constructed from_string_and_config") - - @property - def location(self) -> SourceLocationT: - return "" - - -@dataclass(frozen=True) -class BOF(ty.Generic[CT], BOS[CT]): - """Beginning of file.""" - - path: pathlib.Path - - # Modification time of the file. 
- mtime: float - - @property - def location(self) -> SourceLocationT: - return self.path - - -@dataclass(frozen=True) -class BOR(ty.Generic[CT], BOS[CT]): - """Beginning of resource.""" - - package: str - resource_name: str - - @property - def location(self) -> SourceLocationT: - return self.package, self.resource_name - - -@dataclass(frozen=True) -class EOS(ty.Generic[CT], ParsedStatement[CT]): - """End of sequence.""" - - @classmethod - def from_string_and_config( - cls: type[PST], s: str, config: CT - ) -> NullableParsedResult[PST]: - return cls() - - -class RootBlock(ty.Generic[BPST, CT], Block[BOS[CT], BPST, EOS[CT], CT]): - """A sequence of statement flanked by the beginning and ending of stream.""" - - @classmethod - def consume_opening( - cls, statement_iterator: StatementIterator, config: CT - ) -> NullableParsedResult[BOS[CT]]: - raise RuntimeError( - "Implementation error, 'RootBlock.consume_opening' should never be called" - ) - - @classmethod - def consume(cls, statement_iterator: StatementIterator, config: CT) -> Self: - block = super().consume(statement_iterator, config) - if block is None: - raise RuntimeError( - "Implementation error, 'RootBlock.consume' should never return None" - ) - return block - - @classmethod - def consume_closing( - cls, statement_iterator: StatementIterator, config: CT - ) -> NullableParsedResult[EOS[CT]]: - return None - - @classmethod - def on_stop_iteration(cls, config: CT) -> ParsedResult[EOS[CT]]: - return EOS[CT]() - - -################# -# Source parsing -################# - -ResourceT: TypeAlias = tuple[str, str] # package name, resource name -StrictLocationT: TypeAlias = Union[pathlib.Path, ResourceT] -SourceLocationT: TypeAlias = Union[str, StrictLocationT] - - -@dataclass(frozen=True) -class ParsedSource(ty.Generic[RBT, CT]): - parsed_source: RBT - - # Parser configuration. - config: CT - - @property - def location(self) -> SourceLocationT: - if isinstance(self.parsed_source.opening, ParsingError): - raise self.parsed_source.opening - return self.parsed_source.opening.location - - @cached_property - def has_errors(self) -> bool: - return self.parsed_source.has_errors - - def errors(self) -> ty.Generator[ParsingError, None, None]: - yield from self.parsed_source.errors - - -@dataclass(frozen=True) -class CannotParseResourceAsFile(Exception): - """The requested python package resource cannot be located as a file - in the file system. - """ - - package: str - resource_name: str - - -class Parser(ty.Generic[RBT, CT], GenericInfo): - """Parser class.""" - - #: class to iterate through statements in a source unit. - _statement_iterator_class: type[StatementIterator] = StatementIterator - - #: Delimiters. - _delimiters: DelimiterDictT = SPLIT_EOL - - _strip_spaces: bool = True - - #: source file text encoding. - _encoding: str = "utf-8" - - #: configuration passed to from_string functions. - _config: CT - - #: try to open resources as files. - _prefer_resource_as_file: bool - - #: parser algorithm to us. 
Must be a callable member of hashlib - _hasher: ty.Callable[ - [ - bytes, - ], - HasherProtocol, - ] = hashlib.blake2b - - def __init__(self, config: CT, prefer_resource_as_file: bool = True): - self._config = config - self._prefer_resource_as_file = prefer_resource_as_file - - @classmethod - def root_boot_class(cls) -> type[RBT]: - """Class representing the root block class.""" - try: - return cls.specialization()[RBT] # type: ignore[misc] - except KeyError: - return ty.get_type_hints(cls)["root_boot_class"] # type: ignore[no-redef] - - def parse(self, source_location: SourceLocationT) -> ParsedSource[RBT, CT]: - """Parse a file into a ParsedSourceFile or ParsedResource. - - Parameters - ---------- - source_location: - if str or pathlib.Path is interpreted as a file. - if (str, str) is interpreted as (package, resource) using the resource python api. - """ - if isinstance(source_location, tuple) and len(source_location) == 2: - if self._prefer_resource_as_file: - try: - return self.parse_resource_from_file(*source_location) - except CannotParseResourceAsFile: - pass - return self.parse_resource(*source_location) - - if isinstance(source_location, str): - return self.parse_file(pathlib.Path(source_location)) - - if isinstance(source_location, pathlib.Path): - return self.parse_file(source_location) - - raise TypeError( - f"Unknown type {type(source_location)}, " - "use str or pathlib.Path for files or " - "(package: str, resource_name: str) tuple " - "for a resource." - ) - - def parse_bytes( - self, b: bytes, bos: Optional[BOS[CT]] = None - ) -> ParsedSource[RBT, CT]: - if bos is None: - bos = BOS[CT](Hash.from_bytes(self._hasher, b)).set_simple_position(0, 0, 0) - - sic = self._statement_iterator_class( - b.decode(self._encoding), self._delimiters, self._strip_spaces - ) - - parsed = self.root_boot_class().consume_body_closing(bos, sic, self._config) - - return ParsedSource( - parsed, - self._config, - ) - - def parse_file(self, path: pathlib.Path) -> ParsedSource[RBT, CT]: - """Parse a file into a ParsedSourceFile. - - Parameters - ---------- - path - path of the file. - """ - with path.open(mode="rb") as fi: - content = fi.read() - - bos = BOF[CT]( - Hash.from_bytes(self._hasher, content), path, path.stat().st_mtime - ).set_simple_position(0, 0, 0) - return self.parse_bytes(content, bos) - - def parse_resource_from_file( - self, package: str, resource_name: str - ) -> ParsedSource[RBT, CT]: - """Parse a resource into a ParsedSourceFile, opening as a file. - - Parameters - ---------- - package - package name where the resource is located. - resource_name - name of the resource - """ - with resources.as_file(resources.files(package).joinpath(resource_name)) as p: - path = p.resolve() - - if path.exists(): - return self.parse_file(path) - - raise CannotParseResourceAsFile(package, resource_name) - - def parse_resource(self, package: str, resource_name: str) -> ParsedSource[RBT, CT]: - """Parse a resource into a ParsedResource. - - Parameters - ---------- - package - package name where the resource is located. 
- resource_name - name of the resource - """ - with resources.files(package).joinpath(resource_name).open("rb") as fi: - content = fi.read() - - bos = BOR[CT]( - Hash.from_bytes(self._hasher, content), package, resource_name - ).set_simple_position(0, 0, 0) - - return self.parse_bytes(content, bos) - - -########## -# Project -########## - - -class IncludeStatement(ty.Generic[CT], ParsedStatement[CT]): - """ "Include statements allow to merge files.""" - - @property - def target(self) -> str: - raise NotImplementedError( - "IncludeStatement subclasses must implement target property." - ) - - -class ParsedProject( - ty.Generic[RBT, CT], - dict[ - Optional[tuple[StrictLocationT, str]], - ParsedSource[RBT, CT], - ], -): - """Collection of files, independent or connected via IncludeStatement. - - Keys are either an absolute pathname or a tuple package name, resource name. - - None is the name of the root. - - """ - - @cached_property - def has_errors(self) -> bool: - return any(el.has_errors for el in self.values()) - - def errors(self) -> ty.Generator[ParsingError, None, None]: - for el in self.values(): - yield from el.errors() - - def _iter_statements( - self, - items: ty.Iterable[tuple[Any, Any]], - seen: set[Any], - include_only_once: bool, - ) -> ty.Generator[ParsedStatement[CT], None, None]: - """Iter all definitions in the order they appear, - going into the included files. - """ - for source_location, parsed in items: - seen.add(source_location) - for parsed_statement in parsed.parsed_source: - if isinstance(parsed_statement, IncludeStatement): - location = parsed.location, parsed_statement.target - if location in seen and include_only_once: - raise ValueError(f"{location} was already included.") - yield from self._iter_statements( - ((location, self[location]),), seen, include_only_once - ) - else: - yield parsed_statement - - def iter_statements( - self, include_only_once: bool = True - ) -> ty.Generator[ParsedStatement[CT], None, None]: - """Iter all definitions in the order they appear, - going into the included files. - - Parameters - ---------- - include_only_once - if true, each file cannot be included more than once. - """ - yield from self._iter_statements([(None, self[None])], set(), include_only_once) - - def _iter_blocks( - self, - items: ty.Iterable[tuple[Any, Any]], - seen: set[Any], - include_only_once: bool, - ) -> ty.Generator[ParsedStatement[CT], None, None]: - """Iter all definitions in the order they appear, - going into the included files. - """ - for source_location, parsed in items: - seen.add(source_location) - for parsed_statement in parsed.parsed_source.iter_blocks(): - if isinstance(parsed_statement, IncludeStatement): - location = parsed.location, parsed_statement.target - if location in seen and include_only_once: - raise ValueError(f"{location} was already included.") - yield from self._iter_blocks( - ((location, self[location]),), seen, include_only_once - ) - else: - yield parsed_statement - - def iter_blocks( - self, include_only_once: bool = True - ) -> ty.Generator[ParsedStatement[CT], None, None]: - """Iter all definitions in the order they appear, - going into the included files. - - Parameters - ---------- - include_only_once - if true, each file cannot be included more than once. 
- """ - yield from self._iter_blocks([(None, self[None])], set(), include_only_once) - - -def default_locator(source_location: StrictLocationT, target: str) -> StrictLocationT: - """Return a new location from current_location and target.""" - - if isinstance(source_location, pathlib.Path): - current_location = pathlib.Path(source_location).resolve() - - if current_location.is_file(): - current_path = current_location.parent - else: - current_path = current_location - - target_path = pathlib.Path(target) - if target_path.is_absolute(): - raise ValueError( - f"Cannot refer to absolute paths in import statements ({source_location}, {target})." - ) - - tmp = (current_path / target_path).resolve() - if not tmp.is_relative_to(current_path): - raise ValueError( - f"Cannot refer to locations above the current location ({source_location}, {target})" - ) - - return tmp.absolute() - - elif isinstance(source_location, tuple) and len(source_location) == 2: - return source_location[0], target - - raise TypeError( - f"Cannot handle type {type(source_location)}, " - "use str or pathlib.Path for files or " - "(package: str, resource_name: str) tuple " - "for a resource." - ) - - -@no_type_check -def _build_root_block_class_parsed_statement( - spec: type[ParsedStatement[CT]], config: type[CT] -) -> type[RootBlock[ParsedStatement[CT], CT]]: - """Build root block class from a single ParsedStatement.""" - - @dataclass(frozen=True) - class CustomRootBlockA(RootBlock[spec, config]): # type: ignore - pass - - return CustomRootBlockA - - -@no_type_check -def _build_root_block_class_block( - spec: type[Block[OPST, BPST, CPST, CT]], - config: type[CT], -) -> type[RootBlock[Block[OPST, BPST, CPST, CT], CT]]: - """Build root block class from a single ParsedStatement.""" - - @dataclass(frozen=True) - class CustomRootBlockA(RootBlock[spec, config]): # type: ignore - pass - - return CustomRootBlockA - - -@no_type_check -def _build_root_block_class_parsed_statement_it( - spec: tuple[type[Union[ParsedStatement[CT], Block[OPST, BPST, CPST, CT]]]], - config: type[CT], -) -> type[RootBlock[ParsedStatement[CT], CT]]: - """Build root block class from iterable ParsedStatement.""" - - @dataclass(frozen=True) - class CustomRootBlockA(RootBlock[Union[spec], config]): # type: ignore - pass - - return CustomRootBlockA - - -@no_type_check -def _build_parser_class_root_block( - spec: type[RootBlock[BPST, CT]], - *, - strip_spaces: bool = True, - delimiters: Optional[DelimiterDictT] = None, -) -> type[Parser[RootBlock[BPST, CT], CT]]: - class CustomParser(Parser[spec, spec.specialization()[CT]]): # type: ignore - _delimiters: DelimiterDictT = delimiters or SPLIT_EOL - _strip_spaces: bool = strip_spaces - - return CustomParser - - -@no_type_check -def build_parser_class( - spec: Union[ - type[ - Union[ - Parser[RBT, CT], - RootBlock[BPST, CT], - Block[OPST, BPST, CPST, CT], - ParsedStatement[CT], - ] - ], - ty.Iterable[type[ParsedStatement[CT]]], - ], - config: CT = None, - strip_spaces: bool = True, - delimiters: Optional[DelimiterDictT] = None, -) -> type[ - Union[ - Parser[RBT, CT], - Parser[RootBlock[BPST, CT], CT], - Parser[RootBlock[Block[OPST, BPST, CPST, CT], CT], CT], - ] -]: - """Build a custom parser class. - - Parameters - ---------- - spec - RootBlock derived class. - strip_spaces : bool - if True, spaces will be stripped for each statement before calling - ``from_string_and_config``. - delimiters : dict - Specify how the source file is split into statements (See below). 
- - Delimiters dictionary - --------------------- - The delimiters are specified with the keys of the delimiters dict. - The dict files can be used to further customize the iterator. Each - consist of a tuple of two elements: - 1. A value of the DelimiterMode to indicate what to do with the - delimiter string: skip it, attach keep it with previous or next string - 2. A boolean indicating if parsing should stop after fiSBT - encountering this delimiter. - """ - - if isinstance(spec, type): - if issubclass(spec, Parser): - CustomParser = spec - - elif issubclass(spec, RootBlock): - CustomParser = _build_parser_class_root_block( - spec, strip_spaces=strip_spaces, delimiters=delimiters - ) - - elif issubclass(spec, Block): - CustomRootBlock = _build_root_block_class_block(spec, config.__class__) - CustomParser = _build_parser_class_root_block( - CustomRootBlock, strip_spaces=strip_spaces, delimiters=delimiters - ) - - elif issubclass(spec, ParsedStatement): - CustomRootBlock = _build_root_block_class_parsed_statement( - spec, config.__class__ - ) - CustomParser = _build_parser_class_root_block( - CustomRootBlock, strip_spaces=strip_spaces, delimiters=delimiters - ) - - else: - raise TypeError( - "`spec` must be of type Parser, Block, RootBlock or tuple of type Block or ParsedStatement, " - f"not {type(spec)}" - ) - - elif isinstance(spec, (tuple, list)): - CustomRootBlock = _build_root_block_class_parsed_statement_it( - spec, config.__class__ - ) - CustomParser = _build_parser_class_root_block( - CustomRootBlock, strip_spaces=strip_spaces, delimiters=delimiters - ) - - else: - raise - - return CustomParser - - -@no_type_check -def parse( - entry_point: SourceLocationT, - spec: Union[ - type[ - Union[ - Parser[RBT, CT], - RootBlock[BPST, CT], - Block[OPST, BPST, CPST, CT], - ParsedStatement[CT], - ] - ], - ty.Iterable[type[ParsedStatement[CT]]], - ], - config: CT = None, - *, - strip_spaces: bool = True, - delimiters: Optional[DelimiterDictT] = None, - locator: ty.Callable[[SourceLocationT, str], StrictLocationT] = default_locator, - prefer_resource_as_file: bool = True, - **extra_parser_kwargs: Any, -) -> Union[ParsedProject[RBT, CT], ParsedProject[RootBlock[BPST, CT], CT]]: - """Parse sources into a ParsedProject dictionary. - - Parameters - ---------- - entry_point - file or resource, given as (package_name, resource_name). - spec - specification of the content to parse. Can be one of the following things: - - Parser class. - - Block or ParsedStatement derived class. - - ty.Iterable of Block or ParsedStatement derived class. - - RootBlock derived class. - config - a configuration object that will be passed to `from_string_and_config` - classmethod. - strip_spaces : bool - if True, spaces will be stripped for each statement before calling - ``from_string_and_config``. - delimiters : dict - Specify how the source file is split into statements (See below). - locator : Callable - function that takes the current location and a target of an IncludeStatement - and returns a new location. - prefer_resource_as_file : bool - if True, resources will try to be located in the filesystem if - available. - extra_parser_kwargs - extra keyword arguments to be given to the parser. - - Delimiters dictionary - --------------------- - The delimiters are specified with the keys of the delimiters dict. - The dict files can be used to further customize the iterator. Each - consist of a tuple of two elements: - 1. 
A value of the DelimiterMode to indicate what to do with the - delimiter string: skip it, attach keep it with previous or next string - 2. A boolean indicating if parsing should stop after fiSBT - encountering this delimiter. - """ - - CustomParser = build_parser_class(spec, config, strip_spaces, delimiters) - parser = CustomParser( - config, prefer_resource_as_file=prefer_resource_as_file, **extra_parser_kwargs - ) - - pp = ParsedProject() - - pending: list[tuple[SourceLocationT, str]] = [] - if isinstance(entry_point, (str, pathlib.Path)): - entry_point = pathlib.Path(entry_point) - if not entry_point.is_absolute(): - entry_point = pathlib.Path.cwd() / entry_point - - elif not (isinstance(entry_point, tuple) and len(entry_point) == 2): - raise TypeError( - f"Cannot handle type {type(entry_point)}, " - "use str or pathlib.Path for files or " - "(package: str, resource_name: str) tuple " - "for a resource." - ) - - pp[None] = parsed = parser.parse(entry_point) - pending.extend( - (parsed.location, el.target) - for el in parsed.parsed_source.filter_by(IncludeStatement) - ) - - while pending: - source_location, target = pending.pop(0) - pp[(source_location, target)] = parsed = parser.parse( - locator(source_location, target) - ) - pending.extend( - (parsed.location, el.target) - for el in parsed.parsed_source.filter_by(IncludeStatement) - ) - - return pp - - -@no_type_check -def parse_bytes( - content: bytes, - spec: Union[ - type[ - Union[ - Parser[RBT, CT], - RootBlock[BPST, CT], - Block[OPST, BPST, CPST, CT], - ParsedStatement[CT], - ] - ], - ty.Iterable[type[ParsedStatement[CT]]], - ], - config: Optional[CT] = None, - *, - strip_spaces: bool, - delimiters: Optional[DelimiterDictT], - **extra_parser_kwargs: Any, -) -> ParsedProject[ - Union[RBT, RootBlock[BPST, CT], RootBlock[ParsedStatement[CT], CT]], CT -]: - """Parse sources into a ParsedProject dictionary. - - Parameters - ---------- - content - bytes. - spec - specification of the content to parse. Can be one of the following things: - - Parser class. - - Block or ParsedStatement derived class. - - ty.Iterable of Block or ParsedStatement derived class. - - RootBlock derived class. - config - a configuration object that will be passed to `from_string_and_config` - classmethod. - strip_spaces : bool - if True, spaces will be stripped for each statement before calling - ``from_string_and_config``. - delimiters : dict - Specify how the source file is split into statements (See below). 
- """ - - CustomParser = build_parser_class(spec, config, strip_spaces, delimiters) - - parser = CustomParser(config, prefer_resource_as_file=False, **extra_parser_kwargs) - - pp = ParsedProject() - - pp[None] = parsed = parser.parse_bytes(content) - - if any(parsed.parsed_source.filter_by(IncludeStatement)): - raise ValueError("parse_bytes does not support using an IncludeStatement") - - return pp diff --git a/pint/compat.py b/pint/compat.py index 6bbdf35af..32ad04afb 100644 --- a/pint/compat.py +++ b/pint/compat.py @@ -10,31 +10,22 @@ from __future__ import annotations -import sys import math +import sys +from collections.abc import Callable, Iterable, Mapping from decimal import Decimal from importlib import import_module from numbers import Number -from collections.abc import Mapping -from typing import Any, NoReturn, Callable, Optional, Union -from collections.abc import Iterable - -try: - from uncertainties import UFloat, ufloat - from uncertainties import unumpy as unp - - HAS_UNCERTAINTIES = True -except ImportError: - UFloat = ufloat = unp = None - HAS_UNCERTAINTIES = False - +from typing import ( + Any, + NoReturn, +) if sys.version_info >= (3, 10): from typing import TypeAlias # noqa else: from typing_extensions import TypeAlias # noqa - if sys.version_info >= (3, 11): from typing import Self # noqa else: @@ -60,7 +51,7 @@ def missing_dependency( - package: str, display_name: Optional[str] = None + package: str, display_name: str | None = None ) -> Callable[..., NoReturn]: """Return a helper function that raises an exception when used. @@ -82,6 +73,17 @@ class BehaviorChangeWarning(UserWarning): pass +try: + from uncertainties import UFloat, ufloat + from uncertainties import unumpy as unp + + HAS_UNCERTAINTIES = True +except ImportError: + UFloat = ufloat = unp = None + + HAS_UNCERTAINTIES = False + + try: import numpy as np from numpy import datetime64 as np_datetime64 @@ -176,6 +178,9 @@ def _to_magnitude(value, force_ndarray=False, force_ndarray_like=False): except ImportError: HAS_BABEL = False + babel_parse = missing_dependency("Babel") # noqa: F811 # type:ignore + babel_units = babel_parse + try: import mip @@ -190,19 +195,6 @@ def _to_magnitude(value, force_ndarray=False, force_ndarray_like=False): except ImportError: HAS_MIP = False -# Defines Logarithm and Exponential for Logarithmic Converter -if HAS_NUMPY: - from numpy import exp # noqa: F401 - from numpy import log # noqa: F401 -else: - from math import exp # noqa: F401 - from math import log # noqa: F401 - -if not HAS_BABEL: - babel_parse = missing_dependency("Babel") # noqa: F811 - babel_units = babel_parse - -if not HAS_MIP: mip_missing = missing_dependency("mip") mip_model = mip_missing mip_Model = mip_missing @@ -211,6 +203,19 @@ def _to_magnitude(value, force_ndarray=False, force_ndarray_like=False): mip_xsum = mip_missing mip_OptimizationStatus = mip_missing +# Defines Logarithm and Exponential for Logarithmic Converter +if HAS_NUMPY: + from numpy import ( + exp, # noqa: F401 + log, # noqa: F401 + ) +else: + from math import ( + exp, # noqa: F401 + log, # noqa: F401 + ) + + # Define location of pint.Quantity in NEP-13 type cast hierarchy by defining upcast # types using guarded imports @@ -236,7 +241,7 @@ def _to_magnitude(value, force_ndarray=False, force_ndarray_like=False): ) #: Map type name to the actual type (for upcast types). 
-upcast_type_map: Mapping[str, Optional[type]] = {k: None for k in upcast_type_names} +upcast_type_map: Mapping[str, type | None] = {k: None for k in upcast_type_names} def fully_qualified_name(t: type) -> str: @@ -297,7 +302,7 @@ def is_duck_array(obj: type) -> bool: return is_duck_array_type(type(obj)) -def eq(lhs: Any, rhs: Any, check_all: bool) -> Union[bool, Iterable[bool]]: +def eq(lhs: Any, rhs: Any, check_all: bool) -> bool | Iterable[bool]: """Comparison of scalars and arrays. Parameters @@ -320,7 +325,7 @@ def eq(lhs: Any, rhs: Any, check_all: bool) -> Union[bool, Iterable[bool]]: return out -def isnan(obj: Any, check_all: bool) -> Union[bool, Iterable[bool]]: +def isnan(obj: Any, check_all: bool) -> bool | Iterable[bool]: """Test for NaN or NaT. Parameters @@ -362,7 +367,7 @@ def isnan(obj: Any, check_all: bool) -> Union[bool, Iterable[bool]]: return False -def zero_or_nan(obj: Any, check_all: bool) -> Union[bool, Iterable[bool]]: +def zero_or_nan(obj: Any, check_all: bool) -> bool | Iterable[bool]: """Test if obj is zero, NaN, or NaT. Parameters diff --git a/pint/converters.py b/pint/converters.py index 249cbbf89..fbe3b5fb0 100644 --- a/pint/converters.py +++ b/pint/converters.py @@ -12,12 +12,10 @@ from dataclasses import dataclass from dataclasses import fields as dc_fields - -from typing import Any, Optional, ClassVar +from typing import Any, ClassVar from ._typing import Magnitude - -from .compat import HAS_NUMPY, exp, log, Self # noqa: F401 +from .compat import HAS_NUMPY, Self, exp, log # noqa: F401 @dataclass(frozen=True) @@ -51,7 +49,7 @@ def get_field_names(cls, new_cls: type) -> frozenset[str]: return frozenset(p.name for p in dc_fields(new_cls)) @classmethod - def preprocess_kwargs(cls, **kwargs: Any) -> Optional[dict[str, Any]]: + def preprocess_kwargs(cls, **kwargs: Any) -> dict[str, Any] | None: return None @classmethod diff --git a/pint/default_en.txt b/pint/default_en.txt index 5fc7f8265..45f241f18 100644 --- a/pint/default_en.txt +++ b/pint/default_en.txt @@ -494,12 +494,17 @@ buckingham = debye * angstrom bohr_magneton = e * hbar / (2 * m_e) = µ_B = mu_B nuclear_magneton = e * hbar / (2 * m_p) = µ_N = mu_N +# Refractive index +[refractive_index] = [] +refractive_index_unit = [] = RIU + # Logaritmic Unit Definition # Unit = scale; logbase; logfactor # x_dB = [logfactor] * log( x_lin / [scale] ) / log( [logbase] ) # Logaritmic Units of dimensionless quantity: [ https://en.wikipedia.org/wiki/Level_(logarithmic_quantity) ] +decibelwatt = watt; logbase: 10; logfactor: 10 = dBW decibelmilliwatt = 1e-3 watt; logbase: 10; logfactor: 10 = dBm decibelmicrowatt = 1e-6 watt; logbase: 10; logfactor: 10 = dBu diff --git a/pint/definitions.py b/pint/definitions.py index ce89e94d4..8a6cc496f 100644 --- a/pint/definitions.py +++ b/pint/definitions.py @@ -10,8 +10,9 @@ from __future__ import annotations +import flexparser as fp + from . import errors -from ._vendor import flexparser as fp from .delegates import ParserConfig, txt_defparser diff --git a/pint/delegates/__init__.py b/pint/delegates/__init__.py index e663a10c5..dc4699cf9 100644 --- a/pint/delegates/__init__.py +++ b/pint/delegates/__init__.py @@ -7,6 +7,7 @@ :copyright: 2022 by Pint Authors, see AUTHORS for more details. :license: BSD, see LICENSE for more details. """ +from __future__ import annotations from . 
import txt_defparser from .base_defparser import ParserConfig, build_disk_cache_class diff --git a/pint/delegates/base_defparser.py b/pint/delegates/base_defparser.py index 9e784ac64..44170f842 100644 --- a/pint/delegates/base_defparser.py +++ b/pint/delegates/base_defparser.py @@ -14,15 +14,16 @@ import itertools import numbers import pathlib -from dataclasses import dataclass, field +from dataclasses import dataclass +from typing import Any + +import flexcache as fc +import flexparser as fp from pint import errors from pint.facets.plain.definitions import NotNumeric from pint.util import ParserHelper, UnitsContainer -from .._vendor import flexcache as fc -from .._vendor import flexparser as fp - @dataclass(frozen=True) class ParserConfig: @@ -72,7 +73,7 @@ class PintParsedStatement(fp.ParsedStatement[ParserConfig]): @functools.lru_cache -def build_disk_cache_class(non_int_type: type): +def build_disk_cache_class(chosen_non_int_type: type): """Build disk cache class, taking into account the non_int_type.""" @dataclass(frozen=True) @@ -80,14 +81,18 @@ class PintHeader(fc.InvalidateByExist, fc.NameByFields, fc.BasicPythonHeader): from .. import __version__ pint_version: str = __version__ - non_int_type: str = field(default_factory=lambda: non_int_type.__qualname__) + non_int_type: str = chosen_non_int_type.__qualname__ + @dataclass(frozen=True) class PathHeader(fc.NameByFileContent, PintHeader): pass + @dataclass(frozen=True) class ParsedProjecHeader(fc.NameByHashIter, PintHeader): @classmethod - def from_parsed_project(cls, pp: fp.ParsedProject, reader_id): + def from_parsed_project( + cls, pp: fp.ParsedProject[Any, ParserConfig], reader_id: str + ): tmp = ( f"{stmt.content_hash.algorithm_name}:{stmt.content_hash.hexdigest}" for stmt in pp.iter_statements() diff --git a/pint/delegates/formatter/__init__.py b/pint/delegates/formatter/__init__.py index 31d36b0f6..5dab6a0f0 100644 --- a/pint/delegates/formatter/__init__.py +++ b/pint/delegates/formatter/__init__.py @@ -10,7 +10,7 @@ :copyright: 2022 by Pint Authors, see AUTHORS for more details. :license: BSD, see LICENSE for more details. """ - +from __future__ import annotations from .full import FullFormatter diff --git a/pint/delegates/formatter/_compound_unit_helpers.py b/pint/delegates/formatter/_compound_unit_helpers.py new file mode 100644 index 000000000..06a8ac2d3 --- /dev/null +++ b/pint/delegates/formatter/_compound_unit_helpers.py @@ -0,0 +1,328 @@ +""" + pint.delegates.formatter._compound_unit_helpers + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + Convenient functions to help organize compount units. + + :copyright: 2022 by Pint Authors, see AUTHORS for more details. + :license: BSD, see LICENSE for more details. 
+""" + + +from __future__ import annotations + +import functools +import locale +from collections.abc import Callable, Iterable +from functools import partial +from itertools import filterfalse, tee +from typing import ( + TYPE_CHECKING, + Any, + Literal, + TypedDict, + TypeVar, +) + +from ...compat import TypeAlias, babel_parse +from ...util import UnitsContainer + +T = TypeVar("T") +U = TypeVar("U") +V = TypeVar("V") +W = TypeVar("W") + +if TYPE_CHECKING: + from ...compat import Locale, Number + from ...facets.plain import PlainUnit + from ...registry import UnitRegistry + + +class SortKwds(TypedDict): + registry: UnitRegistry + + +SortFunc: TypeAlias = Callable[ + [Iterable[tuple[str, Any, str]], Any], Iterable[tuple[str, Any, str]] +] + + +class BabelKwds(TypedDict): + """Babel related keywords used in formatters.""" + + use_plural: bool + length: Literal["short", "long", "narrow"] | None + locale: Locale | str | None + + +def partition( + predicate: Callable[[T], bool], iterable: Iterable[T] +) -> tuple[filterfalse[T], filter[T]]: + """Partition entries into false entries and true entries. + + If *predicate* is slow, consider wrapping it with functools.lru_cache(). + """ + # partition(is_odd, range(10)) --> 0 2 4 6 8 and 1 3 5 7 9 + t1, t2 = tee(iterable) + return filterfalse(predicate, t1), filter(predicate, t2) + + +def localize_per( + length: Literal["short", "long", "narrow"] = "long", + locale: Locale | str | None = locale.LC_NUMERIC, + default: str | None = None, +) -> str: + """Localized singular and plural form of a unit. + + THIS IS TAKEN FROM BABEL format_unit. But + - No magnitude is returned in the string. + - If the unit is not found, the default is given. + - If the default is None, then the same value is given. + """ + locale = babel_parse(locale) + + patterns = locale._data["compound_unit_patterns"].get("per", None) + if patterns is None: + return default or "{}/{}" + + patterns = patterns.get(length, None) + if patterns is None: + return default or "{}/{}" + + # babel 2.8 + if isinstance(patterns, str): + return patterns + + # babe; 2.15 + return patterns.get("compound", default or "{}/{}") + + +@functools.lru_cache +def localize_unit_name( + measurement_unit: str, + use_plural: bool, + length: Literal["short", "long", "narrow"] = "long", + locale: Locale | str | None = locale.LC_NUMERIC, + default: str | None = None, +) -> str: + """Localized singular and plural form of a unit. + + THIS IS TAKEN FROM BABEL format_unit. But + - No magnitude is returned in the string. + - If the unit is not found, the default is given. + - If the default is None, then the same value is given. + """ + locale = babel_parse(locale) + from babel.units import _find_unit_pattern, get_unit_name + + q_unit = _find_unit_pattern(measurement_unit, locale=locale) + if not q_unit: + return measurement_unit + + unit_patterns = locale._data["unit_patterns"][q_unit].get(length, {}) + + if use_plural: + grammatical_number = "other" + else: + grammatical_number = "one" + + if grammatical_number in unit_patterns: + return unit_patterns[grammatical_number].format("").replace("\xa0", "").strip() + + if default is not None: + return default + + # Fall back to a somewhat bad representation. + # nb: This is marked as no-cover, as the current CLDR seemingly has no way for this to happen. 
+ fallback_name = get_unit_name( + measurement_unit, length=length, locale=locale + ) # pragma: no cover + return f"{fallback_name or measurement_unit}" # pragma: no cover + + +def extract2(element: tuple[str, T, str]) -> tuple[str, T]: + """Extract display name and exponent from a tuple containing display name, exponent and unit name.""" + + return element[:2] + + +def to_name_exponent_name(element: tuple[str, T]) -> tuple[str, T, str]: + """Convert unit name and exponent to unit name as display name, exponent and unit name.""" + + # TODO: write a generic typing + + return element + (element[0],) + + +def to_symbol_exponent_name( + el: tuple[str, T], registry: UnitRegistry +) -> tuple[str, T, str]: + """Convert unit name and exponent to unit symbol as display name, exponent and unit name.""" + return registry._get_symbol(el[0]), el[1], el[0] + + +def localize_display_exponent_name( + element: tuple[str, T, str], + use_plural: bool, + length: Literal["short", "long", "narrow"] = "long", + locale: Locale | str | None = locale.LC_NUMERIC, + default: str | None = None, +) -> tuple[str, T, str]: + """Localize display name in a triplet display name, exponent and unit name.""" + + return ( + localize_unit_name( + element[2], use_plural, length, locale, default or element[0] + ), + element[1], + element[2], + ) + + +##################### +# Sorting functions +##################### + + +def sort_by_unit_name( + items: Iterable[tuple[str, Number, str]], _registry: UnitRegistry | None +) -> Iterable[tuple[str, Number, str]]: + return sorted(items, key=lambda el: el[2]) + + +def sort_by_display_name( + items: Iterable[tuple[str, Number, str]], _registry: UnitRegistry | None +) -> Iterable[tuple[str, Number, str]]: + return sorted(items) + + +def sort_by_dimensionality( + items: Iterable[tuple[str, Number, str]], registry: UnitRegistry | None +) -> Iterable[tuple[str, Number, str]]: + """Sort a list of units by dimensional order (from `registry.formatter.dim_order`). + + Parameters + ---------- + items : tuple + a list of tuples containing (unit names, exponent values). + registry : UnitRegistry | None + the registry to use for looking up the dimensions of each unit. + + Returns + ------- + list + the list of units sorted by most significant dimension first. + + Raises + ------ + KeyError + If unit cannot be found in the registry. + """ + + if registry is None: + return items + + dim_order = registry.formatter.dim_order + + def sort_key(item: tuple[str, Number, str]): + _display_name, _unit_exponent, unit_name = item + cname = registry.get_name(unit_name) + cname_dims = registry.get_dimensionality(cname) or {"[]": None} + for cname_dim in cname_dims: + if cname_dim in dim_order: + return dim_order.index(cname_dim), cname + + raise KeyError(f"Unit {unit_name} (aka {cname}) has no recognized dimensions") + + return sorted(items, key=sort_key) + + +def prepare_compount_unit( + unit: PlainUnit | UnitsContainer | Iterable[tuple[str, T]], + spec: str = "", + sort_func: SortFunc | None = None, + use_plural: bool = True, + length: Literal["short", "long", "narrow"] | None = None, + locale: Locale | str | None = None, + as_ratio: bool = True, + registry: UnitRegistry | None = None, +) -> tuple[Iterable[tuple[str, T]], Iterable[tuple[str, T]]]: + """Format compound unit into unit container given + an spec and locale. 
+ + Returns + ------- + iterable of display name, exponent, canonical name + """ + + if isinstance(unit, UnitsContainer): + out = unit.items() + elif hasattr(unit, "_units"): + out = unit._units.items() + else: + out = unit + + # out: unit_name, unit_exponent + + if len(out) == 0: + if "~" in spec: + return ([], []) + else: + return ([("dimensionless", 1)], []) + + if "~" in spec: + if registry is None: + raise ValueError( + f"Can't short format a {type(unit)} without a registry." + " This is usually triggered when formatting a instance" + " of the internal `UnitsContainer`." + ) + _to_symbol_exponent_name = partial(to_symbol_exponent_name, registry=registry) + out = map(_to_symbol_exponent_name, out) + else: + out = map(to_name_exponent_name, out) + + # We keep unit_name because the sort or localizing functions might needed. + # out: display_unit_name, unit_exponent, unit_name + + if as_ratio: + numerator, denominator = partition(lambda el: el[1] < 0, out) + else: + numerator, denominator = out, () + + # numerator: display_unit_name, unit_name, unit_exponent + # denominator: display_unit_name, unit_name, unit_exponent + + if locale is None: + if sort_func is not None: + numerator = sort_func(numerator, registry) + denominator = sort_func(denominator, registry) + + return map(extract2, numerator), map(extract2, denominator) + + if length is None: + length = "short" if "~" in spec else "long" + + mapper = partial( + localize_display_exponent_name, use_plural=False, length=length, locale=locale + ) + + numerator = map(mapper, numerator) + denominator = map(mapper, denominator) + + if sort_func is not None: + numerator = sort_func(numerator, registry) + denominator = sort_func(denominator, registry) + + if use_plural: + if not isinstance(numerator, list): + numerator = list(numerator) + numerator[-1] = localize_display_exponent_name( + numerator[-1], + use_plural, + length=length, + locale=locale, + default=numerator[-1][0], + ) + + return map(extract2, numerator), map(extract2, denominator) diff --git a/pint/delegates/formatter/_format_helpers.py b/pint/delegates/formatter/_format_helpers.py index 2ed4ba985..8a2f37a59 100644 --- a/pint/delegates/formatter/_format_helpers.py +++ b/pint/delegates/formatter/_format_helpers.py @@ -11,28 +11,19 @@ from __future__ import annotations +import re +from collections.abc import Callable, Generator, Iterable +from contextlib import contextmanager from functools import partial +from locale import LC_NUMERIC, getlocale, setlocale from typing import ( + TYPE_CHECKING, Any, - Generator, - Iterable, TypeVar, - Callable, - TYPE_CHECKING, - Literal, - TypedDict, ) -from locale import getlocale, setlocale, LC_NUMERIC -from contextlib import contextmanager -from warnings import warn - -import locale - -from pint.delegates.formatter._spec_helpers import FORMATTER, _join - -from ...compat import babel_parse, ndarray -from ...util import UnitsContainer +from ...compat import ndarray +from ._spec_helpers import FORMATTER try: from numpy import integer as np_integer @@ -40,21 +31,15 @@ np_integer = None if TYPE_CHECKING: - from ...registry import UnitRegistry - from ...facets.plain import PlainUnit from ...compat import Locale, Number T = TypeVar("T") U = TypeVar("U") V = TypeVar("V") +W = TypeVar("W") - -class BabelKwds(TypedDict): - """Babel related keywords used in formatters.""" - - use_plural: bool - length: Literal["short", "long", "narrow"] | None - locale: Locale | str | None +_PRETTY_EXPONENTS = "⁰¹²³⁴⁵⁶⁷⁸⁹" +_JOIN_REG_EXP = re.compile(r"{\d*}") def 
format_number(value: Any, spec: str = "") -> str: @@ -113,163 +98,64 @@ def override_locale( setlocale(LC_NUMERIC, prev_locale_string) -def format_unit_no_magnitude( - measurement_unit: str, - use_plural: bool = True, - length: Literal["short", "long", "narrow"] = "long", - locale: Locale | str | None = locale.LC_NUMERIC, -) -> str | None: - """Format a value of a given unit. - - THIS IS TAKEN FROM BABEL format_unit. But - - No magnitude is returned in the string. - - If the unit is not found, the same is given. - - use_plural instead of value - - Values are formatted according to the locale's usual pluralization rules - and number formats. - - >>> format_unit(12, 'length-meter', locale='ro_RO') - u'metri' - >>> format_unit(15.5, 'length-mile', locale='fi_FI') - u'mailia' - >>> format_unit(1200, 'pressure-millimeter-ofhg', locale='nb') - u'millimeter kvikks\\xf8lv' - >>> format_unit(270, 'ton', locale='en') - u'tons' - >>> format_unit(1234.5, 'kilogram', locale='ar_EG', numbering_system='default') - u'كيلوغرام' - - - The locale's usual pluralization rules are respected. - - >>> format_unit(1, 'length-meter', locale='ro_RO') - u'metru' - >>> format_unit(0, 'length-mile', locale='cy') - u'mi' - >>> format_unit(1, 'length-mile', locale='cy') - u'filltir' - >>> format_unit(3, 'length-mile', locale='cy') - u'milltir' - - >>> format_unit(15, 'length-horse', locale='fi') - Traceback (most recent call last): - ... - UnknownUnitError: length-horse is not a known unit in fi - - .. versionadded:: 2.2.0 - - :param value: the value to format. If this is a string, no number formatting will be attempted. - :param measurement_unit: the code of a measurement unit. - Known units can be found in the CLDR Unit Validity XML file: - https://unicode.org/repos/cldr/tags/latest/common/validity/unit.xml - :param length: "short", "long" or "narrow" - :param format: An optional format, as accepted by `format_decimal`. - :param locale: the `Locale` object or locale identifier - :param numbering_system: The numbering system used for formatting number symbols. Defaults to "latn". - The special value "default" will use the default numbering system of the locale. - :raise `UnsupportedNumberingSystemError`: If the numbering system is not supported by the locale. - """ - locale = babel_parse(locale) - from babel.units import _find_unit_pattern, get_unit_name +def pretty_fmt_exponent(num: Number) -> str: + """Format an number into a pretty printed exponent.""" + # unicode dot operator (U+22C5) looks like a superscript decimal + ret = f"{num:n}".replace("-", "⁻").replace(".", "\u22C5") + for n in range(10): + ret = ret.replace(str(n), _PRETTY_EXPONENTS[n]) + return ret - q_unit = _find_unit_pattern(measurement_unit, locale=locale) - if not q_unit: - return measurement_unit - unit_patterns = locale._data["unit_patterns"][q_unit].get(length, {}) +def join_u(fmt: str, iterable: Iterable[Any]) -> str: + """Join an iterable with the format specified in fmt. - if use_plural: - plural_form = "other" - else: - plural_form = "one" - - if plural_form in unit_patterns: - return unit_patterns[plural_form].format("").replace("\xa0", "").strip() - - # Fall back to a somewhat bad representation. - # nb: This is marked as no-cover, as the current CLDR seemingly has no way for this to happen. 
- fallback_name = get_unit_name( - measurement_unit, length=length, locale=locale - ) # pragma: no cover - return f"{fallback_name or measurement_unit}" # pragma: no cover - - -def map_keys( - func: Callable[ - [ - T, - ], - U, - ], - items: Iterable[tuple[T, V]], -) -> Iterable[tuple[U, V]]: - """Map dict keys given an items view.""" - return map(lambda el: (func(el[0]), el[1]), items) - - -def short_form( - units: Iterable[tuple[str, T]], - registry: UnitRegistry, -) -> Iterable[tuple[str, T]]: - """Replace each unit by its short form.""" - return map_keys(registry._get_symbol, units) - - -def localized_form( - units: Iterable[tuple[str, T]], - use_plural: bool, - length: Literal["short", "long", "narrow"], - locale: Locale | str, -) -> Iterable[tuple[str, T]]: - """Replace each unit by its localized version.""" - mapper = partial( - format_unit_no_magnitude, - use_plural=use_plural, - length=length, - locale=babel_parse(locale), - ) - - return map_keys(mapper, units) - - -def format_compound_unit( - unit: PlainUnit | UnitsContainer, - spec: str = "", - use_plural: bool = False, - length: Literal["short", "long", "narrow"] | None = None, - locale: Locale | str | None = None, -) -> Iterable[tuple[str, Number]]: - """Format compound unit into unit container given - an spec and locale. + The format can be specified in two ways: + - PEP3101 format with two replacement fields (eg. '{} * {}') + - The concatenating string (eg. ' * ') """ + if not iterable: + return "" + if not _JOIN_REG_EXP.search(fmt): + return fmt.join(iterable) + miter = iter(iterable) + first = next(miter) + for val in miter: + ret = fmt.format(first, val) + first = ret + return first - # TODO: provisional? Should we allow unbounded units? - # Should we allow UnitsContainer? - registry = getattr(unit, "_REGISTRY", None) - if isinstance(unit, UnitsContainer): - out = unit.items() - else: - out = unit._units.items() +def join_mu(joint_fstring: str, mstr: str, ustr: str) -> str: + """Join magnitude and units. + + This avoids that `3 and `1 / m` becomes `3 1 / m` + """ + if ustr == "": + return mstr + if ustr.startswith("1 / "): + return joint_fstring.format(mstr, ustr[2:]) + return joint_fstring.format(mstr, ustr) - if "~" in spec: - if registry is None: - raise ValueError( - f"Can't short format a {type(unit)} without a registry." - " This is usually triggered when formatting a instance" - " of the internal `UnitsContainer`." - ) - out = short_form(out, registry) - if locale is not None: - out = localized_form(out, use_plural, length or "long", locale) +def join_unc(joint_fstring: str, lpar: str, rpar: str, mstr: str, ustr: str) -> str: + """Join uncertainty magnitude and units. - return out + Uncertainty magnitudes might require extra parenthesis when joined to units. 
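Between them, `join_u` and `join_mu` cover both a PEP 3101 template and a plain separator, and keep a unit like "1 / m" from rendering as "3 1 / m". A registry-free sketch of the intended behaviour (re-implemented here rather than imported from the private module):

import re

_TEMPLATE = re.compile(r"{\d*}")

def demo_join_u(fmt, parts):
    # A format with replacement fields is applied pairwise; otherwise it is a separator.
    parts = list(parts)
    if not parts:
        return ""
    if not _TEMPLATE.search(fmt):
        return fmt.join(parts)
    first, *rest = parts
    for val in rest:
        first = fmt.format(first, val)
    return first

def demo_join_mu(joint_fstring, mstr, ustr):
    # Fold a leading "1 / " into the magnitude so "3" + "1 / m" reads "3 / m".
    if ustr == "":
        return mstr
    if ustr.startswith("1 / "):
        return joint_fstring.format(mstr, ustr[2:])
    return joint_fstring.format(mstr, ustr)

assert demo_join_u(" * ", ["m", "s"]) == "m * s"
assert demo_join_u("{}/{}", ["m", "s", "kg"]) == "m/s/kg"
assert demo_join_mu("{} {}", "3", "1 / m") == "3 / m"
assert demo_join_mu("{} {}", "3", "m") == "3 m"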
+ - YES: 3 +/- 1 + - NO : 3(1) + - NO : (3 +/ 1)e-9 + + This avoids that `(3 + 1)` and `meter` becomes ((3 +/- 1) meter) + """ + if mstr.startswith(lpar) or mstr.endswith(rpar): + return joint_fstring.format(mstr, ustr) + return joint_fstring.format(lpar + mstr + rpar, ustr) def formatter( - items: Iterable[tuple[str, Number]], + numerator: Iterable[tuple[str, Number]], + denominator: Iterable[tuple[str, Number]], as_ratio: bool = True, single_denominator: bool = False, product_fmt: str = " * ", @@ -277,14 +163,6 @@ def formatter( power_fmt: str = "{} ** {}", parentheses_fmt: str = "({0})", exp_call: FORMATTER = "{:n}".format, - sort: bool | None = None, - sort_func: Callable[ - [ - Iterable[tuple[str, Number]], - ], - Iterable[tuple[str, Number]], - ] - | None = sorted, ) -> str: """Format a list of (name, exponent) pairs. @@ -307,8 +185,6 @@ def formatter( the format used for parenthesis. (Default value = "({0})") exp_call : callable (Default value = lambda x: f"{x:n}") - sort : bool, optional - True to sort the formatted units alphabetically (Default value = True) Returns ------- @@ -317,61 +193,43 @@ def formatter( """ - if sort is False: - warn( - "The boolean `sort` argument is deprecated. " - "Use `sort_fun` to specify the sorting function (default=sorted) " - "or None to keep units in the original order." - ) - sort_func = None - elif sort is True: - warn( - "The boolean `sort` argument is deprecated. " - "Use `sort_fun` to specify the sorting function (default=sorted) " - "or None to keep units in the original order." - ) - sort_func = sorted - - if sort_func is None: - items = tuple(items) - else: - items = sort_func(items) - - if not items: - return "" - if as_ratio: fun = lambda x: exp_call(abs(x)) else: fun = exp_call - pos_terms, neg_terms = [], [] - - for key, value in items: + pos_terms: list[str] = [] + for key, value in numerator: if value == 1: pos_terms.append(key) - elif value > 0: + else: pos_terms.append(power_fmt.format(key, fun(value))) - elif value == -1 and as_ratio: + + neg_terms: list[str] = [] + for key, value in denominator: + if value == -1 and as_ratio: neg_terms.append(key) else: neg_terms.append(power_fmt.format(key, fun(value))) + if not pos_terms and not neg_terms: + return "" + if not as_ratio: # Show as Product: positive * negative terms ** -1 - return _join(product_fmt, pos_terms + neg_terms) + return join_u(product_fmt, pos_terms + neg_terms) # Show as Ratio: positive terms / negative terms - pos_ret = _join(product_fmt, pos_terms) or "1" + pos_ret = join_u(product_fmt, pos_terms) or "1" if not neg_terms: return pos_ret if single_denominator: - neg_ret = _join(product_fmt, neg_terms) + neg_ret = join_u(product_fmt, neg_terms) if len(neg_terms) > 1: neg_ret = parentheses_fmt.format(neg_ret) else: - neg_ret = _join(division_fmt, neg_terms) + neg_ret = join_u(division_fmt, neg_terms) - return _join(division_fmt, [pos_ret, neg_ret]) + return join_u(division_fmt, [pos_ret, neg_ret]) diff --git a/pint/delegates/formatter/_spec_helpers.py b/pint/delegates/formatter/_spec_helpers.py index 27f6c5726..344859b38 100644 --- a/pint/delegates/formatter/_spec_helpers.py +++ b/pint/delegates/formatter/_spec_helpers.py @@ -10,10 +10,11 @@ from __future__ import annotations -from typing import Iterable, Callable, Any -import warnings -from ...compat import Number +import functools import re +import warnings +from collections.abc import Callable +from typing import Any FORMATTER = Callable[ [ @@ -26,8 +27,6 @@ # 
http://docs.python.org/2/library/string.html#format-specification-mini-language # We also add uS for uncertainties. _BASIC_TYPES = frozenset("bcdeEfFgGnosxX%uS") -_PRETTY_EXPONENTS = "⁰¹²³⁴⁵⁶⁷⁸⁹" -_JOIN_REG_EXP = re.compile(r"{\d*}") REGISTERED_FORMATTERS: dict[str, Any] = {} @@ -58,34 +57,6 @@ def parse_spec(spec: str) -> str: return result -def _join(fmt: str, iterable: Iterable[Any]) -> str: - """Join an iterable with the format specified in fmt. - - The format can be specified in two ways: - - PEP3101 format with two replacement fields (eg. '{} * {}') - - The concatenating string (eg. ' * ') - """ - if not iterable: - return "" - if not _JOIN_REG_EXP.search(fmt): - return fmt.join(iterable) - miter = iter(iterable) - first = next(miter) - for val in miter: - ret = fmt.format(first, val) - first = ret - return first - - -def pretty_fmt_exponent(num: Number) -> str: - """Format an number into a pretty printed exponent.""" - # unicode dot operator (U+22C5) looks like a superscript decimal - ret = f"{num:n}".replace("-", "⁻").replace(".", "\u22C5") - for n in range(10): - ret = ret.replace(str(n), _PRETTY_EXPONENTS[n]) - return ret - - def extract_custom_flags(spec: str) -> str: """Return custom flags present in a format specification @@ -116,6 +87,7 @@ def remove_custom_flags(spec: str) -> str: return spec +@functools.lru_cache def split_format( spec: str, default: str, separate_format_defaults: bool = True ) -> tuple[str, str]: @@ -157,28 +129,3 @@ def split_format( uspec = uspec or default_uspec return mspec, uspec - - -def join_mu(joint_fstring: str, mstr: str, ustr: str) -> str: - """Join magnitude and units. - - This avoids that `3 and `1 / m` becomes `3 1 / m` - """ - if ustr.startswith("1 / "): - return joint_fstring.format(mstr, ustr[2:]) - return joint_fstring.format(mstr, ustr) - - -def join_unc(joint_fstring: str, lpar: str, rpar: str, mstr: str, ustr: str) -> str: - """Join uncertainty magnitude and units. - - Uncertainty magnitudes might require extra parenthesis when joined to units. 
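Because `split_format` depends only on its hashable string arguments, memoising it with `functools.lru_cache` is a safe way to avoid re-parsing the same spec on every formatting call. The idea, shown on a deliberately simplified stand-in (toy splitting rule, not pint's real one):

import functools

@functools.lru_cache
def demo_split_spec(spec: str, default: str) -> tuple[str, str]:
    # Toy rule: everything before the first unit flag is the magnitude spec.
    for i, ch in enumerate(spec):
        if ch in "~PLHCD":
            return spec[:i] or default, spec[i:]
    return spec or default, default

assert demo_split_spec(".2f~P", "") == (".2f", "~P")
demo_split_spec(".2f~P", "")            # served from the cache this time
print(demo_split_spec.cache_info())     # hits=1, misses=1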
- - YES: 3 +/- 1 - - NO : 3(1) - - NO : (3 +/ 1)e-9 - - This avoids that `(3 + 1)` and `meter` becomes ((3 +/- 1) meter) - """ - if mstr.startswith(lpar) or mstr.endswith(rpar): - return joint_fstring.format(mstr, ustr) - return joint_fstring.format(lpar + mstr + rpar, ustr) diff --git a/pint/delegates/formatter/_to_register.py b/pint/delegates/formatter/_to_register.py index b2c2a3f38..697973716 100644 --- a/pint/delegates/formatter/_to_register.py +++ b/pint/delegates/formatter/_to_register.py @@ -8,16 +8,19 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable -from ...compat import ndarray, np, Unpack -from ._spec_helpers import split_format, join_mu, REGISTERED_FORMATTERS +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, Iterable from ..._typing import Magnitude - -from ._format_helpers import format_compound_unit, BabelKwds, override_locale +from ...compat import Unpack, ndarray, np +from ...util import UnitsContainer +from ._compound_unit_helpers import BabelKwds, prepare_compount_unit +from ._format_helpers import join_mu, override_locale +from ._spec_helpers import REGISTERED_FORMATTERS, split_format +from .plain import BaseFormatter if TYPE_CHECKING: - from ...facets.plain import PlainQuantity, PlainUnit, MagnitudeT + from ...facets.plain import MagnitudeT, PlainQuantity, PlainUnit from ...registry import UnitRegistry @@ -57,7 +60,9 @@ def wrapper(func: Callable[[PlainUnit, UnitRegistry], str]): if name in REGISTERED_FORMATTERS: raise ValueError(f"format {name!r} already exists") # or warn instead - class NewFormatter: + class NewFormatter(BaseFormatter): + spec = name + def format_magnitude( self, magnitude: Magnitude, @@ -78,13 +83,25 @@ def format_magnitude( return mstr def format_unit( - self, unit: PlainUnit, uspec: str = "", **babel_kwds: Unpack[BabelKwds] + self, + unit: PlainUnit | Iterable[tuple[str, Any]], + uspec: str = "", + **babel_kwds: Unpack[BabelKwds], ) -> str: - units = unit._REGISTRY.UnitsContainer( - format_compound_unit(unit, uspec, **babel_kwds) + numerator, _denominator = prepare_compount_unit( + unit, + uspec, + **babel_kwds, + as_ratio=False, + registry=self._registry, ) - return func(units, registry=unit._REGISTRY, **babel_kwds) + if self._registry is None: + units = UnitsContainer(numerator) + else: + units = self._registry.UnitsContainer(numerator) + + return func(units, registry=self._registry) def format_quantity( self, @@ -92,19 +109,22 @@ def format_quantity( qspec: str = "", **babel_kwds: Unpack[BabelKwds], ) -> str: - registry = quantity._REGISTRY + registry = self._registry - mspec, uspec = split_format( - qspec, - registry.formatter.default_format, - registry.separate_format_defaults, - ) + if registry is None: + mspec, uspec = split_format(qspec, "", True) + else: + mspec, uspec = split_format( + qspec, + registry.formatter.default_format, + registry.separate_format_defaults, + ) joint_fstring = "{} {}" return join_mu( joint_fstring, self.format_magnitude(quantity.magnitude, mspec, **babel_kwds), - self.format_unit(quantity.units, uspec, **babel_kwds), + self.format_unit(quantity.unit_items(), uspec, **babel_kwds), ) REGISTERED_FORMATTERS[name] = NewFormatter() diff --git a/pint/delegates/formatter/full.py b/pint/delegates/formatter/full.py index fae26d524..d5de43326 100644 --- a/pint/delegates/formatter/full.py +++ b/pint/delegates/formatter/full.py @@ -11,25 +11,36 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Literal, Optional, Any import locale -from 
...compat import babel_parse, Unpack -from ...util import iterable +from typing import TYPE_CHECKING, Any, Iterable, Literal from ..._typing import Magnitude +from ...compat import Unpack, babel_parse +from ...util import iterable +from ._compound_unit_helpers import BabelKwds, SortFunc, sort_by_unit_name +from ._to_register import REGISTERED_FORMATTERS from .html import HTMLFormatter from .latex import LatexFormatter, SIunitxFormatter -from .plain import RawFormatter, CompactFormatter, PrettyFormatter, DefaultFormatter -from ._format_helpers import BabelKwds -from ._to_register import REGISTERED_FORMATTERS +from .plain import ( + BaseFormatter, + CompactFormatter, + DefaultFormatter, + PrettyFormatter, + RawFormatter, +) if TYPE_CHECKING: - from ...facets.plain import PlainQuantity, PlainUnit, MagnitudeT - from ...facets.measurement import Measurement from ...compat import Locale + from ...facets.measurement import Measurement + from ...facets.plain import ( + MagnitudeT, + PlainQuantity, + PlainUnit, + ) + from ...registry import UnitRegistry -class FullFormatter: +class FullFormatter(BaseFormatter): """A formatter that dispatch to other formatters. Has a default format, locale and babel_length @@ -39,10 +50,35 @@ class FullFormatter: default_format: str = "" - locale: Optional[Locale] = None - babel_length: Literal["short", "long", "narrow"] = "long" + # TODO: This can be over-riden by the registry definitions file + dim_order: tuple[str, ...] = ( + "[substance]", + "[mass]", + "[current]", + "[luminosity]", + "[length]", + "[]", + "[time]", + "[temperature]", + ) + + default_sort_func: SortFunc | None = staticmethod(sort_by_unit_name) - def set_locale(self, loc: Optional[str]) -> None: + locale: Locale | None = None + + def __init__(self, registry: UnitRegistry | None = None): + super().__init__(registry) + + self._formatters = {} + self._formatters["raw"] = RawFormatter(registry) + self._formatters["D"] = DefaultFormatter(registry) + self._formatters["H"] = HTMLFormatter(registry) + self._formatters["P"] = PrettyFormatter(registry) + self._formatters["Lx"] = SIunitxFormatter(registry) + self._formatters["L"] = LatexFormatter(registry) + self._formatters["C"] = CompactFormatter(registry) + + def set_locale(self, loc: str | None) -> None: """Change the locale used by default by `format_babel`. 
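The `_formatters` table built in `__init__` is what backs the single-character spec flags at the user level. A quick tour through the public API (outputs depend on the definitions file, so they are printed rather than asserted):

import pint

ureg = pint.UnitRegistry()
speed = ureg.Quantity(3, "meter / second")

# Each flag selects one of the formatters registered above; "~" asks for unit symbols.
for spec in ("", "P", "~P", "C", "L", "H"):
    print(repr(spec), "->", format(speed, spec))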
Parameters @@ -59,16 +95,6 @@ def set_locale(self, loc: Optional[str]) -> None: self.locale = loc - def __init__(self) -> None: - self._formatters = {} - self._formatters["raw"] = RawFormatter() - self._formatters["D"] = DefaultFormatter() - self._formatters["H"] = HTMLFormatter() - self._formatters["P"] = PrettyFormatter() - self._formatters["Lx"] = SIunitxFormatter() - self._formatters["L"] = LatexFormatter() - self._formatters["C"] = CompactFormatter() - def get_formatter(self, spec: str): if spec == "": return self._formatters["D"] @@ -76,12 +102,20 @@ def get_formatter(self, spec: str): if k in spec: return v - try: - return REGISTERED_FORMATTERS[spec] - except KeyError: - pass + for k, v in REGISTERED_FORMATTERS.items(): + if k in spec: + orphan_fmt = REGISTERED_FORMATTERS[k] + break + else: + return self._formatters["D"] - return self._formatters["D"] + try: + fmt = orphan_fmt.__class__(self._registry) + spec = getattr(fmt, "spec", spec) + self._formatters[spec] = fmt + return fmt + except Exception: + return orphan_fmt def format_magnitude( self, magnitude: Magnitude, mspec: str = "", **babel_kwds: Unpack[BabelKwds] @@ -92,10 +126,17 @@ def format_magnitude( ) def format_unit( - self, unit: PlainUnit, uspec: str = "", **babel_kwds: Unpack[BabelKwds] + self, + unit: PlainUnit | Iterable[tuple[str, Any]], + uspec: str = "", + sort_func: SortFunc | None = None, + **babel_kwds: Unpack[BabelKwds], ) -> str: uspec = uspec or self.default_format - return self.get_formatter(uspec).format_unit(unit, uspec, **babel_kwds) + sort_func = sort_func or self.default_sort_func + return self.get_formatter(uspec).format_unit( + unit, uspec, sort_func=sort_func, **babel_kwds + ) def format_quantity( self, @@ -113,16 +154,25 @@ def format_quantity( del quantity - use_plural = obj.magnitude > 1 - if iterable(use_plural): - use_plural = True + locale = babel_kwds.get("locale", self.locale) + + if locale: + if "use_plural" in babel_kwds: + use_plural = babel_kwds["use_plural"] + else: + use_plural = obj.magnitude > 1 + if iterable(use_plural): + use_plural = True + else: + use_plural = False return self.get_formatter(spec).format_quantity( obj, spec, - use_plural=babel_kwds.get("use_plural", use_plural), - length=babel_kwds.get("length", self.babel_length), - locale=babel_kwds.get("locale", self.locale), + sort_func=self.default_sort_func, + use_plural=use_plural, + length=babel_kwds.get("length", None), + locale=locale, ) def format_measurement( @@ -148,8 +198,9 @@ def format_measurement( return self.get_formatter(meas_spec).format_measurement( obj, meas_spec, + sort_func=self.default_sort_func, use_plural=babel_kwds.get("use_plural", use_plural), - length=babel_kwds.get("length", self.babel_length), + length=babel_kwds.get("length", None), locale=babel_kwds.get("locale", self.locale), ) @@ -159,10 +210,10 @@ def format_measurement( def format_unit_babel( self, - unit: PlainUnit, + unit: PlainUnit | Iterable[tuple[str, Any]], spec: str = "", - length: Optional[Literal["short", "long", "narrow"]] = "long", - locale: Optional[Locale] = None, + length: Literal["short", "long", "narrow"] | None = None, + locale: Locale | None = None, ) -> str: if self.locale is None and locale is None: raise ValueError( @@ -172,8 +223,9 @@ def format_unit_babel( return self.format_unit( unit, spec or self.default_format, + sort_func=self.default_sort_func, use_plural=False, - length=length or self.babel_length, + length=length, locale=locale or self.locale, ) @@ -181,8 +233,8 @@ def format_quantity_babel( self, quantity: 
PlainQuantity[MagnitudeT], spec: str = "", - length: Literal["short", "long", "narrow"] = "long", - locale: Optional[Locale] = None, + length: Literal["short", "long", "narrow"] | None = None, + locale: Locale | None = None, ) -> str: if self.locale is None and locale is None: raise ValueError( @@ -192,11 +244,13 @@ def format_quantity_babel( use_plural = quantity.magnitude > 1 if iterable(use_plural): use_plural = True + return self.format_quantity( quantity, spec or self.default_format, + sort_func=self.default_sort_func, use_plural=use_plural, - length=length or self.babel_length, + length=length, locale=locale or self.locale, ) diff --git a/pint/delegates/formatter/html.py b/pint/delegates/formatter/html.py index 3dc14330c..b8e3f517f 100644 --- a/pint/delegates/formatter/html.py +++ b/pint/delegates/formatter/html.py @@ -11,28 +11,38 @@ from __future__ import annotations -from typing import TYPE_CHECKING import re +from typing import TYPE_CHECKING, Any, Iterable + +from ..._typing import Magnitude +from ...compat import Unpack, ndarray, np from ...util import iterable -from ...compat import ndarray, np, Unpack -from ._spec_helpers import ( - split_format, +from ._compound_unit_helpers import ( + BabelKwds, + SortFunc, + localize_per, + prepare_compount_unit, +) +from ._format_helpers import ( + formatter, join_mu, join_unc, + override_locale, +) +from ._spec_helpers import ( remove_custom_flags, + split_format, ) - -from ..._typing import Magnitude -from ._format_helpers import BabelKwds, format_compound_unit, formatter, override_locale +from .plain import BaseFormatter if TYPE_CHECKING: - from ...facets.plain import PlainQuantity, PlainUnit, MagnitudeT from ...facets.measurement import Measurement + from ...facets.plain import MagnitudeT, PlainQuantity, PlainUnit _EXP_PATTERN = re.compile(r"([0-9]\.?[0-9]*)e(-?)\+?0*([0-9]+)") -class HTMLFormatter: +class HTMLFormatter(BaseFormatter): """HTML localizable text formatter.""" def format_magnitude( @@ -75,16 +85,33 @@ def format_magnitude( return mstr def format_unit( - self, unit: PlainUnit, uspec: str = "", **babel_kwds: Unpack[BabelKwds] + self, + unit: PlainUnit | Iterable[tuple[str, Any]], + uspec: str = "", + sort_func: SortFunc | None = None, + **babel_kwds: Unpack[BabelKwds], ) -> str: - units = format_compound_unit(unit, uspec, **babel_kwds) + numerator, denominator = prepare_compount_unit( + unit, + uspec, + sort_func=sort_func, + **babel_kwds, + registry=self._registry, + ) + + if babel_kwds.get("locale", None): + length = babel_kwds.get("length") or ("short" if "~" in uspec else "long") + division_fmt = localize_per(length, babel_kwds.get("locale"), "{}/{}") + else: + division_fmt = "{}/{}" return formatter( - units, + numerator, + denominator, as_ratio=True, single_denominator=True, product_fmt=r" ", - division_fmt=r"{}/{}", + division_fmt=division_fmt, power_fmt=r"{}{}", parentheses_fmt=r"({})", ) @@ -93,9 +120,10 @@ def format_quantity( self, quantity: PlainQuantity[MagnitudeT], qspec: str = "", + sort_func: SortFunc | None = None, **babel_kwds: Unpack[BabelKwds], ) -> str: - registry = quantity._REGISTRY + registry = self._registry mspec, uspec = split_format( qspec, registry.formatter.default_format, registry.separate_format_defaults @@ -116,13 +144,14 @@ def format_quantity( return join_mu( joint_fstring, self.format_magnitude(quantity.magnitude, mspec, **babel_kwds), - self.format_unit(quantity.units, uspec, **babel_kwds), + self.format_unit(quantity.unit_items(), uspec, sort_func, **babel_kwds), ) def 
format_uncertainty( self, uncertainty, unc_spec: str = "", + sort_func: SortFunc | None = None, **babel_kwds: Unpack[BabelKwds], ) -> str: unc_str = format(uncertainty, unc_spec).replace("+/-", " ± ") @@ -135,9 +164,10 @@ def format_measurement( self, measurement: Measurement, meas_spec: str = "", + sort_func: SortFunc | None = None, **babel_kwds: Unpack[BabelKwds], ) -> str: - registry = measurement._REGISTRY + registry = self._registry mspec, uspec = split_format( meas_spec, @@ -154,5 +184,5 @@ def format_measurement( "(", ")", self.format_uncertainty(measurement.magnitude, unc_spec, **babel_kwds), - self.format_unit(measurement.units, uspec, **babel_kwds), + self.format_unit(measurement.units, uspec, sort_func, **babel_kwds), ) diff --git a/pint/delegates/formatter/latex.py b/pint/delegates/formatter/latex.py index aacf8cdf5..468a65fa4 100644 --- a/pint/delegates/formatter/latex.py +++ b/pint/delegates/formatter/latex.py @@ -12,23 +12,37 @@ from __future__ import annotations -import functools - -from typing import TYPE_CHECKING, Any, Iterable, Union +import functools import re -from ._spec_helpers import split_format, FORMATTER +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any from ..._typing import Magnitude -from ...compat import ndarray, Unpack, Number -from ._format_helpers import BabelKwds, formatter, override_locale, format_compound_unit -from ._spec_helpers import join_mu, join_unc, remove_custom_flags +from ...compat import Number, Unpack, ndarray +from ._compound_unit_helpers import ( + BabelKwds, + SortFunc, + prepare_compount_unit, +) +from ._format_helpers import ( + FORMATTER, + formatter, + join_mu, + join_unc, + override_locale, +) +from ._spec_helpers import ( + remove_custom_flags, + split_format, +) +from .plain import BaseFormatter if TYPE_CHECKING: - from ...facets.plain import PlainQuantity, PlainUnit, MagnitudeT from ...facets.measurement import Measurement - from ...util import ItMatrix + from ...facets.plain import MagnitudeT, PlainQuantity, PlainUnit from ...registry import UnitRegistry + from ...util import ItMatrix def vector_to_latex( @@ -110,8 +124,8 @@ def siunitx_format_unit( ) -> str: """Returns LaTeX code for the unit that can be put into an siunitx command.""" - def _tothe(power: Union[int, float]) -> str: - if isinstance(power, int) or (isinstance(power, float) and power.is_integer()): + def _tothe(power) -> str: + if power == int(power): if power == 1: return "" elif power == 2: @@ -153,7 +167,7 @@ def _tothe(power: Union[int, float]) -> str: _EXP_PATTERN = re.compile(r"([0-9]\.?[0-9]*)e(-?)\+?0*([0-9]+)") -class LatexFormatter: +class LatexFormatter(BaseFormatter): """Latex localizable text formatter.""" def format_magnitude( @@ -170,13 +184,35 @@ def format_magnitude( return mstr def format_unit( - self, unit: PlainUnit, uspec: str = "", **babel_kwds: Unpack[BabelKwds] + self, + unit: PlainUnit | Iterable[tuple[str, Any]], + uspec: str = "", + sort_func: SortFunc | None = None, + **babel_kwds: Unpack[BabelKwds], ) -> str: - units = format_compound_unit(unit, uspec, **babel_kwds) + numerator, denominator = prepare_compount_unit( + unit, + uspec, + sort_func=sort_func, + **babel_kwds, + registry=self._registry, + ) + + numerator = ((rf"\mathrm{{{latex_escape(u)}}}", p) for u, p in numerator) + denominator = ((rf"\mathrm{{{latex_escape(u)}}}", p) for u, p in denominator) + + # Localized latex + # if babel_kwds.get("locale", None): + # length = babel_kwds.get("length") or ("short" if "~" in uspec else "long") + # division_fmt 
= localize_per(length, babel_kwds.get("locale"), "{}/{}") + # else: + # division_fmt = "{}/{}" + + # division_fmt = r"\frac" + division_fmt.format("[{}]", "[{}]") - preprocessed = {rf"\mathrm{{{latex_escape(u)}}}": p for u, p in units} formatted = formatter( - preprocessed.items(), + numerator, + denominator, as_ratio=True, single_denominator=True, product_fmt=r" \cdot ", @@ -184,15 +220,17 @@ def format_unit( power_fmt="{}^[{}]", parentheses_fmt=r"\left({}\right)", ) + return formatted.replace("[", "{").replace("]", "}") def format_quantity( self, quantity: PlainQuantity[MagnitudeT], qspec: str = "", + sort_func: SortFunc | None = None, **babel_kwds: Unpack[BabelKwds], ) -> str: - registry = quantity._REGISTRY + registry = self._registry mspec, uspec = split_format( qspec, registry.formatter.default_format, registry.separate_format_defaults @@ -203,13 +241,14 @@ def format_quantity( return join_mu( joint_fstring, self.format_magnitude(quantity.magnitude, mspec, **babel_kwds), - self.format_unit(quantity.units, uspec, **babel_kwds), + self.format_unit(quantity.unit_items(), uspec, sort_func, **babel_kwds), ) def format_uncertainty( self, uncertainty, unc_spec: str = "", + sort_func: SortFunc | None = None, **babel_kwds: Unpack[BabelKwds], ) -> str: # uncertainties handles everythin related to latex. @@ -224,9 +263,10 @@ def format_measurement( self, measurement: Measurement, meas_spec: str = "", + sort_func: SortFunc | None = None, **babel_kwds: Unpack[BabelKwds], ) -> str: - registry = measurement._REGISTRY + registry = self._registry mspec, uspec = split_format( meas_spec, @@ -240,25 +280,29 @@ def format_measurement( if "L" not in unc_spec: unc_spec += "L" - joint_fstring = "{}\ {}" + joint_fstring = r"{}\ {}" return join_unc( joint_fstring, r"\left(", r"\right)", self.format_uncertainty(measurement.magnitude, unc_spec, **babel_kwds), - self.format_unit(measurement.units, uspec, **babel_kwds), + self.format_unit(measurement.units, uspec, sort_func, **babel_kwds), ) -class SIunitxFormatter: +class SIunitxFormatter(BaseFormatter): """Latex localizable text formatter with siunitx format. See: https://ctan.org/pkg/siunitx """ def format_magnitude( - self, magnitude: Magnitude, mspec: str = "", **babel_kwds: Unpack[BabelKwds] + self, + magnitude: Magnitude, + mspec: str = "", + sort_func: SortFunc | None = None, + **babel_kwds: Unpack[BabelKwds], ) -> str: with override_locale(mspec, babel_kwds.get("locale", None)) as format_number: if isinstance(magnitude, ndarray): @@ -272,9 +316,13 @@ def format_magnitude( return mstr def format_unit( - self, unit: PlainUnit, uspec: str = "", **babel_kwds: Unpack[BabelKwds] + self, + unit: PlainUnit | Iterable[tuple[str, Any]], + uspec: str = "", + sort_func: SortFunc | None = None, + **babel_kwds: Unpack[BabelKwds], ) -> str: - registry = unit._REGISTRY + registry = self._registry if registry is None: raise ValueError( "Can't format as siunitx without a registry." @@ -289,7 +337,12 @@ def format_unit( # should unit names be shortened? 
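From the user's side, `SIunitxFormatter` is reached through the "Lx" spec and produces markup for the LaTeX siunitx package; the exact string depends on the definitions file, so the comment below only sketches the expected shape:

import pint

ureg = pint.UnitRegistry()
q = ureg.Quantity(4.2, "millimeter / second")

# Routed through SIunitxFormatter; paste the result into a document that loads siunitx.
print(f"{q:Lx}")   # e.g. \SI[]{4.2}{\milli\meter\per\second}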
# units = format_compound_unit(unit, uspec, **babel_kwds) - formatted = siunitx_format_unit(unit._units.items(), registry) + try: + units = unit._units.items() + except Exception: + units = unit + + formatted = siunitx_format_unit(units, registry) if "~" in uspec: formatted = formatted.replace(r"\percent", r"\%") @@ -302,9 +355,10 @@ def format_quantity( self, quantity: PlainQuantity[MagnitudeT], qspec: str = "", + sort_func: SortFunc | None = None, **babel_kwds: Unpack[BabelKwds], ) -> str: - registry = quantity._REGISTRY + registry = self._registry mspec, uspec = split_format( qspec, registry.formatter.default_format, registry.separate_format_defaults @@ -313,13 +367,16 @@ def format_quantity( joint_fstring = "{}{}" mstr = self.format_magnitude(quantity.magnitude, mspec, **babel_kwds) - ustr = self.format_unit(quantity.units, uspec, **babel_kwds)[len(r"\si[]") :] + ustr = self.format_unit(quantity.unit_items(), uspec, sort_func, **babel_kwds)[ + len(r"\si[]") : + ] return r"\SI[]" + join_mu(joint_fstring, "{%s}" % mstr, ustr) def format_uncertainty( self, uncertainty, unc_spec: str = "", + sort_func: SortFunc | None = None, **babel_kwds: Unpack[BabelKwds], ) -> str: # SIunitx requires space between "+-" (or "\pm") and the nominal value @@ -337,9 +394,10 @@ def format_measurement( self, measurement: Measurement, meas_spec: str = "", + sort_func: SortFunc | None = None, **babel_kwds: Unpack[BabelKwds], ) -> str: - registry = measurement._REGISTRY + registry = self._registry mspec, uspec = split_format( meas_spec, @@ -357,5 +415,7 @@ def format_measurement( r"", "{%s}" % self.format_uncertainty(measurement.magnitude, unc_spec, **babel_kwds), - self.format_unit(measurement.units, uspec, **babel_kwds)[len(r"\si[]") :], + self.format_unit(measurement.units, uspec, sort_func, **babel_kwds)[ + len(r"\si[]") : + ], ) diff --git a/pint/delegates/formatter/plain.py b/pint/delegates/formatter/plain.py index 4b9616631..d40ec1ae0 100644 --- a/pint/delegates/formatter/plain.py +++ b/pint/delegates/formatter/plain.py @@ -14,30 +14,45 @@ from __future__ import annotations -from typing import TYPE_CHECKING +import itertools import re -from ...compat import ndarray, np, Unpack -from ._spec_helpers import ( - pretty_fmt_exponent, - split_format, +from typing import TYPE_CHECKING, Any, Iterable + +from ..._typing import Magnitude +from ...compat import Unpack, ndarray, np +from ._compound_unit_helpers import ( + BabelKwds, + SortFunc, + localize_per, + prepare_compount_unit, +) +from ._format_helpers import ( + formatter, join_mu, join_unc, + override_locale, + pretty_fmt_exponent, +) +from ._spec_helpers import ( remove_custom_flags, + split_format, ) -from ..._typing import Magnitude - -from ._format_helpers import format_compound_unit, BabelKwds, formatter, override_locale - if TYPE_CHECKING: - from ...facets.plain import PlainQuantity, PlainUnit, MagnitudeT from ...facets.measurement import Measurement + from ...facets.plain import MagnitudeT, PlainQuantity, PlainUnit + from ...registry import UnitRegistry _EXP_PATTERN = re.compile(r"([0-9]\.?[0-9]*)e(-?)\+?0*([0-9]+)") -class DefaultFormatter: +class BaseFormatter: + def __init__(self, registry: UnitRegistry | None = None): + self._registry = registry + + +class DefaultFormatter(BaseFormatter): """Simple, localizable plain text formatter. 
A formatter is a class with methods to format into string each of the objects @@ -62,19 +77,37 @@ def format_magnitude( return mstr def format_unit( - self, unit: PlainUnit, uspec: str = "", **babel_kwds: Unpack[BabelKwds] + self, + unit: PlainUnit | Iterable[tuple[str, Any]], + uspec: str = "", + sort_func: SortFunc | None = None, + **babel_kwds: Unpack[BabelKwds], ) -> str: - units = format_compound_unit(unit, uspec, **babel_kwds) """Format a unit (can be compound) into string given a string formatting specification and locale related arguments. """ + numerator, denominator = prepare_compount_unit( + unit, + uspec, + sort_func=sort_func, + **babel_kwds, + registry=self._registry, + ) + + if babel_kwds.get("locale", None): + length = babel_kwds.get("length") or ("short" if "~" in uspec else "long") + division_fmt = localize_per(length, babel_kwds.get("locale"), "{} / {}") + else: + division_fmt = "{} / {}" + return formatter( - units, + numerator, + denominator, as_ratio=True, single_denominator=False, - product_fmt=" * ", - division_fmt=" / ", + product_fmt="{} * {}", + division_fmt=division_fmt, power_fmt="{} ** {}", parentheses_fmt=r"({})", ) @@ -83,13 +116,14 @@ def format_quantity( self, quantity: PlainQuantity[MagnitudeT], qspec: str = "", + sort_func: SortFunc | None = None, **babel_kwds: Unpack[BabelKwds], ) -> str: """Format a quantity (magnitude and unit) into string given a string formatting specification and locale related arguments. """ - registry = quantity._REGISTRY + registry = self._registry mspec, uspec = split_format( qspec, registry.formatter.default_format, registry.separate_format_defaults @@ -99,13 +133,14 @@ def format_quantity( return join_mu( joint_fstring, self.format_magnitude(quantity.magnitude, mspec, **babel_kwds), - self.format_unit(quantity.units, uspec, **babel_kwds), + self.format_unit(quantity.unit_items(), uspec, sort_func, **babel_kwds), ) def format_uncertainty( self, uncertainty, unc_spec: str = "", + sort_func: SortFunc | None = None, **babel_kwds: Unpack[BabelKwds], ) -> str: """Format an uncertainty magnitude (nominal value and stdev) into string @@ -118,13 +153,14 @@ def format_measurement( self, measurement: Measurement, meas_spec: str = "", + sort_func: SortFunc | None = None, **babel_kwds: Unpack[BabelKwds], ) -> str: """Format an measurement (uncertainty and units) into string given a string formatting specification and locale related arguments. """ - registry = measurement._REGISTRY + registry = self._registry mspec, uspec = split_format( meas_spec, @@ -141,11 +177,11 @@ def format_measurement( "(", ")", self.format_uncertainty(measurement.magnitude, unc_spec, **babel_kwds), - self.format_unit(measurement.units, uspec, **babel_kwds), + self.format_unit(measurement.units, uspec, sort_func, **babel_kwds), ) -class CompactFormatter: +class CompactFormatter(BaseFormatter): """Simple, localizable plain text formatter without extra spaces.""" def format_magnitude( @@ -163,16 +199,30 @@ def format_magnitude( return mstr def format_unit( - self, unit: PlainUnit, uspec: str = "", **babel_kwds: Unpack[BabelKwds] + self, + unit: PlainUnit | Iterable[tuple[str, Any]], + uspec: str = "", + sort_func: SortFunc | None = None, + **babel_kwds: Unpack[BabelKwds], ) -> str: - units = format_compound_unit(unit, uspec, **babel_kwds) + numerator, denominator = prepare_compount_unit( + unit, + uspec, + sort_func=sort_func, + **babel_kwds, + registry=self._registry, + ) + + # Division format in compact formatter is not localized. 
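The new `formatter` contract takes already-partitioned numerator and denominator pairs. Calling the private helper directly with the same arguments `DefaultFormatter` passes makes the split visible (import path as it stands on this branch, internal and subject to change):

from pint.delegates.formatter._format_helpers import formatter

numerator = [("meter", 1)]
denominator = [("second", -2)]

print(
    formatter(
        numerator,
        denominator,
        as_ratio=True,
        single_denominator=False,
        product_fmt="{} * {}",
        division_fmt="{} / {}",
        power_fmt="{} ** {}",
        parentheses_fmt="({})",
    )
)
# -> meter / second ** 2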
+ division_fmt = "{}/{}" return formatter( - units, + numerator, + denominator, as_ratio=True, single_denominator=False, product_fmt="*", # TODO: Should this just be ''? - division_fmt="/", + division_fmt=division_fmt, power_fmt="{}**{}", parentheses_fmt=r"({})", ) @@ -181,9 +231,10 @@ def format_quantity( self, quantity: PlainQuantity[MagnitudeT], qspec: str = "", + sort_func: SortFunc | None = None, **babel_kwds: Unpack[BabelKwds], ) -> str: - registry = quantity._REGISTRY + registry = self._registry mspec, uspec = split_format( qspec, registry.formatter.default_format, registry.separate_format_defaults @@ -194,13 +245,14 @@ def format_quantity( return join_mu( joint_fstring, self.format_magnitude(quantity.magnitude, mspec, **babel_kwds), - self.format_unit(quantity.units, uspec, **babel_kwds), + self.format_unit(quantity.unit_items(), uspec, sort_func, **babel_kwds), ) def format_uncertainty( self, uncertainty, unc_spec: str = "", + sort_func: SortFunc | None = None, **babel_kwds: Unpack[BabelKwds], ) -> str: return format(uncertainty, unc_spec).replace("+/-", "+/-") @@ -209,9 +261,10 @@ def format_measurement( self, measurement: Measurement, meas_spec: str = "", + sort_func: SortFunc | None = None, **babel_kwds: Unpack[BabelKwds], ) -> str: - registry = measurement._REGISTRY + registry = self._registry mspec, uspec = split_format( meas_spec, @@ -228,11 +281,11 @@ def format_measurement( "(", ")", self.format_uncertainty(measurement.magnitude, unc_spec, **babel_kwds), - self.format_unit(measurement.units, uspec, **babel_kwds), + self.format_unit(measurement.units, uspec, sort_func, **babel_kwds), ) -class PrettyFormatter: +class PrettyFormatter(BaseFormatter): """Pretty printed localizable plain text formatter without extra spaces.""" def format_magnitude( @@ -256,16 +309,33 @@ def format_magnitude( return mstr def format_unit( - self, unit: PlainUnit, uspec: str = "", **babel_kwds: Unpack[BabelKwds] + self, + unit: PlainUnit | Iterable[tuple[str, Any]], + uspec: str = "", + sort_func: SortFunc | None = None, + **babel_kwds: Unpack[BabelKwds], ) -> str: - units = format_compound_unit(unit, uspec, **babel_kwds) + numerator, denominator = prepare_compount_unit( + unit, + uspec, + sort_func=sort_func, + **babel_kwds, + registry=self._registry, + ) + + if babel_kwds.get("locale", None): + length = babel_kwds.get("length") or ("short" if "~" in uspec else "long") + division_fmt = localize_per(length, babel_kwds.get("locale"), "{}/{}") + else: + division_fmt = "{}/{}" return formatter( - units, + numerator, + denominator, as_ratio=True, single_denominator=False, product_fmt="·", - division_fmt="/", + division_fmt=division_fmt, power_fmt="{}{}", parentheses_fmt="({})", exp_call=pretty_fmt_exponent, @@ -275,9 +345,10 @@ def format_quantity( self, quantity: PlainQuantity[MagnitudeT], qspec: str = "", + sort_func: SortFunc | None = None, **babel_kwds: Unpack[BabelKwds], ) -> str: - registry = quantity._REGISTRY + registry = self._registry mspec, uspec = split_format( qspec, registry.formatter.default_format, registry.separate_format_defaults @@ -288,13 +359,14 @@ def format_quantity( return join_mu( joint_fstring, self.format_magnitude(quantity.magnitude, mspec, **babel_kwds), - self.format_unit(quantity.units, uspec, **babel_kwds), + self.format_unit(quantity.unit_items(), uspec, sort_func, **babel_kwds), ) def format_uncertainty( self, uncertainty, unc_spec: str = "", + sort_func: SortFunc | None = None, **babel_kwds: Unpack[BabelKwds], ) -> str: return format(uncertainty, 
unc_spec).replace("±", " ± ") @@ -303,9 +375,10 @@ def format_measurement( self, measurement: Measurement, meas_spec: str = "", + sort_func: SortFunc | None = None, **babel_kwds: Unpack[BabelKwds], ) -> str: - registry = measurement._REGISTRY + registry = self._registry mspec, uspec = split_format( meas_spec, @@ -321,11 +394,11 @@ def format_measurement( "(", ")", self.format_uncertainty(measurement.magnitude, unc_spec, **babel_kwds), - self.format_unit(measurement.units, uspec, **babel_kwds), + self.format_unit(measurement.units, uspec, sort_func, **babel_kwds), ) -class RawFormatter: +class RawFormatter(BaseFormatter): """Very simple non-localizable plain text formatter. Ignores all pint custom string formatting specification. @@ -337,19 +410,33 @@ def format_magnitude( return str(magnitude) def format_unit( - self, unit: PlainUnit, uspec: str = "", **babel_kwds: Unpack[BabelKwds] + self, + unit: PlainUnit | Iterable[tuple[str, Any]], + uspec: str = "", + sort_func: SortFunc | None = None, + **babel_kwds: Unpack[BabelKwds], ) -> str: - units = format_compound_unit(unit, uspec, **babel_kwds) + numerator, denominator = prepare_compount_unit( + unit, + uspec, + sort_func=sort_func, + **babel_kwds, + registry=self._registry, + ) - return " * ".join(k if v == 1 else f"{k} ** {v}" for k, v in units) + return " * ".join( + k if v == 1 else f"{k} ** {v}" + for k, v in itertools.chain(numerator, denominator) + ) def format_quantity( self, quantity: PlainQuantity[MagnitudeT], qspec: str = "", + sort_func: SortFunc | None = None, **babel_kwds: Unpack[BabelKwds], ) -> str: - registry = quantity._REGISTRY + registry = self._registry mspec, uspec = split_format( qspec, registry.formatter.default_format, registry.separate_format_defaults @@ -359,13 +446,14 @@ def format_quantity( return join_mu( joint_fstring, self.format_magnitude(quantity.magnitude, mspec, **babel_kwds), - self.format_unit(quantity.units, uspec, **babel_kwds), + self.format_unit(quantity.unit_items(), uspec, sort_func, **babel_kwds), ) def format_uncertainty( self, uncertainty, unc_spec: str = "", + sort_func: SortFunc | None = None, **babel_kwds: Unpack[BabelKwds], ) -> str: return format(uncertainty, unc_spec) @@ -374,9 +462,10 @@ def format_measurement( self, measurement: Measurement, meas_spec: str = "", + sort_func: SortFunc | None = None, **babel_kwds: Unpack[BabelKwds], ) -> str: - registry = measurement._REGISTRY + registry = self._registry mspec, uspec = split_format( meas_spec, @@ -393,5 +482,5 @@ def format_measurement( "(", ")", self.format_uncertainty(measurement.magnitude, unc_spec, **babel_kwds), - self.format_unit(measurement.units, uspec, **babel_kwds), + self.format_unit(measurement.units, uspec, sort_func, **babel_kwds), ) diff --git a/pint/delegates/txt_defparser/__init__.py b/pint/delegates/txt_defparser/__init__.py index 49e4a0bf5..ba0dbbf65 100644 --- a/pint/delegates/txt_defparser/__init__.py +++ b/pint/delegates/txt_defparser/__init__.py @@ -7,7 +7,7 @@ :copyright: 2022 by Pint Authors, see AUTHORS for more details. :license: BSD, see LICENSE for more details. 
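`RawFormatter.format_unit` now chains the numerator and denominator pairs instead of consuming a single mapping; the join it performs reduces to a one-liner over (name, exponent) tuples:

pairs = [("meter", 1), ("second", -2)]
raw = " * ".join(k if v == 1 else f"{k} ** {v}" for k, v in pairs)
assert raw == "meter * second ** -2"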
""" - +from __future__ import annotations from .defparser import DefParser diff --git a/pint/delegates/txt_defparser/block.py b/pint/delegates/txt_defparser/block.py index e8d8aa43f..6e8d18968 100644 --- a/pint/delegates/txt_defparser/block.py +++ b/pint/delegates/txt_defparser/block.py @@ -16,11 +16,11 @@ from __future__ import annotations from dataclasses import dataclass - from typing import Generic, TypeVar -from ..base_defparser import PintParsedStatement, ParserConfig -from ..._vendor import flexparser as fp +import flexparser as fp + +from ..base_defparser import ParserConfig, PintParsedStatement @dataclass(frozen=True) @@ -28,7 +28,7 @@ class EndDirectiveBlock(PintParsedStatement): """An EndDirectiveBlock is simply an "@end" statement.""" @classmethod - def from_string(cls, s: str) -> fp.FromString[EndDirectiveBlock]: + def from_string(cls, s: str) -> fp.NullableParsedResult[EndDirectiveBlock]: if s == "@end": return cls() return None @@ -50,7 +50,5 @@ class DirectiveBlock( Subclass this class for convenience. """ - closing: EndDirectiveBlock - def derive_definition(self) -> DefT: ... diff --git a/pint/delegates/txt_defparser/common.py b/pint/delegates/txt_defparser/common.py index a1195b3bf..ebdabc062 100644 --- a/pint/delegates/txt_defparser/common.py +++ b/pint/delegates/txt_defparser/common.py @@ -14,8 +14,10 @@ from dataclasses import dataclass, field +import flexparser as fp + from ... import errors -from ..._vendor import flexparser as fp +from ..base_defparser import ParserConfig @dataclass(frozen=True) @@ -43,7 +45,7 @@ def set_location(self, value: str) -> None: @dataclass(frozen=True) -class ImportDefinition(fp.IncludeStatement): +class ImportDefinition(fp.IncludeStatement[ParserConfig]): value: str @property @@ -51,7 +53,7 @@ def target(self) -> str: return self.value @classmethod - def from_string(cls, s: str) -> fp.FromString[ImportDefinition]: + def from_string(cls, s: str) -> fp.NullableParsedResult[ImportDefinition]: if s.startswith("@import"): return ImportDefinition(s[len("@import") :].strip()) return None diff --git a/pint/delegates/txt_defparser/context.py b/pint/delegates/txt_defparser/context.py index 5ede7b44b..029b60445 100644 --- a/pint/delegates/txt_defparser/context.py +++ b/pint/delegates/txt_defparser/context.py @@ -19,10 +19,11 @@ import numbers import re import typing as ty -from typing import Optional, Union from dataclasses import dataclass +from typing import Union + +import flexparser as fp -from ..._vendor import flexparser as fp from ...facets.context import definitions from ..base_defparser import ParserConfig, PintParsedStatement from . 
import block, common, plain @@ -33,7 +34,7 @@ def _from_string_and_context_sep( cls: type[T], s: str, config: ParserConfig, separator: str -) -> Optional[T]: +) -> T | None: if separator not in s: return None if ":" not in s: @@ -58,7 +59,7 @@ class ForwardRelation(PintParsedStatement, definitions.ForwardRelation): @classmethod def from_string_and_config( cls, s: str, config: ParserConfig - ) -> fp.FromString[ForwardRelation]: + ) -> fp.NullableParsedResult[ForwardRelation]: return _from_string_and_context_sep(cls, s, config, "->") @@ -74,7 +75,7 @@ class BidirectionalRelation(PintParsedStatement, definitions.BidirectionalRelati @classmethod def from_string_and_config( cls, s: str, config: ParserConfig - ) -> fp.FromString[BidirectionalRelation]: + ) -> fp.NullableParsedResult[BidirectionalRelation]: return _from_string_and_context_sep(cls, s, config, "<->") @@ -96,7 +97,7 @@ class BeginContext(PintParsedStatement): @classmethod def from_string_and_config( cls, s: str, config: ParserConfig - ) -> fp.FromString[BeginContext]: + ) -> fp.NullableParsedResult[BeginContext]: try: r = cls._header_re.search(s) if r is None: @@ -169,16 +170,6 @@ class ContextDefinition( @end """ - opening: fp.Single[BeginContext] - body: fp.Multi[ - ty.Union[ - plain.CommentDefinition, - BidirectionalRelation, - ForwardRelation, - plain.UnitDefinition, - ] - ] - def derive_definition(self) -> definitions.ContextDefinition: return definitions.ContextDefinition( self.name, self.aliases, self.defaults, self.relations, self.redefinitions @@ -200,7 +191,7 @@ def defaults(self) -> dict[str, numbers.Number]: return self.opening.defaults @property - def relations(self) -> tuple[Union[BidirectionalRelation, ForwardRelation], ...]: + def relations(self) -> tuple[BidirectionalRelation | ForwardRelation, ...]: return tuple( r for r in self.body diff --git a/pint/delegates/txt_defparser/defaults.py b/pint/delegates/txt_defparser/defaults.py index b29be18f2..669daddb4 100644 --- a/pint/delegates/txt_defparser/defaults.py +++ b/pint/delegates/txt_defparser/defaults.py @@ -16,10 +16,11 @@ import typing as ty from dataclasses import dataclass, fields -from ..._vendor import flexparser as fp +import flexparser as fp + from ...facets.plain import definitions -from . import block, plain from ..base_defparser import PintParsedStatement +from . import block, plain @dataclass(frozen=True) @@ -30,7 +31,7 @@ class BeginDefaults(PintParsedStatement): """ @classmethod - def from_string(cls, s: str) -> fp.FromString[BeginDefaults]: + def from_string(cls, s: str) -> fp.NullableParsedResult[BeginDefaults]: if s.strip() == "@defaults": return cls() return None @@ -56,14 +57,6 @@ class DefaultsDefinition( See Equality and Comment for more parsing related information. """ - opening: fp.Single[BeginDefaults] - body: fp.Multi[ - ty.Union[ - plain.CommentDefinition, - plain.Equality, - ] - ] - @property def _valid_fields(self) -> tuple[str, ...]: return tuple(f.name for f in fields(definitions.DefaultsDefinition)) diff --git a/pint/delegates/txt_defparser/defparser.py b/pint/delegates/txt_defparser/defparser.py index e89863d00..8c57ac306 100644 --- a/pint/delegates/txt_defparser/defparser.py +++ b/pint/delegates/txt_defparser/defparser.py @@ -2,10 +2,10 @@ import pathlib import typing as ty -from typing import Optional, Union -from ..._vendor import flexcache as fc -from ..._vendor import flexparser as fp +import flexcache as fc +import flexparser as fp + from ..base_defparser import ParserConfig from . 
import block, common, context, defaults, group, plain, system @@ -28,28 +28,6 @@ class PintRootBlock( ParserConfig, ] ): - body: fp.Multi[ - ty.Union[ - plain.CommentDefinition, - common.ImportDefinition, - context.ContextDefinition, - defaults.DefaultsDefinition, - system.SystemDefinition, - group.GroupDefinition, - plain.AliasDefinition, - plain.DerivedDimensionDefinition, - plain.DimensionDefinition, - plain.PrefixDefinition, - plain.UnitDefinition, - ] - ] - - -class PintSource(fp.ParsedSource[PintRootBlock, ParserConfig]): - """Source code in Pint.""" - - -class HashTuple(tuple): pass @@ -66,16 +44,18 @@ class _PintParser(fp.Parser[PintRootBlock, ParserConfig]): _root_block_class = PintRootBlock _strip_spaces = True - _diskcache: fc.DiskCache + _diskcache: fc.DiskCache | None - def __init__(self, config: ParserConfig, *args, **kwargs): + def __init__(self, config: ParserConfig, *args: ty.Any, **kwargs: ty.Any): self._diskcache = kwargs.pop("diskcache", None) super().__init__(config, *args, **kwargs) - def parse_file(self, path: pathlib.Path) -> PintSource: + def parse_file( + self, path: pathlib.Path + ) -> fp.ParsedSource[PintRootBlock, ParserConfig]: if self._diskcache is None: return super().parse_file(path) - content, basename = self._diskcache.load(path, super().parse_file) + content, _basename = self._diskcache.load(path, super().parse_file) return content @@ -88,26 +68,33 @@ class DefParser: plain.CommentDefinition, ) - def __init__(self, default_config, diskcache): + def __init__(self, default_config: ParserConfig, diskcache: fc.DiskCache): self._default_config = default_config self._diskcache = diskcache - def iter_parsed_project(self, parsed_project: fp.ParsedProject): + def iter_parsed_project( + self, parsed_project: fp.ParsedProject[PintRootBlock, ParserConfig] + ) -> ty.Generator[fp.ParsedStatement[ParserConfig], None, None]: last_location = None for stmt in parsed_project.iter_blocks(): - if isinstance(stmt, fp.BOF): - last_location = str(stmt.path) - elif isinstance(stmt, fp.BOR): - last_location = ( - f"[package: {stmt.package}, resource: {stmt.resource_name}]" - ) + if isinstance(stmt, fp.BOS): + if isinstance(stmt, fp.BOF): + last_location = str(stmt.path) + continue + elif isinstance(stmt, fp.BOR): + last_location = ( + f"[package: {stmt.package}, resource: {stmt.resource_name}]" + ) + continue + else: + last_location = "orphan string" + continue if isinstance(stmt, self.skip_classes): continue + assert isinstance(last_location, str) if isinstance(stmt, common.DefinitionSyntaxError): - # TODO: check why this assert fails - # assert isinstance(last_location, str) stmt.set_location(last_location) raise stmt elif isinstance(stmt, block.DirectiveBlock): @@ -132,8 +119,8 @@ def iter_parsed_project(self, parsed_project: fp.ParsedProject): yield stmt def parse_file( - self, filename: Union[pathlib.Path, str], cfg: Optional[ParserConfig] = None - ): + self, filename: pathlib.Path | str, cfg: ParserConfig | None = None + ) -> fp.ParsedProject[PintRootBlock, ParserConfig]: return fp.parse( filename, _PintParser, @@ -143,7 +130,9 @@ def parse_file( delimiters=_PintParser._delimiters, ) - def parse_string(self, content: str, cfg: Optional[ParserConfig] = None): + def parse_string( + self, content: str, cfg: ParserConfig | None = None + ) -> fp.ParsedProject[PintRootBlock, ParserConfig]: return fp.parse_bytes( content.encode("utf-8"), _PintParser, diff --git a/pint/delegates/txt_defparser/group.py b/pint/delegates/txt_defparser/group.py index 851e68572..120438a83 100644 --- 
a/pint/delegates/txt_defparser/group.py +++ b/pint/delegates/txt_defparser/group.py @@ -20,10 +20,11 @@ import typing as ty from dataclasses import dataclass -from ..._vendor import flexparser as fp +import flexparser as fp + from ...facets.group import definitions -from . import block, common, plain from ..base_defparser import PintParsedStatement +from . import block, common, plain @dataclass(frozen=True) @@ -40,7 +41,7 @@ class BeginGroup(PintParsedStatement): using_group_names: ty.Tuple[str, ...] @classmethod - def from_string(cls, s: str) -> fp.FromString[BeginGroup]: + def from_string(cls, s: str) -> fp.NullableParsedResult[BeginGroup]: if not s.startswith("@group"): return None @@ -90,14 +91,6 @@ class GroupDefinition( """ - opening: fp.Single[BeginGroup] - body: fp.Multi[ - ty.Union[ - plain.CommentDefinition, - plain.UnitDefinition, - ] - ] - def derive_definition(self) -> definitions.GroupDefinition: return definitions.GroupDefinition( self.name, self.using_group_names, self.definitions diff --git a/pint/delegates/txt_defparser/plain.py b/pint/delegates/txt_defparser/plain.py index 9c7bd42ef..ac4230bcb 100644 --- a/pint/delegates/txt_defparser/plain.py +++ b/pint/delegates/txt_defparser/plain.py @@ -25,7 +25,8 @@ from dataclasses import dataclass -from ..._vendor import flexparser as fp +import flexparser as fp + from ...converters import Converter from ...facets.plain import definitions from ...util import UnitsContainer @@ -41,7 +42,7 @@ class Equality(PintParsedStatement, definitions.Equality): """ @classmethod - def from_string(cls, s: str) -> fp.FromString[Equality]: + def from_string(cls, s: str) -> fp.NullableParsedResult[Equality]: if "=" not in s: return None parts = [p.strip() for p in s.split("=")] @@ -63,7 +64,7 @@ class CommentDefinition(PintParsedStatement, definitions.CommentDefinition): """ @classmethod - def from_string(cls, s: str) -> fp.FromString[CommentDefinition]: + def from_string(cls, s: str) -> fp.NullableParsedResult[CommentDefinition]: if not s.startswith("#"): return None return cls(s[1:].strip()) @@ -83,7 +84,7 @@ class PrefixDefinition(PintParsedStatement, definitions.PrefixDefinition): @classmethod def from_string_and_config( cls, s: str, config: ParserConfig - ) -> fp.FromString[PrefixDefinition]: + ) -> fp.NullableParsedResult[PrefixDefinition]: if "=" not in s: return None @@ -140,7 +141,7 @@ class UnitDefinition(PintParsedStatement, definitions.UnitDefinition): @classmethod def from_string_and_config( cls, s: str, config: ParserConfig - ) -> fp.FromString[UnitDefinition]: + ) -> fp.NullableParsedResult[UnitDefinition]: if "=" not in s: return None @@ -205,17 +206,12 @@ class DimensionDefinition(PintParsedStatement, definitions.DimensionDefinition): """ @classmethod - def from_string(cls, s: str) -> fp.FromString[DimensionDefinition]: + def from_string(cls, s: str) -> fp.NullableParsedResult[DimensionDefinition]: s = s.strip() if not (s.startswith("[") and "=" not in s): return None - try: - s = definitions.check_dim(s) - except common.DefinitionSyntaxError as ex: - return ex - return cls(s) @@ -235,7 +231,7 @@ class DerivedDimensionDefinition( @classmethod def from_string_and_config( cls, s: str, config: ParserConfig - ) -> fp.FromString[DerivedDimensionDefinition]: + ) -> fp.NullableParsedResult[DerivedDimensionDefinition]: if not (s.startswith("[") and "=" in s): return None @@ -272,7 +268,7 @@ class AliasDefinition(PintParsedStatement, definitions.AliasDefinition): """ @classmethod - def from_string(cls, s: str) -> 
fp.FromString[AliasDefinition]: + def from_string(cls, s: str) -> fp.NullableParsedResult[AliasDefinition]: if not s.startswith("@alias "): return None name, *aliases = s[len("@alias ") :].split("=") diff --git a/pint/delegates/txt_defparser/system.py b/pint/delegates/txt_defparser/system.py index 7a65a36ae..8c45b0b0b 100644 --- a/pint/delegates/txt_defparser/system.py +++ b/pint/delegates/txt_defparser/system.py @@ -12,7 +12,8 @@ import typing as ty from dataclasses import dataclass -from ..._vendor import flexparser as fp +import flexparser as fp + from ...facets.system import definitions from ..base_defparser import PintParsedStatement from . import block, common, plain @@ -21,7 +22,7 @@ @dataclass(frozen=True) class BaseUnitRule(PintParsedStatement, definitions.BaseUnitRule): @classmethod - def from_string(cls, s: str) -> fp.FromString[BaseUnitRule]: + def from_string(cls, s: str) -> fp.NullableParsedResult[BaseUnitRule]: if ":" not in s: return cls(s.strip()) parts = [p.strip() for p in s.split(":")] @@ -46,7 +47,7 @@ class BeginSystem(PintParsedStatement): using_group_names: ty.Tuple[str, ...] @classmethod - def from_string(cls, s: str) -> fp.FromString[BeginSystem]: + def from_string(cls, s: str) -> fp.NullableParsedResult[BeginSystem]: if not s.startswith("@system"): return None @@ -96,9 +97,6 @@ class SystemDefinition( If the new_unit_name and the old_unit_name, the later and the colon can be omitted. """ - opening: fp.Single[BeginSystem] - body: fp.Multi[ty.Union[plain.CommentDefinition, BaseUnitRule]] - def derive_definition(self) -> definitions.SystemDefinition: return definitions.SystemDefinition( self.name, self.using_group_names, self.rules diff --git a/pint/errors.py b/pint/errors.py index 8041c1817..59d3b4569 100644 --- a/pint/errors.py +++ b/pint/errors.py @@ -10,7 +10,6 @@ from __future__ import annotations -from typing import Union import typing as ty from dataclasses import dataclass, fields @@ -135,7 +134,7 @@ def __reduce__(self): class UndefinedUnitError(AttributeError, PintError): """Raised when the units are not defined in the unit registry.""" - unit_names: Union[str, tuple[str, ...]] + unit_names: str | tuple[str, ...] def __str__(self): if isinstance(self.unit_names, str): @@ -246,3 +245,11 @@ def __reduce__(self): class UnexpectedScaleInContainer(Exception): def __reduce__(self): return self.__class__, tuple(getattr(self, f.name) for f in fields(self)) + + +@dataclass(frozen=False) +class UndefinedBehavior(UserWarning, PintError): + msg: str + + def __reduce__(self): + return self.__class__, tuple(getattr(self, f.name) for f in fields(self)) diff --git a/pint/facets/__init__.py b/pint/facets/__init__.py index 2a2bb4cd3..12729289c 100644 --- a/pint/facets/__init__.py +++ b/pint/facets/__init__.py @@ -71,15 +71,15 @@ class that belongs to a registry that has NumpyRegistry as one of its bases. 
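All of these `from_string` hooks follow the same flexparser contract: return None when the line does not belong to the statement, otherwise return a parsed instance (or an error object). A registry-free sketch of the `@alias` rule (a standalone dataclass, not the real `AliasDefinition`):

from __future__ import annotations
from dataclasses import dataclass

@dataclass(frozen=True)
class AliasStatement:
    name: str
    aliases: tuple[str, ...]

    @classmethod
    def from_string(cls, s: str) -> AliasStatement | None:
        # Lines that do not start with "@alias " belong to other statements.
        if not s.startswith("@alias "):
            return None
        name, *aliases = (p.strip() for p in s[len("@alias "):].split("="))
        return cls(name, tuple(aliases))

assert AliasStatement.from_string("meter = m") is None
assert AliasStatement.from_string("@alias meter = m = metre") == AliasStatement(
    "meter", ("m", "metre")
)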
from .context import ContextRegistry, GenericContextRegistry from .dask import DaskRegistry, GenericDaskRegistry -from .group import GroupRegistry, GenericGroupRegistry -from .measurement import MeasurementRegistry, GenericMeasurementRegistry +from .group import GenericGroupRegistry, GroupRegistry +from .measurement import GenericMeasurementRegistry, MeasurementRegistry from .nonmultiplicative import ( - NonMultiplicativeRegistry, GenericNonMultiplicativeRegistry, + NonMultiplicativeRegistry, ) -from .numpy import NumpyRegistry, GenericNumpyRegistry -from .plain import PlainRegistry, GenericPlainRegistry, QuantityT, UnitT, MagnitudeT -from .system import SystemRegistry, GenericSystemRegistry +from .numpy import GenericNumpyRegistry, NumpyRegistry +from .plain import GenericPlainRegistry, MagnitudeT, PlainRegistry, QuantityT, UnitT +from .system import GenericSystemRegistry, SystemRegistry __all__ = [ "ContextRegistry", diff --git a/pint/facets/context/definitions.py b/pint/facets/context/definitions.py index f63a6fcc3..76f84d63d 100644 --- a/pint/facets/context/definitions.py +++ b/pint/facets/context/definitions.py @@ -11,9 +11,9 @@ import itertools import numbers import re +from collections.abc import Callable, Iterable from dataclasses import dataclass -from typing import TYPE_CHECKING, Callable -from collections.abc import Iterable +from typing import TYPE_CHECKING from ... import errors from ..plain import UnitDefinition diff --git a/pint/facets/context/objects.py b/pint/facets/context/objects.py index c0e2f0c67..edd1dfb2a 100644 --- a/pint/facets/context/objects.py +++ b/pint/facets/context/objects.py @@ -10,13 +10,13 @@ import weakref from collections import ChainMap, defaultdict -from typing import Any, Callable, Protocol, Generic, Optional, TYPE_CHECKING -from collections.abc import Iterable +from collections.abc import Callable, Iterable +from typing import TYPE_CHECKING, Any, Generic, Protocol -from ...facets.plain import UnitDefinition, PlainQuantity, PlainUnit, MagnitudeT +from ..._typing import Magnitude +from ...facets.plain import MagnitudeT, PlainQuantity, PlainUnit, UnitDefinition from ...util import UnitsContainer, to_units_container from .definitions import ContextDefinition -from ..._typing import Magnitude if TYPE_CHECKING: from ...registry import UnitRegistry @@ -96,11 +96,11 @@ class Context: def __init__( self, - name: Optional[str] = None, + name: str | None = None, aliases: tuple[str, ...] = tuple(), - defaults: Optional[dict[str, Any]] = None, + defaults: dict[str, Any] | None = None, ) -> None: - self.name: Optional[str] = name + self.name: str | None = name self.aliases: tuple[str, ...] 
= aliases #: Maps (src, dst) -> transformation function @@ -155,7 +155,7 @@ def from_context(cls, context: Context, **defaults: Any) -> Context: def from_lines( cls, lines: Iterable[str], - to_base_func: Optional[ToBaseFunc] = None, + to_base_func: ToBaseFunc | None = None, non_int_type: type = float, ) -> Context: context_definition = ContextDefinition.from_lines(lines, non_int_type) @@ -167,7 +167,7 @@ def from_lines( @classmethod def from_definition( - cls, cd: ContextDefinition, to_base_func: Optional[ToBaseFunc] = None + cls, cd: ContextDefinition, to_base_func: ToBaseFunc | None = None ) -> Context: ctx = cls(cd.name, cd.aliases, cd.defaults) @@ -246,7 +246,7 @@ def _redefine(self, definition: UnitDefinition): def hashable( self, ) -> tuple[ - Optional[str], + str | None, tuple[str, ...], frozenset[tuple[SrcDst, int]], frozenset[tuple[str, Any]], @@ -278,7 +278,7 @@ def __init__(self): super().__init__() self.contexts: list[Context] = [] self.maps.clear() # Remove default empty map - self._graph: Optional[dict[SrcDst, set[UnitsContainer]]] = None + self._graph: dict[SrcDst, set[UnitsContainer]] | None = None def insert_contexts(self, *contexts: Context): """Insert one or more contexts in reversed order the chained map. @@ -292,7 +292,7 @@ def insert_contexts(self, *contexts: Context): self.maps = [ctx.relation_to_context for ctx in reversed(contexts)] + self.maps self._graph = None - def remove_contexts(self, n: Optional[int] = None): + def remove_contexts(self, n: int | None = None): """Remove the last n inserted contexts from the chain. Parameters diff --git a/pint/facets/context/registry.py b/pint/facets/context/registry.py index 3bfb3fd25..8f9f71ca5 100644 --- a/pint/facets/context/registry.py +++ b/pint/facets/context/registry.py @@ -10,16 +10,17 @@ import functools from collections import ChainMap +from collections.abc import Callable, Generator from contextlib import contextmanager -from typing import Any, Callable, Generator, Generic, Optional, Union +from typing import Any, Generic -from ...compat import TypeAlias from ..._typing import F, Magnitude +from ...compat import TypeAlias from ...errors import UndefinedUnitError -from ...util import find_connected_nodes, find_shortest_path, logger, UnitsContainer -from ..plain import GenericPlainRegistry, UnitDefinition, QuantityT, UnitT -from .definitions import ContextDefinition +from ...util import UnitsContainer, find_connected_nodes, find_shortest_path, logger +from ..plain import GenericPlainRegistry, QuantityT, UnitDefinition, UnitT from . import objects +from .definitions import ContextDefinition # TODO: Put back annotation when possible # registry_cache: "RegistryCache" @@ -75,7 +76,7 @@ def _register_definition_adders(self) -> None: super()._register_definition_adders() self._register_adder(ContextDefinition, self.add_context) - def add_context(self, context: Union[objects.Context, ContextDefinition]) -> None: + def add_context(self, context: objects.Context | ContextDefinition) -> None: """Add a context object to the registry. The context will be accessible by its name and aliases. @@ -198,7 +199,7 @@ def _redefine(self, definition: UnitDefinition) -> None: self.define(definition) def enable_contexts( - self, *names_or_contexts: Union[str, objects.Context], **kwargs: Any + self, *names_or_contexts: str | objects.Context, **kwargs: Any ) -> None: """Enable contexts provided by name or by object. 
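The context registry changes above keep add_context and enable_contexts accepting either a context name or a Context object. A short usage sketch with pint's public API (the "spectroscopy" context ships with pint; exact output values are omitted on purpose):

```python
import pint

ureg = pint.UnitRegistry()
wavelength = 500 * ureg.nanometer

# Enable a bundled context by name, then convert wavelength -> frequency.
ureg.enable_contexts("spectroscopy")
print(wavelength.to("terahertz"))
ureg.disable_contexts()

# Or enable it only for the duration of a with-block.
with ureg.context("sp"):
    print(wavelength.to("terahertz"))
```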
@@ -245,7 +246,7 @@ def enable_contexts( self._active_ctx.insert_contexts(*contexts) self._switch_context_cache_and_units() - def disable_contexts(self, n: Optional[int] = None) -> None: + def disable_contexts(self, n: int | None = None) -> None: """Disable the last n enabled contexts. Parameters @@ -404,7 +405,7 @@ def _convert( return super()._convert(value, src, dst, inplace) def _get_compatible_units( - self, input_units: UnitsContainer, group_or_system: Optional[str] = None + self, input_units: UnitsContainer, group_or_system: str | None = None ): src_dim = self._get_dimensionality(input_units) diff --git a/pint/facets/dask/__init__.py b/pint/facets/dask/__init__.py index 8d62f55d7..c3133bc31 100644 --- a/pint/facets/dask/__init__.py +++ b/pint/facets/dask/__init__.py @@ -11,17 +11,17 @@ from __future__ import annotations -from typing import Generic, Any import functools +from typing import Any, Generic -from ...compat import compute, dask_array, persist, visualize, TypeAlias +from ...compat import TypeAlias, compute, dask_array, persist, visualize from ..plain import ( GenericPlainRegistry, + MagnitudeT, PlainQuantity, + PlainUnit, QuantityT, UnitT, - PlainUnit, - MagnitudeT, ) diff --git a/pint/facets/group/__init__.py b/pint/facets/group/__init__.py index b25ea85cf..db488deac 100644 --- a/pint/facets/group/__init__.py +++ b/pint/facets/group/__init__.py @@ -12,7 +12,7 @@ from .definitions import GroupDefinition from .objects import Group, GroupQuantity, GroupUnit -from .registry import GroupRegistry, GenericGroupRegistry +from .registry import GenericGroupRegistry, GroupRegistry __all__ = [ "GroupDefinition", diff --git a/pint/facets/group/definitions.py b/pint/facets/group/definitions.py index 0a22b5072..bec7d8ac0 100644 --- a/pint/facets/group/definitions.py +++ b/pint/facets/group/definitions.py @@ -10,10 +10,9 @@ from collections.abc import Iterable from dataclasses import dataclass -from typing import Optional -from ...compat import Self from ... import errors +from ...compat import Self from .. import plain @@ -31,7 +30,7 @@ class GroupDefinition(errors.WithDefErr): @classmethod def from_lines( cls: type[Self], lines: Iterable[str], non_int_type: type - ) -> Optional[Self]: + ) -> Self | None: # TODO: this is to keep it backwards compatible from ...delegates import ParserConfig, txt_defparser diff --git a/pint/facets/group/objects.py b/pint/facets/group/objects.py index dbd7ecf3c..751dd3765 100644 --- a/pint/facets/group/objects.py +++ b/pint/facets/group/objects.py @@ -8,12 +8,12 @@ from __future__ import annotations -from typing import Callable, Any, TYPE_CHECKING, Generic, Optional +from collections.abc import Callable, Generator, Iterable +from typing import TYPE_CHECKING, Any, Generic -from collections.abc import Generator, Iterable from ...util import SharedRegistryObject, getattr_maybe_raise +from ..plain import MagnitudeT, PlainQuantity, PlainUnit from .definitions import GroupDefinition -from ..plain import PlainQuantity, PlainUnit, MagnitudeT if TYPE_CHECKING: from ..plain import UnitDefinition @@ -81,7 +81,7 @@ def __init__(self, name: str): #: A cache of the included units. #: None indicates that the cache has been invalidated. 
- self._computed_members: Optional[frozenset[str]] = None + self._computed_members: frozenset[str] | None = None @property def members(self) -> frozenset[str]: @@ -197,7 +197,7 @@ def from_lines( def from_definition( cls, group_definition: GroupDefinition, - add_unit_func: Optional[AddUnitFunc] = None, + add_unit_func: AddUnitFunc | None = None, ) -> Group: grp = cls(group_definition.name) diff --git a/pint/facets/group/registry.py b/pint/facets/group/registry.py index da068c5e9..33f78c645 100644 --- a/pint/facets/group/registry.py +++ b/pint/facets/group/registry.py @@ -8,10 +8,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Generic, Any, Optional +from typing import TYPE_CHECKING, Any, Generic -from ...compat import TypeAlias from ... import errors +from ...compat import TypeAlias if TYPE_CHECKING: from ..._typing import Unit, UnitsContainer @@ -19,12 +19,12 @@ from ...util import create_class_with_registry, to_units_container from ..plain import ( GenericPlainRegistry, - UnitDefinition, QuantityT, + UnitDefinition, UnitT, ) -from .definitions import GroupDefinition from . import objects +from .definitions import GroupDefinition class GenericGroupRegistry( @@ -121,7 +121,7 @@ def get_group(self, name: str, create_if_needed: bool = True) -> objects.Group: return self.Group(name) def get_compatible_units( - self, input_units: UnitsContainer, group: Optional[str] = None + self, input_units: UnitsContainer, group: str | None = None ) -> frozenset[Unit]: """ """ if group is None: @@ -134,7 +134,7 @@ def get_compatible_units( return frozenset(self.Unit(eq) for eq in equiv) def _get_compatible_units( - self, input_units: UnitsContainer, group: Optional[str] = None + self, input_units: UnitsContainer, group: str | None = None ) -> frozenset[str]: ret = super()._get_compatible_units(input_units) diff --git a/pint/facets/measurement/__init__.py b/pint/facets/measurement/__init__.py index d36a5c31a..0b241ea1d 100644 --- a/pint/facets/measurement/__init__.py +++ b/pint/facets/measurement/__init__.py @@ -11,7 +11,7 @@ from __future__ import annotations from .objects import Measurement, MeasurementQuantity -from .registry import MeasurementRegistry, GenericMeasurementRegistry +from .registry import GenericMeasurementRegistry, MeasurementRegistry __all__ = [ "Measurement", diff --git a/pint/facets/measurement/objects.py b/pint/facets/measurement/objects.py index f052152e5..4240a91d2 100644 --- a/pint/facets/measurement/objects.py +++ b/pint/facets/measurement/objects.py @@ -13,7 +13,7 @@ from typing import Generic from ...compat import ufloat -from ..plain import PlainQuantity, PlainUnit, MagnitudeT +from ..plain import MagnitudeT, PlainQuantity, PlainUnit MISSING = object() diff --git a/pint/facets/measurement/registry.py b/pint/facets/measurement/registry.py index 4a3e87804..905de7ab7 100644 --- a/pint/facets/measurement/registry.py +++ b/pint/facets/measurement/registry.py @@ -9,9 +9,9 @@ from __future__ import annotations -from typing import Generic, Any +from typing import Any, Generic -from ...compat import ufloat, TypeAlias +from ...compat import TypeAlias, ufloat from ...util import create_class_with_registry from ..plain import GenericPlainRegistry, QuantityT, UnitT from . 
import objects diff --git a/pint/facets/nonmultiplicative/__init__.py b/pint/facets/nonmultiplicative/__init__.py index eb3292b3c..a338dc34a 100644 --- a/pint/facets/nonmultiplicative/__init__.py +++ b/pint/facets/nonmultiplicative/__init__.py @@ -15,6 +15,6 @@ # This import register LogarithmicConverter and OffsetConverter to be usable # (via subclassing) from .definitions import LogarithmicConverter, OffsetConverter # noqa: F401 -from .registry import NonMultiplicativeRegistry, GenericNonMultiplicativeRegistry +from .registry import GenericNonMultiplicativeRegistry, NonMultiplicativeRegistry __all__ = ["NonMultiplicativeRegistry", "GenericNonMultiplicativeRegistry"] diff --git a/pint/facets/nonmultiplicative/objects.py b/pint/facets/nonmultiplicative/objects.py index 8ebe8f8ea..114a256af 100644 --- a/pint/facets/nonmultiplicative/objects.py +++ b/pint/facets/nonmultiplicative/objects.py @@ -8,9 +8,9 @@ from __future__ import annotations -from typing import Generic, Optional +from typing import Generic -from ..plain import PlainQuantity, PlainUnit, MagnitudeT +from ..plain import MagnitudeT, PlainQuantity, PlainUnit class NonMultiplicativeQuantity(Generic[MagnitudeT], PlainQuantity[MagnitudeT]): @@ -42,7 +42,7 @@ def _has_compatible_delta(self, unit: str) -> bool: self._get_unit_definition(d).reference == offset_unit_dim for d in deltas ) - def _ok_for_muldiv(self, no_offset_units: Optional[int] = None) -> bool: + def _ok_for_muldiv(self, no_offset_units: int | None = None) -> bool: """Checks if PlainQuantity object can be multiplied or divided""" is_ok = True diff --git a/pint/facets/nonmultiplicative/registry.py b/pint/facets/nonmultiplicative/registry.py index 67250ea48..7f58d060c 100644 --- a/pint/facets/nonmultiplicative/registry.py +++ b/pint/facets/nonmultiplicative/registry.py @@ -8,15 +8,14 @@ from __future__ import annotations -from typing import Any, TypeVar, Generic, Optional +from typing import Any, Generic, TypeVar from ...compat import TypeAlias from ...errors import DimensionalityError, UndefinedUnitError from ...util import UnitsContainer, logger -from ..plain import GenericPlainRegistry, UnitDefinition, QuantityT, UnitT -from .definitions import OffsetConverter, ScaleConverter +from ..plain import GenericPlainRegistry, QuantityT, UnitDefinition, UnitT from . import objects - +from .definitions import OffsetConverter, ScaleConverter T = TypeVar("T") @@ -60,8 +59,8 @@ def __init__( def parse_units_as_container( self, input_string: str, - as_delta: Optional[bool] = None, - case_sensitive: Optional[bool] = None, + as_delta: bool | None = None, + case_sensitive: bool | None = None, ) -> UnitsContainer: """ """ if as_delta is None: @@ -120,7 +119,7 @@ def _is_multiplicative(self, unit_name: str) -> bool: Raises ------ UndefinedUnitError - If the unit is not in the registyr. + If the unit is not in the registry. """ if unit_name in self._units: return self._units[unit_name].is_multiplicative @@ -136,7 +135,7 @@ def _is_multiplicative(self, unit_name: str) -> bool: except KeyError: raise UndefinedUnitError(unit_name) - def _validate_and_extract(self, units: UnitsContainer) -> Optional[str]: + def _validate_and_extract(self, units: UnitsContainer) -> str | None: """Used to check if a given units is suitable for a simple conversion. 
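_validate_and_extract above, together with the conversion guards added in the hunks that follow, deals with offset units (e.g. degC) versus their delta_ counterparts. A quick reminder of the distinction in ordinary pint usage (standard behaviour, shown here for orientation rather than as a test of the new guard):

```python
import pint

ureg = pint.UnitRegistry()

# Absolute temperature: the offset applies (0 degC == 273.15 K).
print(ureg.Quantity(25, "degC").to("kelvin"))

# Temperature difference: purely multiplicative, no offset is added.
print((5 * ureg.delta_degC).to("kelvin"))

# Converting between an offset unit and a delta_ unit (e.g. degC -> delta_degC)
# is the invalid case the guards in the following hunks reject with a
# DimensionalityError.
```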
@@ -193,7 +192,7 @@ def _add_ref_of_log_or_offset_unit( self, offset_unit: str, all_units: UnitsContainer ) -> UnitsContainer: slct_unit = self._units[offset_unit] - if slct_unit.is_logarithmic or (not slct_unit.is_multiplicative): + if slct_unit.is_logarithmic: # Extract reference unit slct_ref = slct_unit.reference @@ -205,6 +204,11 @@ def _add_ref_of_log_or_offset_unit( (u, e) = [(u, e) for u, e in slct_ref.items()].pop() # Add it back to the unit list return all_units.add(u, e) + + if not slct_unit.is_multiplicative: # is offset unit + # Extract reference unit + return slct_unit.reference + # Otherwise, return the units unmodified return all_units @@ -250,6 +254,7 @@ def _convert( src, dst, extra_msg=f" - In destination units, {ex}" ) + # convert if no offset units are present if not (src_offset_unit or dst_offset_unit): return super()._convert(value, src, dst, inplace) @@ -263,6 +268,8 @@ def _convert( # clean src from offset units by converting to reference if src_offset_unit: + if any(u.startswith("delta_") for u in dst): + raise DimensionalityError(src, dst) value = self._units[src_offset_unit].converter.to_reference(value, inplace) src = src.remove([src_offset_unit]) # Add reference unit for multiplicative section @@ -270,6 +277,8 @@ def _convert( # clean dst units from offset units if dst_offset_unit: + if any(u.startswith("delta_") for u in src): + raise DimensionalityError(src, dst) dst = dst.remove([dst_offset_unit]) # Add reference unit for multiplicative section dst = self._add_ref_of_log_or_offset_unit(dst_offset_unit, dst) diff --git a/pint/facets/numpy/__init__.py b/pint/facets/numpy/__init__.py index 2e38dc1dc..477c09579 100644 --- a/pint/facets/numpy/__init__.py +++ b/pint/facets/numpy/__init__.py @@ -10,6 +10,6 @@ from __future__ import annotations -from .registry import NumpyRegistry, GenericNumpyRegistry +from .registry import GenericNumpyRegistry, NumpyRegistry __all__ = ["NumpyRegistry", "GenericNumpyRegistry"] diff --git a/pint/facets/numpy/numpy_func.py b/pint/facets/numpy/numpy_func.py index 57dc5123d..b79700f9f 100644 --- a/pint/facets/numpy/numpy_func.py +++ b/pint/facets/numpy/numpy_func.py @@ -52,6 +52,10 @@ def _is_sequence_with_quantity_elements(obj): ------- True if obj is a sequence and at least one element is a Quantity; False otherwise """ + if np is not None and isinstance(obj, np.ndarray) and not obj.dtype.hasobject: + # If obj is a numpy array, avoid looping on all elements + # if dtype does not have objects + return False return ( iterable(obj) and sized(obj) @@ -284,6 +288,17 @@ def implement_func(func_type, func_str, input_units=None, output_unit=None): @implements(func_str, func_type) def implementation(*args, **kwargs): + if func_str in ["multiply", "true_divide", "divide", "floor_divide"] and any( + [ + not _is_quantity(arg) and _is_sequence_with_quantity_elements(arg) + for arg in args + ] + ): + # the sequence may contain different units, so fall back to element-wise + return np.array( + [func(*func_args) for func_args in zip(*args)], dtype=object + ) + first_input_units = _get_first_input_units(args, kwargs) if input_units == "all_consistent": # Match all input args/kwargs to same units @@ -413,6 +428,7 @@ def implementation(*args, **kwargs): "take", "trace", "transpose", + "roll", "ceil", "floor", "hypot", @@ -740,8 +756,11 @@ def _base_unit_if_needed(a): raise OffsetUnitCalculusError(a.units) +# NP2 Can remove trapz wrapping when we only support numpy>=2 @implements("trapz", "function") +@implements("trapezoid", "function") def _trapz(y, 
x=None, dx=1.0, **kwargs): + trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz y = _base_unit_if_needed(y) units = y.units if x is not None: @@ -749,17 +768,26 @@ def _trapz(y, x=None, dx=1.0, **kwargs): x = _base_unit_if_needed(x) units *= x.units x = x._magnitude - ret = np.trapz(y._magnitude, x, **kwargs) + ret = trapezoid(y._magnitude, x, **kwargs) else: if hasattr(dx, "units"): dx = _base_unit_if_needed(dx) units *= dx.units dx = dx._magnitude - ret = np.trapz(y._magnitude, dx=dx, **kwargs) + ret = trapezoid(y._magnitude, dx=dx, **kwargs) return y.units._REGISTRY.Quantity(ret, units) +@implements("correlate", "function") +def _correlate(a, v, mode="valid", **kwargs): + a = _base_unit_if_needed(a) + v = _base_unit_if_needed(v) + units = a.units * v.units + ret = np.correlate(a._magnitude, v._magnitude, mode=mode, **kwargs) + return a.units._REGISTRY.Quantity(ret, units) + + def implement_mul_func(func): # If NumPy is not available, do not attempt implement that which does not exist if np is None: @@ -850,6 +878,7 @@ def implementation(*args, **kwargs): ("median", "a", True), ("nanmedian", "a", True), ("transpose", "a", True), + ("roll", "a", True), ("copy", "a", True), ("average", "a", True), ("nanmean", "a", True), @@ -965,7 +994,7 @@ def implementation(a, *args, **kwargs): return a._REGISTRY.Quantity(func(a_stripped, *args, **kwargs)) -for func_str in ("cumprod", "cumproduct", "nancumprod"): +for func_str in ("cumprod", "nancumprod"): implement_single_dimensionless_argument_func(func_str) # Handle single-argument consistent unit functions diff --git a/pint/facets/numpy/quantity.py b/pint/facets/numpy/quantity.py index deaf675da..75dccec54 100644 --- a/pint/facets/numpy/quantity.py +++ b/pint/facets/numpy/quantity.py @@ -13,11 +13,10 @@ import warnings from typing import Any, Generic -from ..plain import PlainQuantity, MagnitudeT - from ..._typing import Shape -from ...compat import _to_magnitude, np, HAS_NUMPY +from ...compat import HAS_NUMPY, _to_magnitude, np from ...errors import DimensionalityError, PintTypeError, UnitStrippedWarning +from ..plain import MagnitudeT, PlainQuantity from .numpy_func import ( HANDLED_UFUNCS, copy_units_output_ufuncs, @@ -31,7 +30,7 @@ try: import uncertainties.unumpy as unp - from uncertainties import ufloat, UFloat + from uncertainties import UFloat, ufloat HAS_UNCERTAINTIES = True except ImportError: diff --git a/pint/facets/numpy/registry.py b/pint/facets/numpy/registry.py index e93de44f0..e1128f383 100644 --- a/pint/facets/numpy/registry.py +++ b/pint/facets/numpy/registry.py @@ -9,7 +9,7 @@ from __future__ import annotations -from typing import Generic, Any +from typing import Any, Generic from ...compat import TypeAlias from ..plain import GenericPlainRegistry, QuantityT, UnitT diff --git a/pint/facets/plain/__init__.py b/pint/facets/plain/__init__.py index 90bf2e35a..f84dd68f3 100644 --- a/pint/facets/plain/__init__.py +++ b/pint/facets/plain/__init__.py @@ -19,8 +19,8 @@ UnitDefinition, ) from .objects import PlainQuantity, PlainUnit -from .registry import PlainRegistry, GenericPlainRegistry, QuantityT, UnitT from .quantity import MagnitudeT +from .registry import GenericPlainRegistry, PlainRegistry, QuantityT, UnitT __all__ = [ "GenericPlainRegistry", diff --git a/pint/facets/plain/definitions.py b/pint/facets/plain/definitions.py index 44bf29858..a43ce0dbc 100644 --- a/pint/facets/plain/definitions.py +++ b/pint/facets/plain/definitions.py @@ -13,10 +13,10 @@ import typing as ty from dataclasses import dataclass from 
functools import cached_property -from typing import Any, Optional +from typing import Any -from ..._typing import Magnitude from ... import errors +from ..._typing import Magnitude from ...converters import Converter from ...util import UnitsContainer @@ -81,7 +81,7 @@ class PrefixDefinition(NamedDefinition, errors.WithDefErr): #: scaling value for this prefix value: numbers.Number #: canonical symbol - defined_symbol: Optional[str] = "" + defined_symbol: str | None = "" #: additional names for the same prefix aliases: ty.Tuple[str, ...] = () @@ -118,7 +118,7 @@ class UnitDefinition(NamedDefinition, errors.WithDefErr): """Definition of a unit.""" #: canonical symbol - defined_symbol: Optional[str] + defined_symbol: str | None #: additional names for the same unit aliases: tuple[str, ...] #: A functiont that converts a value in these units into the reference units @@ -126,9 +126,9 @@ class UnitDefinition(NamedDefinition, errors.WithDefErr): # Briefly, in several places converter attributes like as_multiplicative were # accesed. So having a generic function is a no go. # I guess this was never used as errors where not raised. - converter: Optional[Converter] + converter: Converter | None #: Reference units. - reference: Optional[UnitsContainer] + reference: UnitsContainer | None def __post_init__(self): if not errors.is_valid_unit_name(self.name): diff --git a/pint/facets/plain/qto.py b/pint/facets/plain/qto.py index 726523763..22176491d 100644 --- a/pint/facets/plain/qto.py +++ b/pint/facets/plain/qto.py @@ -1,21 +1,21 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional - import bisect import math import numbers import warnings +from typing import TYPE_CHECKING -from ...util import infer_base_unit from ...compat import ( mip_INF, mip_INTEGER, - mip_model, mip_Model, + mip_model, mip_OptimizationStatus, mip_xsum, ) +from ...errors import UndefinedBehavior +from ...util import infer_base_unit if TYPE_CHECKING: from ..._typing import UnitLike @@ -82,7 +82,7 @@ def to_reduced_units( def to_compact( - quantity: PlainQuantity, unit: Optional[UnitsContainer] = None + quantity: PlainQuantity, unit: UnitsContainer | None = None ) -> PlainQuantity: """ "Return PlainQuantity rescaled to compact, human-readable units. @@ -103,9 +103,11 @@ def to_compact( if not isinstance(quantity.magnitude, numbers.Number) and not hasattr( quantity.magnitude, "nominal_value" ): - msg = "to_compact applied to non numerical types " "has an undefined behavior." - w = RuntimeWarning(msg) - warnings.warn(w, stacklevel=2) + warnings.warn( + "to_compact applied to non numerical types has an undefined behavior.", + UndefinedBehavior, + stacklevel=2, + ) return quantity if ( @@ -170,7 +172,7 @@ def to_compact( def to_preferred( - quantity: PlainQuantity, preferred_units: Optional[list[UnitLike]] = None + quantity: PlainQuantity, preferred_units: list[UnitLike] | None = None ) -> PlainQuantity: """Return Quantity converted to a unit composed of the preferred units. @@ -182,7 +184,7 @@ def to_preferred( >>> (1*ureg.acre).to_preferred([ureg.meters]) >>> (1*(ureg.force_pound*ureg.m)).to_preferred([ureg.W]) - + """ units = _get_preferred(quantity, preferred_units) @@ -190,7 +192,7 @@ def to_preferred( def ito_preferred( - quantity: PlainQuantity, preferred_units: Optional[list[UnitLike]] = None + quantity: PlainQuantity, preferred_units: list[UnitLike] | None = None ) -> PlainQuantity: """Return Quantity converted to a unit composed of the preferred units. 
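to_compact above now emits the UndefinedBehavior warning (added to pint/errors.py earlier in this diff) instead of a bare RuntimeWarning when the magnitude is not a plain number. A sketch of typical use plus handling of that warning; the import path assumes the class lives in pint.errors as shown above:

```python
import warnings

import pint
from pint.errors import UndefinedBehavior

ureg = pint.UnitRegistry()

# Typical use: pick a human-friendly prefix automatically.
print((0.000002 * ureg.meter).to_compact())  # e.g. 2.0 micrometer

# The warning is an ordinary UserWarning subclass, so the usual warnings
# machinery applies: ignore it, or escalate it to an error during testing.
warnings.simplefilter("error", category=UndefinedBehavior)
```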
@@ -202,7 +204,7 @@ def ito_preferred( >>> (1*ureg.acre).to_preferred([ureg.meters]) >>> (1*(ureg.force_pound*ureg.m)).to_preferred([ureg.W]) - + """ units = _get_preferred(quantity, preferred_units) @@ -210,7 +212,7 @@ def ito_preferred( def _get_preferred( - quantity: PlainQuantity, preferred_units: Optional[list[UnitLike]] = None + quantity: PlainQuantity, preferred_units: list[UnitLike] | None = None ) -> PlainQuantity: if preferred_units is None: preferred_units = quantity._REGISTRY.default_preferred_units diff --git a/pint/facets/plain/quantity.py b/pint/facets/plain/quantity.py index 2a4dcf19d..a18919273 100644 --- a/pint/facets/plain/quantity.py +++ b/pint/facets/plain/quantity.py @@ -8,34 +8,31 @@ from __future__ import annotations - import copy import datetime import locale import numbers import operator +from collections.abc import Callable, Iterator, Sequence from typing import ( TYPE_CHECKING, Any, - Callable, - overload, Generic, + Iterable, TypeVar, - Optional, - Union, + overload, ) -from collections.abc import Iterator, Sequence -from ..._typing import UnitLike, QuantityOrUnitLike, Magnitude, Scalar +from ..._typing import Magnitude, QuantityOrUnitLike, Scalar, UnitLike from ...compat import ( HAS_NUMPY, _to_magnitude, + deprecated, eq, is_duck_array_type, is_upcast_type, np, zero_or_nan, - deprecated, ) from ...errors import DimensionalityError, OffsetUnitCalculusError, PintTypeError from ...util import ( @@ -45,8 +42,8 @@ logger, to_units_container, ) -from .definitions import UnitDefinition from . import qto +from .definitions import UnitDefinition if TYPE_CHECKING: from ..context import Context @@ -58,7 +55,7 @@ try: import uncertainties.unumpy as unp - from uncertainties import ufloat, UFloat + from uncertainties import UFloat, ufloat HAS_UNCERTAINTIES = True except ImportError: @@ -143,7 +140,7 @@ class PlainQuantity(Generic[MagnitudeT], PrettyIPython, SharedRegistryObject): def ndim(self) -> int: if isinstance(self.magnitude, numbers.Number): return 0 - if str(self.magnitude) == "": + if str(type(self.magnitude)) == "NAType": return 0 return self.magnitude.ndim @@ -168,25 +165,23 @@ def __reduce__(self) -> tuple[type, Magnitude, UnitsContainer]: @overload def __new__( - cls, value: MagnitudeT, units: Optional[UnitLike] = None + cls, value: MagnitudeT, units: UnitLike | None = None ) -> PlainQuantity[MagnitudeT]: ... @overload - def __new__( - cls, value: str, units: Optional[UnitLike] = None - ) -> PlainQuantity[Any]: + def __new__(cls, value: str, units: UnitLike | None = None) -> PlainQuantity[Any]: ... @overload def __new__( # type: ignore[misc] - cls, value: Sequence[ScalarT], units: Optional[UnitLike] = None + cls, value: Sequence[ScalarT], units: UnitLike | None = None ) -> PlainQuantity[Any]: ... @overload def __new__( - cls, value: PlainQuantity[Any], units: Optional[UnitLike] = None + cls, value: PlainQuantity[Any], units: UnitLike | None = None ) -> PlainQuantity[Any]: ... 
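Looking back at the numpy_func.py hunks further above: trapz/trapezoid are routed through a small compatibility shim (np.trapezoid on NumPy 2, np.trapz otherwise) and np.correlate now returns a quantity whose units are the product of the input units. A hedged usage sketch mirroring that shim (magnitudes depend on NumPy; the units follow from the wrappers shown above):

```python
import numpy as np
import pint

ureg = pint.UnitRegistry()

y = np.array([0.0, 1.0, 2.0]) * ureg.meter
x = np.array([0.0, 1.0, 2.0]) * ureg.second

# Integrate y over x; the result carries meter * second.
trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz
print(trapezoid(y, x))

# Correlation of two quantity arrays: units multiply (meter * second here).
a = np.array([1.0, 2.0, 3.0]) * ureg.meter
v = np.array([0.0, 1.0, 0.5]) * ureg.second
print(np.correlate(a, v, mode="full"))
```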
@@ -334,6 +329,10 @@ def unitless(self) -> bool: """ """ return not bool(self.to_root_units()._units) + def unit_items(self) -> Iterable[tuple[str, Scalar]]: + """A view of the unit items.""" + return self._units.unit_items() + @property def dimensionless(self) -> bool: """ """ @@ -341,7 +340,7 @@ def dimensionless(self) -> bool: return not bool(tmp.dimensionality) - _dimensionality: Optional[UnitsContainerT] = None + _dimensionality: UnitsContainerT | None = None @property def dimensionality(self) -> UnitsContainerT: @@ -436,7 +435,7 @@ def compatible_units(self, *contexts): return self._REGISTRY.get_compatible_units(self._units) def is_compatible_with( - self, other: Any, *contexts: Union[str, Context], **ctx_kwargs: Any + self, other: Any, *contexts: str | Context, **ctx_kwargs: Any ) -> bool: """check if the other object is compatible @@ -493,7 +492,7 @@ def _convert_magnitude(self, other, *contexts, **ctx_kwargs): ) def ito( - self, other: Optional[QuantityOrUnitLike] = None, *contexts, **ctx_kwargs + self, other: QuantityOrUnitLike | None = None, *contexts, **ctx_kwargs ) -> None: """Inplace rescale to different units. @@ -515,7 +514,7 @@ def ito( return None def to( - self, other: Optional[QuantityOrUnitLike] = None, *contexts, **ctx_kwargs + self, other: QuantityOrUnitLike | None = None, *contexts, **ctx_kwargs ) -> PlainQuantity: """Return PlainQuantity rescaled to different units. @@ -1289,7 +1288,7 @@ def __rpow__(self, other) -> PlainQuantity[MagnitudeT]: def __abs__(self) -> PlainQuantity[MagnitudeT]: return self.__class__(abs(self._magnitude), self._units) - def __round__(self, ndigits: Optional[int] = 0) -> PlainQuantity[MagnitudeT]: + def __round__(self, ndigits: int | None = 0) -> PlainQuantity[MagnitudeT]: return self.__class__(round(self._magnitude, ndigits=ndigits), self._units) def __pos__(self) -> PlainQuantity[MagnitudeT]: diff --git a/pint/facets/plain/registry.py b/pint/facets/plain/registry.py index 2e5128fd8..09fd220ee 100644 --- a/pint/facets/plain/registry.py +++ b/pint/facets/plain/registry.py @@ -30,45 +30,40 @@ import pathlib import re from collections import defaultdict +from collections.abc import Callable, Generator, Iterable, Iterator from decimal import Decimal from fractions import Fraction from token import NAME, NUMBER from tokenize import TokenInfo - from typing import ( TYPE_CHECKING, Any, - Callable, + Generic, TypeVar, Union, - Generic, - Generator, - Optional, ) -from collections.abc import Iterable, Iterator if TYPE_CHECKING: - from ..context import Context from ...compat import Locale + from ..context import Context # from ..._typing import Quantity, Unit +import appdirs + +from ... import pint_eval from ..._typing import ( - QuantityOrUnitLike, - UnitLike, + Handler, QuantityArgument, + QuantityOrUnitLike, Scalar, - Handler, + UnitLike, ) - -from ... 
import pint_eval -from ..._vendor import appdirs -from ...compat import TypeAlias, Self, deprecated +from ...compat import Self, TypeAlias, deprecated from ...errors import DimensionalityError, RedefinitionError, UndefinedUnitError from ...pint_eval import build_eval_tree -from ...util import ParserHelper -from ...util import UnitsContainer as UnitsContainer from ...util import ( + ParserHelper, _is_dim, create_class_with_registry, getattr_maybe_raise, @@ -77,15 +72,16 @@ string_preprocessor, to_units_container, ) +from ...util import UnitsContainer as UnitsContainer from .definitions import ( AliasDefinition, CommentDefinition, DefaultsDefinition, DerivedDimensionDefinition, DimensionDefinition, + NamedDefinition, PrefixDefinition, UnitDefinition, - NamedDefinition, ) from .objects import PlainQuantity, PlainUnit @@ -95,7 +91,7 @@ @functools.lru_cache -def pattern_to_regex(pattern: Union[str, re.Pattern[str]]) -> re.Pattern[str]: +def pattern_to_regex(pattern: str | re.Pattern[str]) -> re.Pattern[str]: # TODO: This has been changed during typing improvements. # if hasattr(pattern, "finditer"): if not isinstance(pattern, str): @@ -223,12 +219,12 @@ def __init__( on_redefinition: str = "warn", auto_reduce_dimensions: bool = False, autoconvert_to_preferred: bool = False, - preprocessors: Optional[list[PreprocessorType]] = None, - fmt_locale: Optional[str] = None, + preprocessors: list[PreprocessorType] | None = None, + fmt_locale: str | None = None, non_int_type: NON_INT_TYPE = float, case_sensitive: bool = True, - cache_folder: Optional[Union[str, pathlib.Path]] = None, - separate_format_defaults: Optional[bool] = None, + cache_folder: str | pathlib.Path | None = None, + separate_format_defaults: bool | None = None, mpl_formatter: str = "{:P}", ): #: Map a definition class to a adder methods. @@ -251,7 +247,7 @@ def __init__( delegates.ParserConfig(non_int_type), diskcache=self._diskcache ) - self.formatter = delegates.Formatter() + self.formatter = delegates.Formatter(self) self._filename = filename self.force_ndarray = force_ndarray self.force_ndarray_like = force_ndarray_like @@ -289,7 +285,7 @@ def __init__( #: Map dimension name (string) to its definition (DimensionDefinition). self._dimensions: dict[ - str, Union[DimensionDefinition, DerivedDimensionDefinition] + str, DimensionDefinition | DerivedDimensionDefinition ] = {} #: Map unit name (string) to its definition (UnitDefinition). @@ -419,7 +415,7 @@ def fmt_locale(self, loc: str | None): "This function will be removed in future versions of pint.\n" "Use ureg.formatter.set_locale" ) - def set_fmt_locale(self, loc: Optional[str]) -> None: + def set_fmt_locale(self, loc: str | None) -> None: """Change the locale used by default by `format_babel`. Parameters @@ -448,7 +444,7 @@ def default_format(self, value: str) -> None: self.formatter.default_format = value @property - def cache_folder(self) -> Optional[pathlib.Path]: + def cache_folder(self) -> pathlib.Path | None: if self._diskcache: return self._diskcache.cache_folder return None @@ -457,7 +453,7 @@ def cache_folder(self) -> Optional[pathlib.Path]: def non_int_type(self): return self._non_int_type - def define(self, definition: Union[str, type]) -> None: + def define(self, definition: str | type) -> None: """Add unit to the registry. 
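define() above accepts either a definition string or a Definition object; the string form is the usual entry point. A quick sketch (the unit name and symbol below are made up for illustration):

```python
import pint

ureg = pint.UnitRegistry()

# A made-up unit: name, conversion, symbol, then aliases.
ureg.define("clack = 0.75 * meter = clk = klack")

print((4 * ureg.clack).to("meter"))  # 3.0 meter
print(ureg("2 klack").to("centimeter"))
```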
Parameters @@ -499,7 +495,7 @@ def _helper_adder( self, definition: NamedDefinition, target_dict: dict[str, Any], - casei_target_dict: Optional[dict[str, Any]], + casei_target_dict: dict[str, Any] | None, ) -> None: """Helper function to store a definition in the internal dictionaries. It stores the definition under its name, symbol and aliases. @@ -525,7 +521,7 @@ def _helper_single_adder( key: str, value: NamedDefinition, target_dict: dict[str, Any], - casei_target_dict: Optional[dict[str, Any]], + casei_target_dict: dict[str, Any] | None, ) -> None: """Helper function to store a definition in the internal dictionaries. @@ -575,7 +571,7 @@ def _add_unit(self, definition: UnitDefinition) -> None: self._helper_adder(definition, self._units, self._units_casei) def load_definitions( - self, file: Union[Iterable[str], str, pathlib.Path], is_resource: bool = False + self, file: Iterable[str] | str | pathlib.Path, is_resource: bool = False ): """Add units and prefixes defined in a definition text file. @@ -646,9 +642,7 @@ def _build_cache(self, loaded_files=None) -> None: logger.warning(f"Could not resolve {unit_name}: {exc!r}") return self._cache - def get_name( - self, name_or_alias: str, case_sensitive: Optional[bool] = None - ) -> str: + def get_name(self, name_or_alias: str, case_sensitive: bool | None = None) -> str: """Return the canonical name of a unit.""" if name_or_alias == "dimensionless": @@ -666,8 +660,7 @@ def get_name( prefix, unit_name, _ = candidates[0] if len(candidates) > 1: logger.warning( - "Parsing {} yield multiple results. " - "Options are: {!r}".format(name_or_alias, candidates) + f"Parsing {name_or_alias} yield multiple results. Options are: {candidates!r}" ) if prefix: @@ -685,9 +678,7 @@ def get_name( return unit_name - def get_symbol( - self, name_or_alias: str, case_sensitive: Optional[bool] = None - ) -> str: + def get_symbol(self, name_or_alias: str, case_sensitive: bool | None = None) -> str: """Return the preferred alias for a unit.""" candidates = self.parse_unit_name(name_or_alias, case_sensitive) if not candidates: @@ -696,8 +687,7 @@ def get_symbol( prefix, unit_name, _ = candidates[0] if len(candidates) > 1: logger.warning( - "Parsing {} yield multiple results. " - "Options are: {!r}".format(name_or_alias, candidates) + f"Parsing {name_or_alias} yield multiple results. 
Options are: {candidates!r}" ) return self._prefixes[prefix].symbol + self._units[unit_name].symbol @@ -716,9 +706,7 @@ def get_dimensionality(self, input_units: UnitLike) -> UnitsContainer: return self._get_dimensionality(input_units) - def _get_dimensionality( - self, input_units: Optional[UnitsContainer] - ) -> UnitsContainer: + def _get_dimensionality(self, input_units: UnitsContainer | None) -> UnitsContainer: """Convert a UnitsContainer to plain dimensions.""" if not input_units: return self.UnitsContainer() @@ -748,7 +736,12 @@ def _get_dimensionality_recurse( for key in ref: exp2 = exp * ref[key] if _is_dim(key): - reg = self._dimensions[key] + try: + reg = self._dimensions[key] + except KeyError: + raise ValueError( + f"{key} is not defined as dimension in the pint UnitRegistry" + ) if isinstance(reg, DerivedDimensionDefinition): self._get_dimensionality_recurse(reg.reference, exp2, accumulator) else: @@ -892,7 +885,7 @@ def _get_root_units( except KeyError: pass - accumulators: dict[Optional[str], int] = defaultdict(int) + accumulators: dict[str | None, int] = defaultdict(int) accumulators[None] = 1 self._get_root_units_recurse(input_units, 1, accumulators) @@ -911,7 +904,7 @@ def _get_root_units( def get_base_units( self, - input_units: Union[UnitsContainer, str], + input_units: UnitsContainer | str, check_nonmult: bool = True, system=None, ) -> tuple[Scalar, UnitT]: @@ -943,7 +936,7 @@ def get_base_units( # TODO: accumulators breaks typing list[int, dict[str, int]] # So we have changed the behavior here def _get_root_units_recurse( - self, ref: UnitsContainer, exp: Scalar, accumulators: dict[Optional[str], int] + self, ref: UnitsContainer, exp: Scalar, accumulators: dict[str | None, int] ) -> None: """ @@ -981,7 +974,7 @@ def _get_compatible_units( # TODO: remove context from here def is_compatible_with( - self, obj1: Any, obj2: Any, *contexts: Union[str, Context], **ctx_kwargs + self, obj1: Any, obj2: Any, *contexts: str | Context, **ctx_kwargs ) -> bool: """check if the other object is compatible @@ -1094,7 +1087,7 @@ def _convert( return value def parse_unit_name( - self, unit_name: str, case_sensitive: Optional[bool] = None + self, unit_name: str, case_sensitive: bool | None = None ) -> tuple[tuple[str, str, str], ...]: """Parse a unit to identify prefix, unit name and suffix by walking the list of prefix and suffix. @@ -1156,7 +1149,7 @@ def _yield_unit_triplets( @staticmethod def _dedup_candidates( - candidates: Iterable[tuple[str, str, str]] + candidates: Iterable[tuple[str, str, str]], ) -> tuple[tuple[str, str, str], ...]: """Helper of parse_unit_name. @@ -1178,8 +1171,8 @@ def _dedup_candidates( def parse_units( self, input_string: str, - as_delta: Optional[bool] = None, - case_sensitive: Optional[bool] = None, + as_delta: bool | None = None, + case_sensitive: bool | None = None, ) -> UnitT: """Parse a units expression and returns a UnitContainer with the canonical names. 
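The registry hunks above touch get_dimensionality (an undefined dimension now raises a descriptive ValueError instead of an opaque KeyError) and the parse_unit_name/parse_units helpers. Typical calls, with expected shapes noted as comments rather than guaranteed output:

```python
import pint

ureg = pint.UnitRegistry()

# Dimensionality of a compound unit expression.
print(ureg.get_dimensionality("newton * meter"))  # [length] ** 2 * [mass] / [time] ** 2

# Prefix/unit/suffix candidates considered while parsing a name.
print(ureg.parse_unit_name("millimeter"))  # e.g. (('milli', 'meter', ''),)

# Canonical Unit for a units expression.
print(ureg.parse_units("km / hour"))
```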
@@ -1209,8 +1202,8 @@ def parse_units( def parse_units_as_container( self, input_string: str, - as_delta: Optional[bool] = None, - case_sensitive: Optional[bool] = None, + as_delta: bool | None = None, + case_sensitive: bool | None = None, ) -> UnitsContainer: as_delta = ( as_delta if as_delta is not None else True @@ -1271,7 +1264,7 @@ def _parse_units_as_container( def _eval_token( self, token: TokenInfo, - case_sensitive: Optional[bool] = None, + case_sensitive: bool | None = None, **values: QuantityArgument, ): """Evaluate a single token using the following rules: @@ -1321,9 +1314,9 @@ def parse_pattern( self, input_string: str, pattern: str, - case_sensitive: Optional[bool] = None, + case_sensitive: bool | None = None, many: bool = False, - ) -> Optional[Union[list[str], str]]: + ) -> list[str] | str | None: """Parse a string with a given regex pattern and returns result. Parameters @@ -1372,7 +1365,7 @@ def parse_pattern( def parse_expression( self: Self, input_string: str, - case_sensitive: Optional[bool] = None, + case_sensitive: bool | None = None, **values: QuantityArgument, ) -> QuantityT: """Parse a mathematical expression including units and return a quantity object. diff --git a/pint/facets/plain/unit.py b/pint/facets/plain/unit.py index 4d3a5b12e..0ee05abbc 100644 --- a/pint/facets/plain/unit.py +++ b/pint/facets/plain/unit.py @@ -12,7 +12,7 @@ import locale import operator from numbers import Number -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any from ..._typing import UnitLike from ...compat import NUMERIC_TYPES, deprecated @@ -43,8 +43,9 @@ def __init__(self, units: UnitLike) -> None: self._units = units._units else: raise TypeError( - "units must be of type str, Unit or " - "UnitsContainer; not {}.".format(type(units)) + "units must be of type str, Unit or " "UnitsContainer; not {}.".format( + type(units) + ) ) def __copy__(self) -> PlainUnit: @@ -103,7 +104,7 @@ def compatible_units(self, *contexts): return self._REGISTRY.get_compatible_units(self) def is_compatible_with( - self, other: Any, *contexts: Union[str, Context], **ctx_kwargs: Any + self, other: Any, *contexts: str | Context, **ctx_kwargs: Any ) -> bool: """check if the other object is compatible diff --git a/pint/facets/system/__init__.py b/pint/facets/system/__init__.py index 24e68b761..b9cbc9593 100644 --- a/pint/facets/system/__init__.py +++ b/pint/facets/system/__init__.py @@ -12,6 +12,6 @@ from .definitions import SystemDefinition from .objects import System -from .registry import SystemRegistry, GenericSystemRegistry +from .registry import GenericSystemRegistry, SystemRegistry __all__ = ["SystemDefinition", "System", "SystemRegistry", "GenericSystemRegistry"] diff --git a/pint/facets/system/definitions.py b/pint/facets/system/definitions.py index 008abac78..f47a23fd8 100644 --- a/pint/facets/system/definitions.py +++ b/pint/facets/system/definitions.py @@ -10,10 +10,9 @@ from collections.abc import Iterable from dataclasses import dataclass -from typing import Optional -from ...compat import Self from ... 
import errors +from ...compat import Self @dataclass(frozen=True) @@ -25,7 +24,7 @@ class BaseUnitRule: new_unit_name: str #: name of the unit to be kicked out to make room for the new base uni #: If None, the current base unit with the same dimensionality will be used - old_unit_name: Optional[str] = None + old_unit_name: str | None = None # Instead of defining __post_init__ here, # it will be added to the container class @@ -47,7 +46,7 @@ class SystemDefinition(errors.WithDefErr): @classmethod def from_lines( cls: type[Self], lines: Iterable[str], non_int_type: type - ) -> Optional[Self]: + ) -> Self | None: # TODO: this is to keep it backwards compatible # TODO: check when is None returned. from ...delegates import ParserConfig, txt_defparser @@ -60,7 +59,7 @@ def from_lines( return definition @property - def unit_replacements(self) -> tuple[tuple[str, Optional[str]], ...]: + def unit_replacements(self) -> tuple[tuple[str, str | None], ...]: # TODO: check if None can be dropped. return tuple((el.new_unit_name, el.old_unit_name) for el in self.rules) diff --git a/pint/facets/system/objects.py b/pint/facets/system/objects.py index 912094de7..751a66abf 100644 --- a/pint/facets/system/objects.py +++ b/pint/facets/system/objects.py @@ -10,14 +10,11 @@ from __future__ import annotations import numbers - -from typing import Any, Optional -from collections.abc import Iterable - - -from typing import Callable, Generic +from collections.abc import Callable, Iterable from numbers import Number +from typing import Any, Generic +from ..._typing import UnitLike from ...babel_names import _babel_systems from ...compat import babel_parse from ...util import ( @@ -26,11 +23,9 @@ logger, to_units_container, ) -from .definitions import SystemDefinition from .. import group from ..plain import MagnitudeT - -from ..._typing import UnitLike +from .definitions import SystemDefinition GetRootUnits = Callable[[UnitLike, bool], tuple[Number, UnitLike]] @@ -73,7 +68,7 @@ def __init__(self, name: str): #: Names of the _used_groups in used by this system. self._used_groups: set[str] = set() - self._computed_members: Optional[frozenset[str]] = None + self._computed_members: frozenset[str] | None = None # Add this system to the system dictionary self._REGISTRY._systems[self.name] = self @@ -154,7 +149,7 @@ def from_lines( def from_definition( cls: type[System], system_definition: SystemDefinition, - get_root_func: Optional[GetRootUnits] = None, + get_root_func: GetRootUnits | None = None, ) -> System: if get_root_func is None: # TODO: kept for backwards compatibility diff --git a/pint/facets/system/registry.py b/pint/facets/system/registry.py index 04aaea7b0..e5235a4cb 100644 --- a/pint/facets/system/registry.py +++ b/pint/facets/system/registry.py @@ -9,12 +9,10 @@ from __future__ import annotations from numbers import Number -from typing import TYPE_CHECKING, Generic, Any, Union, Optional +from typing import TYPE_CHECKING, Any, Generic from ... import errors - from ...compat import TypeAlias - from ..plain import QuantityT, UnitT if TYPE_CHECKING: @@ -27,8 +25,8 @@ to_units_container, ) from ..group import GenericGroupRegistry -from .definitions import SystemDefinition from . 
import objects +from .definitions import SystemDefinition class GenericSystemRegistry( @@ -53,7 +51,7 @@ class GenericSystemRegistry( # to enjoy typing goodies System: type[objects.System] - def __init__(self, system: Optional[str] = None, **kwargs): + def __init__(self, system: str | None = None, **kwargs): super().__init__(**kwargs) #: Map system name to system. @@ -62,7 +60,7 @@ def __init__(self, system: Optional[str] = None, **kwargs): #: Maps dimensionality (UnitsContainer) to Dimensionality (UnitsContainer) self._base_units_cache: dict[UnitsContainerT, UnitsContainerT] = {} - self._default_system_name: Optional[str] = system + self._default_system_name: str | None = system def _init_dynamic_classes(self) -> None: """Generate subclasses on the fly and attach them to self""" @@ -103,7 +101,7 @@ def sys(self): return objects.Lister(self._systems) @property - def default_system(self) -> Optional[str]: + def default_system(self) -> str | None: return self._default_system_name @default_system.setter @@ -143,9 +141,9 @@ def get_system(self, name: str, create_if_needed: bool = True) -> objects.System def get_base_units( self, - input_units: Union[UnitLike, Quantity], + input_units: UnitLike | Quantity, check_nonmult: bool = True, - system: Optional[Union[str, objects.System]] = None, + system: str | objects.System | None = None, ) -> tuple[Number, Unit]: """Convert unit or dict of units to the plain units. @@ -183,7 +181,7 @@ def _get_base_units( self, input_units: UnitsContainerT, check_nonmult: bool = True, - system: Optional[Union[str, objects.System]] = None, + system: str | objects.System | None = None, ): if system is None: system = self._default_system_name @@ -225,7 +223,7 @@ def _get_base_units( return base_factor, destination_units def get_compatible_units( - self, input_units: UnitsContainerT, group_or_system: Optional[str] = None + self, input_units: UnitsContainerT, group_or_system: str | None = None ) -> frozenset[Unit]: """ """ @@ -241,7 +239,7 @@ def get_compatible_units( return frozenset(self.Unit(eq) for eq in equiv) def _get_compatible_units( - self, input_units: UnitsContainerT, group_or_system: Optional[str] = None + self, input_units: UnitsContainerT, group_or_system: str | None = None ) -> frozenset[Unit]: if group_or_system and group_or_system in self._systems: members = self._systems[group_or_system].members diff --git a/pint/formatting.py b/pint/formatting.py index 94eb57cf6..9b880ae0e 100644 --- a/pint/formatting.py +++ b/pint/formatting.py @@ -10,31 +10,136 @@ from __future__ import annotations +from numbers import Number +from typing import Iterable -# Backwards compatiblity stuff -from .delegates.formatter.latex import ( - vector_to_latex, # noqa - matrix_to_latex, # noqa - ndarray_to_latex_parts, # noqa - ndarray_to_latex, # noqa - latex_escape, # noqa - siunitx_format_unit, # noqa - _EXP_PATTERN, # noqa -) # noqa +from .delegates.formatter._format_helpers import ( + _PRETTY_EXPONENTS, # noqa: F401 +) +from .delegates.formatter._format_helpers import ( + join_u as _join, # noqa: F401 +) +from .delegates.formatter._format_helpers import ( + pretty_fmt_exponent as _pretty_fmt_exponent, # noqa: F401 +) from .delegates.formatter._spec_helpers import ( - FORMATTER, # noqa - _BASIC_TYPES, # noqa - parse_spec as _parse_spec, # noqa - _JOIN_REG_EXP as __JOIN_REG_EXP, # noqa, - _join, # noqa - _PRETTY_EXPONENTS, # noqa - pretty_fmt_exponent as _pretty_fmt_exponent, # noqa - extract_custom_flags, # noqa - remove_custom_flags, # noqa - split_format, # noqa + 
_BASIC_TYPES, # noqa: F401 + FORMATTER, # noqa: F401 REGISTERED_FORMATTERS, -) # noqa -from .delegates.formatter._to_register import register_unit_format # noqa + extract_custom_flags, # noqa: F401 + remove_custom_flags, # noqa: F401 +) +from .delegates.formatter._spec_helpers import ( + parse_spec as _parse_spec, # noqa: F401 +) +from .delegates.formatter._spec_helpers import ( + split_format as split_format, # noqa: F401 +) + +# noqa +from .delegates.formatter._to_register import register_unit_format # noqa: F401 + +# Backwards compatiblity stuff +from .delegates.formatter.latex import ( + _EXP_PATTERN, # noqa: F401 + latex_escape, # noqa: F401 + matrix_to_latex, # noqa: F401 + ndarray_to_latex, # noqa: F401 + ndarray_to_latex_parts, # noqa: F401 + siunitx_format_unit, # noqa: F401 + vector_to_latex, # noqa: F401 +) + + +def formatter( + items: Iterable[tuple[str, Number]], + as_ratio: bool = True, + single_denominator: bool = False, + product_fmt: str = " * ", + division_fmt: str = " / ", + power_fmt: str = "{} ** {}", + parentheses_fmt: str = "({0})", + exp_call: FORMATTER = "{:n}".format, + sort: bool = True, +) -> str: + """Format a list of (name, exponent) pairs. + + Parameters + ---------- + items : list + a list of (name, exponent) pairs. + as_ratio : bool, optional + True to display as ratio, False as negative powers. (Default value = True) + single_denominator : bool, optional + all with terms with negative exponents are + collected together. (Default value = False) + product_fmt : str + the format used for multiplication. (Default value = " * ") + division_fmt : str + the format used for division. (Default value = " / ") + power_fmt : str + the format used for exponentiation. (Default value = "{} ** {}") + parentheses_fmt : str + the format used for parenthesis. (Default value = "({0})") + exp_call : callable + (Default value = lambda x: f"{x:n}") + sort : bool, optional + True to sort the formatted units alphabetically (Default value = True) + + Returns + ------- + str + the formula as a string. 
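The module-level formatter() helper reinstated here takes (name, exponent) pairs and joins them according to the *_fmt options (its body continues below). With the defaults shown, a call such as the following renders positive exponents as a product and negative ones as a divisor:

```python
from pint.formatting import formatter

# With the default options this yields "meter ** 2 / second".
print(formatter([("meter", 2), ("second", -1)]))

# Negative powers instead of a ratio.
print(formatter([("meter", 2), ("second", -1)], as_ratio=False))
```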
+ + """ + + join_u = _join + + if sort is False: + items = tuple(items) + else: + items = sorted(items) + + if not items: + return "" + + if as_ratio: + fun = lambda x: exp_call(abs(x)) + else: + fun = exp_call + + pos_terms, neg_terms = [], [] + + for key, value in items: + if value == 1: + pos_terms.append(key) + elif value > 0: + pos_terms.append(power_fmt.format(key, fun(value))) + elif value == -1 and as_ratio: + neg_terms.append(key) + else: + neg_terms.append(power_fmt.format(key, fun(value))) + + if not as_ratio: + # Show as Product: positive * negative terms ** -1 + return _join(product_fmt, pos_terms + neg_terms) + + # Show as Ratio: positive terms / negative terms + pos_ret = _join(product_fmt, pos_terms) or "1" + + if not neg_terms: + return pos_ret + + if single_denominator: + neg_ret = join_u(product_fmt, neg_terms) + if len(neg_terms) > 1: + neg_ret = parentheses_fmt.format(neg_ret) + else: + neg_ret = join_u(division_fmt, neg_terms) + + # TODO: first or last pos_ret should be pluralized + + return _join(division_fmt, [pos_ret, neg_ret]) def format_unit(unit, spec: str, registry=None, **options): @@ -54,9 +159,9 @@ def format_unit(unit, spec: str, registry=None, **options): _formatter = REGISTERED_FORMATTERS.get(spec, None) else: try: - _formatter = registry._formatters[spec] + _formatter = registry.formatter._formatters[spec] except Exception: - _formatter = registry._formatters.get(spec, None) + _formatter = registry.formatter._formatters.get(spec, None) if _formatter is None: raise ValueError(f"Unknown conversion specified: {spec}") diff --git a/pint/pint_eval.py b/pint/pint_eval.py index 3f030505b..c2ddb29cd 100644 --- a/pint/pint_eval.py +++ b/pint/pint_eval.py @@ -9,13 +9,12 @@ """ from __future__ import annotations -from io import BytesIO import operator import token as tokenlib import tokenize +from io import BytesIO from tokenize import TokenInfo - -from typing import Any, Optional, Union +from typing import Any try: from uncertainties import ufloat @@ -319,9 +318,9 @@ class EvalTreeNode: def __init__( self, - left: Union[EvalTreeNode, TokenInfo], - operator: Optional[TokenInfo] = None, - right: Optional[EvalTreeNode] = None, + left: EvalTreeNode | TokenInfo, + operator: TokenInfo | None = None, + right: EvalTreeNode | None = None, ): self.left = left self.operator = operator @@ -351,8 +350,8 @@ def evaluate( ], Any, ], - bin_op: Optional[dict[str, BinaryOpT]] = None, - un_op: Optional[dict[str, UnaryOpT]] = None, + bin_op: dict[str, BinaryOpT] | None = None, + un_op: dict[str, UnaryOpT] | None = None, ): """Evaluate node. @@ -528,7 +527,7 @@ def _build_eval_tree( def build_eval_tree( tokens: Iterable[TokenInfo], - op_priority: Optional[dict[str, int]] = None, + op_priority: dict[str, int] | None = None, ) -> EvalTreeNode: """Build an evaluation tree from a set of tokens. diff --git a/pint/registry.py b/pint/registry.py index 3d85ad8ab..ceb9b62d1 100644 --- a/pint/registry.py +++ b/pint/registry.py @@ -16,11 +16,9 @@ from typing import Generic -from . import registry_helpers -from . import facets -from .util import logger, pi_theorem +from . import facets, registry_helpers from .compat import TypeAlias - +from .util import logger, pi_theorem # To build the Quantity and Unit classes # we follow the UnitRegistry bases @@ -71,31 +69,38 @@ class UnitRegistry(GenericUnitRegistry[Quantity, Unit]): ---------- filename : path of the units definition file to load or line-iterable object. - Empty to load the default definition file. 
+ Empty string to load the default definition file. (default) None to leave the UnitRegistry empty. force_ndarray : bool convert any input, scalar or not to a numpy.ndarray. + (Default: False) force_ndarray_like : bool convert all inputs other than duck arrays to a numpy.ndarray. + (Default: False) default_as_delta : In the context of a multiplication of units, interpret non-multiplicative units as their *delta* counterparts. + (Default: False) autoconvert_offset_to_baseunit : If True converts offset units in quantities are converted to their plain units in multiplicative - context. If False no conversion happens. + context. If False no conversion happens. (Default: False) on_redefinition : str action to take in case a unit is redefined. - 'warn', 'raise', 'ignore' + 'warn', 'raise', 'ignore' (Default: 'raise') auto_reduce_dimensions : If True, reduce dimensionality on appropriate operations. + (Default: False) autoconvert_to_preferred : If True, converts preferred units on appropriate operations. + (Default: False) preprocessors : list of callables which are iteratively ran on any input expression - or unit string + or unit string or None for no preprocessor. + (Default=None) fmt_locale : - locale identifier string, used in `format_babel`. Default to None + locale identifier string, used in `format_babel` or None. + (Default=None) case_sensitive : bool, optional Control default case sensitivity of unit parsing. (Default: True) cache_folder : str or pathlib.Path or None, optional diff --git a/pint/registry_helpers.py b/pint/registry_helpers.py index 37c539e35..f2961cc74 100644 --- a/pint/registry_helpers.py +++ b/pint/registry_helpers.py @@ -11,10 +11,10 @@ from __future__ import annotations import functools -from inspect import signature, Parameter +from collections.abc import Callable, Iterable +from inspect import Parameter, signature from itertools import zip_longest -from typing import TYPE_CHECKING, Callable, TypeVar, Any, Union, Optional -from collections.abc import Iterable +from typing import TYPE_CHECKING, Any, TypeVar from ._typing import F from .errors import DimensionalityError @@ -197,8 +197,8 @@ def _apply_defaults(sig, args, kwargs): def wraps( ureg: UnitRegistry, - ret: Optional[Union[str, Unit, Iterable[Optional[Union[str, Unit]]]]], - args: Optional[Union[str, Unit, Iterable[Optional[Union[str, Unit]]]]], + ret: str | Unit | Iterable[str | Unit | None] | None, + args: str | Unit | Iterable[str | Unit | None] | None, strict: bool = True, ) -> Callable[[Callable[..., Any]], Callable[..., Quantity]]: """Wraps a function to become pint-aware. @@ -315,7 +315,7 @@ def wrapper(*values, **kw) -> Quantity: def check( - ureg: UnitRegistry, *args: Optional[Union[str, UnitsContainer, Unit]] + ureg: UnitRegistry, *args: str | UnitsContainer | Unit | None ) -> Callable[[F], F]: """Decorator to for quantity type checking for function inputs. diff --git a/pint/testing.py b/pint/testing.py index f2a570a59..21a1f55dd 100644 --- a/pint/testing.py +++ b/pint/testing.py @@ -3,7 +3,6 @@ import math import warnings from numbers import Number -from typing import Optional from . import Quantity from .compat import ndarray @@ -35,7 +34,7 @@ def _get_comparable_magnitudes(first, second, msg): return m1, m2 -def assert_equal(first, second, msg: Optional[str] = None) -> None: +def assert_equal(first, second, msg: str | None = None) -> None: if msg is None: msg = f"Comparing {first!r} and {second!r}. 
" @@ -45,10 +44,10 @@ def assert_equal(first, second, msg: Optional[str] = None) -> None: if isinstance(m1, ndarray) or isinstance(m2, ndarray): np.testing.assert_array_equal(m1, m2, err_msg=msg) elif not isinstance(m1, Number): - warnings.warn(RuntimeWarning) + warnings.warn("In assert_equal, m1 is not a number ", UserWarning) return elif not isinstance(m2, Number): - warnings.warn(RuntimeWarning) + warnings.warn("In assert_equal, m2 is not a number ", UserWarning) return elif math.isnan(m1): assert math.isnan(m2), msg @@ -59,7 +58,7 @@ def assert_equal(first, second, msg: Optional[str] = None) -> None: def assert_allclose( - first, second, rtol: float = 1e-07, atol: float = 0, msg: Optional[str] = None + first, second, rtol: float = 1e-07, atol: float = 0, msg: str | None = None ) -> None: if msg is None: try: @@ -76,10 +75,10 @@ def assert_allclose( if isinstance(m1, ndarray) or isinstance(m2, ndarray): np.testing.assert_allclose(m1, m2, rtol=rtol, atol=atol, err_msg=msg) elif not isinstance(m1, Number): - warnings.warn(RuntimeWarning) + warnings.warn("In assert_equal, m1 is not a number ", UserWarning) return elif not isinstance(m2, Number): - warnings.warn(RuntimeWarning) + warnings.warn("In assert_equal, m2 is not a number ", UserWarning) return elif math.isnan(m1): assert math.isnan(m2), msg diff --git a/pint/testsuite/__init__.py b/pint/testsuite/__init__.py index 35b0d9116..baafc5016 100644 --- a/pint/testsuite/__init__.py +++ b/pint/testsuite/__init__.py @@ -1,10 +1,12 @@ +from __future__ import annotations + +import contextlib import doctest import math import os +import pathlib import unittest import warnings -import contextlib -import pathlib from pint import UnitRegistry from pint.testsuite.helpers import PintOutputChecker diff --git a/pint/testsuite/benchmarks/test_00_common.py b/pint/testsuite/benchmarks/test_00_common.py index 3974dbcbb..43ee3fee3 100644 --- a/pint/testsuite/benchmarks/test_00_common.py +++ b/pint/testsuite/benchmarks/test_00_common.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import subprocess import sys diff --git a/pint/testsuite/benchmarks/test_01_registry_creation.py b/pint/testsuite/benchmarks/test_01_registry_creation.py index 3a17e5479..9013f2554 100644 --- a/pint/testsuite/benchmarks/test_01_registry_creation.py +++ b/pint/testsuite/benchmarks/test_01_registry_creation.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pint diff --git a/pint/testsuite/benchmarks/test_10_registry.py b/pint/testsuite/benchmarks/test_10_registry.py index ec0a43429..3a1d42da5 100644 --- a/pint/testsuite/benchmarks/test_10_registry.py +++ b/pint/testsuite/benchmarks/test_10_registry.py @@ -1,13 +1,15 @@ -import pytest +from __future__ import annotations import pathlib -from typing import Any, TypeVar, Callable +from collections.abc import Callable +from operator import getitem +from typing import Any, TypeVar -from ...compat import TypeAlias +import pytest import pint -from operator import getitem +from ...compat import TypeAlias UNITS = ("meter", "kilometer", "second", "minute", "angstrom", "millisecond", "ms") @@ -162,6 +164,9 @@ def test_load_definitions_stage_1(benchmark, cache_folder, use_cache_folder): benchmark(pint.UnitRegistry, None, cache_folder=use_cache_folder) +@pytest.mark.skip( + "Test failing ValueError: Group USCSLengthInternational already present in registry" +) @pytest.mark.parametrize("use_cache_folder", (None, True)) def test_load_definitions_stage_2(benchmark, cache_folder, use_cache_folder): """empty registry creation + 
parsing default files + definition object loading""" diff --git a/pint/testsuite/benchmarks/test_20_quantity.py b/pint/testsuite/benchmarks/test_20_quantity.py index 1ec7cbb60..815e3c09c 100644 --- a/pint/testsuite/benchmarks/test_20_quantity.py +++ b/pint/testsuite/benchmarks/test_20_quantity.py @@ -1,12 +1,13 @@ -from typing import Any +from __future__ import annotations + import itertools as it import operator +from typing import Any import pytest import pint - UNITS = ("meter", "kilometer", "second", "minute", "angstrom") ALL_VALUES = ("int", "float", "complex") ALL_VALUES_Q = tuple( diff --git a/pint/testsuite/benchmarks/test_30_numpy.py b/pint/testsuite/benchmarks/test_30_numpy.py index 94e9f1519..482db5792 100644 --- a/pint/testsuite/benchmarks/test_30_numpy.py +++ b/pint/testsuite/benchmarks/test_30_numpy.py @@ -1,6 +1,9 @@ -from typing import Generator, Any +from __future__ import annotations + import itertools as it import operator +from collections.abc import Generator +from typing import Any import pytest diff --git a/pint/testsuite/conftest.py b/pint/testsuite/conftest.py index d51bc8c05..775480f0b 100644 --- a/pint/testsuite/conftest.py +++ b/pint/testsuite/conftest.py @@ -1,4 +1,5 @@ # pytest fixtures +from __future__ import annotations import pathlib @@ -6,7 +7,6 @@ import pint - _TINY = """ yocto- = 1e-24 = y- zepto- = 1e-21 = z- diff --git a/pint/testsuite/helpers.py b/pint/testsuite/helpers.py index 4121e09eb..d317e0755 100644 --- a/pint/testsuite/helpers.py +++ b/pint/testsuite/helpers.py @@ -1,7 +1,9 @@ +from __future__ import annotations + +import contextlib import doctest import pickle import re -import contextlib import pytest from packaging.version import parse as version_parse @@ -126,9 +128,26 @@ def requires_numpy_at_least(version): ) -requires_babel = pytest.mark.skipif( - not HAS_BABEL, reason="Requires Babel with units support" -) +def requires_babel(tested_locales=[]): + if not HAS_BABEL: + return pytest.mark.skip("Requires Babel with units support") + + import locale + + default_locale = locale.getlocale(locale.LC_NUMERIC) + locales_unavailable = False + try: + for loc in tested_locales: + locale.setlocale(locale.LC_NUMERIC, loc) + except locale.Error: + locales_unavailable = True + locale.setlocale(locale.LC_NUMERIC, default_locale) + + return pytest.mark.skipif( + locales_unavailable, reason="Tested locales not available." 
+ ) + + requires_not_babel = pytest.mark.skipif( HAS_BABEL, reason="Requires Babel not to be installed" ) diff --git a/pint/testsuite/test_application_registry.py b/pint/testsuite/test_application_registry.py index a9bc84ee1..477e9f650 100644 --- a/pint/testsuite/test_application_registry.py +++ b/pint/testsuite/test_application_registry.py @@ -1,5 +1,7 @@ """Tests for global UnitRegistry, Unit, and Quantity """ +from __future__ import annotations + import pickle import pytest diff --git a/pint/testsuite/test_babel.py b/pint/testsuite/test_babel.py index d4e2194d7..9adcb04a9 100644 --- a/pint/testsuite/test_babel.py +++ b/pint/testsuite/test_babel.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import pytest @@ -14,7 +16,7 @@ def test_no_babel(func_registry): distance.format_babel(locale="fr_FR", length="long") -@helpers.requires_babel() +@helpers.requires_babel(["fr_FR", "ro_RO"]) def test_format(func_registry): ureg = func_registry dirname = os.path.dirname(__file__) @@ -28,13 +30,13 @@ def test_format(func_registry): acceleration = distance / time**2 assert ( acceleration.format_babel(spec=".3nP", locale="fr_FR", length="long") - == "0,367 mètre/seconde²" + == "0,367 mètre par seconde²" ) mks = ureg.get_system("mks") assert mks.format_babel(locale="fr_FR") == "métrique" -@helpers.requires_babel() +@helpers.requires_babel(["fr_FR", "ro_RO"]) def test_registry_locale(): ureg = UnitRegistry(fmt_locale="fr_FR") dirname = os.path.dirname(__file__) @@ -51,13 +53,14 @@ def test_registry_locale(): == "0,367 mètre/seconde**2" ) assert ( - acceleration.format_babel(spec=".3nP", length="long") == "0,367 mètre/seconde²" + acceleration.format_babel(spec=".3nP", length="long") + == "0,367 mètre par seconde²" ) mks = ureg.get_system("mks") assert mks.format_babel(locale="fr_FR") == "métrique" -@helpers.requires_babel() +@helpers.requires_babel(["fr_FR"]) def test_unit_format_babel(): ureg = UnitRegistry(fmt_locale="fr_FR") volume = ureg.Unit("ml") @@ -82,7 +85,7 @@ def test_no_registry_locale(func_registry): distance.format_babel() -@helpers.requires_babel() +@helpers.requires_babel(["fr_FR"]) def test_str(func_registry): ureg = func_registry d = 24.1 * ureg.meter diff --git a/pint/testsuite/test_compat.py b/pint/testsuite/test_compat.py index 5f3ba5d00..70a6e8e75 100644 --- a/pint/testsuite/test_compat.py +++ b/pint/testsuite/test_compat.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import math from datetime import datetime, timedelta diff --git a/pint/testsuite/test_compat_downcast.py b/pint/testsuite/test_compat_downcast.py index cffc3bbc6..2fccbacab 100644 --- a/pint/testsuite/test_compat_downcast.py +++ b/pint/testsuite/test_compat_downcast.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import operator + import pytest from pint import UnitRegistry diff --git a/pint/testsuite/test_compat_upcast.py b/pint/testsuite/test_compat_upcast.py index c8266f732..76ec69cbf 100644 --- a/pint/testsuite/test_compat_upcast.py +++ b/pint/testsuite/test_compat_upcast.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import operator + import pytest # Conditionally import NumPy and any upcast type libraries diff --git a/pint/testsuite/test_contexts.py b/pint/testsuite/test_contexts.py index 1a5bab237..073a5a69e 100644 --- a/pint/testsuite/test_contexts.py +++ b/pint/testsuite/test_contexts.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import itertools import logging import math @@ -16,7 +18,6 @@ from pint.testsuite import helpers from pint.util import 
UnitsContainer - from .helpers import internal diff --git a/pint/testsuite/test_converters.py b/pint/testsuite/test_converters.py index 71a076ff5..40346c700 100644 --- a/pint/testsuite/test_converters.py +++ b/pint/testsuite/test_converters.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import itertools from pint.compat import np diff --git a/pint/testsuite/test_dask.py b/pint/testsuite/test_dask.py index 0e6a1cfe7..e52640ff4 100644 --- a/pint/testsuite/test_dask.py +++ b/pint/testsuite/test_dask.py @@ -1,5 +1,6 @@ -import importlib +from __future__ import annotations +import importlib import pathlib import pytest diff --git a/pint/testsuite/test_definitions.py b/pint/testsuite/test_definitions.py index 69a337db7..56a107689 100644 --- a/pint/testsuite/test_definitions.py +++ b/pint/testsuite/test_definitions.py @@ -1,7 +1,9 @@ -import pytest +from __future__ import annotations import math +import pytest + from pint.definitions import Definition from pint.errors import DefinitionSyntaxError from pint.facets.nonmultiplicative.definitions import ( diff --git a/pint/testsuite/test_diskcache.py b/pint/testsuite/test_diskcache.py index 060d3f56c..16f3460c6 100644 --- a/pint/testsuite/test_diskcache.py +++ b/pint/testsuite/test_diskcache.py @@ -1,11 +1,13 @@ +from __future__ import annotations + import decimal import pickle import time +import flexparser as fp import pytest import pint -from pint._vendor import flexparser as fp from pint.facets.plain import UnitDefinition FS_SLEEP = 0.010 diff --git a/pint/testsuite/test_errors.py b/pint/testsuite/test_errors.py index a045f6e19..e0c4ec3f4 100644 --- a/pint/testsuite/test_errors.py +++ b/pint/testsuite/test_errors.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pickle import pytest @@ -142,3 +144,13 @@ def test_pickle_definition_syntax_error(self, subtests): with pytest.raises(PintError): raise ex + + def test_dimensionality_error_message(self): + ureg = UnitRegistry(system="SI") + with pytest.raises(ValueError) as error: + ureg.get_dimensionality("[bilbo]") + + assert ( + str(error.value) + == "[bilbo] is not defined as dimension in the pint UnitRegistry" + ) diff --git a/pint/testsuite/test_formatter.py b/pint/testsuite/test_formatter.py index 761414b75..d8b5722bc 100644 --- a/pint/testsuite/test_formatter.py +++ b/pint/testsuite/test_formatter.py @@ -1,61 +1,45 @@ +from __future__ import annotations + import pytest from pint import formatting as fmt -import pint.delegates.formatter._format_helpers +from pint.delegates.formatter._format_helpers import formatter, join_u class TestFormatter: def test_join(self): for empty in ((), []): - assert fmt._join("s", empty) == "" - assert fmt._join("*", "1 2 3".split()) == "1*2*3" - assert fmt._join("{0}*{1}", "1 2 3".split()) == "1*2*3" + assert join_u("s", empty) == "" + assert join_u("*", "1 2 3".split()) == "1*2*3" + assert join_u("{0}*{1}", "1 2 3".split()) == "1*2*3" def test_formatter(self): - assert pint.delegates.formatter._format_helpers.formatter({}.items()) == "" - assert ( - pint.delegates.formatter._format_helpers.formatter(dict(meter=1).items()) - == "meter" - ) - assert ( - pint.delegates.formatter._format_helpers.formatter(dict(meter=-1).items()) - == "1 / meter" - ) - assert ( - pint.delegates.formatter._format_helpers.formatter( - dict(meter=-1).items(), as_ratio=False - ) - == "meter ** -1" - ) + assert formatter({}.items(), ()) == "" + assert formatter(dict(meter=1).items(), ()) == "meter" + assert formatter((), dict(meter=-1).items()) == "1 / meter" + assert 
formatter((), dict(meter=-1).items(), as_ratio=False) == "meter ** -1" assert ( - pint.delegates.formatter._format_helpers.formatter( - dict(meter=-1, second=-1).items(), as_ratio=False - ) + formatter((), dict(meter=-1, second=-1).items(), as_ratio=False) == "meter ** -1 * second ** -1" ) assert ( - pint.delegates.formatter._format_helpers.formatter( - dict(meter=-1, second=-1).items() + formatter( + (), + dict(meter=-1, second=-1).items(), ) == "1 / meter / second" ) assert ( - pint.delegates.formatter._format_helpers.formatter( - dict(meter=-1, second=-1).items(), single_denominator=True - ) + formatter((), dict(meter=-1, second=-1).items(), single_denominator=True) == "1 / (meter * second)" ) assert ( - pint.delegates.formatter._format_helpers.formatter( - dict(meter=-1, second=-2).items() - ) + formatter((), dict(meter=-1, second=-2).items()) == "1 / meter / second ** 2" ) assert ( - pint.delegates.formatter._format_helpers.formatter( - dict(meter=-1, second=-2).items(), single_denominator=True - ) + formatter((), dict(meter=-1, second=-2).items(), single_denominator=True) == "1 / (meter * second ** 2)" ) diff --git a/pint/testsuite/test_formatting.py b/pint/testsuite/test_formatting.py index 48e770b3b..d8f10715b 100644 --- a/pint/testsuite/test_formatting.py +++ b/pint/testsuite/test_formatting.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest import pint.formatting as fmt @@ -57,6 +59,8 @@ def test_split_format(format, default, flag, expected): def test_register_unit_format(func_registry): @fmt.register_unit_format("custom") def format_custom(unit, registry, **options): + # Ensure the registry is correct.. + registry.Unit(unit) return "" quantity = 1.0 * func_registry.meter diff --git a/pint/testsuite/test_infer_base_unit.py b/pint/testsuite/test_infer_base_unit.py index b40e5d6e2..f5d710b7d 100644 --- a/pint/testsuite/test_infer_base_unit.py +++ b/pint/testsuite/test_infer_base_unit.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from decimal import Decimal from fractions import Fraction diff --git a/pint/testsuite/test_issues.py b/pint/testsuite/test_issues.py index 3db01fb4e..97eca3cde 100644 --- a/pint/testsuite/test_issues.py +++ b/pint/testsuite/test_issues.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import copy import decimal import math @@ -5,14 +7,19 @@ import pytest -from pint import Context, DimensionalityError, UnitRegistry, get_application_registry +from pint import ( + Context, + DimensionalityError, + UnitRegistry, + get_application_registry, +) from pint.compat import np +from pint.delegates.formatter._compound_unit_helpers import sort_by_dimensionality from pint.facets.plain.unit import UnitsContainer from pint.testing import assert_equal from pint.testsuite import QuantityTestCase, helpers from pint.util import ParserHelper - from .helpers import internal @@ -886,14 +893,34 @@ def test_issue_1300(self): m = module_registry.Measurement(1, 0.1, "meter") assert m.default_format == "~P" - @helpers.requires_babel() + @helpers.requires_numpy() + def test_issue1674(self, module_registry): + Q_ = module_registry.Quantity + arr_of_q = np.array([Q_(2, "m"), Q_(4, "m")], dtype="object") + q_arr = Q_(np.array([1, 2]), "m") + + helpers.assert_quantity_equal( + arr_of_q * q_arr, np.array([Q_(2, "m^2"), Q_(8, "m^2")], dtype="object") + ) + helpers.assert_quantity_equal( + arr_of_q / q_arr, np.array([Q_(2, ""), Q_(2, "")], dtype="object") + ) + + arr_of_q = np.array([Q_(2, "m"), Q_(4, "s")], dtype="object") + q_arr = Q_(np.array([1, 2]), 
"m") + + helpers.assert_quantity_equal( + arr_of_q * q_arr, np.array([Q_(2, "m^2"), Q_(8, "m s")], dtype="object") + ) + + @helpers.requires_babel(["es_ES"]) def test_issue_1400(self, sess_registry): q1 = 3.1 * sess_registry.W q2 = 3.1 * sess_registry.W / sess_registry.cm assert q1.format_babel("~", locale="es_ES") == "3,1 W" assert q1.format_babel("", locale="es_ES") == "3,1 vatios" - assert q2.format_babel("~", locale="es_ES") == "3,1 W / cm" - assert q2.format_babel("", locale="es_ES") == "3,1 vatios / centímetros" + assert q2.format_babel("~", locale="es_ES") == "3,1 W/cm" + assert q2.format_babel("", locale="es_ES") == "3,1 vatios por centímetro" @helpers.requires_uncertainties() def test_issue1611(self, module_registry): @@ -908,7 +935,7 @@ def test_issue1611(self, module_registry): u2 = ufloat(5.6, 0.78) q1_u = module_registry.Quantity(u2 - u1, "m") q1_str = str(q1_u) - q1_str = "{:.4uS}".format(q1_u) + q1_str = f"{q1_u:.4uS}" q1_m = q1_u.magnitude q2_u = module_registry.Quantity(q1_str) # Not equal because the uncertainties are differently random! @@ -1145,7 +1172,7 @@ def test_issue1725(registry_empty): assert registry_empty.get_compatible_units("dollar") == set() -def test_issues_1505(): +def test_issue1505(): ur = UnitRegistry(non_int_type=decimal.Decimal) assert isinstance(ur.Quantity("1m/s").magnitude, decimal.Decimal) @@ -1155,3 +1182,124 @@ def test_issues_1505(): assert isinstance( ur.Quantity("m/s").magnitude, decimal.Decimal ) # unexpected fail (magnitude should be a decimal) + + +def test_issue_1845(): + ur = UnitRegistry(auto_reduce_dimensions=True, non_int_type=decimal.Decimal) + # before issue 1845 these inputs would have resulted in a TypeError + assert ur("km / h * m").units == ur.Quantity("meter ** 2 / hour") + assert ur("kW / min * W").units == ur.Quantity("watts ** 2 / minute") + + +@pytest.mark.parametrize( + "units,spec,expected", + [ + # (dict(hour=1, watt=1), "P~", "W·h"), + (dict(ampere=1, volt=1), "P~", "V·A"), + # (dict(meter=1, newton=1), "P~", "N·m"), + ], +) +def test_issues_1841(func_registry, units, spec, expected): + ur = func_registry + ur.formatter.default_sort_func = sort_by_dimensionality + ur.default_format = spec + value = ur.Unit(UnitsContainer(**units)) + assert f"{value}" == expected + + +@pytest.mark.xfail +def test_issues_1841_xfail(): + from pint import formatting as fmt + from pint.delegates.formatter._compound_unit_helpers import sort_by_dimensionality + + # sets compact display mode by default + ur = UnitRegistry() + ur.default_format = "~P" + ur.formatter.default_sort_func = sort_by_dimensionality + + q = ur.Quantity("2*pi radian * hour") + + # Note that `radian` (and `bit` and `count`) are treated as dimensionless. + # And note that dimensionless quantities are stripped by this process, + # leading to errorneous output. Suggestions? 
+ assert ( + fmt.format_unit(q.u._units, spec="", registry=ur, sort_dims=True) + == "radian * hour" + ) + assert ( + fmt.format_unit(q.u._units, spec="", registry=ur, sort_dims=False) + == "hour * radian" + ) + + # this prints "2*pi hour * radian", not "2*pi radian * hour" unless sort_dims is True + # print(q) + + +def test_issue1949(registry_empty): + ureg = UnitRegistry() + ureg.define( + "in_Hg_gauge = 3386389 * gram / metre / second ** 2; offset:101325000 = inHg_g = in_Hg_g = inHg_gauge" + ) + q = ureg.Quantity("1 atm").to("inHg_gauge") + assert q.units == ureg.in_Hg_gauge + assert_equal(q.magnitude, 0.0) + + +@pytest.mark.parametrize( + "given,expected", + [ + ( + "8.989e9 newton * meter^2 / coulomb^2", + r"\SI[]{8.989E+9}{\meter\squared\newton\per\coulomb\squared}", + ), + ("5 * meter / second", r"\SI[]{5}{\meter\per\second}"), + ("2.2 * meter^4", r"\SI[]{2.2}{\meter\tothe{4}}"), + ("2.2 * meter^-4", r"\SI[]{2.2}{\per\meter\tothe{4}}"), + ], +) +def test_issue1772(given, expected): + ureg = UnitRegistry(non_int_type=decimal.Decimal) + assert f"{ureg(given):Lx}" == expected + + +def test_issue2017(): + ureg = UnitRegistry() + + from pint import formatting as fmt + + @fmt.register_unit_format("test") + def _test_format(unit, registry, **options): + print("format called") + proc = {u.replace("µ", "u"): e for u, e in unit.items()} + return fmt.formatter( + proc.items(), + as_ratio=True, + single_denominator=False, + product_fmt="*", + division_fmt="/", + power_fmt="{}{}", + parentheses_fmt="({})", + **options, + ) + + base_unit = ureg.microsecond + assert f"{base_unit:~test}" == "us" + assert f"{base_unit:test}" == "microsecond" + + +def test_issue2007(): + ureg = UnitRegistry() + q = ureg.Quantity(1, "") + assert f"{q:P}" == "1 dimensionless" + assert f"{q:C}" == "1 dimensionless" + assert f"{q:D}" == "1 dimensionless" + assert f"{q:H}" == "1 dimensionless" + + assert f"{q:L}" == "1\\ \\mathrm{dimensionless}" + # L returned '1\\ dimensionless' in pint 0.23 + + assert f"{q:Lx}" == "\\SI[]{1}{}" + assert f"{q:~P}" == "1" + assert f"{q:~C}" == "1" + assert f"{q:~D}" == "1" + assert f"{q:~H}" == "1" diff --git a/pint/testsuite/test_log_units.py b/pint/testsuite/test_log_units.py index 3d1c90514..5f1b0be49 100644 --- a/pint/testsuite/test_log_units.py +++ b/pint/testsuite/test_log_units.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import math @@ -63,6 +65,11 @@ def test_log_convert(self): helpers.assert_quantity_almost_equal( self.Q_(0.0, "dBm"), self.Q_(29.999999999999996, "dBu"), atol=1e-7 ) + # ## Test dB to dB units dBm - dBW + # 0 dBW = 1W = 1e3 mW = 30 dBm + helpers.assert_quantity_almost_equal( + self.Q_(0.0, "dBW"), self.Q_(29.999999999999996, "dBm"), atol=1e-7 + ) def test_mix_regular_log_units(self): # Test regular-logarithmic mixed definition, such as dB/km or dB/cm @@ -82,6 +89,8 @@ def test_mix_regular_log_units(self): log_unit_names = [ + "decibelwatt", + "dBW", "decibelmilliwatt", "dBm", "decibelmicrowatt", @@ -133,6 +142,7 @@ def test_quantity_by_multiplication(module_registry_auto_offset, unit_name, mag) @pytest.mark.parametrize( "unit1,unit2", [ + ("decibelwatt", "dBW"), ("decibelmilliwatt", "dBm"), ("decibelmicrowatt", "dBu"), ("decibel", "dB"), diff --git a/pint/testsuite/test_matplotlib.py b/pint/testsuite/test_matplotlib.py index 0735721c0..5327b5b0b 100644 --- a/pint/testsuite/test_matplotlib.py +++ b/pint/testsuite/test_matplotlib.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from pint import UnitRegistry diff --git 
a/pint/testsuite/test_measurement.py b/pint/testsuite/test_measurement.py index 8a98128ef..8f20deead 100644 --- a/pint/testsuite/test_measurement.py +++ b/pint/testsuite/test_measurement.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from pint import DimensionalityError @@ -190,8 +192,9 @@ def test_format_exponential_neg(self, func_registry, spec, expected): ], ) def test_format_default(self, func_registry, spec, expected): - v, u = func_registry.Quantity(4.0, "s ** 2"), func_registry.Quantity( - 0.1, "s ** 2" + v, u = ( + func_registry.Quantity(4.0, "s ** 2"), + func_registry.Quantity(0.1, "s ** 2"), ) m = func_registry.Measurement(v, u) func_registry.default_format = spec diff --git a/pint/testsuite/test_non_int.py b/pint/testsuite/test_non_int.py index 5a74a993a..ccf0dd6ff 100644 --- a/pint/testsuite/test_non_int.py +++ b/pint/testsuite/test_non_int.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import copy import math import operator as op diff --git a/pint/testsuite/test_numpy.py b/pint/testsuite/test_numpy.py index 15e56358a..3075be7ac 100644 --- a/pint/testsuite/test_numpy.py +++ b/pint/testsuite/test_numpy.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import copy import operator as op import pickle @@ -286,6 +288,11 @@ def test_broadcast_arrays(self): result = np.broadcast_arrays(x, y, subok=True) helpers.assert_quantity_equal(result, expected) + def test_roll(self): + helpers.assert_quantity_equal( + np.roll(self.q, 1), [[4, 1], [2, 3]] * self.ureg.m + ) + class TestNumpyMathematicalFunctions(TestNumpyMethods): # https://www.numpy.org/devdocs/reference/routines.math.html @@ -330,9 +337,7 @@ def test_prod_numpy_func(self): helpers.assert_quantity_equal( np.prod(self.q, axis=axis), [3, 8] * self.ureg.m**2 ) - helpers.assert_quantity_equal( - np.prod(self.q, where=where), 12 * self.ureg.m**3 - ) + helpers.assert_quantity_equal(np.prod(self.q, where=where), 12 * self.ureg.m**3) with pytest.raises(DimensionalityError): np.prod(self.q, axis=axis, where=where) @@ -380,12 +385,7 @@ def test_cumprod(self): def test_cumprod_numpy_func(self): with pytest.raises(DimensionalityError): np.cumprod(self.q) - with pytest.raises(DimensionalityError): - np.cumproduct(self.q) helpers.assert_quantity_equal(np.cumprod(self.q / self.ureg.m), [1, 2, 6, 24]) - helpers.assert_quantity_equal( - np.cumproduct(self.q / self.ureg.m), [1, 2, 6, 24] - ) helpers.assert_quantity_equal( np.cumprod(self.q / self.ureg.m, axis=1), [[1, 2], [3, 12]] ) @@ -438,6 +438,7 @@ def test_cross(self): np.cross(a, b), [[-15, -2, 39]] * self.ureg.kPa * self.ureg.m**2 ) + # NP2: Remove this when we only support np>=2.0 @helpers.requires_array_function_protocol() def test_trapz(self): helpers.assert_quantity_equal( @@ -445,6 +446,16 @@ def test_trapz(self): 7.5 * self.ureg.J * self.ureg.m, ) + @helpers.requires_array_function_protocol() + # NP2: Remove this when we only support np>=2.0 + # trapezoid added in numpy 2.0 + @helpers.requires_numpy_at_least("2.0") + def test_trapezoid(self): + helpers.assert_quantity_equal( + np.trapezoid([1.0, 2.0, 3.0, 4.0] * self.ureg.J, dx=1 * self.ureg.m), + 7.5 * self.ureg.J * self.ureg.m, + ) + @helpers.requires_array_function_protocol() def test_dot(self): helpers.assert_quantity_equal( @@ -758,9 +769,12 @@ def test_minimum(self): np.minimum(self.q, self.Q_([0, 5], "m")), self.Q_([[0, 2], [0, 4]], "m") ) + # NP2: Can remove Q_(arr).ptp test when we only support numpy>=2 def test_ptp(self): - assert self.q.ptp() == 3 * self.ureg.m + if not 
np.lib.NumpyVersion(np.__version__) >= "2.0.0b1": + assert self.q.ptp() == 3 * self.ureg.m + # NP2: Keep this test for numpy>=2, it's only arr.ptp() that is deprecated @helpers.requires_array_function_protocol() def test_ptp_numpy_func(self): helpers.assert_quantity_equal(np.ptp(self.q, axis=0), [2, 2] * self.ureg.m) diff --git a/pint/testsuite/test_numpy_func.py b/pint/testsuite/test_numpy_func.py index 7a0cdb7e3..9c69a238d 100644 --- a/pint/testsuite/test_numpy_func.py +++ b/pint/testsuite/test_numpy_func.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from contextlib import ExitStack from unittest.mock import patch @@ -214,6 +216,14 @@ def test_trapz_no_autoconvert(self): with pytest.raises(OffsetUnitCalculusError): np.trapz(t, x=z) + def test_correlate(self): + a = self.Q_(np.array([1, 2, 3]), "m") + v = self.Q_(np.array([0, 1, 0.5]), "s") + res = np.correlate(a, v, "full") + ref = np.array([0.5, 2.0, 3.5, 3.0, 0.0]) + assert np.array_equal(res.magnitude, ref) + assert res.units == "meter * second" + def test_dot(self): with ExitStack() as stack: stack.callback( diff --git a/pint/testsuite/test_pint_eval.py b/pint/testsuite/test_pint_eval.py index fc0012e6d..3cee7d758 100644 --- a/pint/testsuite/test_pint_eval.py +++ b/pint/testsuite/test_pint_eval.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from pint.pint_eval import build_eval_tree, tokenizer diff --git a/pint/testsuite/test_pitheorem.py b/pint/testsuite/test_pitheorem.py index 9893f507c..665d5798e 100644 --- a/pint/testsuite/test_pitheorem.py +++ b/pint/testsuite/test_pitheorem.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import itertools import logging diff --git a/pint/testsuite/test_quantity.py b/pint/testsuite/test_quantity.py index 3fdf8c83b..8c6f15c49 100644 --- a/pint/testsuite/test_quantity.py +++ b/pint/testsuite/test_quantity.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import copy import datetime import logging @@ -16,6 +18,7 @@ get_application_registry, ) from pint.compat import np +from pint.errors import UndefinedBehavior from pint.facets.plain.unit import UnitsContainer from pint.testsuite import QuantityTestCase, assert_no_warnings, helpers @@ -172,7 +175,7 @@ def test_quantity_format(self, subtests): ("{:Lx}", r"\SI[]{4.12345678}{\kilo\gram\meter\squared\per\second}"), ): with subtests.test(spec): - assert spec.format(x) == result + assert spec.format(x) == result, spec # Check the special case that prevents e.g. '3 1 / second' x = self.Q_(3, UnitsContainer(second=-1)) @@ -833,7 +836,7 @@ def test_limits_magnitudes(self): def test_nonnumeric_magnitudes(self): ureg = self.ureg x = "some string" * ureg.m - with pytest.warns(RuntimeWarning): + with pytest.warns(UndefinedBehavior): self.compare_quantity_compact(x, x) def test_very_large_to_compact(self): diff --git a/pint/testsuite/test_systems.py b/pint/testsuite/test_systems.py index 49da32c52..9e78a3d1e 100644 --- a/pint/testsuite/test_systems.py +++ b/pint/testsuite/test_systems.py @@ -1,9 +1,10 @@ +from __future__ import annotations + import pytest from pint import UnitRegistry from pint.testsuite import QuantityTestCase - from .helpers import internal diff --git a/pint/testsuite/test_testing.py b/pint/testsuite/test_testing.py index eab04fcb9..dfb8b0602 100644 --- a/pint/testsuite/test_testing.py +++ b/pint/testsuite/test_testing.py @@ -1,7 +1,9 @@ -import pytest +from __future__ import annotations from typing import Any +import pytest + from .. 
import testing np = pytest.importorskip("numpy") diff --git a/pint/testsuite/test_umath.py b/pint/testsuite/test_umath.py index 73d0ae776..a555a7664 100644 --- a/pint/testsuite/test_umath.py +++ b/pint/testsuite/test_umath.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from pint import DimensionalityError, UnitRegistry diff --git a/pint/testsuite/test_unit.py b/pint/testsuite/test_unit.py index 285ad303a..2156bbafd 100644 --- a/pint/testsuite/test_unit.py +++ b/pint/testsuite/test_unit.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import copy import functools import logging @@ -987,6 +989,8 @@ class TestConvertWithOffset(QuantityTestCase): (({"degC": 2}, {"kelvin": 2}), "error"), (({"degC": 1, "degF": 1}, {"kelvin": 2}), "error"), (({"degC": 1, "kelvin": 1}, {"kelvin": 2}), "error"), + (({"delta_degC": 1}, {"degF": 1}), "error"), + (({"delta_degC": 1}, {"degC": 1}), "error"), ] @pytest.mark.parametrize(("input_tuple", "expected"), convert_with_offset) diff --git a/pint/testsuite/test_util.py b/pint/testsuite/test_util.py index 70136cf35..0a6d357d0 100644 --- a/pint/testsuite/test_util.py +++ b/pint/testsuite/test_util.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import collections import copy import math diff --git a/pint/toktest.py b/pint/toktest.py index ef606d6a9..e0026a21d 100644 --- a/pint/toktest.py +++ b/pint/toktest.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import tokenize + from pint.pint_eval import _plain_tokenizer, uncertainty_tokenizer tokenizer = _plain_tokenizer diff --git a/pint/util.py b/pint/util.py index 45f409135..c7a7ec10c 100644 --- a/pint/util.py +++ b/pint/util.py @@ -14,30 +14,26 @@ import math import operator import re -from collections.abc import Mapping, Iterable, Iterator +import tokenize +import types +from collections.abc import Callable, Generator, Hashable, Iterable, Iterator, Mapping from fractions import Fraction from functools import lru_cache, partial from logging import NullHandler from numbers import Number from token import NAME, NUMBER -import tokenize -import types from typing import ( TYPE_CHECKING, + Any, ClassVar, - Callable, TypeVar, - Any, - Optional, ) -from collections.abc import Hashable, Generator +from . import pint_eval +from ._typing import Scalar from .compat import NUMERIC_TYPES, Self from .errors import DefinitionSyntaxError from .pint_eval import build_eval_tree -from . 
import pint_eval - -from ._typing import Scalar if TYPE_CHECKING: from ._typing import QuantityOrUnitLike @@ -64,8 +60,8 @@ def _noop(x: T) -> T: def matrix_to_string( matrix: ItMatrix, - row_headers: Optional[Iterable[str]] = None, - col_headers: Optional[Iterable[str]] = None, + row_headers: Iterable[str] | None = None, + col_headers: Iterable[str] | None = None, fmtfun: Callable[ [ Scalar, @@ -180,9 +176,7 @@ def column_echelon_form( ItMatrix, ], Matrix, - ] = ( - transpose if transpose_result else _noop - ) + ] = transpose if transpose_result else _noop ech_matrix = matrix_apply( transpose(matrix), @@ -232,7 +226,7 @@ def column_echelon_form( return _transpose(ech_matrix), _transpose(id_matrix), swapped -def pi_theorem(quantities: dict[str, Any], registry: Optional[UnitRegistry] = None): +def pi_theorem(quantities: dict[str, Any], registry: UnitRegistry | None = None): """Builds dimensionless quantities using the Buckingham π theorem Parameters @@ -309,7 +303,7 @@ def pi_theorem(quantities: dict[str, Any], registry: Optional[UnitRegistry] = No def solve_dependencies( - dependencies: dict[TH, set[TH]] + dependencies: dict[TH, set[TH]], ) -> Generator[set[TH], None, None]: """Solve a dependency graph. @@ -348,7 +342,7 @@ def solve_dependencies( def find_shortest_path( - graph: dict[TH, set[TH]], start: TH, end: TH, path: Optional[list[TH]] = None + graph: dict[TH, set[TH]], start: TH, end: TH, path: list[TH] | None = None ): """Find shortest path between two nodes within a graph. @@ -390,8 +384,8 @@ def find_shortest_path( def find_connected_nodes( - graph: dict[TH, set[TH]], start: TH, visited: Optional[set[TH]] = None -) -> Optional[set[TH]]: + graph: dict[TH, set[TH]], start: TH, visited: set[TH] | None = None +) -> set[TH] | None: """Find all nodes connected to a start node within a graph. Parameters @@ -451,12 +445,12 @@ class UnitsContainer(Mapping[str, Scalar]): __slots__ = ("_d", "_hash", "_one", "_non_int_type") _d: udict - _hash: Optional[int] + _hash: int | None _one: Scalar _non_int_type: type def __init__( - self, *args: Any, non_int_type: Optional[type] = None, **kwargs: Any + self, *args: Any, non_int_type: type | None = None, **kwargs: Any ) -> None: if args and isinstance(args[0], UnitsContainer): default_non_int_type = args[0]._non_int_type @@ -501,7 +495,7 @@ def add(self: Self, key: str, value: Number) -> Self: UnitsContainer A copy of this container. 
""" - newval = self._d[key] + value + newval = self._d[key] + self._normalize_nonfloat_value(value) new = self.copy() if newval: new._d[key] = newval @@ -549,6 +543,9 @@ def rename(self: Self, oldkey: str, newkey: str) -> Self: new._hash = None return new + def unit_items(self) -> Iterable[tuple[str, Scalar]]: + return self._d.items() + def __iter__(self) -> Iterator[str]: return iter(self._d) @@ -659,7 +656,7 @@ def __truediv__(self, other: Any): new = self.copy() for key, value in other.items(): - new._d[key] -= value + new._d[key] -= self._normalize_nonfloat_value(value) if new._d[key] == 0: del new._d[key] @@ -673,6 +670,11 @@ def __rtruediv__(self, other: Any): return self**-1 + def _normalize_nonfloat_value(self, value: Scalar) -> Scalar: + if not isinstance(value, int) and not isinstance(value, self._non_int_type): + return self._non_int_type(value) # type: ignore[no-any-return] + return value + class ParserHelper(UnitsContainer): """The ParserHelper stores in place the product of variables and @@ -1027,7 +1029,7 @@ def _repr_pretty_(self, p, cycle: bool): def to_units_container( - unit_like: QuantityOrUnitLike, registry: Optional[UnitRegistry] = None + unit_like: QuantityOrUnitLike, registry: UnitRegistry | None = None ) -> UnitsContainer: """Convert a unit compatible type to a UnitsContainer. @@ -1064,7 +1066,7 @@ def to_units_container( def infer_base_unit( - unit_like: QuantityOrUnitLike, registry: Optional[UnitRegistry] = None + unit_like: QuantityOrUnitLike, registry: UnitRegistry | None = None ) -> UnitsContainer: """ Given a Quantity or UnitLike, give the UnitsContainer for it's plain units. diff --git a/pyproject.toml b/pyproject.toml index 4b6b7312d..9f29f8f92 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,13 +24,11 @@ classifiers = [ "Topic :: Software Development :: Libraries", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11" + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", ] requires-python = ">=3.9" -dynamic = ["version"] # Version is taken from git tags using setuptools_scm -dependencies = [ - "typing_extensions" -] +dynamic = ["version", "dependencies"] [tool.setuptools.package-data] pint = [ @@ -38,6 +36,8 @@ pint = [ "constants_en.txt", "py.typed"] +[tool.setuptools.dynamic] +dependencies = {file = "requirements.txt"} [project.optional-dependencies] testbase = [ @@ -57,7 +57,7 @@ bench = [ "pytest", "pytest-codspeed" ] -numpy = ["numpy >= 1.19.5"] +numpy = ["numpy >= 1.23"] uncertainties = ["uncertainties >= 3.1.6"] babel = ["babel <= 2.8"] pandas = ["pint-pandas >= 0.3"] @@ -81,12 +81,18 @@ build-backend = "setuptools.build_meta" [tool.setuptools_scm] -[tool.ruff.isort] +[tool.ruff] +extend-exclude = ["build"] +line-length=88 + +[tool.ruff.lint.isort] required-imports = ["from __future__ import annotations"] known-first-party= ["pint"] - -[tool.ruff] +[tool.ruff.lint] +extend-select = [ + "I", # isort +] ignore = [ # whitespace before ':' - doesn't work well with black # "E203", @@ -98,5 +104,3 @@ ignore = [ # line break before binary operator # "W503" ] -extend-exclude = ["build"] -line-length=88 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 000000000..0bc99005a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +appdirs>=1.4.4 +typing_extensions +flexcache>=0.3 +flexparser>=0.3