From 5704f99e986301195f74fb1d3ba3eb897350d437 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Thu, 1 Jun 2023 18:20:09 +0200 Subject: [PATCH 01/17] DOC: show how to compute gradient with autodiff --- .cspell.json | 2 + docs/report/022.ipynb | 557 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 559 insertions(+) create mode 100644 docs/report/022.ipynb diff --git a/.cspell.json b/.cspell.json index 3dc6cd16..6bc30402 100644 --- a/.cspell.json +++ b/.cspell.json @@ -69,6 +69,7 @@ "analyticity", "argand", "Atlassian", + "autodiff", "autograd", "blatt", "breit", @@ -130,6 +131,7 @@ "arctan", "asarray", "asdot", + "aslatex", "astype", "autolaunch", "autonumbering", diff --git a/docs/report/022.ipynb b/docs/report/022.ipynb new file mode 100644 index 00000000..3c23373e --- /dev/null +++ b/docs/report/022.ipynb @@ -0,0 +1,557 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "hideCode": true, + "hideOutput": true, + "hidePrompt": true, + "jupyter": { + "source_hidden": true + }, + "slideshow": { + "slide_type": "skip" + }, + "tags": [ + "remove-cell" + ] + }, + "outputs": [], + "source": [ + "%config InlineBackend.figure_formats = ['svg']\n", + "import os\n", + "\n", + "STATIC_WEB_PAGE = {\"EXECUTE_NB\", \"READTHEDOCS\"}.intersection(os.environ)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```{autolink-concat}\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "\n", + "````{margin}\n", + "```{spec} Gradient with autodiff\n", + ":id: TR-022\n", + ":status: WIP\n", + ":tags: tensorwaves\n", + "\n", + "In this report, we investigate whether autodiff can be be used to analytically compute the gradient of an amplitude model. The suspicion is that autodiff cannot handle large expressions well, because the chain rule results in an excessive number of computational nodes for the gradient of the function.\n", + "```\n", + "````" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "# Gradient with autodiff" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "remove-cell" + ] + }, + "outputs": [], + "source": [ + "%pip install -q \"tensorwaves[jax,pwa]@git+https://github.com/ComPWA/tensorwaves@order-function-args\" ampform~=0.14 qrules~=0.9.8" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import inspect\n", + "import os\n", + "\n", + "import ampform\n", + "import jax\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import qrules\n", + "from ampform.dynamics.builder import (\n", + " create_non_dynamic_with_ff,\n", + " create_relativistic_breit_wigner_with_ff,\n", + ")\n", + "from ampform.io import aslatex\n", + "from IPython.display import Latex\n", + "from jax.tree_util import Partial\n", + "from matplotlib import cm\n", + "from tensorwaves.data import (\n", + " IntensityDistributionGenerator,\n", + " SympyDataTransformer,\n", + " TFPhaseSpaceGenerator,\n", + " TFUniformRealNumberGenerator,\n", + " TFWeightedPhaseSpaceGenerator,\n", + ")\n", + "from tensorwaves.function.sympy import create_function, create_parametrized_function\n", + "\n", + "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "jp-MarkdownHeadingCollapsed": true, + "tags": [] + }, + "source": [ + "## Formulate model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "reaction = qrules.generate_transitions(\n", + " initial_state=(\"J/psi(1S)\", [-1, +1]),\n", + " final_state=[\"gamma\", \"pi0\", \"pi0\"],\n", + " allowed_intermediate_particles=[\"a(0)\", \"f(0)\", \"omega\"],\n", + " allowed_interaction_types=[\"strong\", \"EM\"],\n", + " formalism=\"helicity\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "import graphviz\n", + "\n", + "dot = qrules.io.asdot(reaction, collapse_graphs=True)\n", + "graphviz.Source(dot)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "full-width" + ] + }, + "outputs": [], + "source": [ + "model_builder = ampform.get_builder(reaction)\n", + "model_builder.adapter.permutate_registered_topologies()\n", + "model_builder.set_dynamics(\"J/psi(1S)\", create_non_dynamic_with_ff)\n", + "for name in reaction.get_intermediate_particles().names:\n", + " model_builder.set_dynamics(name, create_relativistic_breit_wigner_with_ff)\n", + "model = model_builder.formulate()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "model.intensity" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "hide-input", + "hide-output" + ] + }, + "outputs": [], + "source": [ + "selection = {k: v for i, (k, v) in enumerate(model.amplitudes.items()) if i < 3}\n", + "src = aslatex(selection)\n", + "del selection\n", + "Latex(src)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "jp-MarkdownHeadingCollapsed": true, + "tags": [] + }, + "source": [ + "## Generate data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "rng = TFUniformRealNumberGenerator(seed=0)\n", + "phsp_generator = TFPhaseSpaceGenerator(\n", + " initial_state_mass=reaction.initial_state[-1].mass,\n", + " final_state_masses={i: p.mass for i, p in reaction.final_state.items()},\n", + ")\n", + "phsp_momenta = phsp_generator.generate(1_000_000, rng)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "helicity_transformer = SympyDataTransformer.from_sympy(\n", + " model.kinematic_variables, backend=\"jax\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "skip-flake8" + ] + }, + "outputs": [], + "source": [ + "unfolded_expression = model.expression.doit()\n", + "substituted_expression = unfolded_expression.xreplace(model.parameter_defaults)\n", + "fixed_intensity_func = create_function(substituted_expression, backend=\"jax\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "full-width" + ] + }, + "outputs": [], + "source": [ + "weighted_phsp_generator = TFWeightedPhaseSpaceGenerator(\n", + " initial_state_mass=reaction.initial_state[-1].mass,\n", + " final_state_masses={i: p.mass for i, p in reaction.final_state.items()},\n", + ")\n", + "data_generator = IntensityDistributionGenerator(\n", + " domain_generator=weighted_phsp_generator,\n", + " function=fixed_intensity_func,\n", + " domain_transformer=helicity_transformer,\n", + ")\n", + "data_momenta = data_generator.generate(100_000, rng)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "list(helicity_transformer.functions)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "phsp = helicity_transformer(phsp_momenta)\n", + "data = helicity_transformer(data_momenta)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "sorted(substituted_expression.free_symbols, key=str)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "list(model.kinematic_variables)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "full-width", + "hide-input" + ] + }, + "outputs": [], + "source": [ + "resonances = sorted(reaction.get_intermediate_particles(), key=lambda p: p.mass)\n", + "evenly_spaced_interval = np.linspace(0, 1, len(resonances))\n", + "colors = [cm.rainbow(x) for x in evenly_spaced_interval]\n", + "fig, ax = plt.subplots(figsize=(9, 4))\n", + "ax.hist(\n", + " np.real(data[\"m_12\"]),\n", + " bins=200,\n", + " alpha=0.5,\n", + " density=True,\n", + ")\n", + "ax.set_xlabel(\"$m$ [GeV]\")\n", + "for p, color in zip(resonances, colors):\n", + " ax.axvline(x=p.mass, linestyle=\"dotted\", label=p.name, color=color)\n", + "ax.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Gradient creation with autodiff" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "free_symbols = {\n", + " symbol: value\n", + " for symbol, value in model.parameter_defaults.items()\n", + " if not symbol.name.startswith(\"d\")\n", + "}\n", + "some_coefficient = next(s for s in free_symbols if s.name.startswith(\"C\"))\n", + "free_symbols.pop(some_coefficient)\n", + "fixed_symbols = {\n", + " symbol: value\n", + " for symbol, value in model.parameter_defaults.items()\n", + " if symbol not in free_symbols\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [] + }, + "outputs": [], + "source": [ + "src = aslatex({k: free_symbols[k] for k in sorted(free_symbols, key=str)})\n", + "Latex(src)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "skip-flake8" + ] + }, + "outputs": [], + "source": [ + "expression = unfolded_expression\n", + "expression = expression.xreplace(fixed_symbols)\n", + "intensity_func = create_parametrized_function(\n", + " expression, parameters=free_symbols, backend=\"jax\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "from IPython.display import Markdown\n", + "\n", + "Markdown(f\"Function has **{len(intensity_func.parameters)} free parameters**.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "sig = inspect.signature(intensity_func.function)\n", + "arg_names = tuple(sig.parameters)\n", + "data_columns = {\n", + " arg: data[key]\n", + " for arg, key in zip(arg_names, intensity_func.argument_order)\n", + " if key in data\n", + "}\n", + "data_columns" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "parameter_values = {\n", + " arg: complex(intensity_func.parameters[key]).real\n", + " for arg, key in zip(arg_names, intensity_func.argument_order)\n", + " if key in intensity_func.parameters\n", + "}\n", + "parameter_values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "func_with_data_inserted = Partial(intensity_func.function, *data_columns.values())\n", + "gradient_func = jax.jacfwd(\n", + " func_with_data_inserted,\n", + " argnums=range(len(parameter_values)),\n", + ")\n", + "gradient_func" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%%time\n", + "_ = tuple(v.block_until_ready() for v in gradient_func(*parameter_values.values()))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%%time\n", + "gradient_values = gradient_func(*parameter_values.values())\n", + "gradient_values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "gradient_values[0].shape" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From 2590f01cd389171b6db0a7db232324a084dadb09 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Thu, 1 Jun 2023 18:20:13 +0200 Subject: [PATCH 02/17] DOC: implement unbinned NLL fit with autodiff gradient --- .cspell.json | 4 ++ docs/report/022.ipynb | 157 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 159 insertions(+), 2 deletions(-) diff --git a/.cspell.json b/.cspell.json index 6bc30402..1632c7ac 100644 --- a/.cspell.json +++ b/.cspell.json @@ -174,6 +174,7 @@ "elif", "endswith", "eqnarray", + "errordef", "evaluatable", "expertsystem", "facecolor", @@ -198,6 +199,7 @@ "hspace", "hypotests", "imag", + "iminuit", "infty", "ioff", "iplt", @@ -234,6 +236,7 @@ "maxdepth", "maxsize", "meshgrid", + "migrad", "mname", "multiline", "mystnb", @@ -310,6 +313,7 @@ "threebody", "timeit", "toctree", + "tqdm", "treewise", "unevaluatable", "unsrt", diff --git a/docs/report/022.ipynb b/docs/report/022.ipynb index 3c23373e..ee794d86 100644 --- a/docs/report/022.ipynb +++ b/docs/report/022.ipynb @@ -237,7 +237,7 @@ " initial_state_mass=reaction.initial_state[-1].mass,\n", " final_state_masses={i: p.mass for i, p in reaction.final_state.items()},\n", ")\n", - "phsp_momenta = phsp_generator.generate(1_000_000, rng)" + "phsp_momenta = phsp_generator.generate(100_000, rng)" ] }, { @@ -285,7 +285,7 @@ " function=fixed_intensity_func,\n", " domain_transformer=helicity_transformer,\n", ")\n", - "data_momenta = data_generator.generate(100_000, rng)" + "data_momenta = data_generator.generate(10_000, rng)" ] }, { @@ -531,6 +531,159 @@ "source": [ "gradient_values[0].shape" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Optimize parameters" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Numerical gradient descent" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "phsp_columns = {\n", + " arg: phsp[key]\n", + " for arg, key in zip(arg_names, intensity_func.argument_order)\n", + " if key in data\n", + "}\n", + "func_with_phsp_inserted = Partial(intensity_func.function, *phsp_columns.values())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "func_with_phsp_inserted(*parameter_values.values())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import jax.numpy as jnp\n", + "\n", + "\n", + "# @jax.jit # Do not JIT here, otherwise jax.jacfwd crashes!\n", + "def estimator(args):\n", + " data_intensities = func_with_data_inserted(*args)\n", + " phsp_intensities = func_with_phsp_inserted(*args)\n", + " likelihoods = data_intensities / jnp.mean(phsp_intensities)\n", + " return -jnp.sum(jnp.log(likelihoods))\n", + "\n", + "\n", + "estimator(parameter_values.values())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [], + "source": [ + "import iminuit\n", + "from tqdm.auto import tqdm\n", + "\n", + "PROGRESS_BAR = tqdm()\n", + "\n", + "\n", + "def estimator_with_progress_bar(*args, **kwargs):\n", + " estimator_value = estimator(*args, **kwargs)\n", + " PROGRESS_BAR.update()\n", + " PROGRESS_BAR.set_postfix({\"estimator\": estimator_value})\n", + " return estimator_value\n", + "\n", + "\n", + "starting_values = tuple(parameter_values.values())\n", + "optimizer = iminuit.Minuit(\n", + " estimator_with_progress_bar,\n", + " starting_values,\n", + " name=tuple(parameter_values),\n", + ")\n", + "optimizer.errors = tuple(\n", + " 0.1 * abs(x) if abs(x) != 0.0 else 0.1 for x in starting_values\n", + ")\n", + "optimizer.errordef = iminuit.Minuit.LIKELIHOOD\n", + "optimizer.migrad()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### With analytic gradient" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "estimator_gradient = jax.jacfwd(estimator)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%%time\n", + "estimator_gradient(tuple(parameter_values.values()))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [], + "source": [ + "PROGRESS_BAR = tqdm() # reset\n", + "autodiff_optimizer = iminuit.Minuit(\n", + " estimator_with_progress_bar,\n", + " starting_values,\n", + " grad=estimator_gradient, # analytic!\n", + " name=tuple(parameter_values),\n", + ")\n", + "autodiff_optimizer.errors = tuple(\n", + " 0.1 * abs(x) if abs(x) != 0.0 else 0.1 for x in starting_values\n", + ")\n", + "autodiff_optimizer.errordef = iminuit.Minuit.LIKELIHOOD\n", + "autodiff_optimizer.migrad()" + ] } ], "metadata": { From 85ce0b38a99ab7b268ed814358d2df18a306b1ed Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Thu, 1 Jun 2023 18:20:14 +0200 Subject: [PATCH 03/17] DOC: add summary of fit results --- .cspell.json | 1 + docs/report/022.ipynb | 48 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/.cspell.json b/.cspell.json index 1632c7ac..3eefc678 100644 --- a/.cspell.json +++ b/.cspell.json @@ -181,6 +181,7 @@ "facecolors", "figsize", "filterwarnings", + "fmin", "fontcolor", "fontsize", "framealpha", diff --git a/docs/report/022.ipynb b/docs/report/022.ipynb index ee794d86..ef82e7aa 100644 --- a/docs/report/022.ipynb +++ b/docs/report/022.ipynb @@ -364,7 +364,10 @@ }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "jp-MarkdownHeadingCollapsed": true, + "tags": [] + }, "source": [ "## Gradient creation with autodiff" ] @@ -534,7 +537,10 @@ }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "jp-MarkdownHeadingCollapsed": true, + "tags": [] + }, "source": [ "## Optimize parameters" ] @@ -684,6 +690,44 @@ "autodiff_optimizer.errordef = iminuit.Minuit.LIKELIHOOD\n", "autodiff_optimizer.migrad()" ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "## Conclusion" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "def compute_diff(minuit):\n", + " original_pars = np.array(starting_values)\n", + " optimized_pars = np.array([p.value for p in minuit.params])\n", + " diff = original_pars - optimized_pars\n", + " return np.sqrt(np.sum(np.abs(diff) ** 2)) / len(minuit.params)\n", + "\n", + "\n", + "src = f\"\"\"\n", + "| | numerical | autodiff |\n", + "|--|-----------|----------|\n", + "| time (s) | {optimizer.fmin.time:.1f} | {autodiff_optimizer.fmin.time:.1f} |\n", + "| average parameter offset | {compute_diff(optimizer):.4f} | {compute_diff(autodiff_optimizer):.4f} |\n", + "\"\"\"\n", + "Markdown(src)" + ] } ], "metadata": { From 61e638702bed65c570f82e446f859dc8daf279b4 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Tue, 20 Jun 2023 13:48:39 +0200 Subject: [PATCH 04/17] MAINT: rename TR to draft --- docs/report/{022.ipynb => draft.ipynb} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename docs/report/{022.ipynb => draft.ipynb} (99%) diff --git a/docs/report/022.ipynb b/docs/report/draft.ipynb similarity index 99% rename from docs/report/022.ipynb rename to docs/report/draft.ipynb index ef82e7aa..84feeed4 100644 --- a/docs/report/022.ipynb +++ b/docs/report/draft.ipynb @@ -41,7 +41,7 @@ "\n", "````{margin}\n", "```{spec} Gradient with autodiff\n", - ":id: TR-022\n", + ":id: TR-999\n", ":status: WIP\n", ":tags: tensorwaves\n", "\n", From 8d87d8ec9c38675731c0621e6f2ef0a95c1dda34 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Thu, 13 Jul 2023 14:45:52 +0200 Subject: [PATCH 05/17] MAINT: autpupdate pip constraints --- .constraints/py3.10.txt | 28 ++++++++++++++-------------- .constraints/py3.11.txt | 28 ++++++++++++++-------------- .constraints/py3.7.txt | 8 ++++---- .constraints/py3.8.txt | 28 ++++++++++++++-------------- .constraints/py3.9.txt | 30 +++++++++++++++--------------- 5 files changed, 61 insertions(+), 61 deletions(-) diff --git a/.constraints/py3.10.txt b/.constraints/py3.10.txt index 00cf381f..cff928ad 100644 --- a/.constraints/py3.10.txt +++ b/.constraints/py3.10.txt @@ -2,7 +2,7 @@ # This file is autogenerated by pip-compile with Python 3.10 # by the following command: # -# pip-compile --extra=dev --no-annotate --output-file=.constraints/py3.10.txt --strip-extras setup.py +# pip-compile --extra=dev --no-annotate --output-file=.constraints/py3.10.txt --resolver=backtracking --strip-extras # accessible-pygments==0.0.4 alabaster==0.7.13 @@ -12,12 +12,12 @@ argon2-cffi==21.3.0 argon2-cffi-bindings==21.2.0 arrow==1.2.3 asttokens==2.2.1 -async-lru==2.0.2 +async-lru==2.0.3 attrs==23.1.0 babel==2.12.1 backcall==0.2.0 beautifulsoup4==4.12.2 -black==23.3.0 +black==23.7.0 bleach==6.0.0 cachetools==5.3.1 cattrs==23.1.2 @@ -25,7 +25,7 @@ certifi==2023.5.7 cffi==1.15.1 cfgv==3.3.1 chardet==5.1.0 -charset-normalizer==3.1.0 +charset-normalizer==3.2.0 click==8.1.4 colorama==0.4.6 comm==0.1.3 @@ -41,14 +41,14 @@ exceptiongroup==1.1.2 executing==1.2.0 fastjsonschema==2.17.1 filelock==3.12.2 -fonttools==4.40.0 +fonttools==4.41.0 fqdn==1.5.1 graphviz==0.20.1 greenlet==2.0.2 identify==2.5.24 idna==3.4 imagesize==1.4.1 -importlib-metadata==6.7.0 +importlib-metadata==6.8.0 iniconfig==2.0.0 ipykernel==6.24.0 ipympl==0.9.3 @@ -60,7 +60,7 @@ jedi==0.18.2 jinja2==3.1.2 json5==0.9.14 jsonpointer==2.4 -jsonschema==4.18.0 +jsonschema==4.18.2 jsonschema-specifications==2023.6.1 jupyter==1.0.0 jupyter-cache==0.6.1 @@ -71,7 +71,7 @@ jupyter-events==0.6.3 jupyter-lsp==2.2.0 jupyter-server==2.7.0 jupyter-server-terminals==0.4.4 -jupyterlab==4.0.2 +jupyterlab==4.0.3 jupyterlab-code-formatter==2.2.1 jupyterlab-myst==2.0.1 jupyterlab-pygments==0.2.2 @@ -95,13 +95,13 @@ myst-parser==0.18.1 nbclassic==1.0.0 nbclient==0.6.8 nbconvert==7.6.0 -nbformat==5.9.0 +nbformat==5.9.1 nbmake==1.4.1 nest-asyncio==1.5.6 nodeenv==1.8.0 notebook==6.5.4 notebook-shim==0.2.3 -numpy==1.25.0 +numpy==1.25.1 overrides==7.3.1 packaging==23.1 pandocfilters==1.5.0 @@ -113,7 +113,7 @@ pillow==10.0.0 platformdirs==3.8.1 pluggy==1.2.0 pre-commit==3.3.3 -prometheus-client==0.17.0 +prometheus-client==0.17.1 prompt-toolkit==3.0.39 psutil==5.9.5 ptyprocess==0.7.0 @@ -140,8 +140,8 @@ requests==2.31.0 requests-file==1.5.1 rfc3339-validator==0.1.4 rfc3986-validator==0.1.1 -rpds-py==0.8.8 -ruff==0.0.277 +rpds-py==0.8.10 +ruff==0.0.278 send2trash==1.8.2 six==1.16.0 sniffio==1.3.0 @@ -188,7 +188,7 @@ webencodings==0.5.1 websocket-client==1.6.1 wheel==0.40.0 widgetsnbextension==4.0.8 -zipp==3.15.0 +zipp==3.16.1 # The following packages are considered to be unsafe in a requirements file: # setuptools diff --git a/.constraints/py3.11.txt b/.constraints/py3.11.txt index bb96e4ef..0c651a9b 100644 --- a/.constraints/py3.11.txt +++ b/.constraints/py3.11.txt @@ -2,7 +2,7 @@ # This file is autogenerated by pip-compile with Python 3.11 # by the following command: # -# pip-compile --extra=dev --no-annotate --output-file=.constraints/py3.11.txt --strip-extras setup.py +# pip-compile --extra=dev --no-annotate --output-file=.constraints/py3.11.txt --resolver=backtracking --strip-extras # accessible-pygments==0.0.4 alabaster==0.7.13 @@ -12,12 +12,12 @@ argon2-cffi==21.3.0 argon2-cffi-bindings==21.2.0 arrow==1.2.3 asttokens==2.2.1 -async-lru==2.0.2 +async-lru==2.0.3 attrs==23.1.0 babel==2.12.1 backcall==0.2.0 beautifulsoup4==4.12.2 -black==23.3.0 +black==23.7.0 bleach==6.0.0 cachetools==5.3.1 cattrs==23.1.2 @@ -25,7 +25,7 @@ certifi==2023.5.7 cffi==1.15.1 cfgv==3.3.1 chardet==5.1.0 -charset-normalizer==3.1.0 +charset-normalizer==3.2.0 click==8.1.4 colorama==0.4.6 comm==0.1.3 @@ -40,14 +40,14 @@ esbonio==0.16.1 executing==1.2.0 fastjsonschema==2.17.1 filelock==3.12.2 -fonttools==4.40.0 +fonttools==4.41.0 fqdn==1.5.1 graphviz==0.20.1 greenlet==2.0.2 identify==2.5.24 idna==3.4 imagesize==1.4.1 -importlib-metadata==6.7.0 +importlib-metadata==6.8.0 iniconfig==2.0.0 ipykernel==6.24.0 ipympl==0.9.3 @@ -59,7 +59,7 @@ jedi==0.18.2 jinja2==3.1.2 json5==0.9.14 jsonpointer==2.4 -jsonschema==4.18.0 +jsonschema==4.18.2 jsonschema-specifications==2023.6.1 jupyter==1.0.0 jupyter-cache==0.6.1 @@ -70,7 +70,7 @@ jupyter-events==0.6.3 jupyter-lsp==2.2.0 jupyter-server==2.7.0 jupyter-server-terminals==0.4.4 -jupyterlab==4.0.2 +jupyterlab==4.0.3 jupyterlab-code-formatter==2.2.1 jupyterlab-myst==2.0.1 jupyterlab-pygments==0.2.2 @@ -94,13 +94,13 @@ myst-parser==0.18.1 nbclassic==1.0.0 nbclient==0.6.8 nbconvert==7.6.0 -nbformat==5.9.0 +nbformat==5.9.1 nbmake==1.4.1 nest-asyncio==1.5.6 nodeenv==1.8.0 notebook==6.5.4 notebook-shim==0.2.3 -numpy==1.25.0 +numpy==1.25.1 overrides==7.3.1 packaging==23.1 pandocfilters==1.5.0 @@ -112,7 +112,7 @@ pillow==10.0.0 platformdirs==3.8.1 pluggy==1.2.0 pre-commit==3.3.3 -prometheus-client==0.17.0 +prometheus-client==0.17.1 prompt-toolkit==3.0.39 psutil==5.9.5 ptyprocess==0.7.0 @@ -139,8 +139,8 @@ requests==2.31.0 requests-file==1.5.1 rfc3339-validator==0.1.4 rfc3986-validator==0.1.1 -rpds-py==0.8.8 -ruff==0.0.277 +rpds-py==0.8.10 +ruff==0.0.278 send2trash==1.8.2 six==1.16.0 sniffio==1.3.0 @@ -186,7 +186,7 @@ webencodings==0.5.1 websocket-client==1.6.1 wheel==0.40.0 widgetsnbextension==4.0.8 -zipp==3.15.0 +zipp==3.16.1 # The following packages are considered to be unsafe in a requirements file: # setuptools diff --git a/.constraints/py3.7.txt b/.constraints/py3.7.txt index 1511de00..9eea8b37 100644 --- a/.constraints/py3.7.txt +++ b/.constraints/py3.7.txt @@ -2,7 +2,7 @@ # This file is autogenerated by pip-compile with Python 3.7 # by the following command: # -# pip-compile --extra=dev --no-annotate --output-file=.constraints/py3.7.txt --strip-extras setup.py +# pip-compile --extra=dev --no-annotate --output-file=.constraints/py3.7.txt --resolver=backtracking --strip-extras # accessible-pygments==0.0.4 aiofiles==22.1.0 @@ -26,7 +26,7 @@ certifi==2023.5.7 cffi==1.15.1 cfgv==3.3.1 chardet==5.1.0 -charset-normalizer==3.1.0 +charset-normalizer==3.2.0 click==8.1.4 colorama==0.4.6 cycler==0.11.0 @@ -113,7 +113,7 @@ pkgutil-resolve-name==1.3.10 platformdirs==3.8.1 pluggy==1.2.0 pre-commit==2.21.0 -prometheus-client==0.17.0 +prometheus-client==0.17.1 prompt-toolkit==3.0.39 psutil==5.9.5 ptyprocess==0.7.0 @@ -140,7 +140,7 @@ requests==2.31.0 requests-file==1.5.1 rfc3339-validator==0.1.4 rfc3986-validator==0.1.1 -ruff==0.0.277 +ruff==0.0.278 send2trash==1.8.2 six==1.16.0 sniffio==1.3.0 diff --git a/.constraints/py3.8.txt b/.constraints/py3.8.txt index 82593697..6cddc22b 100644 --- a/.constraints/py3.8.txt +++ b/.constraints/py3.8.txt @@ -2,7 +2,7 @@ # This file is autogenerated by pip-compile with Python 3.8 # by the following command: # -# pip-compile --extra=dev --no-annotate --output-file=.constraints/py3.8.txt --strip-extras setup.py +# pip-compile --extra=dev --no-annotate --output-file=.constraints/py3.8.txt --resolver=backtracking --strip-extras # accessible-pygments==0.0.4 alabaster==0.7.13 @@ -12,12 +12,12 @@ argon2-cffi==21.3.0 argon2-cffi-bindings==21.2.0 arrow==1.2.3 asttokens==2.2.1 -async-lru==2.0.2 +async-lru==2.0.3 attrs==23.1.0 babel==2.12.1 backcall==0.2.0 beautifulsoup4==4.12.2 -black==23.3.0 +black==23.7.0 bleach==6.0.0 cachetools==5.3.1 cattrs==23.1.2 @@ -25,7 +25,7 @@ certifi==2023.5.7 cffi==1.15.1 cfgv==3.3.1 chardet==5.1.0 -charset-normalizer==3.1.0 +charset-normalizer==3.2.0 click==8.1.4 colorama==0.4.6 comm==0.1.3 @@ -41,15 +41,15 @@ exceptiongroup==1.1.2 executing==1.2.0 fastjsonschema==2.17.1 filelock==3.12.2 -fonttools==4.40.0 +fonttools==4.41.0 fqdn==1.5.1 graphviz==0.20.1 greenlet==2.0.2 identify==2.5.24 idna==3.4 imagesize==1.4.1 -importlib-metadata==6.7.0 -importlib-resources==5.12.0 +importlib-metadata==6.8.0 +importlib-resources==6.0.0 iniconfig==2.0.0 ipykernel==6.24.0 ipympl==0.9.3 @@ -61,7 +61,7 @@ jedi==0.18.2 jinja2==3.1.2 json5==0.9.14 jsonpointer==2.4 -jsonschema==4.18.0 +jsonschema==4.18.2 jsonschema-specifications==2023.6.1 jupyter==1.0.0 jupyter-cache==0.6.1 @@ -72,7 +72,7 @@ jupyter-events==0.6.3 jupyter-lsp==2.2.0 jupyter-server==2.7.0 jupyter-server-terminals==0.4.4 -jupyterlab==4.0.2 +jupyterlab==4.0.3 jupyterlab-code-formatter==2.2.1 jupyterlab-myst==2.0.1 jupyterlab-pygments==0.2.2 @@ -96,7 +96,7 @@ myst-parser==0.18.1 nbclassic==1.0.0 nbclient==0.6.8 nbconvert==7.6.0 -nbformat==5.9.0 +nbformat==5.9.1 nbmake==1.4.1 nest-asyncio==1.5.6 nodeenv==1.8.0 @@ -115,7 +115,7 @@ pkgutil-resolve-name==1.3.10 platformdirs==3.8.1 pluggy==1.2.0 pre-commit==3.3.3 -prometheus-client==0.17.0 +prometheus-client==0.17.1 prompt-toolkit==3.0.39 psutil==5.9.5 ptyprocess==0.7.0 @@ -143,8 +143,8 @@ requests==2.31.0 requests-file==1.5.1 rfc3339-validator==0.1.4 rfc3986-validator==0.1.1 -rpds-py==0.8.8 -ruff==0.0.277 +rpds-py==0.8.10 +ruff==0.0.278 send2trash==1.8.2 six==1.16.0 sniffio==1.3.0 @@ -191,7 +191,7 @@ webencodings==0.5.1 websocket-client==1.6.1 wheel==0.40.0 widgetsnbextension==4.0.8 -zipp==3.15.0 +zipp==3.16.1 # The following packages are considered to be unsafe in a requirements file: # setuptools diff --git a/.constraints/py3.9.txt b/.constraints/py3.9.txt index d56df4c5..37a60334 100644 --- a/.constraints/py3.9.txt +++ b/.constraints/py3.9.txt @@ -2,7 +2,7 @@ # This file is autogenerated by pip-compile with Python 3.9 # by the following command: # -# pip-compile --extra=dev --no-annotate --output-file=.constraints/py3.9.txt --strip-extras setup.py +# pip-compile --extra=dev --no-annotate --output-file=.constraints/py3.9.txt --resolver=backtracking --strip-extras # accessible-pygments==0.0.4 alabaster==0.7.13 @@ -12,12 +12,12 @@ argon2-cffi==21.3.0 argon2-cffi-bindings==21.2.0 arrow==1.2.3 asttokens==2.2.1 -async-lru==2.0.2 +async-lru==2.0.3 attrs==23.1.0 babel==2.12.1 backcall==0.2.0 beautifulsoup4==4.12.2 -black==23.3.0 +black==23.7.0 bleach==6.0.0 cachetools==5.3.1 cattrs==23.1.2 @@ -25,7 +25,7 @@ certifi==2023.5.7 cffi==1.15.1 cfgv==3.3.1 chardet==5.1.0 -charset-normalizer==3.1.0 +charset-normalizer==3.2.0 click==8.1.4 colorama==0.4.6 comm==0.1.3 @@ -41,15 +41,15 @@ exceptiongroup==1.1.2 executing==1.2.0 fastjsonschema==2.17.1 filelock==3.12.2 -fonttools==4.40.0 +fonttools==4.41.0 fqdn==1.5.1 graphviz==0.20.1 greenlet==2.0.2 identify==2.5.24 idna==3.4 imagesize==1.4.1 -importlib-metadata==6.7.0 -importlib-resources==5.12.0 +importlib-metadata==6.8.0 +importlib-resources==6.0.0 iniconfig==2.0.0 ipykernel==6.24.0 ipympl==0.9.3 @@ -61,7 +61,7 @@ jedi==0.18.2 jinja2==3.1.2 json5==0.9.14 jsonpointer==2.4 -jsonschema==4.18.0 +jsonschema==4.18.2 jsonschema-specifications==2023.6.1 jupyter==1.0.0 jupyter-cache==0.6.1 @@ -72,7 +72,7 @@ jupyter-events==0.6.3 jupyter-lsp==2.2.0 jupyter-server==2.7.0 jupyter-server-terminals==0.4.4 -jupyterlab==4.0.2 +jupyterlab==4.0.3 jupyterlab-code-formatter==2.2.1 jupyterlab-myst==2.0.1 jupyterlab-pygments==0.2.2 @@ -96,13 +96,13 @@ myst-parser==0.18.1 nbclassic==1.0.0 nbclient==0.6.8 nbconvert==7.6.0 -nbformat==5.9.0 +nbformat==5.9.1 nbmake==1.4.1 nest-asyncio==1.5.6 nodeenv==1.8.0 notebook==6.5.4 notebook-shim==0.2.3 -numpy==1.25.0 +numpy==1.25.1 overrides==7.3.1 packaging==23.1 pandocfilters==1.5.0 @@ -114,7 +114,7 @@ pillow==10.0.0 platformdirs==3.8.1 pluggy==1.2.0 pre-commit==3.3.3 -prometheus-client==0.17.0 +prometheus-client==0.17.1 prompt-toolkit==3.0.39 psutil==5.9.5 ptyprocess==0.7.0 @@ -141,8 +141,8 @@ requests==2.31.0 requests-file==1.5.1 rfc3339-validator==0.1.4 rfc3986-validator==0.1.1 -rpds-py==0.8.8 -ruff==0.0.277 +rpds-py==0.8.10 +ruff==0.0.278 send2trash==1.8.2 six==1.16.0 sniffio==1.3.0 @@ -189,7 +189,7 @@ webencodings==0.5.1 websocket-client==1.6.1 wheel==0.40.0 widgetsnbextension==4.0.8 -zipp==3.15.0 +zipp==3.16.1 # The following packages are considered to be unsafe in a requirements file: # setuptools From a018b5e5f21875a5a2a57ad7ab7fbdb16ee92052 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Thu, 13 Jul 2023 16:33:40 +0200 Subject: [PATCH 06/17] DX: install `jupyter-resource-usage` --- .constraints/py3.10.txt | 1 + .constraints/py3.11.txt | 1 + .constraints/py3.7.txt | 1 + .constraints/py3.8.txt | 1 + .constraints/py3.9.txt | 1 + setup.cfg | 1 + 6 files changed, 6 insertions(+) diff --git a/.constraints/py3.10.txt b/.constraints/py3.10.txt index cff928ad..542fc06d 100644 --- a/.constraints/py3.10.txt +++ b/.constraints/py3.10.txt @@ -69,6 +69,7 @@ jupyter-console==6.6.3 jupyter-core==5.3.1 jupyter-events==0.6.3 jupyter-lsp==2.2.0 +jupyter-resource-usage==0.7.2 jupyter-server==2.7.0 jupyter-server-terminals==0.4.4 jupyterlab==4.0.3 diff --git a/.constraints/py3.11.txt b/.constraints/py3.11.txt index 0c651a9b..7c9b4c3c 100644 --- a/.constraints/py3.11.txt +++ b/.constraints/py3.11.txt @@ -68,6 +68,7 @@ jupyter-console==6.6.3 jupyter-core==5.3.1 jupyter-events==0.6.3 jupyter-lsp==2.2.0 +jupyter-resource-usage==0.7.2 jupyter-server==2.7.0 jupyter-server-terminals==0.4.4 jupyterlab==4.0.3 diff --git a/.constraints/py3.7.txt b/.constraints/py3.7.txt index 9eea8b37..552c49b3 100644 --- a/.constraints/py3.7.txt +++ b/.constraints/py3.7.txt @@ -67,6 +67,7 @@ jupyter-client==7.4.9 jupyter-console==6.6.3 jupyter-core==4.12.0 jupyter-events==0.6.3 +jupyter-resource-usage==0.7.2 jupyter-server==1.24.0 jupyter-server-fileid==0.9.0 jupyter-server-ydoc==0.8.0 diff --git a/.constraints/py3.8.txt b/.constraints/py3.8.txt index 6cddc22b..f9aea9d3 100644 --- a/.constraints/py3.8.txt +++ b/.constraints/py3.8.txt @@ -70,6 +70,7 @@ jupyter-console==6.6.3 jupyter-core==5.3.1 jupyter-events==0.6.3 jupyter-lsp==2.2.0 +jupyter-resource-usage==0.7.2 jupyter-server==2.7.0 jupyter-server-terminals==0.4.4 jupyterlab==4.0.3 diff --git a/.constraints/py3.9.txt b/.constraints/py3.9.txt index 37a60334..55bcf868 100644 --- a/.constraints/py3.9.txt +++ b/.constraints/py3.9.txt @@ -70,6 +70,7 @@ jupyter-console==6.6.3 jupyter-core==5.3.1 jupyter-events==0.6.3 jupyter-lsp==2.2.0 +jupyter-resource-usage==0.7.2 jupyter-server==2.7.0 jupyter-server-terminals==0.4.4 jupyterlab==4.0.3 diff --git a/setup.cfg b/setup.cfg index 57a1184b..136052b2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -69,6 +69,7 @@ sty = pre-commit >=1.4.0 jupyter = %(doc)s + jupyter-resource-usage jupyterlab jupyterlab-code-formatter jupyterlab-myst From 18d72e5e2fe1c5efcc385301557c30e1c1499e11 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Fri, 11 Aug 2023 10:50:10 +0200 Subject: [PATCH 07/17] Switch to jacrev` --- docs/report/draft.ipynb | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/report/draft.ipynb b/docs/report/draft.ipynb index 84feeed4..199bb33f 100644 --- a/docs/report/draft.ipynb +++ b/docs/report/draft.ipynb @@ -37,7 +37,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\n", + "\n", "\n", "````{margin}\n", "```{spec} Gradient with autodiff\n", @@ -492,7 +492,7 @@ "outputs": [], "source": [ "func_with_data_inserted = Partial(intensity_func.function, *data_columns.values())\n", - "gradient_func = jax.jacfwd(\n", + "gradient_func = jax.jacrev(\n", " func_with_data_inserted,\n", " argnums=range(len(parameter_values)),\n", ")\n", @@ -590,7 +590,7 @@ "import jax.numpy as jnp\n", "\n", "\n", - "# @jax.jit # Do not JIT here, otherwise jax.jacfwd crashes!\n", + "# @jax.jit # Do not JIT here, otherwise jax.jacrev crashes!\n", "def estimator(args):\n", " data_intensities = func_with_data_inserted(*args)\n", " phsp_intensities = func_with_phsp_inserted(*args)\n", @@ -652,7 +652,7 @@ }, "outputs": [], "source": [ - "estimator_gradient = jax.jacfwd(estimator)" + "estimator_gradient = jax.jacrev(estimator)" ] }, { From d179312510a0f31ada2b6f1038e5d1ba3a712267 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Fri, 11 Aug 2023 12:03:05 +0200 Subject: [PATCH 08/17] FIX: store `gradient_values` %%time undoes variable storing --- docs/report/draft.ipynb | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/docs/report/draft.ipynb b/docs/report/draft.ipynb index 199bb33f..bd5cc023 100644 --- a/docs/report/draft.ipynb +++ b/docs/report/draft.ipynb @@ -507,7 +507,7 @@ }, "outputs": [], "source": [ - "%%time\n", + "%%time # compilation\n", "_ = tuple(v.block_until_ready() for v in gradient_func(*parameter_values.values()))" ] }, @@ -519,19 +519,7 @@ }, "outputs": [], "source": [ - "%%time\n", "gradient_values = gradient_func(*parameter_values.values())\n", - "gradient_values" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ "gradient_values[0].shape" ] }, From add3046be451d04816e02587c388b5f610c1f707 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Fri, 11 Aug 2023 12:06:53 +0200 Subject: [PATCH 09/17] DX: ignore `jacrev` globally --- .cspell.json | 1 + docs/report/draft.ipynb | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.cspell.json b/.cspell.json index 67c4aab1..39e1f462 100644 --- a/.cspell.json +++ b/.cspell.json @@ -213,6 +213,7 @@ "ipywidgets", "isinstance", "isort", + "jacrev", "jaxlib", "jupyterlab", "kernelspec", diff --git a/docs/report/draft.ipynb b/docs/report/draft.ipynb index bd5cc023..d0ee39d2 100644 --- a/docs/report/draft.ipynb +++ b/docs/report/draft.ipynb @@ -37,7 +37,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\n", + "\n", "\n", "````{margin}\n", "```{spec} Gradient with autodiff\n", From 448b62d44b74dbf5c86692bc98f416ea9347b75c Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Fri, 11 Aug 2023 12:10:29 +0200 Subject: [PATCH 10/17] FIX: remove comments after `%%time%` --- docs/report/draft.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/report/draft.ipynb b/docs/report/draft.ipynb index d0ee39d2..835e73c2 100644 --- a/docs/report/draft.ipynb +++ b/docs/report/draft.ipynb @@ -507,7 +507,7 @@ }, "outputs": [], "source": [ - "%%time # compilation\n", + "%%time\n", "_ = tuple(v.block_until_ready() for v in gradient_func(*parameter_values.values()))" ] }, From c58019daded4ea5a72878f36847c8b5eda312463 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Fri, 11 Aug 2023 12:15:14 +0200 Subject: [PATCH 11/17] MAINT: autoupdate `jupyter-resource-usage` --- .constraints/py3.10.txt | 2 +- .constraints/py3.11.txt | 2 +- .constraints/py3.8.txt | 2 +- .constraints/py3.9.txt | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.constraints/py3.10.txt b/.constraints/py3.10.txt index 0cb6ffba..b622eb42 100644 --- a/.constraints/py3.10.txt +++ b/.constraints/py3.10.txt @@ -69,7 +69,7 @@ jupyter-console==6.6.3 jupyter-core==5.3.1 jupyter-events==0.6.3 jupyter-lsp==2.2.0 -jupyter-resource-usage==0.7.2 +jupyter-resource-usage==1.0.0 jupyter-server==2.7.0 jupyter-server-terminals==0.4.4 jupyterlab==4.0.3 diff --git a/.constraints/py3.11.txt b/.constraints/py3.11.txt index 59709be3..51a5f2cd 100644 --- a/.constraints/py3.11.txt +++ b/.constraints/py3.11.txt @@ -68,7 +68,7 @@ jupyter-console==6.6.3 jupyter-core==5.3.1 jupyter-events==0.6.3 jupyter-lsp==2.2.0 -jupyter-resource-usage==0.7.2 +jupyter-resource-usage==1.0.0 jupyter-server==2.7.0 jupyter-server-terminals==0.4.4 jupyterlab==4.0.3 diff --git a/.constraints/py3.8.txt b/.constraints/py3.8.txt index 82d77904..3ec74ec7 100644 --- a/.constraints/py3.8.txt +++ b/.constraints/py3.8.txt @@ -70,7 +70,7 @@ jupyter-console==6.6.3 jupyter-core==5.3.1 jupyter-events==0.6.3 jupyter-lsp==2.2.0 -jupyter-resource-usage==0.7.2 +jupyter-resource-usage==1.0.0 jupyter-server==2.7.0 jupyter-server-terminals==0.4.4 jupyterlab==4.0.3 diff --git a/.constraints/py3.9.txt b/.constraints/py3.9.txt index a08d8517..73a035f2 100644 --- a/.constraints/py3.9.txt +++ b/.constraints/py3.9.txt @@ -70,7 +70,7 @@ jupyter-console==6.6.3 jupyter-core==5.3.1 jupyter-events==0.6.3 jupyter-lsp==2.2.0 -jupyter-resource-usage==0.7.2 +jupyter-resource-usage==1.0.0 jupyter-server==2.7.0 jupyter-server-terminals==0.4.4 jupyterlab==4.0.3 From 922a85f85064f56147d007f88540f508beeca25e Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Fri, 11 Aug 2023 14:33:30 +0200 Subject: [PATCH 12/17] FIX: switch back to `jacfwd` --- .cspell.json | 2 +- docs/report/draft.ipynb | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.cspell.json b/.cspell.json index 38391e16..b859ff23 100644 --- a/.cspell.json +++ b/.cspell.json @@ -213,7 +213,7 @@ "ipywidgets", "isinstance", "isort", - "jacrev", + "jacfwd", "jaxlib", "joinpath", "juliaup", diff --git a/docs/report/draft.ipynb b/docs/report/draft.ipynb index 835e73c2..7742fc2e 100644 --- a/docs/report/draft.ipynb +++ b/docs/report/draft.ipynb @@ -492,7 +492,7 @@ "outputs": [], "source": [ "func_with_data_inserted = Partial(intensity_func.function, *data_columns.values())\n", - "gradient_func = jax.jacrev(\n", + "gradient_func = jax.jacfwd(\n", " func_with_data_inserted,\n", " argnums=range(len(parameter_values)),\n", ")\n", @@ -578,7 +578,7 @@ "import jax.numpy as jnp\n", "\n", "\n", - "# @jax.jit # Do not JIT here, otherwise jax.jacrev crashes!\n", + "# @jax.jit # Do not JIT here, otherwise jax.jacfwd crashes!\n", "def estimator(args):\n", " data_intensities = func_with_data_inserted(*args)\n", " phsp_intensities = func_with_phsp_inserted(*args)\n", @@ -640,7 +640,7 @@ }, "outputs": [], "source": [ - "estimator_gradient = jax.jacrev(estimator)" + "estimator_gradient = jax.jacfwd(estimator)" ] }, { From dbbdc4802c882450162aed2d04bdd8e53aaa8dc1 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Fri, 27 Oct 2023 16:38:37 +0200 Subject: [PATCH 13/17] DOC: print JAX precision --- .cspell.json | 1 + docs/report/999.ipynb | 48 ++++++++++++++++++++++++++++++++++++++----- 2 files changed, 44 insertions(+), 5 deletions(-) diff --git a/.cspell.json b/.cspell.json index 058edd9b..743535c3 100644 --- a/.cspell.json +++ b/.cspell.json @@ -156,6 +156,7 @@ "docnb", "docstrings", "dotprint", + "dtype", "einsum", "elif", "endswith", diff --git a/docs/report/999.ipynb b/docs/report/999.ipynb index 3c5d7600..19ccc286 100644 --- a/docs/report/999.ipynb +++ b/docs/report/999.ipynb @@ -114,6 +114,27 @@ "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\"" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + } + }, + "outputs": [], + "source": [ + "import jax.numpy as jnp\n", + "\n", + "\n", + "def print_jax_precision() -> None:\n", + " arr = jnp.array([1.0])\n", + " print(arr.dtype)\n", + "\n", + "\n", + "print_jax_precision()" + ] + }, { "cell_type": "markdown", "metadata": { @@ -526,7 +547,6 @@ { "cell_type": "markdown", "metadata": { - "jp-MarkdownHeadingCollapsed": true, "tags": [] }, "source": [ @@ -575,9 +595,6 @@ }, "outputs": [], "source": [ - "import jax.numpy as jnp\n", - "\n", - "\n", "# @jax.jit # Do not JIT here, otherwise jax.jacfwd crashes!\n", "def estimator(args):\n", " data_intensities = func_with_data_inserted(*args)\n", @@ -589,6 +606,15 @@ "estimator(parameter_values.values())" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print_jax_precision()" + ] + }, { "cell_type": "code", "execution_count": null, @@ -632,6 +658,18 @@ "### With analytic gradient" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from jax.config import config\n", + "\n", + "config.update(\"jax_enable_x64\", False)\n", + "print_jax_precision()" + ] + }, { "cell_type": "code", "execution_count": null, @@ -737,7 +775,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.16" + "version": "3.8.18" } }, "nbformat": 4, From 3f3dcf04e1f26f9bceb3e0c13069a019508d6257 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Sun, 29 Oct 2023 17:05:11 +0100 Subject: [PATCH 14/17] DX: uninstall `jupyter-resource-usage` --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d3f9a2b1..777db99c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,7 +66,6 @@ format = [ ] jupyter = [ "compwa-org[doc]", - "jupyter-resource-usage", "jupyterlab", "jupyterlab-code-formatter", "jupyterlab-lsp", From ea22ddd9e44eb64aeb02476f9451acf25e3a3bcf Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Sun, 29 Oct 2023 17:09:57 +0100 Subject: [PATCH 15/17] MAINT: autoupdate dependencies --- .constraints/py3.10.txt | 68 ++++++++++++++++++++--------------------- .constraints/py3.11.txt | 68 ++++++++++++++++++++--------------------- .constraints/py3.7.txt | 24 +++++++-------- .constraints/py3.8.txt | 66 +++++++++++++++++++-------------------- .constraints/py3.9.txt | 68 ++++++++++++++++++++--------------------- 5 files changed, 147 insertions(+), 147 deletions(-) diff --git a/.constraints/py3.10.txt b/.constraints/py3.10.txt index ce718bab..f0966ac0 100644 --- a/.constraints/py3.10.txt +++ b/.constraints/py3.10.txt @@ -10,20 +10,20 @@ anyio==4.0.0 argon2-cffi==23.1.0 argon2-cffi-bindings==21.2.0 arrow==1.3.0 -asttokens==2.4.0 +asttokens==2.4.1 async-lru==2.0.4 attrs==23.1.0 -babel==2.13.0 +babel==2.13.1 backcall==0.2.0 beautifulsoup4==4.12.2 -black==23.9.1 +black==23.10.1 bleach==6.1.0 -cachetools==5.3.1 +cachetools==5.3.2 certifi==2023.7.22 cffi==1.16.0 cfgv==3.4.0 chardet==5.2.0 -charset-normalizer==3.3.0 +charset-normalizer==3.3.1 click==8.1.7 colorama==0.4.6 comm==0.1.4 @@ -33,22 +33,22 @@ debugpy==1.8.0 decorator==5.1.1 defusedxml==0.7.1 distlib==0.3.7 -docstring-to-markdown==0.12 +docstring-to-markdown==0.13 docutils==0.17.1 exceptiongroup==1.1.3 -executing==2.0.0 +executing==2.0.1 fastjsonschema==2.18.1 -filelock==3.12.4 +filelock==3.13.0 fonttools==4.43.1 fqdn==1.5.1 graphviz==0.20.1 -greenlet==3.0.0 -identify==2.5.30 +greenlet==3.0.1 +identify==2.5.31 idna==3.4 imagesize==1.4.1 importlib-metadata==6.8.0 iniconfig==2.0.0 -ipykernel==6.25.2 +ipykernel==6.26.0 ipympl==0.9.3 ipython==8.16.1 ipython-genutils==0.2.0 @@ -62,17 +62,17 @@ jsonschema==4.19.1 jsonschema-specifications==2023.7.1 jupyter==1.0.0 jupyter-cache==0.6.1 -jupyter-client==8.3.1 +jupyter-client==8.5.0 jupyter-console==6.6.3 -jupyter-core==5.3.2 -jupyter-events==0.7.0 +jupyter-core==5.4.0 +jupyter-events==0.8.0 jupyter-lsp==2.2.0 -jupyter-server==2.7.3 +jupyter-server==2.9.1 jupyter-server-terminals==0.4.4 -jupyterlab==4.0.6 +jupyterlab==4.0.7 jupyterlab-code-formatter==2.2.1 jupyterlab-lsp==5.0.0 -jupyterlab-myst==2.0.2 +jupyterlab-myst==2.1.0 jupyterlab-pygments==0.2.2 jupyterlab-server==2.25.0 jupyterlab-widgets==3.0.9 @@ -93,12 +93,12 @@ myst-parser==0.18.1 nbclient==0.6.8 nbconvert==7.9.2 nbformat==5.9.2 -nbmake==1.4.5 +nbmake==1.4.6 nest-asyncio==1.5.8 nodeenv==1.8.0 -notebook==7.0.4 +notebook==7.0.6 notebook-shim==0.2.3 -numpy==1.26.0 +numpy==1.26.1 overrides==7.4.0 packaging==23.2 pandocfilters==1.5.0 @@ -106,39 +106,39 @@ parso==0.8.3 pathspec==0.11.2 pexpect==4.8.0 pickleshare==0.7.5 -pillow==10.0.1 +pillow==10.1.0 platformdirs==3.11.0 pluggy==1.3.0 -pre-commit==3.4.0 +pre-commit==3.5.0 prometheus-client==0.17.1 prompt-toolkit==3.0.39 -psutil==5.9.5 +psutil==5.9.6 ptyprocess==0.7.0 pure-eval==0.2.2 pybtex==0.24.0 pybtex-docutils==1.0.3 pycparser==2.21 -pydata-sphinx-theme==0.14.1 +pydata-sphinx-theme==0.14.2 pygments==2.16.1 pyparsing==3.1.1 pyproject-api==1.6.1 -pytest==7.4.2 +pytest==7.4.3 python-dateutil==2.8.2 python-json-logger==2.0.7 python-lsp-jsonrpc==1.1.2 python-lsp-server==1.8.2 -pytoolconfig==1.2.5 +pytoolconfig==1.2.6 pyyaml==6.0.1 pyzmq==25.1.1 qtconsole==5.4.4 -qtpy==2.4.0 +qtpy==2.4.1 referencing==0.30.2 requests==2.31.0 rfc3339-validator==0.1.4 rfc3986-validator==0.1.1 rope==1.10.0 -rpds-py==0.10.4 -ruff==0.0.292 +rpds-py==0.10.6 +ruff==0.1.3 send2trash==1.8.2 six==1.16.0 sniffio==1.3.0 @@ -153,7 +153,7 @@ sphinx-copybutton==0.5.2 sphinx-design==0.5.0 sphinx-hep-pdgref==0.2.0 sphinx-remove-toctrees==0.0.3 -sphinx-thebe==0.2.1 +sphinx-thebe==0.3.0 sphinx-togglebutton==0.3.2 sphinxcontrib-applehelp==1.0.7 sphinxcontrib-bibtex==2.6.1 @@ -162,7 +162,7 @@ sphinxcontrib-htmlhelp==2.0.4 sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.6 sphinxcontrib-serializinghtml==1.1.9 -sqlalchemy==2.0.21 +sqlalchemy==2.0.22 stack-data==0.6.3 tabulate==0.9.0 terminado==0.17.1 @@ -170,13 +170,13 @@ tinycss2==1.2.1 tomli==2.0.1 tornado==6.3.3 tox==4.11.3 -traitlets==5.11.2 +traitlets==5.12.0 types-python-dateutil==2.8.19.14 typing-extensions==4.8.0 ujson==5.8.0 uri-template==1.3.0 -urllib3==2.0.6 -virtualenv==20.24.5 +urllib3==2.0.7 +virtualenv==20.24.6 wcwidth==0.2.8 webcolors==1.13 webencodings==0.5.1 diff --git a/.constraints/py3.11.txt b/.constraints/py3.11.txt index e46a7b85..b07b760c 100644 --- a/.constraints/py3.11.txt +++ b/.constraints/py3.11.txt @@ -10,20 +10,20 @@ anyio==4.0.0 argon2-cffi==23.1.0 argon2-cffi-bindings==21.2.0 arrow==1.3.0 -asttokens==2.4.0 +asttokens==2.4.1 async-lru==2.0.4 attrs==23.1.0 -babel==2.13.0 +babel==2.13.1 backcall==0.2.0 beautifulsoup4==4.12.2 -black==23.9.1 +black==23.10.1 bleach==6.1.0 -cachetools==5.3.1 +cachetools==5.3.2 certifi==2023.7.22 cffi==1.16.0 cfgv==3.4.0 chardet==5.2.0 -charset-normalizer==3.3.0 +charset-normalizer==3.3.1 click==8.1.7 colorama==0.4.6 comm==0.1.4 @@ -33,21 +33,21 @@ debugpy==1.8.0 decorator==5.1.1 defusedxml==0.7.1 distlib==0.3.7 -docstring-to-markdown==0.12 +docstring-to-markdown==0.13 docutils==0.17.1 -executing==2.0.0 +executing==2.0.1 fastjsonschema==2.18.1 -filelock==3.12.4 +filelock==3.13.0 fonttools==4.43.1 fqdn==1.5.1 graphviz==0.20.1 -greenlet==3.0.0 -identify==2.5.30 +greenlet==3.0.1 +identify==2.5.31 idna==3.4 imagesize==1.4.1 importlib-metadata==6.8.0 iniconfig==2.0.0 -ipykernel==6.25.2 +ipykernel==6.26.0 ipympl==0.9.3 ipython==8.16.1 ipython-genutils==0.2.0 @@ -61,17 +61,17 @@ jsonschema==4.19.1 jsonschema-specifications==2023.7.1 jupyter==1.0.0 jupyter-cache==0.6.1 -jupyter-client==8.3.1 +jupyter-client==8.5.0 jupyter-console==6.6.3 -jupyter-core==5.3.2 -jupyter-events==0.7.0 +jupyter-core==5.4.0 +jupyter-events==0.8.0 jupyter-lsp==2.2.0 -jupyter-server==2.7.3 +jupyter-server==2.9.1 jupyter-server-terminals==0.4.4 -jupyterlab==4.0.6 +jupyterlab==4.0.7 jupyterlab-code-formatter==2.2.1 jupyterlab-lsp==5.0.0 -jupyterlab-myst==2.0.2 +jupyterlab-myst==2.1.0 jupyterlab-pygments==0.2.2 jupyterlab-server==2.25.0 jupyterlab-widgets==3.0.9 @@ -92,12 +92,12 @@ myst-parser==0.18.1 nbclient==0.6.8 nbconvert==7.9.2 nbformat==5.9.2 -nbmake==1.4.5 +nbmake==1.4.6 nest-asyncio==1.5.8 nodeenv==1.8.0 -notebook==7.0.4 +notebook==7.0.6 notebook-shim==0.2.3 -numpy==1.26.0 +numpy==1.26.1 overrides==7.4.0 packaging==23.2 pandocfilters==1.5.0 @@ -105,39 +105,39 @@ parso==0.8.3 pathspec==0.11.2 pexpect==4.8.0 pickleshare==0.7.5 -pillow==10.0.1 +pillow==10.1.0 platformdirs==3.11.0 pluggy==1.3.0 -pre-commit==3.4.0 +pre-commit==3.5.0 prometheus-client==0.17.1 prompt-toolkit==3.0.39 -psutil==5.9.5 +psutil==5.9.6 ptyprocess==0.7.0 pure-eval==0.2.2 pybtex==0.24.0 pybtex-docutils==1.0.3 pycparser==2.21 -pydata-sphinx-theme==0.14.1 +pydata-sphinx-theme==0.14.2 pygments==2.16.1 pyparsing==3.1.1 pyproject-api==1.6.1 -pytest==7.4.2 +pytest==7.4.3 python-dateutil==2.8.2 python-json-logger==2.0.7 python-lsp-jsonrpc==1.1.2 python-lsp-server==1.8.2 -pytoolconfig==1.2.5 +pytoolconfig==1.2.6 pyyaml==6.0.1 pyzmq==25.1.1 qtconsole==5.4.4 -qtpy==2.4.0 +qtpy==2.4.1 referencing==0.30.2 requests==2.31.0 rfc3339-validator==0.1.4 rfc3986-validator==0.1.1 rope==1.10.0 -rpds-py==0.10.4 -ruff==0.0.292 +rpds-py==0.10.6 +ruff==0.1.3 send2trash==1.8.2 six==1.16.0 sniffio==1.3.0 @@ -152,7 +152,7 @@ sphinx-copybutton==0.5.2 sphinx-design==0.5.0 sphinx-hep-pdgref==0.2.0 sphinx-remove-toctrees==0.0.3 -sphinx-thebe==0.2.1 +sphinx-thebe==0.3.0 sphinx-togglebutton==0.3.2 sphinxcontrib-applehelp==1.0.7 sphinxcontrib-bibtex==2.6.1 @@ -161,20 +161,20 @@ sphinxcontrib-htmlhelp==2.0.4 sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.6 sphinxcontrib-serializinghtml==1.1.9 -sqlalchemy==2.0.21 +sqlalchemy==2.0.22 stack-data==0.6.3 tabulate==0.9.0 terminado==0.17.1 tinycss2==1.2.1 tornado==6.3.3 tox==4.11.3 -traitlets==5.11.2 +traitlets==5.12.0 types-python-dateutil==2.8.19.14 typing-extensions==4.8.0 ujson==5.8.0 uri-template==1.3.0 -urllib3==2.0.6 -virtualenv==20.24.5 +urllib3==2.0.7 +virtualenv==20.24.6 wcwidth==0.2.8 webcolors==1.13 webencodings==0.5.1 diff --git a/.constraints/py3.7.txt b/.constraints/py3.7.txt index 88318dd4..c56fde8b 100644 --- a/.constraints/py3.7.txt +++ b/.constraints/py3.7.txt @@ -13,18 +13,18 @@ argon2-cffi==23.1.0 argon2-cffi-bindings==21.2.0 arrow==1.2.3 attrs==23.1.0 -babel==2.13.0 +babel==2.13.1 backcall==0.2.0 beautifulsoup4==4.12.2 black==23.3.0 bleach==6.0.0 cached-property==1.5.2 -cachetools==5.3.1 +cachetools==5.3.2 certifi==2023.7.22 cffi==1.15.1 cfgv==3.3.1 chardet==5.2.0 -charset-normalizer==3.3.0 +charset-normalizer==3.3.1 click==8.1.7 colorama==0.4.6 comm==0.1.4 @@ -33,7 +33,7 @@ debugpy==1.7.0 decorator==5.1.1 defusedxml==0.7.1 distlib==0.3.7 -docstring-to-markdown==0.12 +docstring-to-markdown==0.13 docutils==0.17.1 entrypoints==0.4 exceptiongroup==1.1.3 @@ -42,7 +42,7 @@ filelock==3.12.2 fonttools==4.38.0 fqdn==1.5.1 graphviz==0.20.1 -greenlet==3.0.0 +greenlet==3.0.1 identify==2.5.24 idna==3.4 imagesize==1.4.1 @@ -115,7 +115,7 @@ pluggy==1.2.0 pre-commit==2.21.0 prometheus-client==0.17.1 prompt-toolkit==3.0.39 -psutil==5.9.5 +psutil==5.9.6 ptyprocess==0.7.0 pybtex==0.24.0 pybtex-docutils==1.0.3 @@ -126,22 +126,22 @@ pygments==2.16.1 pyparsing==3.1.1 pyproject-api==1.5.3 pyrsistent==0.19.3 -pytest==7.4.2 +pytest==7.4.3 python-dateutil==2.8.2 python-json-logger==2.0.7 python-lsp-jsonrpc==1.0.0 python-lsp-server==1.7.4 -pytoolconfig==1.2.5 +pytoolconfig==1.2.6 pytz==2023.3.post1 pyyaml==6.0.1 pyzmq==24.0.1 qtconsole==5.4.4 -qtpy==2.4.0 +qtpy==2.4.1 requests==2.31.0 rfc3339-validator==0.1.4 rfc3986-validator==0.1.1 rope==1.9.0 -ruff==0.0.292 +ruff==0.1.3 send2trash==1.8.2 six==1.16.0 sniffio==1.3.0 @@ -177,8 +177,8 @@ typed-ast==1.5.5 typing-extensions==4.7.1 ujson==5.7.0 uri-template==1.3.0 -urllib3==2.0.6 -virtualenv==20.24.5 +urllib3==2.0.7 +virtualenv==20.24.6 wcwidth==0.2.8 webcolors==1.13 webencodings==0.5.1 diff --git a/.constraints/py3.8.txt b/.constraints/py3.8.txt index 64b4abb5..e53293fc 100644 --- a/.constraints/py3.8.txt +++ b/.constraints/py3.8.txt @@ -10,20 +10,20 @@ anyio==4.0.0 argon2-cffi==23.1.0 argon2-cffi-bindings==21.2.0 arrow==1.3.0 -asttokens==2.4.0 +asttokens==2.4.1 async-lru==2.0.4 attrs==23.1.0 -babel==2.13.0 +babel==2.13.1 backcall==0.2.0 beautifulsoup4==4.12.2 -black==23.9.1 +black==23.10.1 bleach==6.1.0 -cachetools==5.3.1 +cachetools==5.3.2 certifi==2023.7.22 cffi==1.16.0 cfgv==3.4.0 chardet==5.2.0 -charset-normalizer==3.3.0 +charset-normalizer==3.3.1 click==8.1.7 colorama==0.4.6 comm==0.1.4 @@ -33,23 +33,23 @@ debugpy==1.8.0 decorator==5.1.1 defusedxml==0.7.1 distlib==0.3.7 -docstring-to-markdown==0.12 +docstring-to-markdown==0.13 docutils==0.17.1 exceptiongroup==1.1.3 -executing==2.0.0 +executing==2.0.1 fastjsonschema==2.18.1 -filelock==3.12.4 +filelock==3.13.0 fonttools==4.43.1 fqdn==1.5.1 graphviz==0.20.1 -greenlet==3.0.0 -identify==2.5.30 +greenlet==3.0.1 +identify==2.5.31 idna==3.4 imagesize==1.4.1 importlib-metadata==6.8.0 importlib-resources==6.1.0 iniconfig==2.0.0 -ipykernel==6.25.2 +ipykernel==6.26.0 ipympl==0.9.3 ipython==8.12.3 ipython-genutils==0.2.0 @@ -63,17 +63,17 @@ jsonschema==4.19.1 jsonschema-specifications==2023.7.1 jupyter==1.0.0 jupyter-cache==0.6.1 -jupyter-client==8.3.1 +jupyter-client==8.5.0 jupyter-console==6.6.3 -jupyter-core==5.3.2 -jupyter-events==0.7.0 +jupyter-core==5.4.0 +jupyter-events==0.8.0 jupyter-lsp==2.2.0 -jupyter-server==2.7.3 +jupyter-server==2.9.1 jupyter-server-terminals==0.4.4 -jupyterlab==4.0.6 +jupyterlab==4.0.7 jupyterlab-code-formatter==2.2.1 jupyterlab-lsp==5.0.0 -jupyterlab-myst==2.0.2 +jupyterlab-myst==2.1.0 jupyterlab-pygments==0.2.2 jupyterlab-server==2.25.0 jupyterlab-widgets==3.0.9 @@ -94,10 +94,10 @@ myst-parser==0.18.1 nbclient==0.6.8 nbconvert==7.9.2 nbformat==5.9.2 -nbmake==1.4.5 +nbmake==1.4.6 nest-asyncio==1.5.8 nodeenv==1.8.0 -notebook==7.0.4 +notebook==7.0.6 notebook-shim==0.2.3 numpy==1.24.4 overrides==7.4.0 @@ -107,41 +107,41 @@ parso==0.8.3 pathspec==0.11.2 pexpect==4.8.0 pickleshare==0.7.5 -pillow==10.0.1 +pillow==10.1.0 pkgutil-resolve-name==1.3.10 platformdirs==3.11.0 pluggy==1.3.0 -pre-commit==3.4.0 +pre-commit==3.5.0 prometheus-client==0.17.1 prompt-toolkit==3.0.39 -psutil==5.9.5 +psutil==5.9.6 ptyprocess==0.7.0 pure-eval==0.2.2 pybtex==0.24.0 pybtex-docutils==1.0.3 pycparser==2.21 -pydata-sphinx-theme==0.14.1 +pydata-sphinx-theme==0.14.2 pygments==2.16.1 pyparsing==3.1.1 pyproject-api==1.6.1 -pytest==7.4.2 +pytest==7.4.3 python-dateutil==2.8.2 python-json-logger==2.0.7 python-lsp-jsonrpc==1.1.2 python-lsp-server==1.8.2 -pytoolconfig==1.2.5 +pytoolconfig==1.2.6 pytz==2023.3.post1 pyyaml==6.0.1 pyzmq==25.1.1 qtconsole==5.4.4 -qtpy==2.4.0 +qtpy==2.4.1 referencing==0.30.2 requests==2.31.0 rfc3339-validator==0.1.4 rfc3986-validator==0.1.1 rope==1.10.0 -rpds-py==0.10.4 -ruff==0.0.292 +rpds-py==0.10.6 +ruff==0.1.3 send2trash==1.8.2 six==1.16.0 sniffio==1.3.0 @@ -156,7 +156,7 @@ sphinx-copybutton==0.5.2 sphinx-design==0.5.0 sphinx-hep-pdgref==0.2.0 sphinx-remove-toctrees==0.0.3 -sphinx-thebe==0.2.1 +sphinx-thebe==0.3.0 sphinx-togglebutton==0.3.2 sphinxcontrib-applehelp==1.0.4 sphinxcontrib-bibtex==2.6.1 @@ -165,7 +165,7 @@ sphinxcontrib-htmlhelp==2.0.1 sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.3 sphinxcontrib-serializinghtml==1.1.5 -sqlalchemy==2.0.21 +sqlalchemy==2.0.22 stack-data==0.6.3 tabulate==0.9.0 terminado==0.17.1 @@ -173,13 +173,13 @@ tinycss2==1.2.1 tomli==2.0.1 tornado==6.3.3 tox==4.11.3 -traitlets==5.11.2 +traitlets==5.12.0 types-python-dateutil==2.8.19.14 typing-extensions==4.8.0 ujson==5.8.0 uri-template==1.3.0 -urllib3==2.0.6 -virtualenv==20.24.5 +urllib3==2.0.7 +virtualenv==20.24.6 wcwidth==0.2.8 webcolors==1.13 webencodings==0.5.1 diff --git a/.constraints/py3.9.txt b/.constraints/py3.9.txt index 8b616965..8764618a 100644 --- a/.constraints/py3.9.txt +++ b/.constraints/py3.9.txt @@ -10,20 +10,20 @@ anyio==4.0.0 argon2-cffi==23.1.0 argon2-cffi-bindings==21.2.0 arrow==1.3.0 -asttokens==2.4.0 +asttokens==2.4.1 async-lru==2.0.4 attrs==23.1.0 -babel==2.13.0 +babel==2.13.1 backcall==0.2.0 beautifulsoup4==4.12.2 -black==23.9.1 +black==23.10.1 bleach==6.1.0 -cachetools==5.3.1 +cachetools==5.3.2 certifi==2023.7.22 cffi==1.16.0 cfgv==3.4.0 chardet==5.2.0 -charset-normalizer==3.3.0 +charset-normalizer==3.3.1 click==8.1.7 colorama==0.4.6 comm==0.1.4 @@ -33,23 +33,23 @@ debugpy==1.8.0 decorator==5.1.1 defusedxml==0.7.1 distlib==0.3.7 -docstring-to-markdown==0.12 +docstring-to-markdown==0.13 docutils==0.17.1 exceptiongroup==1.1.3 -executing==2.0.0 +executing==2.0.1 fastjsonschema==2.18.1 -filelock==3.12.4 +filelock==3.13.0 fonttools==4.43.1 fqdn==1.5.1 graphviz==0.20.1 -greenlet==3.0.0 -identify==2.5.30 +greenlet==3.0.1 +identify==2.5.31 idna==3.4 imagesize==1.4.1 importlib-metadata==6.8.0 importlib-resources==6.1.0 iniconfig==2.0.0 -ipykernel==6.25.2 +ipykernel==6.26.0 ipympl==0.9.3 ipython==8.16.1 ipython-genutils==0.2.0 @@ -63,17 +63,17 @@ jsonschema==4.19.1 jsonschema-specifications==2023.7.1 jupyter==1.0.0 jupyter-cache==0.6.1 -jupyter-client==8.3.1 +jupyter-client==8.5.0 jupyter-console==6.6.3 -jupyter-core==5.3.2 -jupyter-events==0.7.0 +jupyter-core==5.4.0 +jupyter-events==0.8.0 jupyter-lsp==2.2.0 -jupyter-server==2.7.3 +jupyter-server==2.9.1 jupyter-server-terminals==0.4.4 -jupyterlab==4.0.6 +jupyterlab==4.0.7 jupyterlab-code-formatter==2.2.1 jupyterlab-lsp==5.0.0 -jupyterlab-myst==2.0.2 +jupyterlab-myst==2.1.0 jupyterlab-pygments==0.2.2 jupyterlab-server==2.25.0 jupyterlab-widgets==3.0.9 @@ -94,12 +94,12 @@ myst-parser==0.18.1 nbclient==0.6.8 nbconvert==7.9.2 nbformat==5.9.2 -nbmake==1.4.5 +nbmake==1.4.6 nest-asyncio==1.5.8 nodeenv==1.8.0 -notebook==7.0.4 +notebook==7.0.6 notebook-shim==0.2.3 -numpy==1.26.0 +numpy==1.26.1 overrides==7.4.0 packaging==23.2 pandocfilters==1.5.0 @@ -107,39 +107,39 @@ parso==0.8.3 pathspec==0.11.2 pexpect==4.8.0 pickleshare==0.7.5 -pillow==10.0.1 +pillow==10.1.0 platformdirs==3.11.0 pluggy==1.3.0 -pre-commit==3.4.0 +pre-commit==3.5.0 prometheus-client==0.17.1 prompt-toolkit==3.0.39 -psutil==5.9.5 +psutil==5.9.6 ptyprocess==0.7.0 pure-eval==0.2.2 pybtex==0.24.0 pybtex-docutils==1.0.3 pycparser==2.21 -pydata-sphinx-theme==0.14.1 +pydata-sphinx-theme==0.14.2 pygments==2.16.1 pyparsing==3.1.1 pyproject-api==1.6.1 -pytest==7.4.2 +pytest==7.4.3 python-dateutil==2.8.2 python-json-logger==2.0.7 python-lsp-jsonrpc==1.1.2 python-lsp-server==1.8.2 -pytoolconfig==1.2.5 +pytoolconfig==1.2.6 pyyaml==6.0.1 pyzmq==25.1.1 qtconsole==5.4.4 -qtpy==2.4.0 +qtpy==2.4.1 referencing==0.30.2 requests==2.31.0 rfc3339-validator==0.1.4 rfc3986-validator==0.1.1 rope==1.10.0 -rpds-py==0.10.4 -ruff==0.0.292 +rpds-py==0.10.6 +ruff==0.1.3 send2trash==1.8.2 six==1.16.0 sniffio==1.3.0 @@ -154,7 +154,7 @@ sphinx-copybutton==0.5.2 sphinx-design==0.5.0 sphinx-hep-pdgref==0.2.0 sphinx-remove-toctrees==0.0.3 -sphinx-thebe==0.2.1 +sphinx-thebe==0.3.0 sphinx-togglebutton==0.3.2 sphinxcontrib-applehelp==1.0.7 sphinxcontrib-bibtex==2.6.1 @@ -163,7 +163,7 @@ sphinxcontrib-htmlhelp==2.0.4 sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.6 sphinxcontrib-serializinghtml==1.1.9 -sqlalchemy==2.0.21 +sqlalchemy==2.0.22 stack-data==0.6.3 tabulate==0.9.0 terminado==0.17.1 @@ -171,13 +171,13 @@ tinycss2==1.2.1 tomli==2.0.1 tornado==6.3.3 tox==4.11.3 -traitlets==5.11.2 +traitlets==5.12.0 types-python-dateutil==2.8.19.14 typing-extensions==4.8.0 ujson==5.8.0 uri-template==1.3.0 -urllib3==2.0.6 -virtualenv==20.24.5 +urllib3==2.0.7 +virtualenv==20.24.6 wcwidth==0.2.8 webcolors==1.13 webencodings==0.5.1 From 84e654328f5a6184d7f22135b61539eeba4e1a11 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Sun, 29 Oct 2023 22:21:53 +0100 Subject: [PATCH 16/17] DOC: split analytic and numerical fit notebook --- .cspell.json | 4 + docs/report/998.ipynb | 709 ++++++++++++++++++++++++++++++++++++++++++ docs/report/999.ipynb | 565 ++++++++++++++++++++------------- 3 files changed, 1067 insertions(+), 211 deletions(-) create mode 100644 docs/report/998.ipynb diff --git a/.cspell.json b/.cspell.json index 743535c3..f231d08c 100644 --- a/.cspell.json +++ b/.cspell.json @@ -134,6 +134,7 @@ "clim", "cmap", "cmath", + "cmin", "codegen", "codemirror", "colorbar", @@ -174,6 +175,7 @@ "framealpha", "funcs", "getitem", + "getpid", "getsource", "graphviz", "griddata", @@ -260,6 +262,7 @@ "preorder", "prereleased", "println", + "psutil", "pvalues", "py's", "pycode", @@ -307,6 +310,7 @@ "surfacecolor", "symplot", "tbody", + "textwrap", "thead", "theano", "threebody", diff --git a/docs/report/998.ipynb b/docs/report/998.ipynb new file mode 100644 index 00000000..57879a0f --- /dev/null +++ b/docs/report/998.ipynb @@ -0,0 +1,709 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "hideCode": true, + "hideOutput": true, + "hidePrompt": true, + "jupyter": { + "source_hidden": true + }, + "slideshow": { + "slide_type": "skip" + }, + "tags": [ + "remove-cell" + ] + }, + "outputs": [], + "source": [ + "%config InlineBackend.figure_formats = ['svg']\n", + "import os\n", + "\n", + "STATIC_WEB_PAGE = {\"EXECUTE_NB\", \"READTHEDOCS\"}.intersection(os.environ)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```{autolink-concat}\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "\n", + "::::{margin}\n", + ":::{card} Gradient of an amplitude model with autodiff\n", + "TR-999\n", + "^^^\n", + "In this report, we investigate whether autodiff can be be used to analytically compute the gradient of an amplitude model. The suspicion is that autodiff cannot handle large expressions well, because the chain rule results in an excessive number of computational nodes for the gradient of the function.\n", + "+++\n", + "WIP\n", + ":::\n", + "::::" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "# Gradient with autodiff" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "remove-cell" + ] + }, + "outputs": [], + "source": [ + "%pip install -q \"tensorwaves[jax,pwa]@git+https://github.com/ComPWA/tensorwaves@order-function-args\" ampform~=0.14 psutil==5.9.6 qrules~=0.9.8" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [] + }, + "outputs": [], + "source": [ + "from __future__ import annotations\n", + "\n", + "import inspect\n", + "import os\n", + "from textwrap import dedent\n", + "\n", + "import ampform\n", + "import jax.numpy as jnp\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import psutil\n", + "import qrules\n", + "import sympy as sp\n", + "from ampform.dynamics.builder import (\n", + " create_non_dynamic_with_ff,\n", + " create_relativistic_breit_wigner_with_ff,\n", + ")\n", + "from ampform.io import aslatex\n", + "from IPython.display import Latex, Markdown\n", + "from matplotlib import cm\n", + "from tensorwaves.data import (\n", + " IntensityDistributionGenerator,\n", + " SympyDataTransformer,\n", + " TFPhaseSpaceGenerator,\n", + " TFUniformRealNumberGenerator,\n", + " TFWeightedPhaseSpaceGenerator,\n", + ")\n", + "from tensorwaves.function.sympy import create_parametrized_function\n", + "\n", + "\n", + "def display_memory_usage() -> None:\n", + " process = psutil.Process(os.getpid())\n", + " memory = process.memory_info().rss\n", + " if memory < 1024**2:\n", + " memory_str = f\"{memory / 1024**1:.2f} kB\"\n", + " elif memory < 1024**3:\n", + " memory_str = f\"{memory / 1024**2:.2f} MB\"\n", + " else:\n", + " memory_str = f\"{memory / 1024**3:.2f} GB\"\n", + " msg = dedent(f\"\"\"\n", + " :::{{hint}}\n", + " Memory Usage: **{memory_str}**\n", + " :::\n", + " \"\"\").strip()\n", + " display(Markdown(msg))\n", + "\n", + "\n", + "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\"\n", + "display_memory_usage()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "## Formulate model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "REACTION = qrules.generate_transitions(\n", + " initial_state=\"B0\",\n", + " final_state=[\"K+\", \"pi-\", \"pi0\"],\n", + " allowed_intermediate_particles=[\"K*(892)\", \"rho\"],\n", + " formalism=\"helicity\",\n", + " mass_conservation_factor=0,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "import graphviz\n", + "\n", + "dot = qrules.io.asdot(REACTION, collapse_graphs=True)\n", + "graphviz.Source(dot)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "full-width" + ] + }, + "outputs": [], + "source": [ + "INITIAL_STATE, *_ = REACTION.initial_state.values()\n", + "BUILDER = ampform.get_builder(REACTION)\n", + "BUILDER.adapter.permutate_registered_topologies()\n", + "BUILDER.set_dynamics(INITIAL_STATE.name, create_non_dynamic_with_ff)\n", + "for name in REACTION.get_intermediate_particles().names:\n", + " BUILDER.set_dynamics(name, create_relativistic_breit_wigner_with_ff)\n", + " del name\n", + "MODEL = BUILDER.formulate()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "MODEL.intensity" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "hide-input", + "hide-output", + "full-width" + ] + }, + "outputs": [], + "source": [ + "selection = {k: v for i, (k, v) in enumerate(MODEL.amplitudes.items()) if i < 3}\n", + "src = aslatex(selection)\n", + "del selection\n", + "Latex(src)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "remove-input" + ] + }, + "outputs": [], + "source": [ + "display_memory_usage()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "## Generate data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "rng = TFUniformRealNumberGenerator(seed=0)\n", + "phsp_generator = TFPhaseSpaceGenerator(\n", + " initial_state_mass=REACTION.initial_state[-1].mass,\n", + " final_state_masses={i: p.mass for i, p in REACTION.final_state.items()},\n", + ")\n", + "phsp_momenta = phsp_generator.generate(100_000, rng)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "helicity_transformer = SympyDataTransformer.from_sympy(\n", + " MODEL.kinematic_variables, backend=\"jax\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "unfolded_expression = MODEL.expression.doit()\n", + "intensity_func = create_parametrized_function(\n", + " unfolded_expression,\n", + " parameters=MODEL.parameter_defaults,\n", + " backend=\"jax\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "full-width" + ] + }, + "outputs": [], + "source": [ + "weighted_phsp_generator = TFWeightedPhaseSpaceGenerator(\n", + " initial_state_mass=REACTION.initial_state[-1].mass,\n", + " final_state_masses={i: p.mass for i, p in REACTION.final_state.items()},\n", + ")\n", + "data_generator = IntensityDistributionGenerator(\n", + " domain_generator=weighted_phsp_generator,\n", + " function=intensity_func,\n", + " domain_transformer=helicity_transformer,\n", + ")\n", + "data_momenta = data_generator.generate(10_000, rng)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "phsp = helicity_transformer(phsp_momenta)\n", + "data = helicity_transformer(data_momenta)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "## Gradient creation with autodiff" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "free_symbols = {\n", + " symbol: value\n", + " for symbol, value in MODEL.parameter_defaults.items()\n", + " if symbol.name[0] in {\"C\", \"d\"}\n", + "}\n", + "fixed_symbols = {\n", + " symbol: value\n", + " for symbol, value in MODEL.parameter_defaults.items()\n", + " if symbol not in free_symbols\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "intensity_func = create_parametrized_function(\n", + " unfolded_expression.xreplace(fixed_symbols),\n", + " parameters=free_symbols,\n", + " backend=\"jax\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "display_memory_usage()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%config InlineBackend.figure_formats = ['png']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "jupyter": { + "source_hidden": true + }, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "full-width", + "hide-input" + ] + }, + "outputs": [], + "source": [ + "def plot_mass_projection(ax, decay_ids: set[int]):\n", + " decay_products = get_decay_products(decay_ids)\n", + " decay_products_str = \"\".join(p.latex for p in decay_products)\n", + " resonances = get_resonances(decay_ids)\n", + " evenly_spaced_interval = np.linspace(0, 1, len(resonances))\n", + " colors = [cm.rainbow(x) for x in evenly_spaced_interval]\n", + " ax.hist(\n", + " phsp[f\"m_{''.join(map(str, sorted(decay_ids)))}\"].real,\n", + " bins=200,\n", + " alpha=0.5,\n", + " density=True,\n", + " weights=intensity_func(phsp),\n", + " )\n", + " ax.set_xlabel(f\"$m_{{{decay_products_str}}}$ [GeV]\")\n", + " for p, color in zip(resonances, colors):\n", + " ax.axvline(x=p.mass, linestyle=\"dotted\", label=f\"${p.latex}$\", color=color)\n", + " ax.legend()\n", + "\n", + "\n", + "def get_decay_products(decay_ids: set[int]) -> tuple[Particle, Particle]:\n", + " return tuple(REACTION.final_state[i] for i in sorted(decay_ids))\n", + "\n", + "\n", + "def get_resonances(decay_ids: set[int]) -> list[Particle]:\n", + " resonances = {\n", + " t.states[3].particle\n", + " for t in REACTION.transitions\n", + " if t.topology.get_edge_ids_outgoing_from_node(1) == decay_ids\n", + " }\n", + " return sorted(resonances, key=lambda p: (p.name[0], p.mass))\n", + "\n", + "\n", + "fig, axes = plt.subplots(figsize=(9, 12), nrows=3)\n", + "plot_mass_projection(axes[0], decay_ids={0, 1})\n", + "plot_mass_projection(axes[1], decay_ids={1, 2})\n", + "plot_mass_projection(axes[2], decay_ids={2, 0})\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "jupyter": { + "source_hidden": true + }, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "full-width", + "hide-input" + ] + }, + "outputs": [], + "source": [ + "def indicate_resonances(ax_func, decay_ids) -> None:\n", + " resonances = get_resonances(decay_ids)\n", + " evenly_spaced_interval = np.linspace(0, 1, len(resonances))\n", + " colors = [cm.rainbow(x) for x in evenly_spaced_interval]\n", + " for p, color in zip(resonances, colors):\n", + " ax_func(p.mass**2, linestyle=\"dotted\", label=f\"${p.latex}$\", color=color)\n", + "\n", + "\n", + "x_subsystem = {0, 1}\n", + "y_subsystem = {1, 2}\n", + "x_products = get_decay_products(x_subsystem)\n", + "y_products = get_decay_products(y_subsystem)\n", + "fig, ax = plt.subplots(figsize=(8, 6))\n", + "ax.hist2d(\n", + " phsp[f\"m_{''.join(map(str, sorted(x_subsystem)))}\"].real ** 2,\n", + " phsp[f\"m_{''.join(map(str, sorted(y_subsystem)))}\"].real ** 2,\n", + " bins=100,\n", + " cmin=1,\n", + " weights=intensity_func(phsp),\n", + ")\n", + "ax.set_xlabel(f\"$m_{{{' '.join(p.latex for p in x_products)}}}$\")\n", + "ax.set_ylabel(f\"$m_{{{' '.join(p.latex for p in y_products)}}}$\")\n", + "indicate_resonances(ax.axvline, x_subsystem)\n", + "indicate_resonances(ax.axhline, y_subsystem)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "remove-input" + ] + }, + "outputs": [], + "source": [ + "display_memory_usage()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "## Optimize parameters" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "full-width", + "hide-input" + ] + }, + "source": [ + "### Numerical gradient descent" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "sig = inspect.signature(intensity_func.function)\n", + "arg_names = tuple(sig.parameters)\n", + "arg_to_par = {\n", + " arg: par\n", + " for arg, par in zip(arg_names, intensity_func.argument_order)\n", + " if par in intensity_func.parameters\n", + "}\n", + "idx_to_par = dict(enumerate(arg_to_par.values()))\n", + "parameter_values = {\n", + " arg: complex(intensity_func.parameters[par]).real\n", + " for arg, par in arg_to_par.items()\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def estimator(args):\n", + " parameters = {idx_to_par[i]: v for i, v in enumerate(args)}\n", + " intensity_func.update_parameters(parameters)\n", + " data_intensities = intensity_func(data)\n", + " phsp_intensities = intensity_func(phsp)\n", + " likelihoods = data_intensities / jnp.mean(phsp_intensities)\n", + " return -jnp.sum(jnp.log(likelihoods))\n", + "\n", + "\n", + "estimator(parameter_values.values())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "jupyter": { + "source_hidden": true + }, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "arr = jnp.array([1.0])\n", + "msg = f\"\"\"\n", + ":::{{hint}}\n", + "JAX is using this precision: **{arr.dtype}**. For the model, we have:\n", + "- {len(REACTION.get_intermediate_particles())} resonances in {len(REACTION.transition_groups)} subsystems\n", + "- {len(parameter_values)} of {len(MODEL.parameter_defaults)} free parameters\n", + "- {sp.count_ops(unfolded_expression):,d} computational nodes\n", + ":::\n", + "\"\"\"\n", + "msg = dedent(msg).strip()\n", + "display(Markdown(msg))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide-output" + ] + }, + "outputs": [], + "source": [ + "import iminuit\n", + "from tqdm.auto import tqdm\n", + "\n", + "PROGRESS_BAR = tqdm()\n", + "\n", + "\n", + "def estimator_with_progress_bar(*args, **kwargs):\n", + " estimator_value = estimator(*args, **kwargs)\n", + " PROGRESS_BAR.update()\n", + " PROGRESS_BAR.set_postfix({\"estimator\": f\"{estimator_value:,10g}\"})\n", + " return estimator_value\n", + "\n", + "\n", + "RNG = np.random.default_rng(seed=0)\n", + "δ = 0.01\n", + "starting_values = tuple(\n", + " p * RNG.uniform(1 - δ, 1 + δ) for p in parameter_values.values()\n", + ")\n", + "optimizer = iminuit.Minuit(\n", + " estimator_with_progress_bar,\n", + " starting_values,\n", + " name=tuple(parameter_values),\n", + ")\n", + "optimizer.errors = tuple(\n", + " 0.1 * abs(x) if abs(x) != 0.0 else 0.1 for x in starting_values\n", + ")\n", + "optimizer.errordef = iminuit.Minuit.LIKELIHOOD\n", + "optimizer.migrad()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### With analytic gradient" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "display_memory_usage()" + ] + } + ], + "metadata": { + "colab": { + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.18" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/report/999.ipynb b/docs/report/999.ipynb index 19ccc286..42703f51 100644 --- a/docs/report/999.ipynb +++ b/docs/report/999.ipynb @@ -72,7 +72,7 @@ }, "outputs": [], "source": [ - "%pip install -q \"tensorwaves[jax,pwa]@git+https://github.com/ComPWA/tensorwaves@order-function-args\" ampform~=0.14 qrules~=0.9.8" + "%pip install -q \"tensorwaves[jax,pwa]@git+https://github.com/ComPWA/tensorwaves@order-function-args\" ampform~=0.14 psutil==5.9.6 qrules~=0.9.8" ] }, { @@ -86,20 +86,26 @@ }, "outputs": [], "source": [ + "from __future__ import annotations\n", + "\n", "import inspect\n", "import os\n", + "from textwrap import dedent\n", "\n", "import ampform\n", "import jax\n", + "import jax.numpy as jnp\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", + "import psutil\n", "import qrules\n", + "import sympy as sp\n", "from ampform.dynamics.builder import (\n", " create_non_dynamic_with_ff,\n", " create_relativistic_breit_wigner_with_ff,\n", ")\n", "from ampform.io import aslatex\n", - "from IPython.display import Latex\n", + "from IPython.display import Latex, Markdown\n", "from jax.tree_util import Partial\n", "from matplotlib import cm\n", "from tensorwaves.data import (\n", @@ -109,24 +115,37 @@ " TFUniformRealNumberGenerator,\n", " TFWeightedPhaseSpaceGenerator,\n", ")\n", - "from tensorwaves.function.sympy import create_function, create_parametrized_function\n", + "from tensorwaves.function.sympy import create_parametrized_function\n", + "\n", + "\n", + "def display_memory_usage() -> None:\n", + " process = psutil.Process(os.getpid())\n", + " memory = process.memory_info().rss\n", + " if memory < 1024**2:\n", + " memory_str = f\"{memory / 1024**1:.2f} kB\"\n", + " elif memory < 1024**3:\n", + " memory_str = f\"{memory / 1024**2:.2f} MB\"\n", + " else:\n", + " memory_str = f\"{memory / 1024**3:.2f} GB\"\n", + " msg = dedent(f\"\"\"\n", + " :::{{hint}}\n", + " Memory Usage: **{memory_str}**\n", + " :::\n", + " \"\"\").strip()\n", + " display(Markdown(msg))\n", + "\n", "\n", - "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\"" + "jax.config.update(\"jax_enable_x64\", False)\n", + "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\"\n", + "display_memory_usage()" ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "jupyter": { - "source_hidden": true - } - }, + "metadata": {}, "outputs": [], "source": [ - "import jax.numpy as jnp\n", - "\n", - "\n", "def print_jax_precision() -> None:\n", " arr = jnp.array([1.0])\n", " print(arr.dtype)\n", @@ -138,7 +157,6 @@ { "cell_type": "markdown", "metadata": { - "jp-MarkdownHeadingCollapsed": true, "tags": [] }, "source": [ @@ -153,12 +171,12 @@ }, "outputs": [], "source": [ - "reaction = qrules.generate_transitions(\n", - " initial_state=(\"J/psi(1S)\", [-1, +1]),\n", - " final_state=[\"gamma\", \"pi0\", \"pi0\"],\n", - " allowed_intermediate_particles=[\"a(0)\", \"f(0)\", \"omega\"],\n", - " allowed_interaction_types=[\"strong\", \"EM\"],\n", + "REACTION = qrules.generate_transitions(\n", + " initial_state=\"B0\",\n", + " final_state=[\"K+\", \"pi-\", \"pi0\"],\n", + " allowed_intermediate_particles=[\"K*(892)\", \"rho\"],\n", " formalism=\"helicity\",\n", + " mass_conservation_factor=0,\n", ")" ] }, @@ -177,7 +195,7 @@ "source": [ "import graphviz\n", "\n", - "dot = qrules.io.asdot(reaction, collapse_graphs=True)\n", + "dot = qrules.io.asdot(REACTION, collapse_graphs=True)\n", "graphviz.Source(dot)" ] }, @@ -191,12 +209,14 @@ }, "outputs": [], "source": [ - "model_builder = ampform.get_builder(reaction)\n", - "model_builder.adapter.permutate_registered_topologies()\n", - "model_builder.set_dynamics(\"J/psi(1S)\", create_non_dynamic_with_ff)\n", - "for name in reaction.get_intermediate_particles().names:\n", - " model_builder.set_dynamics(name, create_relativistic_breit_wigner_with_ff)\n", - "model = model_builder.formulate()" + "INITIAL_STATE, *_ = REACTION.initial_state.values()\n", + "BUILDER = ampform.get_builder(REACTION)\n", + "BUILDER.adapter.permutate_registered_topologies()\n", + "BUILDER.set_dynamics(INITIAL_STATE.name, create_non_dynamic_with_ff)\n", + "for name in REACTION.get_intermediate_particles().names:\n", + " BUILDER.set_dynamics(name, create_relativistic_breit_wigner_with_ff)\n", + " del name\n", + "MODEL = BUILDER.formulate()" ] }, { @@ -212,7 +232,7 @@ }, "outputs": [], "source": [ - "model.intensity" + "MODEL.intensity" ] }, { @@ -224,21 +244,38 @@ }, "tags": [ "hide-input", - "hide-output" + "hide-output", + "full-width" ] }, "outputs": [], "source": [ - "selection = {k: v for i, (k, v) in enumerate(model.amplitudes.items()) if i < 3}\n", + "selection = {k: v for i, (k, v) in enumerate(MODEL.amplitudes.items()) if i < 3}\n", "src = aslatex(selection)\n", "del selection\n", "Latex(src)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "remove-input" + ] + }, + "outputs": [], + "source": [ + "display_memory_usage()" + ] + }, { "cell_type": "markdown", "metadata": { - "jp-MarkdownHeadingCollapsed": true, "tags": [] }, "source": [ @@ -255,8 +292,8 @@ "source": [ "rng = TFUniformRealNumberGenerator(seed=0)\n", "phsp_generator = TFPhaseSpaceGenerator(\n", - " initial_state_mass=reaction.initial_state[-1].mass,\n", - " final_state_masses={i: p.mass for i, p in reaction.final_state.items()},\n", + " initial_state_mass=REACTION.initial_state[-1].mass,\n", + " final_state_masses={i: p.mass for i, p in REACTION.final_state.items()},\n", ")\n", "phsp_momenta = phsp_generator.generate(100_000, rng)" ] @@ -268,23 +305,33 @@ "outputs": [], "source": [ "helicity_transformer = SympyDataTransformer.from_sympy(\n", - " model.kinematic_variables, backend=\"jax\"\n", + " MODEL.kinematic_variables, backend=\"jax\"\n", ")" ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [ - "skip-flake8" - ] - }, + "metadata": {}, + "outputs": [], + "source": [ + "unfolded_expression = MODEL.expression.doit()\n", + "intensity_func = create_parametrized_function(\n", + " unfolded_expression,\n", + " parameters=MODEL.parameter_defaults,\n", + " backend=\"jax\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, "outputs": [], "source": [ - "unfolded_expression = model.expression.doit()\n", - "substituted_expression = unfolded_expression.xreplace(model.parameter_defaults)\n", - "fixed_intensity_func = create_function(substituted_expression, backend=\"jax\")" + "print_jax_precision()\n", + "jax.config.update(\"jax_enable_x64\", False)\n", + "print_jax_precision()" ] }, { @@ -298,12 +345,12 @@ "outputs": [], "source": [ "weighted_phsp_generator = TFWeightedPhaseSpaceGenerator(\n", - " initial_state_mass=reaction.initial_state[-1].mass,\n", - " final_state_masses={i: p.mass for i, p in reaction.final_state.items()},\n", + " initial_state_mass=REACTION.initial_state[-1].mass,\n", + " final_state_masses={i: p.mass for i, p in REACTION.final_state.items()},\n", ")\n", "data_generator = IntensityDistributionGenerator(\n", " domain_generator=weighted_phsp_generator,\n", - " function=fixed_intensity_func,\n", + " function=intensity_func,\n", " domain_transformer=helicity_transformer,\n", ")\n", "data_momenta = data_generator.generate(10_000, rng)" @@ -312,162 +359,201 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ - "list(helicity_transformer.functions)" + "phsp = helicity_transformer(phsp_momenta)\n", + "data = helicity_transformer(data_momenta)" ] }, { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "cell_type": "markdown", + "metadata": { + "tags": [] + }, "source": [ - "phsp = helicity_transformer(phsp_momenta)\n", - "data = helicity_transformer(data_momenta)" + "## Gradient creation with autodiff" ] }, { "cell_type": "code", "execution_count": null, "metadata": { - "tags": [] + "jupyter": { + "source_hidden": true + }, + "tags": [ + "hide-input" + ] }, "outputs": [], "source": [ - "sorted(substituted_expression.free_symbols, key=str)" + "free_symbols = {\n", + " symbol: value\n", + " for symbol, value in MODEL.parameter_defaults.items()\n", + " if symbol.name[0] in {\"C\", \"d\"}\n", + "}\n", + "fixed_symbols = {\n", + " symbol: value\n", + " for symbol, value in MODEL.parameter_defaults.items()\n", + " if symbol not in free_symbols\n", + "}" ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ - "list(model.kinematic_variables)" + "intensity_func = create_parametrized_function(\n", + " unfolded_expression.xreplace(fixed_symbols),\n", + " parameters=free_symbols,\n", + " backend=\"jax\",\n", + ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { - "jupyter": { - "source_hidden": true - }, "tags": [ - "full-width", "hide-input" ] }, "outputs": [], "source": [ - "resonances = sorted(reaction.get_intermediate_particles(), key=lambda p: p.mass)\n", - "evenly_spaced_interval = np.linspace(0, 1, len(resonances))\n", - "colors = [cm.rainbow(x) for x in evenly_spaced_interval]\n", - "fig, ax = plt.subplots(figsize=(9, 4))\n", - "ax.hist(\n", - " np.real(data[\"m_12\"]),\n", - " bins=200,\n", - " alpha=0.5,\n", - " density=True,\n", - ")\n", - "ax.set_xlabel(\"$m$ [GeV]\")\n", - "for p, color in zip(resonances, colors):\n", - " ax.axvline(x=p.mass, linestyle=\"dotted\", label=p.name, color=color)\n", - "ax.legend()\n", - "plt.show()" + "display_memory_usage()" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": { - "jp-MarkdownHeadingCollapsed": true, "tags": [] }, + "outputs": [], "source": [ - "## Gradient creation with autodiff" + "%config InlineBackend.figure_formats = ['png']" ] }, { "cell_type": "code", "execution_count": null, "metadata": { + "editable": true, "jupyter": { "source_hidden": true }, + "slideshow": { + "slide_type": "" + }, "tags": [ + "full-width", "hide-input" ] }, "outputs": [], "source": [ - "free_symbols = {\n", - " symbol: value\n", - " for symbol, value in model.parameter_defaults.items()\n", - " if not symbol.name.startswith(\"d\")\n", - "}\n", - "some_coefficient = next(s for s in free_symbols if s.name.startswith(\"C\"))\n", - "free_symbols.pop(some_coefficient)\n", - "fixed_symbols = {\n", - " symbol: value\n", - " for symbol, value in model.parameter_defaults.items()\n", - " if symbol not in free_symbols\n", - "}" + "def plot_mass_projection(ax, decay_ids: set[int]):\n", + " decay_products = get_decay_products(decay_ids)\n", + " decay_products_str = \"\".join(p.latex for p in decay_products)\n", + " resonances = get_resonances(decay_ids)\n", + " evenly_spaced_interval = np.linspace(0, 1, len(resonances))\n", + " colors = [cm.rainbow(x) for x in evenly_spaced_interval]\n", + " ax.hist(\n", + " phsp[f\"m_{''.join(map(str, sorted(decay_ids)))}\"].real,\n", + " bins=200,\n", + " alpha=0.5,\n", + " density=True,\n", + " weights=intensity_func(phsp),\n", + " )\n", + " ax.set_xlabel(f\"$m_{{{decay_products_str}}}$ [GeV]\")\n", + " for p, color in zip(resonances, colors):\n", + " ax.axvline(x=p.mass, linestyle=\"dotted\", label=f\"${p.latex}$\", color=color)\n", + " ax.legend()\n", + "\n", + "\n", + "def get_decay_products(decay_ids: set[int]) -> tuple[Particle, Particle]:\n", + " return tuple(REACTION.final_state[i] for i in sorted(decay_ids))\n", + "\n", + "\n", + "def get_resonances(decay_ids: set[int]) -> list[Particle]:\n", + " resonances = {\n", + " t.states[3].particle\n", + " for t in REACTION.transitions\n", + " if t.topology.get_edge_ids_outgoing_from_node(1) == decay_ids\n", + " }\n", + " return sorted(resonances, key=lambda p: (p.name[0], p.mass))\n", + "\n", + "\n", + "fig, axes = plt.subplots(figsize=(9, 12), nrows=3)\n", + "plot_mass_projection(axes[0], decay_ids={0, 1})\n", + "plot_mass_projection(axes[1], decay_ids={1, 2})\n", + "plot_mass_projection(axes[2], decay_ids={2, 0})\n", + "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { + "editable": true, "jupyter": { "source_hidden": true }, - "tags": [] - }, - "outputs": [], - "source": [ - "src = aslatex({k: free_symbols[k] for k in sorted(free_symbols, key=str)})\n", - "Latex(src)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { + "slideshow": { + "slide_type": "" + }, "tags": [ - "skip-flake8" + "full-width", + "hide-input" ] }, "outputs": [], "source": [ - "expression = unfolded_expression\n", - "expression = expression.xreplace(fixed_symbols)\n", - "intensity_func = create_parametrized_function(\n", - " expression, parameters=free_symbols, backend=\"jax\"\n", - ")" + "def indicate_resonances(ax_func, decay_ids) -> None:\n", + " resonances = get_resonances(decay_ids)\n", + " evenly_spaced_interval = np.linspace(0, 1, len(resonances))\n", + " colors = [cm.rainbow(x) for x in evenly_spaced_interval]\n", + " for p, color in zip(resonances, colors):\n", + " ax_func(p.mass**2, linestyle=\"dotted\", label=f\"${p.latex}$\", color=color)\n", + "\n", + "\n", + "x_subsystem = {0, 1}\n", + "y_subsystem = {1, 2}\n", + "x_products = get_decay_products(x_subsystem)\n", + "y_products = get_decay_products(y_subsystem)\n", + "fig, ax = plt.subplots(figsize=(8, 6))\n", + "ax.hist2d(\n", + " phsp[f\"m_{''.join(map(str, sorted(x_subsystem)))}\"].real ** 2,\n", + " phsp[f\"m_{''.join(map(str, sorted(y_subsystem)))}\"].real ** 2,\n", + " bins=100,\n", + " cmin=1,\n", + " weights=intensity_func(phsp),\n", + ")\n", + "ax.set_xlabel(f\"$m_{{{' '.join(p.latex for p in x_products)}}}$\")\n", + "ax.set_ylabel(f\"$m_{{{' '.join(p.latex for p in y_products)}}}$\")\n", + "indicate_resonances(ax.axvline, x_subsystem)\n", + "indicate_resonances(ax.axhline, y_subsystem)\n", + "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { - "jupyter": { - "source_hidden": true + "editable": true, + "slideshow": { + "slide_type": "" }, "tags": [ - "hide-input" + "remove-input" ] }, "outputs": [], "source": [ - "from IPython.display import Markdown\n", - "\n", - "Markdown(f\"Function has **{len(intensity_func.parameters)} free parameters**.\")" + "display_memory_usage()" ] }, { @@ -485,7 +571,16 @@ " for arg, key in zip(arg_names, intensity_func.argument_order)\n", " if key in data\n", "}\n", - "data_columns" + "arg_to_par = {\n", + " arg: par\n", + " for arg, par in zip(arg_names, intensity_func.argument_order)\n", + " if par in intensity_func.parameters\n", + "}\n", + "idx_to_par = dict(enumerate(arg_to_par.values()))\n", + "parameter_values = {\n", + " arg: complex(intensity_func.parameters[par]).real\n", + " for arg, par in arg_to_par.items()\n", + "}" ] }, { @@ -496,28 +591,29 @@ }, "outputs": [], "source": [ - "parameter_values = {\n", - " arg: complex(intensity_func.parameters[key]).real\n", - " for arg, key in zip(arg_names, intensity_func.argument_order)\n", - " if key in intensity_func.parameters\n", - "}\n", - "parameter_values" + "func_with_data_inserted = Partial(intensity_func.function, *data_columns.values())\n", + "gradient_func = jax.jacfwd(\n", + " func_with_data_inserted,\n", + " argnums=range(len(parameter_values)),\n", + ")\n", + "gradient_func" ] }, { "cell_type": "code", "execution_count": null, "metadata": { - "tags": [] + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "remove-input" + ] }, "outputs": [], "source": [ - "func_with_data_inserted = Partial(intensity_func.function, *data_columns.values())\n", - "gradient_func = jax.jacfwd(\n", - " func_with_data_inserted,\n", - " argnums=range(len(parameter_values)),\n", - ")\n", - "gradient_func" + "display_memory_usage()" ] }, { @@ -532,6 +628,23 @@ "_ = tuple(v.block_until_ready() for v in gradient_func(*parameter_values.values()))" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "remove-input" + ] + }, + "outputs": [], + "source": [ + "display_memory_usage()" + ] + }, { "cell_type": "code", "execution_count": null, @@ -544,6 +657,23 @@ "gradient_values[0].shape" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "remove-input" + ] + }, + "outputs": [], + "source": [ + "display_memory_usage()" + ] + }, { "cell_type": "markdown", "metadata": { @@ -555,7 +685,16 @@ }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "full-width", + "hide-input" + ] + }, "source": [ "### Numerical gradient descent" ] @@ -595,7 +734,6 @@ }, "outputs": [], "source": [ - "# @jax.jit # Do not JIT here, otherwise jax.jacfwd crashes!\n", "def estimator(args):\n", " data_intensities = func_with_data_inserted(*args)\n", " phsp_intensities = func_with_phsp_inserted(*args)\n", @@ -609,88 +747,105 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "remove-input" + ] + }, "outputs": [], "source": [ - "print_jax_precision()" + "display_memory_usage()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { - "tags": [ - "hide-output" - ] + "tags": [] }, "outputs": [], "source": [ - "import iminuit\n", - "from tqdm.auto import tqdm\n", - "\n", - "PROGRESS_BAR = tqdm()\n", - "\n", - "\n", - "def estimator_with_progress_bar(*args, **kwargs):\n", - " estimator_value = estimator(*args, **kwargs)\n", - " PROGRESS_BAR.update()\n", - " PROGRESS_BAR.set_postfix({\"estimator\": estimator_value})\n", - " return estimator_value\n", - "\n", - "\n", - "starting_values = tuple(parameter_values.values())\n", - "optimizer = iminuit.Minuit(\n", - " estimator_with_progress_bar,\n", - " starting_values,\n", - " name=tuple(parameter_values),\n", - ")\n", - "optimizer.errors = tuple(\n", - " 0.1 * abs(x) if abs(x) != 0.0 else 0.1 for x in starting_values\n", - ")\n", - "optimizer.errordef = iminuit.Minuit.LIKELIHOOD\n", - "optimizer.migrad()" + "estimator_gradient = jax.jacfwd(estimator)" ] }, { - "cell_type": "markdown", - "metadata": {}, + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "remove-input" + ] + }, + "outputs": [], "source": [ - "### With analytic gradient" + "display_memory_usage()" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ - "from jax.config import config\n", - "\n", - "config.update(\"jax_enable_x64\", False)\n", - "print_jax_precision()" + "%%time\n", + "estimator_gradient(tuple(parameter_values.values()))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { - "tags": [] + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "remove-input" + ] }, "outputs": [], "source": [ - "estimator_gradient = jax.jacfwd(estimator)" + "display_memory_usage()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { - "tags": [] + "editable": true, + "jupyter": { + "source_hidden": true + }, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "hide-input" + ] }, "outputs": [], "source": [ - "%%time\n", - "estimator_gradient(tuple(parameter_values.values()))" + "arr = jnp.array([1.0])\n", + "msg = f\"\"\"\n", + ":::{{hint}}\n", + "JAX is using this precision: **{arr.dtype}**. For the model, we have:\n", + "- {len(REACTION.get_intermediate_particles())} resonances in {len(REACTION.transition_groups)} subsystems\n", + "- {len(parameter_values)} of {len(MODEL.parameter_defaults)} free parameters\n", + "- {sp.count_ops(unfolded_expression):,d} computational nodes\n", + ":::\n", + "\"\"\"\n", + "msg = dedent(msg).strip()\n", + "display(Markdown(msg))" ] }, { @@ -703,56 +858,44 @@ }, "outputs": [], "source": [ - "PROGRESS_BAR = tqdm() # reset\n", - "autodiff_optimizer = iminuit.Minuit(\n", + "import iminuit\n", + "from tqdm.auto import tqdm\n", + "\n", + "PROGRESS_BAR = tqdm()\n", + "\n", + "\n", + "def estimator_with_progress_bar(*args, **kwargs):\n", + " estimator_value = estimator(*args, **kwargs)\n", + " PROGRESS_BAR.update()\n", + " PROGRESS_BAR.set_postfix({\"estimator\": f\"{estimator_value:10g}\"})\n", + " return estimator_value\n", + "\n", + "\n", + "RNG = np.random.default_rng(seed=0)\n", + "δ = 0.01\n", + "starting_values = tuple(\n", + " p * RNG.uniform(1 - δ, 1 + δ) for p in parameter_values.values()\n", + ")\n", + "optimizer = iminuit.Minuit(\n", " estimator_with_progress_bar,\n", " starting_values,\n", " grad=estimator_gradient, # analytic!\n", " name=tuple(parameter_values),\n", ")\n", - "autodiff_optimizer.errors = tuple(\n", + "optimizer.errors = tuple(\n", " 0.1 * abs(x) if abs(x) != 0.0 else 0.1 for x in starting_values\n", ")\n", - "autodiff_optimizer.errordef = iminuit.Minuit.LIKELIHOOD\n", - "autodiff_optimizer.migrad()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "tags": [] - }, - "source": [ - "## Conclusion" + "optimizer.errordef = iminuit.Minuit.LIKELIHOOD\n", + "optimizer.migrad()" ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "jupyter": { - "source_hidden": true - }, - "tags": [ - "hide-input" - ] - }, + "metadata": {}, "outputs": [], "source": [ - "def compute_diff(minuit):\n", - " original_pars = np.array(starting_values)\n", - " optimized_pars = np.array([p.value for p in minuit.params])\n", - " diff = original_pars - optimized_pars\n", - " return np.sqrt(np.sum(np.abs(diff) ** 2)) / len(minuit.params)\n", - "\n", - "\n", - "src = f\"\"\"\n", - "| | numerical | autodiff |\n", - "|--|-----------|----------|\n", - "| time (s) | {optimizer.fmin.time:.1f} | {autodiff_optimizer.fmin.time:.1f} |\n", - "| average parameter offset | {compute_diff(optimizer):.4f} | {compute_diff(autodiff_optimizer):.4f} |\n", - "\"\"\"\n", - "Markdown(src)" + "display_memory_usage()" ] } ], From 0f2a7a8ddd51542c85a84786e383009602ca820c Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Mon, 30 Oct 2023 10:22:50 +0100 Subject: [PATCH 17/17] FIX: remove cspell comment --- .cspell.json | 1 + docs/report/998.ipynb | 2 -- docs/report/999.ipynb | 2 -- 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/.cspell.json b/.cspell.json index f231d08c..c77c9ca4 100644 --- a/.cspell.json +++ b/.cspell.json @@ -113,6 +113,7 @@ "arange", "arccos", "arctan", + "argnums", "asarray", "asdot", "aslatex", diff --git a/docs/report/998.ipynb b/docs/report/998.ipynb index 57879a0f..bae20599 100644 --- a/docs/report/998.ipynb +++ b/docs/report/998.ipynb @@ -37,8 +37,6 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\n", - "\n", "::::{margin}\n", ":::{card} Gradient of an amplitude model with autodiff\n", "TR-999\n", diff --git a/docs/report/999.ipynb b/docs/report/999.ipynb index 42703f51..b3751b29 100644 --- a/docs/report/999.ipynb +++ b/docs/report/999.ipynb @@ -37,8 +37,6 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\n", - "\n", "::::{margin}\n", ":::{card} Gradient of an amplitude model with autodiff\n", "TR-999\n",