From b181e89b71b6ba16f2bcd243d091d202a8b7177f Mon Sep 17 00:00:00 2001 From: Remco de Boer Date: Fri, 19 Feb 2021 00:26:19 +0100 Subject: [PATCH] docs: add live plot of fit with custom callback --- cspell.json | 4 ++- docs/usage.ipynb | 85 +++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 72 insertions(+), 17 deletions(-) diff --git a/cspell.json b/cspell.json index 8904634da..a3611bd75 100644 --- a/cspell.json +++ b/cspell.json @@ -181,6 +181,8 @@ "unflattened", "unnormalized", "vstack", - "xlabel" + "xlabel", + "xlim", + "ylim" ] } \ No newline at end of file diff --git a/docs/usage.ipynb b/docs/usage.ipynb index f0a6af1ad..3278bd42e 100644 --- a/docs/usage.ipynb +++ b/docs/usage.ipynb @@ -242,9 +242,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "jupyter": { - "source_hidden": true - }, "tags": [ "hide-cell" ] @@ -271,7 +268,7 @@ "def indicate_masses():\n", " plt.xlabel(\"$m$ [GeV]\")\n", " for i, p in enumerate(intermediate_states):\n", - " plt.axvline(\n", + " plt.gca().axvline(\n", " x=p.mass, linestyle=\"dotted\", label=p.name, color=colors[i]\n", " )" ] @@ -319,9 +316,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "jupyter": { - "source_hidden": true - }, "tags": [ "hide-cell" ] @@ -365,10 +359,8 @@ "initial_parameters = {\n", " \"C[J/\\\\psi(1S) \\\\to f_{0}(1500)_{0} \\\\gamma_{+1};f_{0}(1500) \\\\to \\\\pi^{0}_{0} \\\\pi^{0}_{0}]\": 1.0\n", " + 0.0j,\n", - " \"m_f(0)(500)\": 0.6,\n", " \"Gamma_f(0)(500)\": 0.3,\n", - " \"Gamma_f(0)(980)\": 0.2,\n", - " \"m_f(0)(1370)\": 1.3,\n", + " \"Gamma_f(0)(980)\": 0.1,\n", " \"m_f(0)(1710)\": 1.75,\n", " \"Gamma_f(0)(1710)\": 0.2,\n", "}\n", @@ -377,15 +369,71 @@ "print(\"Number of free parameters:\", len(initial_parameters))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```{tip}\n", + "Insert behavior into the {class}`.Optimizer` by defining a custom {class}`.Callback` class. Here's one that live updates a plot of the latest fit model!\n", + "```" + ] + }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "tags": [ + "hide-cell" + ] + }, "outputs": [], "source": [ - "minuit2 = Minuit2(callback=CSVSummary(\"traceback.csv\", step_size=2))\n", - "result = minuit2.optimize(estimator, initial_parameters)\n", - "result" + "import os\n", + "\n", + "from IPython.display import clear_output\n", + "from tensorwaves.optimizer.callbacks import Callback, CallbackList\n", + "\n", + "\n", + "class PyplotCallback(Callback):\n", + " def __init__(self, step_size=10):\n", + " self.__step_size = step_size\n", + " self.__fig, self.__ax = plt.subplots(1, figsize=(8, 5))\n", + " self.__latest_parameters = {}\n", + "\n", + " def on_iteration_end(self, function_call, logs=None):\n", + " self.__latest_parameters = logs[\"parameters\"]\n", + " if function_call % self.__step_size != 0:\n", + " return\n", + " if \"READTHEDOCS\" in os.environ:\n", + " return\n", + " self.update_plot()\n", + " clear_output(wait=True)\n", + " display(plt.gcf())\n", + "\n", + " def on_function_call_end(self):\n", + " self.update_plot()\n", + "\n", + " def update_plot(self):\n", + " bins = 100\n", + " data = data_set[\"m_3+4\"]\n", + " phsp = phsp_set[\"m_3+4\"]\n", + " intensity.update_parameters(self.__latest_parameters)\n", + " intensities = intensity(phsp_set)\n", + " self.__ax.cla()\n", + " self.__ax.hist(data, bins=bins, alpha=0.5, label=\"data\", density=True)\n", + " self.__ax.hist(\n", + " phsp,\n", + " weights=intensities,\n", + " bins=bins,\n", + " histtype=\"step\",\n", + " color=\"red\",\n", + " label=\"fit model\",\n", + " density=True,\n", + " )\n", + " self.__ax.set_xlim((0.25, 2.5))\n", + " self.__ax.set_ylim((0, 1.9))\n", + " indicate_masses()\n", + " plt.gcf().legend()" ] }, { @@ -394,8 +442,13 @@ "metadata": {}, "outputs": [], "source": [ - "intensity.update_parameters(result[\"parameter_values\"])\n", - "compare_model(\"m_3+4\", data_set, phsp_set, intensity)" + "minuit2 = Minuit2(\n", + " callback=CallbackList(\n", + " [CSVSummary(\"traceback.csv\", step_size=2), PyplotCallback()]\n", + " )\n", + ")\n", + "fit_result = minuit2.optimize(estimator, initial_parameters)\n", + "fit_result" ] }, {