From d45f427fc152b8f514da6b6b8ff23527a8b297b1 Mon Sep 17 00:00:00 2001 From: Saelyos Date: Wed, 25 Oct 2023 11:27:11 +0200 Subject: [PATCH] Improve wraps performances --- pint/registry_helpers.py | 54 ++++++++++++++++++++++++------------- pint/testsuite/test_unit.py | 18 +++++++++++++ 2 files changed, 54 insertions(+), 18 deletions(-) diff --git a/pint/registry_helpers.py b/pint/registry_helpers.py index a31836ea6..37c539e35 100644 --- a/pint/registry_helpers.py +++ b/pint/registry_helpers.py @@ -11,7 +11,7 @@ from __future__ import annotations import functools -from inspect import signature +from inspect import signature, Parameter from itertools import zip_longest from typing import TYPE_CHECKING, Callable, TypeVar, Any, Union, Optional from collections.abc import Iterable @@ -119,8 +119,13 @@ def _parse_wrap_args(args, registry=None): "Not all variable referenced in %s are defined using !" % args[ndx] ) - def _converter(ureg, values, strict): - new_values = list(value for value in values) + def _converter(ureg, sig, values, kw, strict): + len_initial_values = len(values) + + # pack kwargs + for i, param_name in enumerate(sig.parameters): + if i >= len_initial_values: + values.append(kw[param_name]) values_by_name = {} @@ -128,13 +133,13 @@ def _converter(ureg, values, strict): for ndx in defs_args_ndx: value = values[ndx] values_by_name[args_as_uc[ndx][0]] = value - new_values[ndx] = getattr(value, "_magnitude", value) + values[ndx] = getattr(value, "_magnitude", value) # second pass: calculate derived values based on named values for ndx in dependent_args_ndx: value = values[ndx] assert _replace_units(args_as_uc[ndx][0], values_by_name) is not None - new_values[ndx] = ureg._convert( + values[ndx] = ureg._convert( getattr(value, "_magnitude", value), getattr(value, "_units", UnitsContainer({})), _replace_units(args_as_uc[ndx][0], values_by_name), @@ -143,7 +148,7 @@ def _converter(ureg, values, strict): # third pass: convert other arguments for ndx in unit_args_ndx: if isinstance(values[ndx], ureg.Quantity): - new_values[ndx] = ureg._convert( + values[ndx] = ureg._convert( values[ndx]._magnitude, values[ndx]._units, args_as_uc[ndx][0] ) else: @@ -151,7 +156,7 @@ def _converter(ureg, values, strict): if isinstance(values[ndx], str): # if the value is a string, we try to parse it tmp_value = ureg.parse_expression(values[ndx]) - new_values[ndx] = ureg._convert( + values[ndx] = ureg._convert( tmp_value._magnitude, tmp_value._units, args_as_uc[ndx][0] ) else: @@ -159,11 +164,16 @@ def _converter(ureg, values, strict): "A wrapped function using strict=True requires " "quantity or a string for all arguments with not None units. " "(error found for {}, {})".format( - args_as_uc[ndx][0], new_values[ndx] + args_as_uc[ndx][0], values[ndx] ) ) - return new_values, values_by_name + # unpack kwargs + for i, param_name in enumerate(sig.parameters): + if i >= len_initial_values: + kw[param_name] = values[i] + + return values[:len_initial_values], kw, values_by_name return _converter @@ -175,12 +185,14 @@ def _apply_defaults(sig, args, kwargs): values so that every argument is defined. """ - bound_arguments = sig.bind(*args, **kwargs) - for param in sig.parameters.values(): - if param.name not in bound_arguments.arguments: - bound_arguments.arguments[param.name] = param.default - args = [bound_arguments.arguments[key] for key in sig.parameters.keys()] - return args, {} + for i, param in enumerate(sig.parameters.values()): + if ( + i >= len(args) + and param.default != Parameter.empty + and param.name not in kwargs + ): + kwargs[param.name] = param.default + return list(args), kwargs def wraps( @@ -274,9 +286,11 @@ def wrapper(*values, **kw) -> Quantity: # In principle, the values are used as is # When then extract the magnitudes when needed. - new_values, values_by_name = converter(ureg, values, strict) + new_values, new_kw, values_by_name = converter( + ureg, sig, values, kw, strict + ) - result = func(*new_values, **kw) + result = func(*new_values, **new_kw) if is_ret_container: out_units = ( @@ -352,7 +366,11 @@ def decorator(func): @functools.wraps(func, assigned=assigned, updated=updated) def wrapper(*args, **kwargs): - list_args, empty = _apply_defaults(sig, args, kwargs) + list_args, kw = _apply_defaults(sig, args, kwargs) + + for i, param_name in enumerate(sig.parameters): + if i >= len(args): + list_args.append(kw[param_name]) for dim, value in zip(dimensions, list_args): if dim is None: diff --git a/pint/testsuite/test_unit.py b/pint/testsuite/test_unit.py index c1a2704b5..2fb6ecac2 100644 --- a/pint/testsuite/test_unit.py +++ b/pint/testsuite/test_unit.py @@ -593,6 +593,23 @@ def hfunc(x, y): h3 = ureg.wraps((None,), (None, None))(hfunc) assert h3(3, 1) == (3, 1) + def kfunc(a, /, b, c=5, *, d=6): + return a, b, c, d + + k1 = ureg.wraps((None,), (None, None, None, None))(kfunc) + assert k1(1, 2, 3, d=4) == (1, 2, 3, 4) + assert k1(1, 2, c=3, d=4) == (1, 2, 3, 4) + assert k1(1, b=2, c=3, d=4) == (1, 2, 3, 4) + assert k1(1, d=4, b=2, c=3) == (1, 2, 3, 4) + assert k1(1, 2, c=3) == (1, 2, 3, 6) + assert k1(1, 2, d=4) == (1, 2, 5, 4) + assert k1(1, 2) == (1, 2, 5, 6) + + k2 = ureg.wraps((None,), ("meter", "centimeter", "meter", "centimeter"))(kfunc) + assert k2( + 1 * ureg.meter, 2 * ureg.centimeter, 3 * ureg.meter, d=4 * ureg.centimeter + ) == (1, 2, 3, 4) + def test_wrap_referencing(self): ureg = self.ureg @@ -641,6 +658,7 @@ def func(x): assert f0(3.0 * ureg.centimeter) == 0.03 * ureg.meter with pytest.raises(DimensionalityError): f0(3.0 * ureg.kilogram) + assert f0(x=3.0 * ureg.centimeter) == 0.03 * ureg.meter f0b = ureg.check(ureg.meter)(func) with pytest.raises(DimensionalityError):