Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT: support class attributes in unevaluated_expression #375

Merged
merged 4 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/_extend_docstrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ def extend_docstrings() -> None:
def extend_BlattWeisskopfSquared() -> None:
from ampform.dynamics import BlattWeisskopfSquared

L = sp.Symbol("L", integer=True)
z = sp.Symbol("z", real=True)
expr = BlattWeisskopfSquared(L, z)
L = sp.Symbol("L", integer=True)
expr = BlattWeisskopfSquared(z, angular_momentum=L)
_append_latex_doit_definition(expr, deep=True, full_width=True)


Expand Down
4 changes: 2 additions & 2 deletions docs/usage/dynamics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@
"\n",
"L = sp.Symbol(\"L\", integer=True)\n",
"z = sp.Symbol(\"z\", real=True)\n",
"ff2 = BlattWeisskopfSquared(L, z)\n",
"ff2 = BlattWeisskopfSquared(z, L)\n",
"Math(sp.multiline_latex(ff2, ff2.doit(), environment=\"eqnarray\"))"
]
},
Expand All @@ -183,7 +183,7 @@
"m, m_a, m_b, d = sp.symbols(\"m, m_a, m_b, d\")\n",
"s = m**2\n",
"q_squared = BreakupMomentumSquared(s, m_a, m_b)\n",
"ff2 = BlattWeisskopfSquared(L, z=q_squared * d**2)"
"ff2 = BlattWeisskopfSquared(q_squared * d**2, angular_momentum=L)"
]
},
{
Expand Down
36 changes: 16 additions & 20 deletions src/ampform/dynamics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# cspell:ignore asner mhash
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, ClassVar

import sympy as sp
from sympy.core.basic import _aresame
Expand All @@ -24,27 +24,27 @@
)
from ampform.sympy import (
UnevaluatedExpression,
create_expression,
determine_indices,
implement_doit_method,
unevaluated_expression,
)

if TYPE_CHECKING:
from sympy.printing.latex import LatexPrinter


@implement_doit_method
class BlattWeisskopfSquared(UnevaluatedExpression):
@unevaluated_expression
class BlattWeisskopfSquared(sp.Expr):
# cspell:ignore pychyGekoppeltePartialwellenanalyseAnnihilationen
r"""Blatt-Weisskopf function :math:`B_L^2(z)`, up to :math:`L \leq 8`.

Args:
angular_momentum: Angular momentum :math:`L` of the decaying particle.

z: Argument of the Blatt-Weisskopf function :math:`B_L^2(z)`. A usual
choice is :math:`z = (d q)^2` with :math:`d` the impact parameter and
:math:`q` the breakup-momentum (see `.BreakupMomentumSquared`).

angular_momentum: Angular momentum :math:`L` of the decaying particle.

Note that equal powers of :math:`z` appear in the nominator and the denominator,
while some sources have nominator :math:`1`, instead of :math:`z^L`. Compare for
instance Equation (50.27) in :pdg-review:`2021; Resonances; p.9`.
Expand All @@ -57,20 +57,20 @@ class BlattWeisskopfSquared(UnevaluatedExpression):

See also :ref:`usage/dynamics:Form factor`.
"""
is_commutative = True
max_angular_momentum: int | None = None
z: Any
angular_momentum: Any
_latex_repr_ = R"B_{{{angular_momentum}}}^2\left({z}\right)"

max_angular_momentum: ClassVar[int | None] = None
"""Limit the maximum allowed angular momentum :math:`L`.

This improves performance when :math:`L` is a `~sympy.core.symbol.Symbol` and you
are note interested in higher angular momenta.
"""

def __new__(cls, angular_momentum, z, **hints) -> BlattWeisskopfSquared:
return create_expression(cls, angular_momentum, z, **hints)

def evaluate(self) -> sp.Expr:
angular_momentum: sp.Expr = self.args[0] # type: ignore[assignment]
z: sp.Expr = self.args[1] # type: ignore[assignment]
z: sp.Expr = self.args[0] # type: ignore[assignment]
angular_momentum: sp.Expr = self.args[1] # type: ignore[assignment]
cases: dict[int, sp.Expr] = {
0: sp.S.One,
1: 2 * z / (z + 1),
Expand Down Expand Up @@ -138,10 +138,6 @@ def evaluate(self) -> sp.Expr:
if self.max_angular_momentum is None or value <= self.max_angular_momentum
])

def _latex(self, printer: LatexPrinter, *args) -> str:
angular_momentum, z = tuple(map(printer._print, self.args))
return Rf"B_{{{angular_momentum}}}^2\left({z}\right)"


@implement_doit_method
class EnergyDependentWidth(UnevaluatedExpression):
Expand Down Expand Up @@ -208,12 +204,12 @@ def evaluate(self) -> sp.Expr:
q_squared = BreakupMomentumSquared(s, m_a, m_b)
q0_squared = BreakupMomentumSquared(mass0**2, m_a, m_b) # type: ignore[operator]
form_factor_sq = BlattWeisskopfSquared(
q_squared * meson_radius**2, # type: ignore[operator]
angular_momentum,
z=q_squared * meson_radius**2, # type: ignore[operator]
)
form_factor0_sq = BlattWeisskopfSquared(
q0_squared * meson_radius**2, # type: ignore[operator]
angular_momentum,
z=q0_squared * meson_radius**2, # type: ignore[operator]
)
rho = self.phsp_factor(s, m_a, m_b)
rho0 = self.phsp_factor(mass0**2, m_a, m_b) # type: ignore[operator]
Expand Down Expand Up @@ -303,5 +299,5 @@ def formulate_form_factor(s, m_a, m_b, angular_momentum, meson_radius) -> sp.Exp
`~sympy.functions.elementary.miscellaneous.sqrt` of a `.BlattWeisskopfSquared`.
"""
q_squared = BreakupMomentumSquared(s, m_a, m_b)
ff_squared = BlattWeisskopfSquared(angular_momentum, z=q_squared * meson_radius**2)
ff_squared = BlattWeisskopfSquared(q_squared * meson_radius**2, angular_momentum)
return sp.sqrt(ff_squared)
7 changes: 6 additions & 1 deletion src/ampform/sympy/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,14 +256,19 @@ def _get_attribute_names(cls: type) -> tuple[str, ...]:
... a: int
... b: int
... _c: int
... n: ClassVar[int] = 2
...
... def print(self): ...
...
>>> _get_attribute_names(MyClass)
('a', 'b')
"""
return tuple(
k for k in cls.__annotations__ if not callable(k) if not k.startswith("_")
k
for k, v in cls.__annotations__.items()
if not callable(k)
if not k.startswith("_")
if not str(v).startswith("ClassVar")
)


Expand Down
14 changes: 13 additions & 1 deletion tests/dynamics/test_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class TestBlattWeisskopfSquared:
def test_max_angular_momentum(self):
z = sp.Symbol("z")
angular_momentum = sp.Symbol("L", integer=True)
form_factor = BlattWeisskopfSquared(angular_momentum, z=z)
form_factor = BlattWeisskopfSquared(z, angular_momentum)
form_factor_9 = form_factor.subs(angular_momentum, 8).evaluate()
factor, z_power, _ = form_factor_9.args
assert factor == 4392846440677
Expand All @@ -35,6 +35,18 @@ def test_max_angular_momentum(self):
(1, sp.Eq(angular_momentum, 0)),
(2 * z / (z + 1), sp.Eq(angular_momentum, 1)),
)
BlattWeisskopfSquared.max_angular_momentum = None

def test_unevaluated_expression(self):
z = sp.Symbol("z")
ff1 = BlattWeisskopfSquared(z, angular_momentum=1)
ff2 = BlattWeisskopfSquared(z, angular_momentum=2)
assert ff1.max_angular_momentum is None
assert ff2.max_angular_momentum is None
BlattWeisskopfSquared.max_angular_momentum = 3
assert ff1.max_angular_momentum is 3 # noqa: F632
assert ff2.max_angular_momentum is 3 # noqa: F632
BlattWeisskopfSquared.max_angular_momentum = None


class TestEnergyDependentWidth:
Expand Down
2 changes: 1 addition & 1 deletion tests/dynamics/test_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_pickle():
assert expr == imported_expr

# Pickle classes derived from UnevaluatedExpression
expr = BlattWeisskopfSquared(angular_momentum, z=z)
expr = BlattWeisskopfSquared(z, angular_momentum)
pickled_obj = pickle.dumps(expr)
imported_expr = pickle.loads(pickled_obj) # noqa: S301
assert expr == imported_expr
Expand Down
12 changes: 6 additions & 6 deletions tests/sympy/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,17 @@ def test_get_readable_hash_large(amplitude_model: tuple[str, HelicityModel]):
# https://github.com/ComPWA/ampform/actions/runs/3277058875/jobs/5393849802
# https://github.com/ComPWA/ampform/actions/runs/3277143883/jobs/5394043014
expected_hash = {
"canonical-helicity": "pythonhashseed-0-6040455869260657745",
"helicity": "pythonhashseed-0-1928646339459384503",
"canonical-helicity": "pythonhashseed-0-3873186712292274641",
"helicity": "pythonhashseed-0-8800154542426799839",
}[formalism]
elif sys.version_info >= (3, 11):
expected_hash = {
"canonical-helicity": "pythonhashseed-0+409069872540431022",
"helicity": "pythonhashseed-0-8907705932662936900",
"canonical-helicity": "pythonhashseed-0+4035132515642199515",
"helicity": "pythonhashseed-0-2843057473565885663",
}[formalism]
else:
expected_hash = {
"canonical-helicity": "pythonhashseed-0-7143983882032045549",
"helicity": "pythonhashseed-0+3357246175053927117",
"canonical-helicity": "pythonhashseed-0+3420919389670627445",
"helicity": "pythonhashseed-0-6681863313351758450",
}[formalism]
assert get_readable_hash(model.expression) == expected_hash
23 changes: 22 additions & 1 deletion tests/sympy/test_decorator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import inspect
from typing import Any
from typing import Any, ClassVar

import pytest
import sympy as sp
Expand Down Expand Up @@ -113,6 +113,27 @@ def evaluate(self) -> sp.Expr:
assert isinstance(q_value.m2, sp.Float)


def test_unevaluated_expression_classvar():
@unevaluated_expression
class MyExpr(sp.Expr):
x: float
m: ClassVar[int] = 2

def evaluate(self) -> sp.Expr:
return self.x**self.m # type: ignore[return-value]

x_expr = MyExpr(4)
assert x_expr.x is sp.Integer(4)
assert x_expr.m is 2 # noqa: F632

y_expr = MyExpr(5)
assert x_expr.doit() == 4**2
assert y_expr.doit() == 5**2
MyExpr.m = 3
assert x_expr.doit() == 4**3
assert y_expr.doit() == 5**3


def test_unevaluated_expression_callable():
@unevaluated_expression(implement_doit=False)
class Squared(sp.Expr):
Expand Down
Loading