Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨feat: dataclasses with converters #20

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,8 @@ ignore = [
[tool.ruff.lint.per-file-ignores]
"tests/**" = ["T20"]
"noxfile.py" = ["T20"]

[dependency-groups]
dev = [
"ipykernel>=6.29.5",
]
34 changes: 23 additions & 11 deletions src/quantity/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from __future__ import annotations

import operator
from dataclasses import dataclass, replace
from dataclasses import replace
from typing import TYPE_CHECKING

import array_api_compat
Expand All @@ -11,7 +11,7 @@
from astropy.units.quantity_helper import UFUNC_HELPERS

from .api import QuantityArray
from .utils import has_array_namespace
from .utils import dataclass, field, has_array_namespace

if TYPE_CHECKING:
from typing import Any
Expand Down Expand Up @@ -105,12 +105,20 @@ def _make_same_unit_method(attr):
if array_api_func := getattr(array_api_compat, attr, None):

def same_unit(self, *args, **kwargs):
return replace(self, value=array_api_func(self.value, *args, **kwargs))
return replace(
self,
value=array_api_func(self.value, *args, **kwargs),
_skip_convert=True,
)

else:

def same_unit(self, *args, **kwargs):
return replace(self, value=getattr(self.value, attr)(*args, **kwargs))
return replace(
self,
value=getattr(self.value, attr)(*args, **kwargs),
_skip_convert=True,
)

return same_unit

Expand All @@ -119,7 +127,7 @@ def _make_same_unit_attribute(attr):
attr_getter = getattr(array_api_compat, attr, operator.attrgetter(attr))

def same_unit(self):
return replace(self, value=attr_getter(self.value))
return replace(self, value=attr_getter(self.value), _skip_convert=True)

return property(same_unit)

Expand Down Expand Up @@ -150,10 +158,14 @@ def _check_pow_args(exp, mod):
return exp.real if exp.imag == 0 else exp


def _value_converter(v: Any, /) -> Array:
return v if has_array_namespace(v) else np.asarray(v)


@dataclass(frozen=True, eq=False)
class Quantity:
value: Any
unit: u.UnitBase
value: Array = field(converter=_value_converter)
unit: u.UnitBase = field(converter=u.Unit)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In future, to make this completely unit-type agnostic, we would set the type to be a runtime-checkable protocol and the converter would be something like converter=lambda x: x if isinstance(x, UnitAPI) else config.get("units_backend")(x), where the user can configure a preferred units package. Or, if we keep this more closely tied to Astropy, we would have converter=lambda x: x if isinstance(x, UnitAPI) else u.Unit(x)


def __array_namespace__(self, *, api_version: str | None = None) -> Any:
# TODO: make our own?
Expand All @@ -170,7 +182,7 @@ def _operate(self, other, op_func, units_helper):
except Exception:
return NotImplemented
else:
return replace(self, unit=unit)
return replace(self, unit=unit, _skip_convert=True)

other_value, other_unit = get_value_and_unit(other)
self_value = self.value
Expand All @@ -185,7 +197,7 @@ def _operate(self, other, op_func, units_helper):
# Deal with the very unlikely case that other is an array type
# that knows about Quantity, but cannot handle the array we carry.
return NotImplemented
return replace(self, value=value, unit=unit)
return replace(self, value=value, unit=unit, _skip_convert=True)

# Operators (skipping ones that make no sense, like __and__);
# __pow__ and __rpow__ need special treatment and are defined below.
Expand Down Expand Up @@ -234,15 +246,15 @@ def __pow__(self, exp, mod=None):
return NotImplemented

value = operator.__pow__(self.value, exp)
return replace(self, value=value, unit=self.unit**exp)
return replace(self, value=value, unit=self.unit**exp, _skip_convert=True)

def __ipow__(self, exp, mod=None):
exp = _check_pow_args(exp, mod)
if exp is NotImplemented:
return NotImplemented

value = operator.__ipow__(self.value, exp)
return replace(self, value=value, unit=self.unit**exp)
return replace(self, value=value, unit=self.unit**exp, _skip_convert=True)

def __setitem__(self, item, value):
self.value[item] = value_in_unit(value, self.unit)
Expand Down
124 changes: 124 additions & 0 deletions src/quantity/_src/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
"""Utility functions for the quantity package."""

import dataclasses
import functools
import inspect
from collections.abc import Callable, Hashable, Mapping
from typing import Any, TypeVar, dataclass_transform, overload

import array_api_compat


Expand All @@ -10,3 +16,121 @@ def has_array_namespace(arg: object) -> bool:
return False
else:
return True


# ===================================================================
# Dataclass utilities

_CT = TypeVar("_CT")


def field(
*,
converter: Callable[[Any], Any] | None = None,
metadata: Mapping[Hashable, Any] | None = None,
**kwargs: Any,
) -> Any:
"""Dataclass field with a converter argument.

Parameters
----------
converter : callable, optional
A callable that converts the value of the field. This is added to the
metadata of the field.
metadata : Mapping[Hashable, Any], optional
Additional metadata to add to the field.
See `dataclasses.field` for more information.
**kwargs : Any
Additional keyword arguments to pass to `dataclasses.field`.

"""
if converter is not None:
# Check the converter
if not callable(converter):
msg = f"converter must be callable, got {converter!r}"
raise TypeError(msg)

# Convert the metadata to a mutable dict if it is not None.
metadata = dict(metadata) if metadata is not None else {}

if "converter" in metadata:
msg = "Cannot specify 'converter' in metadata and as a keyword argument."
raise ValueError(msg)

# Add the converter to the metadata
metadata["converter"] = converter

return dataclasses.field(metadata=metadata, **kwargs)


def _process_dataclass(cls: type[_CT], **kwargs: Any) -> type[_CT]:
# Make the dataclass from the class.
# This does all the usual dataclass stuff.
dcls: type[_CT] = dataclasses.dataclass(cls, **kwargs)

# Compute the signature of the __init__ method
sig = inspect.signature(dcls.__init__)
# Eliminate the 'self' parameter
sig = sig.replace(parameters=list(sig.parameters.values())[1:])
# Store the signature on the __init__ method (Not assigning to __signature__
# because that should have `self`).
dcls.__init__._obj_signature_ = sig # type: ignore[attr-defined]

# Ensure that the __init__ method does conversion
@functools.wraps(dcls.__init__) # give it the same signature
def __init__(self, *args: Any, _skip_convert: bool = False, **kwargs: Any) -> None:
# Fast path: no conversion
if _skip_convert:
self.__init__.__wrapped__(self, *args, **kwargs)
return

# Bind the arguments to the signature
ba = self.__init__._obj_signature_.bind_partial(*args, **kwargs)
ba.apply_defaults() # so eligible for conversion

# Convert the fields, if there's a converter
for f in dataclasses.fields(self):
k = f.name
if k not in ba.arguments: # mandatory field not provided?!
continue # defer the error to the dataclass __init__

converter = f.metadata.get("converter")
if converter is not None:
ba.arguments[k] = converter(ba.arguments[k])

# Call the original dataclass __init__ method
self.__init__.__wrapped__(self, *ba.args, **ba.kwargs)

dcls.__init__ = __init__ # type: ignore[method-assign]

return dcls


@overload
def dataclass(cls: type[_CT], /, **kwargs: Any) -> type[_CT]: ...


@overload
def dataclass(**kwargs: Any) -> Callable[[type[_CT]], type[_CT]]: ...


@dataclass_transform(field_specifiers=(dataclasses.Field, dataclasses.field, field))
def dataclass(
cls: type[_CT] | None = None, /, **kwargs: Any
) -> type[_CT] | Callable[[type[_CT]], type[_CT]]:
"""Make a dataclass, supporting field converters.

For more information about dataclasses see the `dataclasses` module.

Parameters
----------
cls : type | None, optional
The class to transform into a dataclass.
If None, returns a partial function that can be used as a decorator.
**kwargs : Any
Additional keyword arguments to pass to `dataclasses.dataclass`.

"""
if cls is None:
return functools.partial(_process_dataclass, **kwargs)
return _process_dataclass(cls, **kwargs)
Loading