Skip to content

Commit

Permalink
DOC: illustrate SympyDataTransformer usage (#17)
Browse files Browse the repository at this point in the history
* DOC: add sigma labels to Dalitz plotaxes
* DOC: indicate resonance lines
* DOC: remove resonances that are outside phase space
* DX: only hide caching warnings in Sphinx build
  • Loading branch information
redeboer authored Oct 11, 2022
1 parent 46f6100 commit 13a1561
Show file tree
Hide file tree
Showing 2 changed files with 189 additions and 37 deletions.
8 changes: 7 additions & 1 deletion .cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@
"absl",
"arange",
"aslatex",
"autoscale",
"autoupdate",
"axvline",
"caplog",
"codecov",
"codemirror",
Expand All @@ -53,10 +55,12 @@
"docnb",
"elif",
"figsize",
"fontsize",
"gcov",
"ipykernel",
"ipython",
"isinstance",
"isnan",
"itertools",
"kernelspec",
"linkcheck",
Expand All @@ -82,11 +86,13 @@
"startswith",
"textwrap",
"toctree",
"tqdm",
"wspace",
"xlabel",
"xreplace",
"xrightarrow",
"ylabel"
"ylabel",
"ylim"
],
"language": "en-US",
"version": "0.2",
Expand Down
218 changes: 182 additions & 36 deletions docs/jpsi2ksp.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"\n",
"import itertools\n",
"import logging\n",
"import os\n",
"from typing import Iterable\n",
"\n",
"import jax.numpy as jnp\n",
Expand All @@ -45,6 +46,7 @@
" make_commutative,\n",
")\n",
"from IPython.display import Latex, Markdown\n",
"from tensorwaves.data.transform import SympyDataTransformer\n",
"from tensorwaves.function.sympy import create_function\n",
"\n",
"from ampform_dpd import (\n",
Expand All @@ -62,8 +64,10 @@
"from ampform_dpd.spin import filter_parity_violating_ls, generate_ls_couplings\n",
"\n",
"simplify_latex_rendering()\n",
"logging.getLogger(\"ampform_dpd.io\").setLevel(logging.ERROR)\n",
"logging.getLogger(\"absl\").setLevel(logging.ERROR) # mute JAX"
"logging.getLogger(\"absl\").setLevel(logging.ERROR) # mute JAX\n",
"NO_TQDM = \"EXECUTE_NB\" in os.environ\n",
"if NO_TQDM:\n",
" logging.getLogger(\"ampform_dpd.io\").setLevel(logging.ERROR)"
]
},
{
Expand All @@ -77,7 +81,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We follow [this example](https://qrules.readthedocs.io/en/0.9.7/usage.html#investigate-intermediate-resonances), which was generated with QRules, and leave out the $K$-resonances:\n",
"We follow [this example](https://qrules.readthedocs.io/en/0.9.7/usage.html#investigate-intermediate-resonances), which was generated with QRules, and leave out the $K$-resonances and the resonances that lie far outside of phase space:\n",
"\n",
"![](https://qrules.readthedocs.io/en/0.9.7/_images/usage_9_0.svg)\n",
"\n",
Expand Down Expand Up @@ -134,16 +138,11 @@
"outputs": [],
"source": [
"resonance_names = [\n",
" \"Sigma(1385)~-\",\n",
" \"Sigma(1660)~-\",\n",
" \"Sigma(1670)~-\",\n",
" \"Sigma(1750)~-\",\n",
" \"Sigma(1775)~-\",\n",
" \"Sigma(1910)~-\",\n",
" \"N(1440)+\",\n",
" \"N(1520)+\",\n",
" \"N(1535)+\",\n",
" \"N(1650)+\",\n",
" \"N(1675)+\",\n",
" \"N(1700)+\",\n",
" \"N(1710)+\",\n",
Expand Down Expand Up @@ -585,14 +584,33 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dalitz plot"
"## Preparing for input data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The {meth}`~sympy.core.basic.Basic.doit` operation can be cached to disk with {func}`.perform_cached_doit`. We do this twice, once for the unfolding of the {attr}`~.AmplitudeModel.intensity` expression and second for the substitution and unfolding of the {attr}`~.AmplitudeModel.amplitudes`. Note that we could also have unfolded the intensity and substituted the amplitudes with {attr}`~.AmplitudeModel.full_expression`, but then the unfolded {attr}`~.AmplitudeModel.intensity` expression is not cached."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"unfolded_intensity_expr = perform_cached_doit(model.intensity)\n",
"full_intensity_expr = perform_cached_doit(\n",
" unfolded_intensity_expr.xreplace(model.amplitudes)\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We set each helicity coupling to $1$, so that each each parameter value has a definition:"
"We set each helicity coupling to $1$, so that each each parameter {class}`~sympy.core.symbol.Symbol` in the expression has a definition:"
]
},
{
Expand All @@ -601,32 +619,110 @@
"metadata": {},
"outputs": [],
"source": [
"full_expression = perform_cached_doit(model.full_expression)\n",
"couplings = {\n",
" s: 1\n",
" for s in full_expression.free_symbols\n",
" for s in full_intensity_expr.free_symbols\n",
" if \"production\" in str(s) or \"decay\" in str(s)\n",
"}\n",
"model.parameter_defaults.update(couplings)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With this, the remaining {class}`~sympy.core.symbol.Symbol`s in the full expression are kinematic variables."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"jupyter": {
"source_hidden": true
},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"sp.Array(full_intensity_expr.free_symbols - set(model.parameter_defaults))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The $\\theta$ and $\\zeta$ angles are defined by the {attr}`~.AmplitudeModel.variables` attribute (they are shown under {ref}`jpsi2ksp:Model formulation`). Those definitions allow us to create a converter that computes kinematic variables from masses and Mandelstam variables:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"hide-output"
]
},
"outputs": [],
"source": [
"masses_to_angles = SympyDataTransformer.from_sympy(model.variables, backend=\"jax\")\n",
"masses_to_angles.functions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dalitz plot"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The data input for this data transformer can be several things. One can compute them from a (generated) data sample of four-momenta. Or one can compute them for a Dalitz plane. We do the latter in this section.\n",
"\n",
"First, the data transformer defined above expects values for the masses. We have already defined these values above, but we need to convert them from {mod}`sympy` objects to numerical data:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dalitz_data = {str(s): float(v) for s, v in masses.items()}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we define a grid of data points over Mandelstam (Dalitz) variables $\\sigma_2=m_{13}, \\sigma_3=m_{12}$:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"subs_expression = perform_cached_doit(\n",
" full_expression.xreplace(model.variables)\n",
").xreplace(model.parameter_defaults)\n",
"subs_expression.free_symbols"
"resolution = 500\n",
"X, Y = jnp.meshgrid(\n",
" jnp.linspace(1.66**2, 2.18**2, num=resolution),\n",
" jnp.linspace(1.4**2, 1.93**2, num=resolution),\n",
")\n",
"dalitz_data[\"sigma3\"] = X\n",
"dalitz_data[\"sigma2\"] = Y"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The remaining symbols in the expressions are three Mandelstam variables and initial and final state masses. The masses we know, and one of the Mandelstam variables can be expressed in terms of the others as follows:"
"The remaining Mandelstam variable can be expressed in terms of the others as follows:"
]
},
{
Expand All @@ -652,7 +748,29 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"With this, we have a single expression for the intensity that only depends on Mandelstam (Dalitz) variables $\\sigma_2=m_{13}, \\sigma_3=m_{12}$:"
"That completes the data sample over which we want to evaluate the intensity model defined above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"hide-output"
]
},
"outputs": [],
"source": [
"sigma1_func = create_function(s1_expr, backend=\"jax\")\n",
"dalitz_data[\"sigma1\"] = sigma1_func(dalitz_data)\n",
"dalitz_data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now extend the sample with angle definitions so that we have a data sample over which the intensity can be evaluated."
]
},
{
Expand All @@ -661,10 +779,22 @@
"metadata": {},
"outputs": [],
"source": [
"dalitz_expression = perform_cached_doit(\n",
" subs_expression.xreplace({s1: s1_expr}).xreplace(masses)\n",
")\n",
"dalitz_expression.free_symbols"
"angle_data = masses_to_angles(dalitz_data)\n",
"dalitz_data.update(angle_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-input"
]
},
"outputs": [],
"source": [
"for k, v in dalitz_data.items():\n",
" assert not jnp.all(jnp.isnan(v)), f\"All values for {k} are NaN\""
]
},
{
Expand All @@ -673,7 +803,10 @@
"metadata": {},
"outputs": [],
"source": [
"func = create_function(dalitz_expression, backend=\"jax\")"
"intensity_func = create_function(\n",
" full_intensity_expr.xreplace(model.parameter_defaults),\n",
" backend=\"jax\",\n",
")"
]
},
{
Expand Down Expand Up @@ -703,25 +836,16 @@
"outputs": [],
"source": [
"plt.rc(\"font\", size=18)\n",
"resolution = 500\n",
"X, Y = jnp.meshgrid(\n",
" jnp.linspace(1.66**2, 2.18**2, num=resolution),\n",
" jnp.linspace(1.4**2, 1.93**2, num=resolution),\n",
")\n",
"data = {\n",
" \"sigma3\": X,\n",
" \"sigma2\": Y,\n",
"}\n",
"intensities = func(data)\n",
"intensities = intensity_func(dalitz_data)\n",
"normalized_intensities = intensities / jnp.nansum(intensities)\n",
"\n",
"fig, ax = plt.subplots(figsize=(14, 10))\n",
"mesh = ax.pcolormesh(X, Y, normalized_intensities)\n",
"ax.set_aspect(\"equal\")\n",
"c_bar = plt.colorbar(mesh, ax=ax, pad=0.01)\n",
"c_bar.ax.set_ylabel(\"Normalized intensity (a.u.)\")\n",
"ax.set_xlabel(R\"$M^2\\left(K^0\\Sigma^+\\right)$\")\n",
"ax.set_ylabel(R\"$M^2\\left(K^0\\bar{p}\\right)$\")\n",
"ax.set_xlabel(R\"$\\sigma_3 = M^2\\left(K^0\\Sigma^+\\right)$\")\n",
"ax.set_ylabel(R\"$\\sigma_2 = M^2\\left(K^0\\bar{p}\\right)$\")\n",
"plt.show()"
]
},
Expand All @@ -747,7 +871,8 @@
},
"tags": [
"hide-input",
"full-width"
"full-width",
"scroll-input"
]
},
"outputs": [],
Expand All @@ -758,9 +883,30 @@
"ax1, ax2 = axes\n",
"ax1.fill_between(jnp.sqrt(X[0]), jnp.nansum(normalized_intensities, axis=0))\n",
"ax2.fill_between(jnp.sqrt(Y[:, 0]), jnp.nansum(normalized_intensities, axis=1))\n",
"for ax in axes:\n",
" _, y_max = ax.get_ylim()\n",
" ax.set_ylim(0, y_max)\n",
" ax.autoscale(enable=False, axis=\"x\")\n",
"ax1.set_ylabel(\"Normalized intensity (a.u.)\")\n",
"ax1.set_xlabel(R\"$M\\left(K^0\\Sigma^+\\right)$\")\n",
"ax2.set_xlabel(R\"$M\\left(K^0\\bar{p}\\right)$\")\n",
"i1, i2 = 0, 0\n",
"for chain in model.decay.chains:\n",
" resonance = chain.resonance\n",
" decay_product = set(chain.decay_products)\n",
" if decay_product == {K, Σ}:\n",
" ax = ax1\n",
" i1 += 1\n",
" i = i1\n",
" elif decay_product == {K, pbar}:\n",
" ax = ax2\n",
" i2 += 1\n",
" i = i2\n",
" else:\n",
" continue\n",
" ax.axvline(resonance.mass, label=f\"${resonance.latex}$\", c=f\"C{i}\", ls=\"dashed\")\n",
"for ax in axes:\n",
" ax.legend(fontsize=12)\n",
"plt.show()"
]
}
Expand Down

0 comments on commit 13a1561

Please sign in to comment.