Skip to content

Commit

Permalink
Merge pull request #149 from RoseauTechnologies/wrapper
Browse files Browse the repository at this point in the history
Add custom pint wrapper
  • Loading branch information
Saelyos authored Nov 17, 2023
2 parents 57eff2a + 94073bd commit f7f6070
Show file tree
Hide file tree
Showing 14 changed files with 339 additions and 87 deletions.
1 change: 1 addition & 0 deletions doc/Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

**In development**

- {gh-pr}`149` {gh-issue}`145` Add custom pint wrapper for better handling of pint arrays.
- {gh-pr}`148` {gh-issue}`122` deprecate `LineParameters.from_name_lv()` in favor of the more generic
`LineParameters.from_geometry()`. The method will be removed in a future release.
- {gh-pr}`142` {gh-issue}`136` Add `Bus.res_voltage_unbalance()` method to get the Voltage Unbalance
Expand Down
148 changes: 148 additions & 0 deletions roseau/load_flow/_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import functools
from collections.abc import Iterable, MutableSequence
from inspect import Parameter, Signature, signature
from itertools import zip_longest
from typing import Any, Callable, Optional, TypeVar, Union

from pint import Quantity, Unit
from pint.registry import UnitRegistry
from pint.util import to_units_container

T = TypeVar("T")
FuncT = TypeVar("FuncT", bound=Callable)


def _parse_wrap_args(args: Iterable[Optional[Union[str, Unit]]]) -> Callable:
"""Create a converter function for the wrapper"""
# _to_units_container
args_as_uc = [to_units_container(arg) for arg in args]

# Check for references in args, remove None values
unit_args_ndx = {ndx for ndx, arg in enumerate(args_as_uc) if arg is not None}

def _converter(ureg: UnitRegistry, sig: Signature, values: list[Any], kw: dict[Any]):
len_initial_values = len(values)

# pack kwargs
for i, param_name in enumerate(sig.parameters):
if i >= len_initial_values:
values.append(kw[param_name])

# convert arguments
for ndx in unit_args_ndx:
value = values[ndx]
if isinstance(value, ureg.Quantity):
values[ndx] = ureg.convert(value.magnitude, value.units, args_as_uc[ndx])
elif isinstance(value, MutableSequence):
for i, val in enumerate(value):
if isinstance(val, ureg.Quantity):
value[i] = ureg.convert(val.magnitude, val.units, args_as_uc[ndx])

# unpack kwargs
for i, param_name in enumerate(sig.parameters):
if i >= len_initial_values:
kw[param_name] = values[i]

return values[:len_initial_values], kw

return _converter


def _apply_defaults(sig: Signature, args: tuple[Any], kwargs: dict[str, Any]) -> tuple[list[Any], dict[str, Any]]:
"""Apply default keyword arguments.
Named keywords may have been left blank. This function applies the default
values so that every argument is defined.
"""
n = len(args)
for i, param in enumerate(sig.parameters.values()):
if i >= n and param.default != Parameter.empty and param.name not in kwargs:
kwargs[param.name] = param.default
return list(args), kwargs


def wraps(
ureg: UnitRegistry,
ret: Optional[Union[str, Unit, Iterable[Optional[Union[str, Unit]]]]],
args: Optional[Union[str, Unit, Iterable[Optional[Union[str, Unit]]]]],
) -> Callable[[FuncT], FuncT]:
"""Wraps a function to become pint-aware.
Use it when a function requires a numerical value but in some specific
units. The wrapper function will take a pint quantity, convert to the units
specified in `args` and then call the wrapped function with the resulting
magnitude.
The value returned by the wrapped function will be converted to the units
specified in `ret`.
Args:
ureg:
A UnitRegistry instance.
ret:
Units of each of the return values. Use `None` to skip argument conversion.
args:
Units of each of the input arguments. Use `None` to skip argument conversion.
Returns:
The wrapper function.
Raises:
TypeError
if the number of given arguments does not match the number of function parameters.
if any of the provided arguments is not a unit a string or Quantity
"""
if not isinstance(args, (list, tuple)):
args = (args,)

for arg in args:
if arg is not None and not isinstance(arg, (ureg.Unit, str)):
raise TypeError(f"wraps arguments must by of type str or Unit, not {type(arg)} ({arg})")

converter = _parse_wrap_args(args)

is_ret_container = isinstance(ret, (list, tuple))
if is_ret_container:
for arg in ret:
if arg is not None and not isinstance(arg, (ureg.Unit, str)):
raise TypeError(f"wraps 'ret' argument must by of type str or Unit, not {type(arg)} ({arg})")
ret = ret.__class__([to_units_container(arg, ureg) for arg in ret])
else:
if ret is not None and not isinstance(ret, (ureg.Unit, str)):
raise TypeError(f"wraps 'ret' argument must by of type str or Unit, not {type(ret)} ({ret})")
ret = to_units_container(ret, ureg)

def decorator(func: Callable[..., Any]) -> Callable[..., Quantity]:
sig = signature(func)
count_params = len(sig.parameters)
if len(args) != count_params:
raise TypeError(f"{func.__name__} takes {count_params} parameters, but {len(args)} units were passed")

assigned = tuple(attr for attr in functools.WRAPPER_ASSIGNMENTS if hasattr(func, attr))
updated = tuple(attr for attr in functools.WRAPPER_UPDATES if hasattr(func, attr))

@functools.wraps(func, assigned=assigned, updated=updated)
def wrapper(*values, **kw) -> Quantity:
values, kw = _apply_defaults(sig, values, kw)

# In principle, the values are used as is
# When then extract the magnitudes when needed.
new_values, new_kw = converter(ureg, sig, values, kw)

result = func(*new_values, **new_kw)

if is_ret_container:
return ret.__class__(
res if unit is None else ureg.Quantity(res, unit) for unit, res in zip_longest(ret, result)
)

if ret is None:
return result

return ureg.Quantity(result, ret)

return wrapper

return decorator
8 changes: 4 additions & 4 deletions roseau/load_flow/models/branches.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def _res_currents_getter(self, warning: bool) -> tuple[ComplexArray, ComplexArra
return self._res_getter(value=self._res_currents, warning=warning)

@property
@ureg_wraps(("A", "A"), (None,), strict=False)
@ureg_wraps(("A", "A"), (None,))
def res_currents(self) -> tuple[Q_[ComplexArray], Q_[ComplexArray]]:
"""The load flow result of the branch currents (A)."""
return self._res_currents_getter(warning=True)
Expand All @@ -93,7 +93,7 @@ def _res_powers_getter(self, warning: bool) -> tuple[ComplexArray, ComplexArray]
return powers1, powers2

@property
@ureg_wraps(("VA", "VA"), (None,), strict=False)
@ureg_wraps(("VA", "VA"), (None,))
def res_powers(self) -> tuple[Q_[ComplexArray], Q_[ComplexArray]]:
"""The load flow result of the branch powers (VA)."""
return self._res_powers_getter(warning=True)
Expand All @@ -104,7 +104,7 @@ def _res_potentials_getter(self, warning: bool) -> tuple[ComplexArray, ComplexAr
return pot1, pot2

@property
@ureg_wraps(("V", "V"), (None,), strict=False)
@ureg_wraps(("V", "V"), (None,))
def res_potentials(self) -> tuple[Q_[ComplexArray], Q_[ComplexArray]]:
"""The load flow result of the branch potentials (V)."""
return self._res_potentials_getter(warning=True)
Expand All @@ -114,7 +114,7 @@ def _res_voltages_getter(self, warning: bool) -> tuple[ComplexArray, ComplexArra
return calculate_voltages(pot1, self.phases1), calculate_voltages(pot2, self.phases2)

@property
@ureg_wraps(("V", "V"), (None,), strict=False)
@ureg_wraps(("V", "V"), (None,))
def res_voltages(self) -> tuple[Q_[ComplexArray], Q_[ComplexArray]]:
"""The load flow result of the branch voltages (V)."""
return self._res_voltages_getter(warning=True)
Expand Down
14 changes: 7 additions & 7 deletions roseau/load_flow/models/buses.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,13 @@ def __repr__(self) -> str:
return f"{type(self).__name__}(id={self.id!r}, phases={self.phases!r})"

@property
@ureg_wraps("V", (None,), strict=False)
@ureg_wraps("V", (None,))
def potentials(self) -> Q_[ComplexArray]:
"""An array of initial potentials of the bus (V)."""
return self._potentials

@potentials.setter
@ureg_wraps(None, (None, "V"), strict=False)
@ureg_wraps(None, (None, "V"))
def potentials(self, value: ComplexArrayLike1D) -> None:
if len(value) != len(self.phases):
msg = f"Incorrect number of potentials: {len(value)} instead of {len(self.phases)}"
Expand All @@ -110,7 +110,7 @@ def _res_potentials_getter(self, warning: bool) -> ComplexArray:
return self._res_getter(value=self._res_potentials, warning=warning)

@property
@ureg_wraps("V", (None,), strict=False)
@ureg_wraps("V", (None,))
def res_potentials(self) -> Q_[ComplexArray]:
"""The load flow result of the bus potentials (V)."""
return self._res_potentials_getter(warning=True)
Expand All @@ -120,7 +120,7 @@ def _res_voltages_getter(self, warning: bool) -> ComplexArray:
return calculate_voltages(potentials, self.phases)

@property
@ureg_wraps("V", (None,), strict=False)
@ureg_wraps("V", (None,))
def res_voltages(self) -> Q_[ComplexArray]:
"""The load flow result of the bus voltages (V).
Expand All @@ -146,7 +146,7 @@ def min_voltage(self) -> Optional[Q_[float]]:
return None if self._min_voltage is None else Q_(self._min_voltage, "V")

@min_voltage.setter
@ureg_wraps(None, (None, "V"), strict=False)
@ureg_wraps(None, (None, "V"))
def min_voltage(self, value: Optional[Union[float, Q_[float]]]) -> None:
if value is not None and self._max_voltage is not None and value > self._max_voltage:
msg = (
Expand All @@ -165,7 +165,7 @@ def max_voltage(self) -> Optional[Q_[float]]:
return None if self._max_voltage is None else Q_(self._max_voltage, "V")

@max_voltage.setter
@ureg_wraps(None, (None, "V"), strict=False)
@ureg_wraps(None, (None, "V"))
def max_voltage(self, value: Optional[Union[float, Q_[float]]]) -> None:
if value is not None and self._min_voltage is not None and value < self._min_voltage:
msg = (
Expand Down Expand Up @@ -284,7 +284,7 @@ def get_connected_buses(self) -> Iterator[Id]:
to_add = set(element._connected_elements).difference(visited)
remaining.update(to_add)

@ureg_wraps("percent", (None,), strict=False)
@ureg_wraps("percent", (None,))
def res_voltage_unbalance(self) -> Q_[float]:
"""Calculate the voltage unbalance on this bus according to the IEC definition.
Expand Down
2 changes: 1 addition & 1 deletion roseau/load_flow/models/grounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _res_potential_getter(self, warning: bool) -> complex:
return self._res_getter(self._res_potential, warning)

@property
@ureg_wraps("V", (None,), strict=False)
@ureg_wraps("V", (None,))
def res_potential(self) -> Q_[complex]:
"""The load flow result of the ground potential (V)."""
return self._res_potential_getter(warning=True)
Expand Down
18 changes: 9 additions & 9 deletions roseau/load_flow/models/lines/lines.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,12 +225,12 @@ def __init__(
self._connect(self.ground)

@property
@ureg_wraps("km", (None,), strict=False)
@ureg_wraps("km", (None,))
def length(self) -> Q_[float]:
return self._length

@length.setter
@ureg_wraps(None, (None, "km"), strict=False)
@ureg_wraps(None, (None, "km"))
def length(self, value: Union[float, Q_[float]]) -> None:
if value <= 0:
msg = f"A line length must be greater than 0. {value:.2f} km provided."
Expand Down Expand Up @@ -274,13 +274,13 @@ def parameters(self, value: LineParameters) -> None:
self._invalidate_network_results()

@property
@ureg_wraps("ohm", (None,), strict=False)
@ureg_wraps("ohm", (None,))
def z_line(self) -> Q_[ComplexArray]:
"""Impedance of the line in Ohm"""
return self.parameters._z_line * self._length

@property
@ureg_wraps("S", (None,), strict=False)
@ureg_wraps("S", (None,))
def y_shunt(self) -> Q_[ComplexArray]:
"""Shunt admittance of the line in Siemens"""
return self.parameters._y_shunt * self._length
Expand All @@ -307,7 +307,7 @@ def _res_series_currents_getter(self, warning: bool) -> ComplexArray:
return i_line

@property
@ureg_wraps("A", (None,), strict=False)
@ureg_wraps("A", (None,))
def res_series_currents(self) -> Q_[ComplexArray]:
"""Get the current in the series elements of the line (A)."""
return self._res_series_currents_getter(warning=True)
Expand All @@ -317,7 +317,7 @@ def _res_series_power_losses_getter(self, warning: bool) -> ComplexArray:
return du_line * i_line.conj() # Sₗ = ΔU.Iₗ*

@property
@ureg_wraps("VA", (None,), strict=False)
@ureg_wraps("VA", (None,))
def res_series_power_losses(self) -> Q_[ComplexArray]:
"""Get the power losses in the series elements of the line (VA)."""
return self._res_series_power_losses_getter(warning=True)
Expand All @@ -341,7 +341,7 @@ def _res_shunt_currents_getter(self, warning: bool) -> tuple[ComplexArray, Compl
return cur1, cur2

@property
@ureg_wraps(("A", "A"), (None,), strict=False)
@ureg_wraps(("A", "A"), (None,))
def res_shunt_currents(self) -> tuple[Q_[ComplexArray], Q_[ComplexArray]]:
"""Get the currents in the shunt elements of the line (A)."""
return self._res_shunt_currents_getter(warning=True)
Expand All @@ -353,7 +353,7 @@ def _res_shunt_power_losses_getter(self, warning: bool) -> ComplexArray:
return pot1 * cur1.conj() + pot2 * cur2.conj()

@property
@ureg_wraps("VA", (None,), strict=False)
@ureg_wraps("VA", (None,))
def res_shunt_power_losses(self) -> Q_[ComplexArray]:
"""Get the power losses in the shunt elements of the line (VA)."""
return self._res_shunt_power_losses_getter(warning=True)
Expand All @@ -364,7 +364,7 @@ def _res_power_losses_getter(self, warning: bool) -> ComplexArray:
return series_losses + shunt_losses

@property
@ureg_wraps("VA", (None,), strict=False)
@ureg_wraps("VA", (None,))
def res_power_losses(self) -> Q_[ComplexArray]:
"""Get the power losses in the line (VA)."""
return self._res_power_losses_getter(warning=True)
Expand Down
Loading

0 comments on commit f7f6070

Please sign in to comment.