Skip to content

Commit

Permalink
Improve wraps performances (#1866)
Browse files Browse the repository at this point in the history
  • Loading branch information
Saelyos authored Dec 3, 2023
1 parent 04cc929 commit 5d533d6
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 18 deletions.
54 changes: 36 additions & 18 deletions pint/registry_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from __future__ import annotations

import functools
from inspect import signature
from inspect import signature, Parameter
from itertools import zip_longest
from typing import TYPE_CHECKING, Callable, TypeVar, Any, Union, Optional
from collections.abc import Iterable
Expand Down Expand Up @@ -119,22 +119,27 @@ def _parse_wrap_args(args, registry=None):
"Not all variable referenced in %s are defined using !" % args[ndx]
)

def _converter(ureg, values, strict):
new_values = list(value for value in values)
def _converter(ureg, sig, values, kw, strict):
len_initial_values = len(values)

# pack kwargs
for i, param_name in enumerate(sig.parameters):
if i >= len_initial_values:
values.append(kw[param_name])

values_by_name = {}

# first pass: Grab named values
for ndx in defs_args_ndx:
value = values[ndx]
values_by_name[args_as_uc[ndx][0]] = value
new_values[ndx] = getattr(value, "_magnitude", value)
values[ndx] = getattr(value, "_magnitude", value)

# second pass: calculate derived values based on named values
for ndx in dependent_args_ndx:
value = values[ndx]
assert _replace_units(args_as_uc[ndx][0], values_by_name) is not None
new_values[ndx] = ureg._convert(
values[ndx] = ureg._convert(
getattr(value, "_magnitude", value),
getattr(value, "_units", UnitsContainer({})),
_replace_units(args_as_uc[ndx][0], values_by_name),
Expand All @@ -143,27 +148,32 @@ def _converter(ureg, values, strict):
# third pass: convert other arguments
for ndx in unit_args_ndx:
if isinstance(values[ndx], ureg.Quantity):
new_values[ndx] = ureg._convert(
values[ndx] = ureg._convert(
values[ndx]._magnitude, values[ndx]._units, args_as_uc[ndx][0]
)
else:
if strict:
if isinstance(values[ndx], str):
# if the value is a string, we try to parse it
tmp_value = ureg.parse_expression(values[ndx])
new_values[ndx] = ureg._convert(
values[ndx] = ureg._convert(
tmp_value._magnitude, tmp_value._units, args_as_uc[ndx][0]
)
else:
raise ValueError(
"A wrapped function using strict=True requires "
"quantity or a string for all arguments with not None units. "
"(error found for {}, {})".format(
args_as_uc[ndx][0], new_values[ndx]
args_as_uc[ndx][0], values[ndx]
)
)

return new_values, values_by_name
# unpack kwargs
for i, param_name in enumerate(sig.parameters):
if i >= len_initial_values:
kw[param_name] = values[i]

return values[:len_initial_values], kw, values_by_name

return _converter

Expand All @@ -175,12 +185,14 @@ def _apply_defaults(sig, args, kwargs):
values so that every argument is defined.
"""

bound_arguments = sig.bind(*args, **kwargs)
for param in sig.parameters.values():
if param.name not in bound_arguments.arguments:
bound_arguments.arguments[param.name] = param.default
args = [bound_arguments.arguments[key] for key in sig.parameters.keys()]
return args, {}
for i, param in enumerate(sig.parameters.values()):
if (
i >= len(args)
and param.default != Parameter.empty
and param.name not in kwargs
):
kwargs[param.name] = param.default
return list(args), kwargs


def wraps(
Expand Down Expand Up @@ -274,9 +286,11 @@ def wrapper(*values, **kw) -> Quantity:

# In principle, the values are used as is
# When then extract the magnitudes when needed.
new_values, values_by_name = converter(ureg, values, strict)
new_values, new_kw, values_by_name = converter(
ureg, sig, values, kw, strict
)

result = func(*new_values, **kw)
result = func(*new_values, **new_kw)

if is_ret_container:
out_units = (
Expand Down Expand Up @@ -352,7 +366,11 @@ def decorator(func):

@functools.wraps(func, assigned=assigned, updated=updated)
def wrapper(*args, **kwargs):
list_args, empty = _apply_defaults(sig, args, kwargs)
list_args, kw = _apply_defaults(sig, args, kwargs)

for i, param_name in enumerate(sig.parameters):
if i >= len(args):
list_args.append(kw[param_name])

for dim, value in zip(dimensions, list_args):
if dim is None:
Expand Down
18 changes: 18 additions & 0 deletions pint/testsuite/test_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,23 @@ def hfunc(x, y):
h3 = ureg.wraps((None,), (None, None))(hfunc)
assert h3(3, 1) == (3, 1)

def kfunc(a, /, b, c=5, *, d=6):
return a, b, c, d

k1 = ureg.wraps((None,), (None, None, None, None))(kfunc)
assert k1(1, 2, 3, d=4) == (1, 2, 3, 4)
assert k1(1, 2, c=3, d=4) == (1, 2, 3, 4)
assert k1(1, b=2, c=3, d=4) == (1, 2, 3, 4)
assert k1(1, d=4, b=2, c=3) == (1, 2, 3, 4)
assert k1(1, 2, c=3) == (1, 2, 3, 6)
assert k1(1, 2, d=4) == (1, 2, 5, 4)
assert k1(1, 2) == (1, 2, 5, 6)

k2 = ureg.wraps((None,), ("meter", "centimeter", "meter", "centimeter"))(kfunc)
assert k2(
1 * ureg.meter, 2 * ureg.centimeter, 3 * ureg.meter, d=4 * ureg.centimeter
) == (1, 2, 3, 4)

def test_wrap_referencing(self):
ureg = self.ureg

Expand Down Expand Up @@ -643,6 +660,7 @@ def func(x):
assert f0(3.0 * ureg.centimeter) == 0.03 * ureg.meter
with pytest.raises(DimensionalityError):
f0(3.0 * ureg.kilogram)
assert f0(x=3.0 * ureg.centimeter) == 0.03 * ureg.meter

f0b = ureg.check(ureg.meter)(func)
with pytest.raises(DimensionalityError):
Expand Down

0 comments on commit 5d533d6

Please sign in to comment.