Skip to content

Commit

Permalink
BREAK: rename unevaluated_expression() to unevaluated() (#379)
Browse files Browse the repository at this point in the history
* BEHAVIOR: do not sympify `str` attributes
* DOC: add docstring to `_get_attribute_values()`
* ENH: implement hash for non-sympy attributes
* ENH: support non-sympy arguments in `@unevaluated()`
* ENH: implement `subs()` method for `unevaluated_expression` classes
* ENH: implement `xreplace()` method for non-sympy attributes
* MAINT: move method implementations to module level
* MAINT: put `TypeVar` definitions under `TYPE_CHECKING`
* MAINT: remove redundant type ignore
* MAINT: test class var definition without `ClassVar`
* MAINT: write test with unsympifiable class
  • Loading branch information
redeboer authored Dec 21, 2023
1 parent aaf585b commit b76435c
Show file tree
Hide file tree
Showing 7 changed files with 257 additions and 41 deletions.
1 change: 1 addition & 0 deletions .cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@
"sharey",
"startswith",
"suptitle",
"sympifiable",
"sympified",
"sympify",
"symplot",
Expand Down
13 changes: 7 additions & 6 deletions docs/usage/sympy.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The {func}`.unevaluated_expression` decorator makes it easier to write classes that represent a mathematical function definition. It makes a class that derives from {class}`sp.Expr <sympy.core.expr.Expr>` behave more like a {func}`~.dataclasses.dataclass` (see [PEP&nbsp;861](https://peps.python.org/pep-0681)). All you have to do is:\n",
"The {func}`.unevaluated` decorator makes it easier to write classes that represent a mathematical function definition. It makes a class that derives from {class}`sp.Expr <sympy.core.expr.Expr>` behave more like a {func}`~.dataclasses.dataclass` (see [PEP&nbsp;861](https://peps.python.org/pep-0681)). All you have to do is:\n",
"\n",
"1. Specify the arguments the function requires.\n",
"2. Specify how to render the 'unevaluated' or 'folded' form of the expression with a `_latex_repr_` string or method.\n",
Expand All @@ -98,10 +98,10 @@
"source": [
"import sympy as sp\n",
"\n",
"from ampform.sympy import unevaluated_expression\n",
"from ampform.sympy import unevaluated\n",
"\n",
"\n",
"@unevaluated_expression(real=False)\n",
"@unevaluated(real=False)\n",
"class PhspFactorSWave(sp.Expr):\n",
" s: sp.Symbol\n",
" m1: sp.Symbol\n",
Expand All @@ -119,7 +119,7 @@
" return 16 * sp.pi * sp.I * cm\n",
"\n",
"\n",
"@unevaluated_expression(real=False)\n",
"@unevaluated(real=False)\n",
"class BreakupMomentum(sp.Expr):\n",
" s: sp.Symbol\n",
" m1: sp.Symbol\n",
Expand Down Expand Up @@ -166,7 +166,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Class variables and default arguments to instance arguments are also supported:"
"Class variables and default arguments to instance arguments are also supported. They can either be indicated with {class}`typing.ClassVar` or by not providing a type hint:"
]
},
{
Expand All @@ -180,11 +180,12 @@
"from typing import Any, ClassVar\n",
"\n",
"\n",
"@unevaluated_expression\n",
"@unevaluated\n",
"class FunkyPower(sp.Expr):\n",
" x: Any\n",
" m: int = 1\n",
" default_return: ClassVar[sp.Expr | None] = None\n",
" class_name = \"my name\"\n",
" _latex_repr_ = R\"f_{{{m}}}\\left({x}\\right)\"\n",
"\n",
" def evaluate(self) -> sp.Expr | None:\n",
Expand Down
4 changes: 2 additions & 2 deletions src/ampform/dynamics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@
UnevaluatedExpression,
determine_indices,
implement_doit_method,
unevaluated_expression,
unevaluated,
)

if TYPE_CHECKING:
from sympy.printing.latex import LatexPrinter


@unevaluated_expression
@unevaluated
class BlattWeisskopfSquared(sp.Expr):
# cspell:ignore pychyGekoppeltePartialwellenanalyseAnnihilationen
r"""Blatt-Weisskopf function :math:`B_L^2(z)`, up to :math:`L \leq 8`.
Expand Down
6 changes: 3 additions & 3 deletions src/ampform/kinematics/phasespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

import sympy as sp

from ampform.sympy import unevaluated_expression
from ampform.sympy import unevaluated


@unevaluated_expression
@unevaluated
class Kibble(sp.Expr):
"""Kibble function for determining the phase space region."""

Expand All @@ -34,7 +34,7 @@ def evaluate(self) -> Kallen:
)


@unevaluated_expression
@unevaluated
class Kallen(sp.Expr):
"""Källén function, used for computing break-up momenta."""

Expand Down
4 changes: 2 additions & 2 deletions src/ampform/sympy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Tools that facilitate in building :mod:`sympy` expressions.
.. autodecorator:: unevaluated_expression
.. autodecorator:: unevaluated
.. dropdown:: SymPy assumptions
.. autodata:: ExprClass
Expand Down Expand Up @@ -30,7 +30,7 @@
from ._decorator import (
ExprClass, # noqa: F401 # pyright: ignore[reportUnusedImport]
SymPyAssumptions, # noqa: F401 # pyright: ignore[reportUnusedImport]
unevaluated_expression, # noqa: F401 # pyright: ignore[reportUnusedImport]
unevaluated, # noqa: F401 # pyright: ignore[reportUnusedImport]
)

if TYPE_CHECKING:
Expand Down
179 changes: 159 additions & 20 deletions src/ampform/sympy/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,38 @@
import functools
import inspect
import sys
from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar, overload
from collections import abc
from inspect import isclass
from typing import TYPE_CHECKING, Any, Callable, Hashable, Iterable, TypeVar, overload

import sympy as sp
from attrs import frozen
from sympy.core.basic import _aresame
from sympy.utilities.exceptions import SymPyDeprecationWarning

if sys.version_info < (3, 8):
from typing_extensions import Protocol, TypedDict
else:
from typing import Protocol, TypedDict

if sys.version_info < (3, 11):
from typing_extensions import ParamSpec, Unpack, dataclass_transform
from typing_extensions import dataclass_transform
else:
from typing import ParamSpec, Unpack, dataclass_transform
from typing import dataclass_transform

if TYPE_CHECKING:
from sympy.printing.latex import LatexPrinter

if sys.version_info < (3, 11):
from typing_extensions import ParamSpec, Unpack
else:
from typing import ParamSpec, Unpack

H = TypeVar("H", bound=Hashable)
P = ParamSpec("P")
T = TypeVar("T")

ExprClass = TypeVar("ExprClass", bound=sp.Expr)
_P = ParamSpec("_P")
_T = TypeVar("_T")


class SymPyAssumptions(TypedDict, total=False):
Expand Down Expand Up @@ -56,25 +68,23 @@ class SymPyAssumptions(TypedDict, total=False):


@overload
def unevaluated_expression(cls: type[ExprClass]) -> type[ExprClass]: ...
def unevaluated(cls: type[ExprClass]) -> type[ExprClass]: ...
@overload
def unevaluated_expression(
def unevaluated(
*,
implement_doit: bool = True,
**assumptions: Unpack[SymPyAssumptions],
) -> Callable[[type[ExprClass]], type[ExprClass]]: ...


@dataclass_transform() # type: ignore[misc]
def unevaluated_expression( # type: ignore[misc]
@dataclass_transform()
def unevaluated(
cls: type[ExprClass] | None = None, *, implement_doit=True, **assumptions
):
r"""Decorator for defining 'unevaluated' SymPy expressions.
Unevaluated expressions are handy for defining large expressions that consist of
several sub-definitions.
>>> @unevaluated_expression
>>> @unevaluated
... class MyExpr(sp.Expr):
... x: sp.Symbol
... y: sp.Symbol
Expand Down Expand Up @@ -133,22 +143,54 @@ def _implement_new_method(cls: type[ExprClass]) -> type[ExprClass]:
@functools.wraps(cls.__new__)
@_insert_args_in_signature(attr_names, idx=1)
def new_method(cls, *args, evaluate: bool = False, **kwargs) -> type[ExprClass]:
positional_args, hints = _get_attribute_values(cls, attr_names, *args, **kwargs)
sympified_args = sp.sympify(positional_args)
expr = sp.Expr.__new__(cls, *sympified_args, **hints)
for name, value in zip(attr_names, sympified_args):
attr_values, hints = _get_attribute_values(cls, attr_names, *args, **kwargs)
converted_attr_values = _safe_sympify(*attr_values)
expr = sp.Expr.__new__(cls, *converted_attr_values.sympy, **hints)
for name, value in zip(attr_names, converted_attr_values.all_args):
setattr(expr, name, value)
expr._all_args = converted_attr_values.all_args
expr._non_sympy_args = converted_attr_values.non_sympy
if evaluate:
return expr.evaluate()
return expr

cls.__new__ = new_method # type: ignore[method-assign]
cls._eval_subs = _eval_subs_method # type: ignore[method-assign]
cls._hashable_content = _hashable_content_method # type: ignore[method-assign]
cls._xreplace = _xreplace_method # type: ignore[method-assign]
return cls


@overload
def _get_hashable_object(obj: type) -> str: ... # type: ignore[overload-overlap]
@overload
def _get_hashable_object(obj: H) -> H: ...
@overload
def _get_hashable_object(obj: Any) -> str: ...
def _get_hashable_object(obj):
if isclass(obj):
return str(obj)
try:
hash(obj)
except TypeError:
return str(obj)
return obj


def _get_attribute_values(
cls: type[ExprClass], attr_names: tuple[str, ...], *args, **kwargs
) -> tuple[tuple, dict[str, Any]]:
"""Extract the attribute values from the constructor arguments.
Returns a `tuple` of:
1. the extracted, ordered attributes as requested by :code:`attr_names`,
2. a `dict` of remaining keyword arguments that can be used hints for the
constructed :class:`sp.Expr<sympy.core.expr.Expr>` instance.
An attempt is made to get any missing attributes from the type hints in the class
definition.
"""
if len(args) == len(attr_names):
return args, kwargs
if len(args) > len(attr_names):
Expand All @@ -173,12 +215,46 @@ def _get_attribute_values(
return tuple(attr_values), kwargs


def _safe_sympify(*args: Any) -> _ExprNewArumgents:
all_args = []
sympy_args = []
non_sympy_args = []
for arg in args:
converted_arg, is_sympy = _try_sympify(arg)
if is_sympy:
sympy_args.append(converted_arg)
else:
non_sympy_args.append(converted_arg)
all_args.append(converted_arg)
return _ExprNewArumgents(
all_args=tuple(all_args),
sympy=tuple(sympy_args),
non_sympy=tuple(non_sympy_args),
)


def _try_sympify(obj) -> tuple[Any, bool]:
if isinstance(obj, str):
return obj, False
try:
return sp.sympify(obj), True
except (TypeError, SymPyDeprecationWarning, sp.SympifyError):
return obj, False


@frozen
class _ExprNewArumgents:
all_args: tuple[Any, ...]
sympy: tuple[sp.Basic, ...]
non_sympy: tuple[Any, ...]


class LatexMethod(Protocol):
def __call__(self, printer: LatexPrinter, *args) -> str: ...


@dataclass_transform()
def _implement_latex_repr(cls: type[_T]) -> type[_T]:
def _implement_latex_repr(cls: type[T]) -> type[T]:
_latex_repr_: LatexMethod | str | None = getattr(cls, "_latex_repr_", None)
if _latex_repr_ is None:
msg = (
Expand Down Expand Up @@ -228,7 +304,7 @@ def _check_has_implementation(cls: type) -> None:

def _insert_args_in_signature(
new_params: Iterable[str] | None = None, idx: int = 0
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
) -> Callable[[Callable[P, T]], Callable[P, T]]:
if new_params is None:
new_params = []

Expand Down Expand Up @@ -279,10 +355,73 @@ def _get_attribute_names(cls: type) -> tuple[str, ...]:
@dataclass_transform()
def _set_assumptions(
**assumptions: Unpack[SymPyAssumptions],
) -> Callable[[type[_T]], type[_T]]:
def class_wrapper(cls: _T) -> _T:
) -> Callable[[type[T]], type[T]]:
def class_wrapper(cls: T) -> T:
for assumption, value in assumptions.items():
setattr(cls, f"is_{assumption}", value)
return cls

return class_wrapper


def _eval_subs_method(self, old, new, **hints):
# https://github.com/sympy/sympy/blob/1.12/sympy/core/basic.py#L1117-L1147
hit = False
substituted_attrs = list(self._all_args)
for i, old_attr in enumerate(substituted_attrs):
if not hasattr(old_attr, "_eval_subs"):
continue
if isclass(old_attr):
continue
new_attr = old_attr._subs(old, new, **hints)
if not _aresame(new_attr, old_attr):
hit = True
substituted_attrs[i] = new_attr
if hit:
rv = self.func(*substituted_attrs)
hack2 = hints.get("hack2", False)
if hack2 and self.is_Mul and not rv.is_Mul: # 2-arg hack
coefficient = sp.S.One
nonnumber = []
for i in substituted_attrs:
if i.is_Number:
coefficient *= i
else:
nonnumber.append(i)
nonnumber = self.func(*nonnumber)
if coefficient is sp.S.One:
return nonnumber
return self.func(coefficient, nonnumber, evaluate=False)
return rv
return self


def _hashable_content_method(self) -> tuple:
hashable_content = super(sp.Expr, self)._hashable_content()
if not self._non_sympy_args:
return hashable_content
remaining_content = (_get_hashable_object(arg) for arg in self._non_sympy_args)
return (*hashable_content, *remaining_content)


def _xreplace_method(self, rule) -> tuple[sp.Expr, bool]:
# https://github.com/sympy/sympy/blob/1.12/sympy/core/basic.py#L1233-L1253
if self in rule:
return rule[self], True
if rule:
new_args = []
hit = False
for arg in self._all_args:
if hasattr(arg, "_xreplace") and not isclass(arg):
replace_result, is_replaced = arg._xreplace(rule)
elif isinstance(rule, abc.Mapping):
is_replaced = bool(arg in rule)
replace_result = rule.get(arg, arg)
else:
replace_result = arg
is_replaced = False
new_args.append(replace_result)
hit |= is_replaced
if hit:
return self.func(*new_args), True
return self, False
Loading

0 comments on commit b76435c

Please sign in to comment.