From c4a253cd331755c9b071b12e87528079f381002e Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Thu, 16 May 2024 15:03:35 +0200 Subject: [PATCH] FEAT: simplify Hankel polynomial internally --- src/ampform/dynamics/form_factor.py | 34 ++++++++++++++++++++++++----- tests/dynamics/test_dynamics.py | 20 +---------------- 2 files changed, 30 insertions(+), 24 deletions(-) diff --git a/src/ampform/dynamics/form_factor.py b/src/ampform/dynamics/form_factor.py index d17af8a8d..dd526b4db 100644 --- a/src/ampform/dynamics/form_factor.py +++ b/src/ampform/dynamics/form_factor.py @@ -2,6 +2,7 @@ from __future__ import annotations +from functools import lru_cache from typing import Any import sympy as sp @@ -37,12 +38,16 @@ class BlattWeisskopfSquared(sp.Expr): _latex_repr_ = R"B_{{{angular_momentum}}}^2\left({z}\right)" def evaluate(self) -> sp.Expr: - z, angular_momentum = self.args - return ( - sp.Abs(SphericalHankel1(angular_momentum, 1)) ** 2 - / sp.Abs(SphericalHankel1(angular_momentum, sp.sqrt(z))) ** 2 + ell = self.angular_momentum + z = sp.Dummy("z", nonnegative=True, real=True) + expr = ( + sp.Abs(SphericalHankel1(ell, 1)) ** 2 + / sp.Abs(SphericalHankel1(ell, sp.sqrt(z))) ** 2 / z ) + if not ell.free_symbols: + expr = expr.doit().simplify() + return expr.xreplace({z: self.z}) @unevaluated(implement_doit=False) @@ -75,10 +80,29 @@ def evaluate(self) -> sp.Expr: return ( (-sp.I) ** (1 + l) # type:ignore[operator] * (sp.exp(z * sp.I) / z) - * sp.Sum( + * _SymbolicSum( sp.factorial(l + k) / (sp.factorial(l - k) * sp.factorial(k)) * (sp.I / (2 * z)) ** k, # type:ignore[operator] (k, 0, l), ) ) + + +class _SymbolicSum(sp.Sum): + """See [TR-029](https://compwa.github.io/report/029.html) for why this class is needed.""" + + def doit(self, deep: bool = True, **kwargs) -> sp.Expr: + if _get_indices(self): + expression = self.args[0] + indices = self.args[1:] + return _SymbolicSum(expression.doit(deep=deep, **kwargs), *indices) + return super().doit(deep=deep, **kwargs) + + +@lru_cache(maxsize=None) +def _get_indices(expr: sp.Sum) -> set[sp.Basic]: + free_symbols = set() + for index in expr.args[1:]: + free_symbols.update(index.free_symbols) + return {s for s in free_symbols if not isinstance(s, sp.Dummy)} diff --git a/tests/dynamics/test_dynamics.py b/tests/dynamics/test_dynamics.py index 3ad8f561e..3473c6f33 100644 --- a/tests/dynamics/test_dynamics.py +++ b/tests/dynamics/test_dynamics.py @@ -21,7 +21,7 @@ class TestBlattWeisskopfSquared: - def test_max_angular_momentum(self): + def test_factorials(self): z = sp.Symbol("z") angular_momentum = sp.Symbol("L", integer=True) form_factor = BlattWeisskopfSquared(z, angular_momentum) @@ -29,24 +29,6 @@ def test_max_angular_momentum(self): factor, z_power, _ = form_factor_9.args assert factor == 4392846440677 assert z_power == z**8 - assert BlattWeisskopfSquared.max_angular_momentum is None - BlattWeisskopfSquared.max_angular_momentum = 1 - assert form_factor.evaluate() == sp.Piecewise( - (1, sp.Eq(angular_momentum, 0)), - (2 * z / (z + 1), sp.Eq(angular_momentum, 1)), - ) - BlattWeisskopfSquared.max_angular_momentum = None - - def test_unevaluated_expression(self): - z = sp.Symbol("z") - ff1 = BlattWeisskopfSquared(z, angular_momentum=1) - ff2 = BlattWeisskopfSquared(z, angular_momentum=2) - assert ff1.max_angular_momentum is None - assert ff2.max_angular_momentum is None - BlattWeisskopfSquared.max_angular_momentum = 3 - assert ff1.max_angular_momentum is 3 # noqa: F632 - assert ff2.max_angular_momentum is 3 # noqa: F632 - BlattWeisskopfSquared.max_angular_momentum = None class TestEnergyDependentWidth: