diff --git a/pyproject.toml b/pyproject.toml index 2e3f1d3ee..ff62ec395 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -256,6 +256,7 @@ reportPrivateImportUsage = false reportPrivateUsage = false reportUnboundVariable = false reportUnknownArgumentType = false +reportUnknownLambdaType = false reportUnknownMemberType = false reportUnknownParameterType = false reportUnknownVariableType = false diff --git a/src/ampform/helicity/__init__.py b/src/ampform/helicity/__init__.py index 47809a32c..80d6ee9e8 100644 --- a/src/ampform/helicity/__init__.py +++ b/src/ampform/helicity/__init__.py @@ -30,8 +30,14 @@ from attrs.validators import deep_iterable, instance_of, optional from qrules.combinatorics import perform_external_edge_identical_particle_combinatorics from qrules.particle import Particle -from qrules.transition import ReactionInfo, StateTransition +from qrules.transition import ( + InteractionProperties, + ReactionInfo, + State, + StateTransition, +) +from ampform._qrules import get_qrules_version from ampform.dynamics.builder import ( ResonanceDynamicsBuilder, TwoBodyKinematicVariableSet, @@ -70,6 +76,7 @@ if TYPE_CHECKING: from IPython.lib.pretty import PrettyPrinter + from qrules.topology import MutableTransition _LOGGER = logging.getLogger(__name__) @@ -450,11 +457,9 @@ def __formulate_topology_amplitude( ) -> sp.Expr: sequential_expressions: list[sp.Expr] = [] for transition in transitions: - sequential_graphs = perform_external_edge_identical_particle_combinatorics( - transition.to_graph() - ) + sequential_graphs = _perform_combinatorics(transition) for graph in sequential_graphs: - first_transition = StateTransition.from_graph(graph) + first_transition = _freeze(graph) expression = self.__formulate_sequential_decay(first_transition) sequential_expressions.append(expression) @@ -558,6 +563,24 @@ def __generate_amplitude_prefactor( return None +def _perform_combinatorics( + transition: StateTransition, +) -> list[MutableTransition[State, InteractionProperties]]: + if get_qrules_version() < (0, 10): + return perform_external_edge_identical_particle_combinatorics( + transition.to_graph() # type: ignore[attr-defined] + ) + graph = transition.convert(lambda s: (s.particle, s.spin_projection)).unfreeze() + combinations = perform_external_edge_identical_particle_combinatorics(graph) + return [g.freeze().convert(lambda s: State(*s)).unfreeze() for g in combinations] + + +def _freeze(graph: MutableTransition[State, InteractionProperties]) -> StateTransition: + if get_qrules_version() < (0, 10): + return StateTransition.from_graph(graph) # type: ignore[attr-defined] + return graph.freeze() + + class CanonicalAmplitudeBuilder(HelicityAmplitudeBuilder): r"""Amplitude model generator for the canonical helicity formalism. diff --git a/src/ampform/helicity/align/dpd.py b/src/ampform/helicity/align/dpd.py index 79138bb03..09a5598eb 100644 --- a/src/ampform/helicity/align/dpd.py +++ b/src/ampform/helicity/align/dpd.py @@ -13,9 +13,10 @@ from attrs import define, field from attrs.validators import in_ from qrules.topology import Topology -from qrules.transition import ReactionInfo, StateTransition, StateTransitionCollection +from qrules.transition import ReactionInfo, StateTransition from sympy.physics.quantum.spin import Rotation as Wigner +from ampform._qrules import get_qrules_version from ampform.helicity.align import SpinAlignment from ampform.helicity.decay import ( get_outer_state_ids, @@ -34,6 +35,11 @@ if TYPE_CHECKING: from sympy.physics.quantum.spin import WignerD +if get_qrules_version() < (0, 10): + from qrules.transition import ( # type: ignore[attr-defined] + StateTransitionCollection, + ) + @define class DalitzPlotDecomposition(SpinAlignment): @@ -109,8 +115,14 @@ def __call__( return Wigner.d(j, m, m_prime, zeta) -T = TypeVar("T", ReactionInfo, StateTransition, StateTransitionCollection, Topology) -"""Allowed types for :func:`relabel_edge_ids`.""" +if get_qrules_version() < (0, 10): + T = TypeVar("T", ReactionInfo, StateTransition, StateTransitionCollection, Topology) + """Allowed types for :func:`relabel_edge_ids`.""" +else: + T = TypeVar( # type: ignore[misc] # pyright: ignore[reportConstantRedefinition] + "T", ReactionInfo, StateTransition, Topology + ) + """Allowed types for :func:`relabel_edge_ids`.""" @singledispatch @@ -121,21 +133,29 @@ def relabel_edge_ids(obj: T) -> T: @relabel_edge_ids.register(ReactionInfo) def _(obj: ReactionInfo) -> ReactionInfo: # type: ignore[misc] - return ReactionInfo( # no attrs.evolve() in order to call __attrs_post_init__() - transition_groups=[relabel_edge_ids(g) for g in obj.transition_groups], + if get_qrules_version() < (0, 10): + return ReactionInfo( # type: ignore[call-arg] + transition_groups=[relabel_edge_ids(g) for g in obj.transition_groups], # type: ignore[attr-defined] + formalism=obj.formalism, + ) + return ReactionInfo( + # no attrs.evolve() in order to call __attrs_post_init__() + transitions=[relabel_edge_ids(g) for g in obj.transitions], formalism=obj.formalism, ) -@relabel_edge_ids.register(StateTransitionCollection) -def _(obj: StateTransitionCollection) -> StateTransitionCollection: # type: ignore[misc] - return StateTransitionCollection( # no attrs.evolve() for __attrs_post_init__() - [relabel_edge_ids(transition) for transition in obj.transitions] - ) +if get_qrules_version() < (0, 10): + def __relabel_stc(obj: StateTransitionCollection) -> StateTransitionCollection: # type: ignore[misc] + return StateTransitionCollection( + [relabel_edge_ids(transition) for transition in obj.transitions] + ) -@relabel_edge_ids.register(StateTransition) -def _(obj: StateTransition) -> StateTransition: # type: ignore[misc] + relabel_edge_ids.register(StateTransitionCollection)(__relabel_stc) + + +def __relabel_st(obj: StateTransition) -> StateTransition: # type: ignore[misc] mapping = __get_default_relabel_mapping() return attrs.evolve( obj, @@ -144,6 +164,14 @@ def _(obj: StateTransition) -> StateTransition: # type: ignore[misc] ) +if get_qrules_version() < (0, 10): + relabel_edge_ids.register(StateTransition)(__relabel_st) +else: + from qrules.topology import FrozenTransition + + relabel_edge_ids.register(FrozenTransition)(__relabel_st) + + @relabel_edge_ids.register(Topology) def _(obj: Topology) -> Topology: # type: ignore[misc] mapping = __get_default_relabel_mapping() diff --git a/src/ampform/helicity/decay.py b/src/ampform/helicity/decay.py index c2e715583..a2e738b79 100644 --- a/src/ampform/helicity/decay.py +++ b/src/ampform/helicity/decay.py @@ -7,16 +7,22 @@ from typing import TYPE_CHECKING, Iterable from attrs import frozen +from qrules.quantum_numbers import InteractionProperties from qrules.transition import ReactionInfo, State, StateTransition +from ampform._qrules import get_qrules_version + if TYPE_CHECKING: - from qrules.quantum_numbers import InteractionProperties from qrules.topology import Topology if sys.version_info < (3, 8): from typing_extensions import Literal else: from typing import Literal +if sys.version_info < (3, 10): + from typing_extensions import TypeGuard +else: + from typing import TypeGuard @frozen @@ -103,12 +109,30 @@ def _(obj: TwoBodyDecay) -> TwoBodyDecay: def _(obj: tuple) -> TwoBodyDecay: if len(obj) == 2: # noqa: PLR2004 transition, node_id = obj - if isinstance(transition, StateTransition) and isinstance(node_id, int): - return TwoBodyDecay.from_transition(*obj) + if _is_qrules_state_transition(transition) and isinstance(node_id, int): + return TwoBodyDecay.from_transition(transition, node_id) msg = f"Cannot create a {TwoBodyDecay.__name__} from {obj}" raise NotImplementedError(msg) +def _is_qrules_state_transition(obj) -> TypeGuard[StateTransition]: + if get_qrules_version() >= (0, 10): + from qrules.topology import FrozenTransition + + if isinstance(obj, FrozenTransition): + if any(not isinstance(s, State) for s in obj.states.values()): + return False + if any( + not isinstance(i, InteractionProperties) + for i in obj.interactions.values() + ): + return False + return True + if get_qrules_version() < (0, 10) and isinstance(obj, StateTransition): # type: ignore[misc] + return True + return False + + @lru_cache(maxsize=None) def is_opposite_helicity_state(topology: Topology, state_id: int) -> bool: """Determine if an edge is an "opposite helicity" state. @@ -328,8 +352,13 @@ def determine_attached_final_state(topology: Topology, state_id: int) -> list[in >>> from qrules.topology import create_isobar_topologies >>> topologies = create_isobar_topologies(5) - >>> determine_attached_final_state(topologies[0], state_id=5) + >>> determine_attached_final_state(topologies[3], state_id=5) [0, 3, 4] + >>> import pytest + >>> from ampform._qrules import get_qrules_version + >>> if get_qrules_version() < (0, 10): + ... pytest.skip('Doctest only works for qrules>=0.10') + ... """ edge = topology.edges[state_id] if edge.ending_node_id is None: @@ -343,13 +372,20 @@ def get_outer_state_ids(obj: ReactionInfo | StateTransition) -> list[int]: raise NotImplementedError(msg) -@get_outer_state_ids.register(StateTransition) -def _(transition: StateTransition) -> list[int]: +def __convert_state_transition(transition: StateTransition) -> list[int]: outer_state_ids = list(transition.initial_states) outer_state_ids += sorted(transition.final_states) return outer_state_ids +if get_qrules_version() < (0, 10): + get_outer_state_ids.register(StateTransition)(__convert_state_transition) +else: + from qrules.topology import FrozenTransition + + get_outer_state_ids.register(FrozenTransition)(__convert_state_transition) + + @get_outer_state_ids.register(ReactionInfo) def _(reaction: ReactionInfo) -> list[int]: return get_outer_state_ids(reaction.transitions[0]) diff --git a/src/ampform/helicity/naming.py b/src/ampform/helicity/naming.py index 761d10585..91339e9cf 100644 --- a/src/ampform/helicity/naming.py +++ b/src/ampform/helicity/naming.py @@ -351,8 +351,9 @@ def get_boost_chain_suffix(topology: Topology, state_id: int) -> str: the internal decay topology. >>> from qrules.topology import create_isobar_topologies + >>> from ampform._qrules import get_qrules_version >>> topologies = create_isobar_topologies(5) - >>> topology = topologies[0] + >>> topology = topologies[0 if get_qrules_version() < (0, 10) else 3] >>> for i in topology.intermediate_edge_ids | topology.outgoing_edge_ids: ... suffix = get_boost_chain_suffix(topology, i) ... print(f"{i}: 'phi{suffix}'") @@ -364,7 +365,7 @@ def get_boost_chain_suffix(topology: Topology, state_id: int) -> str: 5: 'phi_034' 6: 'phi_12' 7: 'phi_34^034' - >>> topology = topologies[1] + >>> topology = topologies[1 if get_qrules_version() < (0, 10) else 2] >>> for i in topology.intermediate_edge_ids | topology.outgoing_edge_ids: ... suffix = get_boost_chain_suffix(topology, i) ... print(f"{i}: 'phi{suffix}'") diff --git a/src/ampform/kinematics/__init__.py b/src/ampform/kinematics/__init__.py index 3ae4cddcd..cac5816f7 100644 --- a/src/ampform/kinematics/__init__.py +++ b/src/ampform/kinematics/__init__.py @@ -16,6 +16,7 @@ from qrules.topology import Topology from qrules.transition import ReactionInfo, StateTransition +from ampform._qrules import get_qrules_version from ampform.helicity.decay import assert_isobar_topology from ampform.kinematics.angles import compute_helicity_angles from ampform.kinematics.lorentz import ( @@ -120,6 +121,13 @@ def _(obj: Topology) -> Topology: return obj -@_get_topology.register(StateTransition) -def _(obj: StateTransition) -> Topology: +def __get_state_transition(obj: StateTransition) -> Topology: return obj.topology + + +if get_qrules_version() < (0, 10): + _get_topology.register(StateTransition)(__get_state_transition) +else: + from qrules.topology import FrozenTransition + + _get_topology.register(FrozenTransition)(__get_state_transition) diff --git a/src/ampform/kinematics/lorentz.py b/src/ampform/kinematics/lorentz.py index d01c6fa7b..c169e2ce9 100644 --- a/src/ampform/kinematics/lorentz.py +++ b/src/ampform/kinematics/lorentz.py @@ -742,8 +742,10 @@ def get_invariant_mass_symbol(topology: Topology, state_id: int) -> sp.Symbol: state :math:`5` is :math:`m_{034}`, because :math:`p_5=p_0+p_3+p_4`: >>> from qrules.topology import create_isobar_topologies + >>> from ampform._qrules import get_qrules_version >>> topologies = create_isobar_topologies(5) - >>> get_invariant_mass_symbol(topologies[0], state_id=5) + >>> topology = topologies[0 if get_qrules_version() < (0, 10) else 3] + >>> get_invariant_mass_symbol(topology, state_id=5) m_034 Naturally, the 'invariant' mass label for a final state is just the mass of the diff --git a/tests/helicity/test_decay.py b/tests/helicity/test_decay.py index 857a4fd4c..4a06b3b4d 100644 --- a/tests/helicity/test_decay.py +++ b/tests/helicity/test_decay.py @@ -6,6 +6,7 @@ import pytest from qrules.topology import Topology, create_isobar_topologies +from ampform._qrules import get_qrules_version from ampform.helicity.decay import ( determine_attached_final_state, get_sibling_state_id, @@ -24,10 +25,10 @@ def test_determine_attached_final_state(): topology.outgoing_edge_ids ) # intermediate states - topology = topologies[0] + topology = topologies[0 if get_qrules_version() < (0, 10) else 1] assert determine_attached_final_state(topology, state_id=4) == [0, 1] assert determine_attached_final_state(topology, state_id=5) == [2, 3] - topology = topologies[1] + topology = topologies[1 if get_qrules_version() < (0, 10) else 0] assert determine_attached_final_state(topology, state_id=4) == [1, 2, 3] assert determine_attached_final_state(topology, state_id=5) == [2, 3] diff --git a/tests/kinematics/conftest.py b/tests/kinematics/conftest.py index 385c10807..ba2ac41ff 100644 --- a/tests/kinematics/conftest.py +++ b/tests/kinematics/conftest.py @@ -5,6 +5,7 @@ import pytest from qrules.topology import Topology, create_isobar_topologies +from ampform._qrules import get_qrules_version from ampform.kinematics.lorentz import FourMomenta, create_four_momentum_symbols if TYPE_CHECKING: @@ -18,6 +19,6 @@ def topology_and_momentum_symbols( n = len(data_sample) assert n == 4 topologies = create_isobar_topologies(n) - topology = topologies[1] + topology = topologies[1 if get_qrules_version() < (0, 10) else 0] momentum_symbols = create_four_momentum_symbols(topology) return topology, momentum_symbols