Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom pint wrapper #149

Merged
merged 3 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

**In development**

- {gh-pr}`149` {gh-issue}`145` Add custom pint wrapper for better handling of pint arrays.
- {gh-pr}`148` {gh-issue}`122` deprecate `LineParameters.from_name_lv()` in favor of the more generic
`LineParameters.from_geometry()`. The method will be removed in a future release.
- {gh-pr}`142` {gh-issue}`136` Add `Bus.res_voltage_unbalance()` method to get the Voltage Unbalance
Expand Down
148 changes: 148 additions & 0 deletions roseau/load_flow/_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import functools
from collections.abc import Iterable, MutableSequence
from inspect import Parameter, Signature, signature
from itertools import zip_longest
from typing import Any, Callable, Optional, TypeVar, Union

from pint import Quantity, Unit
from pint.registry import UnitRegistry
from pint.util import to_units_container

T = TypeVar("T")
FuncT = TypeVar("FuncT", bound=Callable)


def _parse_wrap_args(args: Iterable[Optional[Union[str, Unit]]]) -> Callable:
"""Create a converter function for the wrapper"""
# _to_units_container
args_as_uc = [to_units_container(arg) for arg in args]

# Check for references in args, remove None values
unit_args_ndx = {ndx for ndx, arg in enumerate(args_as_uc) if arg is not None}

def _converter(ureg: UnitRegistry, sig: Signature, values: list[Any], kw: dict[Any]):
len_initial_values = len(values)

# pack kwargs
for i, param_name in enumerate(sig.parameters):
if i >= len_initial_values:
values.append(kw[param_name])

# convert arguments
for ndx in unit_args_ndx:
value = values[ndx]
if isinstance(value, ureg.Quantity):
values[ndx] = ureg.convert(value.magnitude, value.units, args_as_uc[ndx])
elif isinstance(value, MutableSequence):
for i, val in enumerate(value):
if isinstance(val, ureg.Quantity):
value[i] = ureg.convert(val.magnitude, val.units, args_as_uc[ndx])

# unpack kwargs
for i, param_name in enumerate(sig.parameters):
if i >= len_initial_values:
kw[param_name] = values[i]

return values[:len_initial_values], kw

return _converter


def _apply_defaults(sig: Signature, args: tuple[Any], kwargs: dict[str, Any]) -> tuple[list[Any], dict[str, Any]]:
"""Apply default keyword arguments.

Named keywords may have been left blank. This function applies the default
values so that every argument is defined.
"""
n = len(args)
for i, param in enumerate(sig.parameters.values()):
if i >= n and param.default != Parameter.empty and param.name not in kwargs:
kwargs[param.name] = param.default
return list(args), kwargs


def wraps(
ureg: UnitRegistry,
ret: Optional[Union[str, Unit, Iterable[Optional[Union[str, Unit]]]]],
args: Optional[Union[str, Unit, Iterable[Optional[Union[str, Unit]]]]],
) -> Callable[[FuncT], FuncT]:
"""Wraps a function to become pint-aware.

Use it when a function requires a numerical value but in some specific
units. The wrapper function will take a pint quantity, convert to the units
specified in `args` and then call the wrapped function with the resulting
magnitude.

The value returned by the wrapped function will be converted to the units
specified in `ret`.

Args:
ureg:
A UnitRegistry instance.

ret:
Units of each of the return values. Use `None` to skip argument conversion.

args:
Units of each of the input arguments. Use `None` to skip argument conversion.

Returns:
The wrapper function.

Raises:
TypeError
if the number of given arguments does not match the number of function parameters.
if any of the provided arguments is not a unit a string or Quantity
"""
if not isinstance(args, (list, tuple)):
args = (args,)

for arg in args:
if arg is not None and not isinstance(arg, (ureg.Unit, str)):
raise TypeError(f"wraps arguments must by of type str or Unit, not {type(arg)} ({arg})")

converter = _parse_wrap_args(args)

is_ret_container = isinstance(ret, (list, tuple))
if is_ret_container:
for arg in ret:
if arg is not None and not isinstance(arg, (ureg.Unit, str)):
raise TypeError(f"wraps 'ret' argument must by of type str or Unit, not {type(arg)} ({arg})")
ret = ret.__class__([to_units_container(arg, ureg) for arg in ret])
else:
if ret is not None and not isinstance(ret, (ureg.Unit, str)):
raise TypeError(f"wraps 'ret' argument must by of type str or Unit, not {type(ret)} ({ret})")
ret = to_units_container(ret, ureg)

def decorator(func: Callable[..., Any]) -> Callable[..., Quantity]:
sig = signature(func)
count_params = len(sig.parameters)
if len(args) != count_params:
raise TypeError(f"{func.__name__} takes {count_params} parameters, but {len(args)} units were passed")

assigned = tuple(attr for attr in functools.WRAPPER_ASSIGNMENTS if hasattr(func, attr))
updated = tuple(attr for attr in functools.WRAPPER_UPDATES if hasattr(func, attr))

@functools.wraps(func, assigned=assigned, updated=updated)
def wrapper(*values, **kw) -> Quantity:
values, kw = _apply_defaults(sig, values, kw)

# In principle, the values are used as is
# When then extract the magnitudes when needed.
new_values, new_kw = converter(ureg, sig, values, kw)

result = func(*new_values, **new_kw)

if is_ret_container:
return ret.__class__(
res if unit is None else ureg.Quantity(res, unit) for unit, res in zip_longest(ret, result)
)

if ret is None:
return result

return ureg.Quantity(result, ret)

return wrapper

return decorator
8 changes: 4 additions & 4 deletions roseau/load_flow/models/branches.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def _res_currents_getter(self, warning: bool) -> tuple[ComplexArray, ComplexArra
return self._res_getter(value=self._res_currents, warning=warning)

@property
@ureg_wraps(("A", "A"), (None,), strict=False)
@ureg_wraps(("A", "A"), (None,))
def res_currents(self) -> tuple[Q_[ComplexArray], Q_[ComplexArray]]:
"""The load flow result of the branch currents (A)."""
return self._res_currents_getter(warning=True)
Expand All @@ -93,7 +93,7 @@ def _res_powers_getter(self, warning: bool) -> tuple[ComplexArray, ComplexArray]
return powers1, powers2

@property
@ureg_wraps(("VA", "VA"), (None,), strict=False)
@ureg_wraps(("VA", "VA"), (None,))
def res_powers(self) -> tuple[Q_[ComplexArray], Q_[ComplexArray]]:
"""The load flow result of the branch powers (VA)."""
return self._res_powers_getter(warning=True)
Expand All @@ -104,7 +104,7 @@ def _res_potentials_getter(self, warning: bool) -> tuple[ComplexArray, ComplexAr
return pot1, pot2

@property
@ureg_wraps(("V", "V"), (None,), strict=False)
@ureg_wraps(("V", "V"), (None,))
def res_potentials(self) -> tuple[Q_[ComplexArray], Q_[ComplexArray]]:
"""The load flow result of the branch potentials (V)."""
return self._res_potentials_getter(warning=True)
Expand All @@ -114,7 +114,7 @@ def _res_voltages_getter(self, warning: bool) -> tuple[ComplexArray, ComplexArra
return calculate_voltages(pot1, self.phases1), calculate_voltages(pot2, self.phases2)

@property
@ureg_wraps(("V", "V"), (None,), strict=False)
@ureg_wraps(("V", "V"), (None,))
def res_voltages(self) -> tuple[Q_[ComplexArray], Q_[ComplexArray]]:
"""The load flow result of the branch voltages (V)."""
return self._res_voltages_getter(warning=True)
Expand Down
14 changes: 7 additions & 7 deletions roseau/load_flow/models/buses.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,13 @@ def __repr__(self) -> str:
return f"{type(self).__name__}(id={self.id!r}, phases={self.phases!r})"

@property
@ureg_wraps("V", (None,), strict=False)
@ureg_wraps("V", (None,))
def potentials(self) -> Q_[ComplexArray]:
"""An array of initial potentials of the bus (V)."""
return self._potentials

@potentials.setter
@ureg_wraps(None, (None, "V"), strict=False)
@ureg_wraps(None, (None, "V"))
def potentials(self, value: ComplexArrayLike1D) -> None:
if len(value) != len(self.phases):
msg = f"Incorrect number of potentials: {len(value)} instead of {len(self.phases)}"
Expand All @@ -110,7 +110,7 @@ def _res_potentials_getter(self, warning: bool) -> ComplexArray:
return self._res_getter(value=self._res_potentials, warning=warning)

@property
@ureg_wraps("V", (None,), strict=False)
@ureg_wraps("V", (None,))
def res_potentials(self) -> Q_[ComplexArray]:
"""The load flow result of the bus potentials (V)."""
return self._res_potentials_getter(warning=True)
Expand All @@ -120,7 +120,7 @@ def _res_voltages_getter(self, warning: bool) -> ComplexArray:
return calculate_voltages(potentials, self.phases)

@property
@ureg_wraps("V", (None,), strict=False)
@ureg_wraps("V", (None,))
def res_voltages(self) -> Q_[ComplexArray]:
"""The load flow result of the bus voltages (V).

Expand All @@ -146,7 +146,7 @@ def min_voltage(self) -> Optional[Q_[float]]:
return None if self._min_voltage is None else Q_(self._min_voltage, "V")

@min_voltage.setter
@ureg_wraps(None, (None, "V"), strict=False)
@ureg_wraps(None, (None, "V"))
def min_voltage(self, value: Optional[Union[float, Q_[float]]]) -> None:
if value is not None and self._max_voltage is not None and value > self._max_voltage:
msg = (
Expand All @@ -165,7 +165,7 @@ def max_voltage(self) -> Optional[Q_[float]]:
return None if self._max_voltage is None else Q_(self._max_voltage, "V")

@max_voltage.setter
@ureg_wraps(None, (None, "V"), strict=False)
@ureg_wraps(None, (None, "V"))
def max_voltage(self, value: Optional[Union[float, Q_[float]]]) -> None:
if value is not None and self._min_voltage is not None and value < self._min_voltage:
msg = (
Expand Down Expand Up @@ -284,7 +284,7 @@ def get_connected_buses(self) -> Iterator[Id]:
to_add = set(element._connected_elements).difference(visited)
remaining.update(to_add)

@ureg_wraps("percent", (None,), strict=False)
@ureg_wraps("percent", (None,))
def res_voltage_unbalance(self) -> Q_[float]:
"""Calculate the voltage unbalance on this bus according to the IEC definition.

Expand Down
2 changes: 1 addition & 1 deletion roseau/load_flow/models/grounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _res_potential_getter(self, warning: bool) -> complex:
return self._res_getter(self._res_potential, warning)

@property
@ureg_wraps("V", (None,), strict=False)
@ureg_wraps("V", (None,))
def res_potential(self) -> Q_[complex]:
"""The load flow result of the ground potential (V)."""
return self._res_potential_getter(warning=True)
Expand Down
18 changes: 9 additions & 9 deletions roseau/load_flow/models/lines/lines.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,12 +225,12 @@ def __init__(
self._connect(self.ground)

@property
@ureg_wraps("km", (None,), strict=False)
@ureg_wraps("km", (None,))
def length(self) -> Q_[float]:
return self._length

@length.setter
@ureg_wraps(None, (None, "km"), strict=False)
@ureg_wraps(None, (None, "km"))
def length(self, value: Union[float, Q_[float]]) -> None:
if value <= 0:
msg = f"A line length must be greater than 0. {value:.2f} km provided."
Expand Down Expand Up @@ -274,13 +274,13 @@ def parameters(self, value: LineParameters) -> None:
self._invalidate_network_results()

@property
@ureg_wraps("ohm", (None,), strict=False)
@ureg_wraps("ohm", (None,))
def z_line(self) -> Q_[ComplexArray]:
"""Impedance of the line in Ohm"""
return self.parameters._z_line * self._length

@property
@ureg_wraps("S", (None,), strict=False)
@ureg_wraps("S", (None,))
def y_shunt(self) -> Q_[ComplexArray]:
"""Shunt admittance of the line in Siemens"""
return self.parameters._y_shunt * self._length
Expand All @@ -307,7 +307,7 @@ def _res_series_currents_getter(self, warning: bool) -> ComplexArray:
return i_line

@property
@ureg_wraps("A", (None,), strict=False)
@ureg_wraps("A", (None,))
def res_series_currents(self) -> Q_[ComplexArray]:
"""Get the current in the series elements of the line (A)."""
return self._res_series_currents_getter(warning=True)
Expand All @@ -317,7 +317,7 @@ def _res_series_power_losses_getter(self, warning: bool) -> ComplexArray:
return du_line * i_line.conj() # Sₗ = ΔU.Iₗ*

@property
@ureg_wraps("VA", (None,), strict=False)
@ureg_wraps("VA", (None,))
def res_series_power_losses(self) -> Q_[ComplexArray]:
"""Get the power losses in the series elements of the line (VA)."""
return self._res_series_power_losses_getter(warning=True)
Expand All @@ -341,7 +341,7 @@ def _res_shunt_currents_getter(self, warning: bool) -> tuple[ComplexArray, Compl
return cur1, cur2

@property
@ureg_wraps(("A", "A"), (None,), strict=False)
@ureg_wraps(("A", "A"), (None,))
def res_shunt_currents(self) -> tuple[Q_[ComplexArray], Q_[ComplexArray]]:
"""Get the currents in the shunt elements of the line (A)."""
return self._res_shunt_currents_getter(warning=True)
Expand All @@ -353,7 +353,7 @@ def _res_shunt_power_losses_getter(self, warning: bool) -> ComplexArray:
return pot1 * cur1.conj() + pot2 * cur2.conj()

@property
@ureg_wraps("VA", (None,), strict=False)
@ureg_wraps("VA", (None,))
def res_shunt_power_losses(self) -> Q_[ComplexArray]:
"""Get the power losses in the shunt elements of the line (VA)."""
return self._res_shunt_power_losses_getter(warning=True)
Expand All @@ -364,7 +364,7 @@ def _res_power_losses_getter(self, warning: bool) -> ComplexArray:
return series_losses + shunt_losses

@property
@ureg_wraps("VA", (None,), strict=False)
@ureg_wraps("VA", (None,))
def res_power_losses(self) -> Q_[ComplexArray]:
"""Get the power losses in the line (VA)."""
return self._res_power_losses_getter(warning=True)
Expand Down
Loading