Skip to content

Commit

Permalink
MAINT: switch to @unevaluated where possible (#382)
Browse files Browse the repository at this point in the history
* FIX: stabilize hash of `@unevaluated` expression classes
* MAINT: define all array classes with `@unevaluated`
* MAINT: define `EnergyDependentWidth` with `@unevaluated`
* MAINT: define phase space protocols with `unevaluated`
  • Loading branch information
redeboer authored Dec 22, 2023
1 parent 2722118 commit 7cd9a32
Show file tree
Hide file tree
Showing 13 changed files with 317 additions and 532 deletions.
108 changes: 17 additions & 91 deletions src/ampform/dynamics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from typing import TYPE_CHECKING, Any, ClassVar

import sympy as sp
from sympy.core.basic import _aresame

# pyright: reportUnusedImport=false
from ampform.dynamics.phasespace import (
Expand All @@ -22,12 +21,7 @@
PhaseSpaceFactorSWave, # noqa: F401
_indices_to_subscript,
)
from ampform.sympy import (
UnevaluatedExpression,
determine_indices,
implement_doit_method,
unevaluated,
)
from ampform.sympy import argument, determine_indices, unevaluated

if TYPE_CHECKING:
from sympy.printing.latex import LatexPrinter
Expand Down Expand Up @@ -139,8 +133,8 @@ def evaluate(self) -> sp.Expr:
])


@implement_doit_method
class EnergyDependentWidth(UnevaluatedExpression):
@unevaluated
class EnergyDependentWidth(sp.Expr):
r"""Mass-dependent width, coupled to the pole position of the resonance.
See Equation (50.28) in :pdg-review:`2021; Resonances; p.9` and
Expand All @@ -154,50 +148,17 @@ class EnergyDependentWidth(UnevaluatedExpression):
:math:`\left(q/q_0\right)^{2L}` in the definition for :math:`\Gamma(m)`.
"""

# https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L74-L77
__slots__ = ("phsp_factor",)
phsp_factor: PhaseSpaceFactorProtocol
is_commutative = True

def __new__(
cls,
s,
mass0,
gamma0,
m_a,
m_b,
angular_momentum,
meson_radius,
phsp_factor: PhaseSpaceFactorProtocol | None = None,
name: str | None = None,
evaluate: bool = False,
) -> EnergyDependentWidth:
args = sp.sympify((s, mass0, gamma0, m_a, m_b, angular_momentum, meson_radius))
if phsp_factor is None:
phsp_factor = PhaseSpaceFactor
# Overwritting Basic.__new__ to store phase space factor type
# https://github.com/sympy/sympy/blob/1.10/sympy/core/basic.py#L121-L127
expr = object.__new__(cls)
expr._assumptions = cls.default_assumptions # type: ignore[attr-defined]
expr._mhash = None
expr._args = args
expr._name = name
expr.phsp_factor = phsp_factor
if evaluate:
return expr.evaluate() # type: ignore[return-value]
return expr

def __getnewargs_ex__(self) -> tuple[tuple, dict]:
# Pickling support, see
# https://github.com/sympy/sympy/blob/1.10/sympy/core/basic.py#L132-L133
args = (*self.args, self.phsp_factor)
kwargs = {"name": self._name}
return args, kwargs

def _hashable_content(self) -> tuple:
# https://github.com/sympy/sympy/blob/1.10/sympy/core/basic.py#L157-L165
# phsp_factor is converted to string because of unstable hash for classes
return (*super()._hashable_content(), str(self.phsp_factor))
s: Any
mass0: Any
gamma0: Any
m_a: Any
m_b: Any
angular_momentum: Any
meson_radius: Any
phsp_factor: PhaseSpaceFactorProtocol = argument(
default=PhaseSpaceFactor, sympify=False
)
name: str | None = argument(default=None, sympify=False)

def evaluate(self) -> sp.Expr:
s, mass0, gamma0, m_a, m_b, angular_momentum, meson_radius = self.args
Expand All @@ -215,48 +176,13 @@ def evaluate(self) -> sp.Expr:
rho0 = self.phsp_factor(mass0**2, m_a, m_b) # type: ignore[operator]
return gamma0 * (form_factor_sq / form_factor0_sq) * (rho / rho0)

def _latex(self, printer: LatexPrinter, *args) -> str:
def _latex_repr_(self, printer: LatexPrinter, *args) -> str:
s = printer._print(self.args[0])
gamma0 = self.args[2]
subscript = _indices_to_subscript(determine_indices(gamma0))
name = Rf"\Gamma{subscript}" if self._name is None else self._name
name = Rf"\Gamma{subscript}" if self.name is None else self.name
return Rf"{name}\left({s}\right)"

def _eval_subs(self, old, new):
# https://github.com/ComPWA/sympy/blob/bd0cf9a/sympy/core/basic.py#L1074-L1104
hit = False
new_args = list(self.args)
for i, arg in enumerate(self.args):
if not hasattr(arg, "_eval_subs"):
continue
arg = arg._subs(old, new)
if not _aresame(arg, new_args[i]):
hit = True
new_args[i] = arg
if hit:
return self.func(*new_args, self.phsp_factor, self._name)
return self

def _xreplace(self, rule):
# https://github.com/sympy/sympy/blob/bd0cf9a/sympy/core/basic.py#L1190-L1210
if self in rule:
return rule[self], True
if rule:
new_args = []
hit = False
for a in self.args:
_xreplace = getattr(a, "_xreplace", None)
if _xreplace is not None:
a_xr = _xreplace(rule)
new_args.append(a_xr[0])
hit |= a_xr[1]
else:
new_args.append(a)
new_args = tuple(new_args)
if hit:
return self.func(*new_args, self.phsp_factor, self._name), True
return self, False


def relativistic_breit_wigner(s, mass0, gamma0) -> sp.Expr:
"""Relativistic Breit-Wigner lineshape.
Expand All @@ -275,7 +201,7 @@ def relativistic_breit_wigner_with_ff(
m_b,
angular_momentum,
meson_radius,
phsp_factor: PhaseSpaceFactorProtocol | None = None,
phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor,
) -> sp.Expr:
"""Relativistic Breit-Wigner with `.BlattWeisskopfSquared` factor.
Expand Down
2 changes: 1 addition & 1 deletion src/ampform/dynamics/kmatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def parametrization(
pole_id,
angular_momentum=0,
meson_radius=1,
phsp_factor: PhaseSpaceFactorProtocol | None = None,
phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor,
) -> sp.Expr:
def residue_function(pole_id, i) -> sp.Expr:
return residue_constant[pole_id, i] * sp.sqrt(
Expand Down
Loading

0 comments on commit 7cd9a32

Please sign in to comment.