Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MAINT: import v0.15 refactorings into v0.14 #399

Merged
merged 3 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 14 additions & 62 deletions src/ampform/helicity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from collections import OrderedDict, abc
from decimal import Decimal
from difflib import get_close_matches
from functools import reduce, singledispatch
from functools import reduce
from typing import (
TYPE_CHECKING,
Generator,
Expand Down Expand Up @@ -45,6 +45,7 @@
)
from ampform.helicity.decay import (
TwoBodyDecay,
get_outer_state_ids,
get_parent_id,
get_prefactor,
get_sibling_state_id,
Expand All @@ -56,10 +57,13 @@
CanonicalAmplitudeNameGenerator,
HelicityAmplitudeNameGenerator,
NameGenerator,
create_amplitude_base,
create_amplitude_symbol,
create_helicity_symbol,
create_spin_projection_symbol,
generate_transition_label,
get_helicity_angle_symbols,
get_helicity_suffix,
get_topology_identifier,
natural_sorting,
)
from ampform.kinematics import HelicityAdapter
Expand Down Expand Up @@ -595,7 +599,7 @@ def formulate(self) -> HelicityModel:
)

def __formulate_top_expression(self) -> PoolSum:
outer_state_ids = _get_outer_state_ids(self.__reaction)
outer_state_ids = get_outer_state_ids(self.__reaction)
spin_projections: collections.defaultdict[sp.Symbol, set[sp.Rational]] = (
collections.defaultdict(set)
)
Expand All @@ -605,7 +609,7 @@ def __formulate_top_expression(self) -> PoolSum:
for transition in group:
for i in outer_state_ids:
state = transition.states[i]
symbol = _create_spin_projection_symbol(i)
symbol = create_spin_projection_symbol(i)
value = sp.Rational(state.spin_projection)
spin_projections[symbol].add(value)

Expand All @@ -615,21 +619,20 @@ def __formulate_top_expression(self) -> PoolSum:
else:
indices = list(spin_projections)
amplitude = sum( # type: ignore[assignment]
_create_amplitude_base(topology)[indices]
for topology in topology_groups
create_amplitude_base(topology)[indices] for topology in topology_groups
)
return PoolSum(abs(amplitude) ** 2, *spin_projections.items())

def __formulate_aligned_amplitude(
self, topology_groups: dict[Topology, list[StateTransition]]
) -> sp.Expr:
outer_state_ids = _get_outer_state_ids(self.__reaction)
outer_state_ids = get_outer_state_ids(self.__reaction)
amplitude = sp.S.Zero
for topology, transitions in topology_groups.items():
base = _create_amplitude_base(topology)
base = create_amplitude_base(topology)
helicities = [
_get_opposite_helicity_sign(topology, i)
* _create_helicity_symbol(topology, i)
* create_helicity_symbol(topology, i)
for i in outer_state_ids
]
amplitude_symbol = base[helicities]
Expand Down Expand Up @@ -666,7 +669,7 @@ def __formulate_topology_amplitude(
sequential_expressions.append(expression)

first_transition = transitions[0]
symbol = _create_amplitude_symbol(first_transition)
symbol = create_amplitude_symbol(first_transition)
expression = sum(sequential_expressions) # type: ignore[assignment]
self.__ingredients.amplitudes[symbol] = expression
return expression
Expand Down Expand Up @@ -765,63 +768,12 @@ def __generate_amplitude_prefactor(
return None


def _create_amplitude_symbol(transition: StateTransition) -> sp.Indexed:
outer_state_ids = _get_outer_state_ids(transition)
helicities = tuple(
sp.Rational(transition.states[i].spin_projection) for i in outer_state_ids
)
base = _create_amplitude_base(transition.topology)
return base[helicities]


def _get_opposite_helicity_sign(topology: Topology, state_id: int) -> Literal[-1, 1]:
if state_id != -1 and is_opposite_helicity_state(topology, state_id):
return -1
return 1


def _create_amplitude_base(topology: Topology) -> sp.IndexedBase:
superscript = get_topology_identifier(topology)
return sp.IndexedBase(f"A^{superscript}", complex=True)


def _create_helicity_symbol(
topology: Topology, state_id: int, root: str = "lambda"
) -> sp.Symbol:
if state_id == -1: # initial state
name = "m_A"
else:
suffix = get_helicity_suffix(topology, state_id)
name = f"{root}{suffix}"
return sp.Symbol(name, rational=True)


def _create_spin_projection_symbol(state_id: int) -> sp.Symbol:
if state_id == -1: # initial state
suffix = "_A"
else:
suffix = str(state_id)
return sp.Symbol(f"m{suffix}", rational=True)


@singledispatch
def _get_outer_state_ids(obj: ReactionInfo | StateTransition) -> list[int]:
msg = f"Cannot get outer state IDs from a {type(obj).__name__}"
raise NotImplementedError(msg)


@_get_outer_state_ids.register(StateTransition)
def _(transition: StateTransition) -> list[int]:
outer_state_ids = list(transition.initial_states)
outer_state_ids += sorted(transition.final_states)
return outer_state_ids


@_get_outer_state_ids.register(ReactionInfo)
def _(reaction: ReactionInfo) -> list[int]:
return _get_outer_state_ids(reaction.transitions[0])


class CanonicalAmplitudeBuilder(HelicityAmplitudeBuilder):
r"""Amplitude model generator for the canonical helicity formalism.

Expand Down Expand Up @@ -1033,7 +985,7 @@ def formulate_rotation_chain(
plus a Wigner rotation (see :func:`.formulate_wigner_rotation`) in case there is
more than one helicity rotation.
"""
helicity_symbol = _create_spin_projection_symbol(rotated_state_id)
helicity_symbol = create_spin_projection_symbol(rotated_state_id)
helicity_rotations = formulate_helicity_rotation_chain(
transition, rotated_state_id, helicity_symbol
)
Expand Down
20 changes: 19 additions & 1 deletion src/ampform/helicity/decay.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import TYPE_CHECKING, Iterable

from attrs import frozen
from qrules.transition import State, StateTransition
from qrules.transition import ReactionInfo, State, StateTransition

if TYPE_CHECKING:
from qrules.quantum_numbers import InteractionProperties
Expand Down Expand Up @@ -295,6 +295,24 @@ def determine_attached_final_state(topology: Topology, state_id: int) -> list[in
return sorted(topology.get_originating_final_state_edge_ids(edge.ending_node_id))


@singledispatch
def get_outer_state_ids(obj: ReactionInfo | StateTransition) -> list[int]:
msg = f"Cannot get outer state IDs from a {type(obj).__name__}"
raise NotImplementedError(msg)


@get_outer_state_ids.register(StateTransition)
def _(transition: StateTransition) -> list[int]:
outer_state_ids = list(transition.initial_states)
outer_state_ids += sorted(transition.final_states)
return outer_state_ids


@get_outer_state_ids.register(ReactionInfo)
def _(reaction: ReactionInfo) -> list[int]:
return get_outer_state_ids(reaction.transitions[0])


def get_prefactor(transition: StateTransition) -> float:
"""Calculate the product of all prefactors defined in this transition.

Expand Down
34 changes: 34 additions & 0 deletions src/ampform/helicity/naming.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
assert_isobar_topology,
determine_attached_final_state,
get_helicity_info,
get_outer_state_ids,
get_sorted_states,
)

Expand Down Expand Up @@ -287,6 +288,20 @@ def __generate_ls_arrow(transition: StateTransition, node_id: int) -> str:
return Rf" \xrightarrow[S={coupled_spin}]{{L={angular_momentum}}} "


def create_amplitude_symbol(transition: StateTransition) -> sp.Indexed:
outer_state_ids = get_outer_state_ids(transition)
helicities = tuple(
sp.Rational(transition.states[i].spin_projection) for i in outer_state_ids
)
base = create_amplitude_base(transition.topology)
return base[helicities]


def create_amplitude_base(topology: Topology) -> sp.IndexedBase:
superscript = get_topology_identifier(topology)
return sp.IndexedBase(f"A^{superscript}", complex=True)


def generate_transition_label(transition: StateTransition) -> str:
r"""Generate a label for a coherent intensity, including spin projection.

Expand Down Expand Up @@ -495,3 +510,22 @@ def _render_float(value: float) -> str:
if value > 0:
return f"+{rational}"
return str(rational)


def create_helicity_symbol(
topology: Topology, state_id: int, root: str = "lambda"
) -> sp.Symbol:
if state_id == -1: # initial state
name = "m_A"
else:
suffix = get_helicity_suffix(topology, state_id)
name = f"{root}{suffix}"
return sp.Symbol(name, rational=True)


def create_spin_projection_symbol(state_id: int) -> sp.Symbol:
if state_id == -1: # initial state
suffix = "_A"
else:
suffix = str(state_id)
return sp.Symbol(f"m{suffix}", rational=True)
10 changes: 7 additions & 3 deletions src/ampform/kinematics/lorentz.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,16 @@ def create_four_momentum_symbols(topology: Topology) -> FourMomenta:
>>> create_four_momentum_symbols(topologies[0])
{0: p0, 1: p1, 2: p2}
"""
n_final_states = len(topology.outgoing_edge_ids)
return {i: FourMomentumSymbol(f"p{i}", shape=[]) for i in range(n_final_states)}
final_state_ids = sorted(topology.outgoing_edge_ids)
return {i: create_four_momentum_symbol(i) for i in final_state_ids}


def create_four_momentum_symbol(index: int) -> FourMomentumSymbol:
return FourMomentumSymbol(f"p{index}", shape=[])


FourMomenta = Dict[int, "FourMomentumSymbol"]
"""A mapping of state IDs to their corresponding `FourMomentumSymbol`.
"""A mapping of state IDs to their corresponding `.FourMomentumSymbol`.

It's best to create a `dict` of `.FourMomenta` with
:func:`create_four_momentum_symbols`.
Expand Down
Loading