From c13280dd3c0b1590b47872b2279da948f8aa0c2a Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Tue, 21 May 2024 16:27:10 +0200 Subject: [PATCH] FIX: avoid overwriting chain amplitudes --- docs/serialization.ipynb | 3 ++- src/ampform_dpd/io/serialization/amplitude.py | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/docs/serialization.ipynb b/docs/serialization.ipynb index dab3a7aa..bffeb873 100644 --- a/docs/serialization.ipynb +++ b/docs/serialization.ipynb @@ -678,7 +678,8 @@ "metadata": { "tags": [ "hide-input", - "full-width" + "full-width", + "scroll-output" ] }, "outputs": [], diff --git a/src/ampform_dpd/io/serialization/amplitude.py b/src/ampform_dpd/io/serialization/amplitude.py index e37a56e4..8aacdb82 100644 --- a/src/ampform_dpd/io/serialization/amplitude.py +++ b/src/ampform_dpd/io/serialization/amplitude.py @@ -63,7 +63,7 @@ def formulate( # noqa: PLR0914 symbol: create_spin_range(states[i].spin) # type:ignore[index] for i, symbol in enumerate(helicity_symbols) } - amplitude_definitions = {} + amplitude_definitions = {} # type:ignore[var-annotated] angle_definitions = {} parameter_defaults = {} n_chains = len(get_decay_chains(model)) @@ -78,7 +78,9 @@ def formulate( # noqa: PLR0914 msg = f"Expected an expression, got {amp_expr!r}" raise TypeError(msg) helicity_substitutions = dict(zip(helicity_symbols, helicity_values)) - amplitude_definitions[amp_symbol] = amp_expr.subs(helicity_substitutions) + existing_amplitude = amplitude_definitions.get(amp_symbol, sp.Integer(0)) + existing_amplitude += amp_expr.subs(helicity_substitutions) + amplitude_definitions[amp_symbol] = existing_amplitude angle_definitions[θij] = θij_expr parameter_defaults.update(dict(parameters)) aligned_amp, zeta_defs = formulate_aligned_amplitude(model, *helicity_symbols)