From b7a4efd24ef78f4efbf52f48af5b8b1a0db0ac77 Mon Sep 17 00:00:00 2001 From: Remco de Boer Date: Mon, 6 Dec 2021 13:57:57 +0100 Subject: [PATCH] refactor!: adapt implementation to AmpForm v0.12.x (#345) * build: switch to AmpForm v0.12 * ci: update pip constraints and pre-commit config * feat: compute kinematic helicity angles with different backends * feat: define PositionalArgumentFunction * feat: define create_function * fix: force-push to matching branches * refactor: accept only str as DataSample keys * refactor: extract _printer module from sympy module * refactor: implement SympyDataTransformer * test: benchmark data generation with numpy and jax Co-authored-by: GitHub Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .constraints/py3.6.txt | 4 +- .constraints/py3.7.txt | 7 +- .constraints/py3.8.txt | 7 +- .constraints/py3.9.txt | 7 +- .github/workflows/cd.yml | 4 +- .pre-commit-config.yaml | 8 +- .pylintrc | 1 + benchmarks/ampform.py | 39 ++-- docs/abbreviate_signature.py | 5 +- docs/usage.ipynb | 79 ++++---- docs/usage/analytic-continuation.ipynb | 35 ++-- docs/usage/step2.ipynb | 51 +++-- docs/usage/step3.ipynb | 81 ++++---- pytest.ini | 2 + setup.cfg | 2 +- src/tensorwaves/data/__init__.py | 90 +++++---- src/tensorwaves/data/phasespace.py | 6 +- src/tensorwaves/data/transform.py | 70 +++++-- src/tensorwaves/function/__init__.py | 74 +++++++- src/tensorwaves/function/_backend.py | 5 +- .../function/{sympy.py => sympy/__init__.py} | 178 ++++++++++-------- src/tensorwaves/function/sympy/_printer.py | 82 ++++++++ src/tensorwaves/interface.py | 2 +- tests/data/test_generate.py | 32 ++-- tests/data/test_phasespace.py | 14 +- tests/data/test_transform.py | 26 +++ tests/function/test_ampform.py | 42 ++++- tests/function/test_function.py | 33 +++- tests/function/test_sympy.py | 16 +- 29 files changed, 667 insertions(+), 335 deletions(-) rename src/tensorwaves/function/{sympy.py => sympy/__init__.py} (72%) create mode 100644 src/tensorwaves/function/sympy/_printer.py create mode 100644 tests/data/test_transform.py diff --git a/.constraints/py3.6.txt b/.constraints/py3.6.txt index d87e6da3..e85db13c 100644 --- a/.constraints/py3.6.txt +++ b/.constraints/py3.6.txt @@ -6,7 +6,7 @@ # absl-py==0.15.0 alabaster==0.7.12 -ampform==0.11.5 +ampform==0.12.0 anyio==3.4.0 appdirs==1.4.4 aquirdturtle-collapsible-headings==3.1.0 @@ -175,7 +175,7 @@ python-dateutil==2.8.2 pytz==2021.3 pyyaml==6.0 pyzmq==22.3.0 -qrules==0.9.4 +qrules==0.9.5 qtconsole==5.2.1 qtpy==1.11.3 regex==2021.11.10 diff --git a/.constraints/py3.7.txt b/.constraints/py3.7.txt index 0471a9a1..875a18ae 100644 --- a/.constraints/py3.7.txt +++ b/.constraints/py3.7.txt @@ -6,7 +6,7 @@ # absl-py==1.0.0 alabaster==0.7.12 -ampform==0.11.5 +ampform==0.12.0 anyio==3.4.0 aquirdturtle-collapsible-headings==3.1.0 argcomplete==1.12.3 @@ -18,7 +18,7 @@ babel==2.9.1 backcall==0.2.0 backports.entry-points-selectable==1.1.1 beautifulsoup4==4.10.0 -black==21.11b1 ; python_version >= "3.7.0" +black==21.12b0 ; python_version >= "3.7.0" bleach==4.1.0 cached-property==1.5.2 cachetools==4.2.4 @@ -174,10 +174,9 @@ python-dateutil==2.8.2 pytz==2021.3 pyyaml==6.0 pyzmq==22.3.0 -qrules==0.9.4 +qrules==0.9.5 qtconsole==5.2.1 qtpy==1.11.3 -regex==2021.11.10 requests==2.26.0 requests-oauthlib==1.3.0 restructuredtext-lint==1.3.2 diff --git a/.constraints/py3.8.txt b/.constraints/py3.8.txt index a8f38f47..5b117944 100644 --- a/.constraints/py3.8.txt +++ b/.constraints/py3.8.txt @@ -6,7 +6,7 @@ # absl-py==1.0.0 alabaster==0.7.12 -ampform==0.11.5 +ampform==0.12.0 anyio==3.4.0 aquirdturtle-collapsible-headings==3.1.0 argon2-cffi==21.1.0 @@ -17,7 +17,7 @@ babel==2.9.1 backcall==0.2.0 backports.entry-points-selectable==1.1.1 beautifulsoup4==4.10.0 -black==21.11b1 ; python_version >= "3.7.0" +black==21.12b0 ; python_version >= "3.7.0" bleach==4.1.0 cachetools==4.2.4 certifi==2021.10.8 @@ -173,10 +173,9 @@ python-dateutil==2.8.2 pytz==2021.3 pyyaml==6.0 pyzmq==22.3.0 -qrules==0.9.4 +qrules==0.9.5 qtconsole==5.2.1 qtpy==1.11.3 -regex==2021.11.10 requests==2.26.0 requests-oauthlib==1.3.0 restructuredtext-lint==1.3.2 diff --git a/.constraints/py3.9.txt b/.constraints/py3.9.txt index be4ac057..cd91d0b5 100644 --- a/.constraints/py3.9.txt +++ b/.constraints/py3.9.txt @@ -6,7 +6,7 @@ # absl-py==1.0.0 alabaster==0.7.12 -ampform==0.11.5 +ampform==0.12.0 anyio==3.4.0 aquirdturtle-collapsible-headings==3.1.0 argon2-cffi==21.1.0 @@ -17,7 +17,7 @@ babel==2.9.1 backcall==0.2.0 backports.entry-points-selectable==1.1.1 beautifulsoup4==4.10.0 -black==21.11b1 ; python_version >= "3.7.0" +black==21.12b0 ; python_version >= "3.7.0" bleach==4.1.0 cachetools==4.2.4 certifi==2021.10.8 @@ -172,10 +172,9 @@ python-dateutil==2.8.2 pytz==2021.3 pyyaml==6.0 pyzmq==22.3.0 -qrules==0.9.4 +qrules==0.9.5 qtconsole==5.2.1 qtpy==1.11.3 -regex==2021.11.10 requests==2.26.0 requests-oauthlib==1.3.0 restructuredtext-lint==1.3.2 diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index 54338422..afe566ac 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -19,7 +19,7 @@ jobs: git config --global user.email "41898282+github-actions[bot]@users.noreply.github.com" - name: Push to stable branch run: | - git push origin HEAD:refs/heads/stable + git push origin HEAD:refs/heads/stable --force - name: Push to matching minor version branch env: TAG: ${{ github.ref_name }} @@ -27,7 +27,7 @@ jobs: re='^([0-9]+)\.([0-9]+)\.[0-9]+' if [[ $TAG =~ $re ]]; then MINOR_VERSION_BRANCH="${BASH_REMATCH[1]}.${BASH_REMATCH[2]}.x" - git push origin HEAD:refs/heads/$MINOR_VERSION_BRANCH + git push origin HEAD:refs/heads/$MINOR_VERSION_BRANCH --force fi pypi: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fe897b38..d653ab1a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -44,7 +44,7 @@ repos: - id: trailing-whitespace - repo: https://github.com/ComPWA/repo-maintenance - rev: 0.0.92 + rev: 0.0.93 hooks: - id: check-dev-files args: @@ -57,7 +57,7 @@ repos: - id: format-setup-cfg - repo: https://github.com/psf/black - rev: 21.11b1 + rev: 21.12b0 hooks: - id: black @@ -120,7 +120,7 @@ repos: metadata.toc-showtags - repo: https://github.com/pre-commit/mirrors-prettier - rev: v2.5.0 + rev: v2.5.1 hooks: - id: prettier @@ -130,7 +130,7 @@ repos: - id: pydocstyle - repo: https://github.com/ComPWA/mirrors-pyright - rev: v1.1.191 + rev: v1.1.192 hooks: - id: pyright diff --git a/.pylintrc b/.pylintrc index e5f5ebe6..fd275f95 100644 --- a/.pylintrc +++ b/.pylintrc @@ -27,6 +27,7 @@ disable= too-few-public-methods, # data containers (attrs) and interface classes unspecified-encoding, # http://pylint.pycqa.org/en/latest/whatsnew/2.10.html unused-import, # https://www.flake8rules.com/rules/F401 + wrong-import-order, # handled by isort [SIMILARITIES] ignore-imports=yes # https://stackoverflow.com/a/30007053 diff --git a/benchmarks/ampform.py b/benchmarks/ampform.py index b20dc02b..872ef23a 100644 --- a/benchmarks/ampform.py +++ b/benchmarks/ampform.py @@ -7,7 +7,7 @@ import tensorwaves as tw from tensorwaves.data.phasespace import TFUniformRealNumberGenerator -from tensorwaves.data.transform import HelicityTransformer +from tensorwaves.data.transform import SympyDataTransformer from tensorwaves.function.sympy import create_parametrized_function from tensorwaves.interface import ( DataSample, @@ -64,9 +64,10 @@ def generate_data( function: ParametrizedFunction, data_sample_size: int, phsp_sample_size: int, + backend: str, transform: bool = False, ) -> Tuple[DataSample, DataSample]: - reaction = model.adapter.reaction_info + reaction = model.reaction_info final_state = reaction.final_state phsp = tw.data.generate_phsp( size=phsp_sample_size, @@ -75,57 +76,57 @@ def generate_data( random_generator=TFUniformRealNumberGenerator(seed=0), ) - helicity_transformer = HelicityTransformer(model.adapter) + expressions = model.kinematic_variables + converter = SympyDataTransformer.from_sympy(expressions, backend) data = tw.data.generate_data( size=data_sample_size, initial_state_mass=reaction.initial_state[-1].mass, final_state_masses={i: p.mass for i, p in final_state.items()}, - data_transformer=helicity_transformer, + data_transformer=converter, intensity=function, random_generator=TFUniformRealNumberGenerator(seed=0), ) if transform: - data = helicity_transformer(data) - phsp = helicity_transformer(phsp) + data = converter(data) + phsp = converter(phsp) return data, phsp def fit( - data_set: DataSample, - phsp_set: DataSample, + data: DataSample, + phsp: DataSample, function: ParametrizedFunction, initial_parameters: Mapping[str, ParameterValue], backend: str, ) -> FitResult: estimator = tw.estimator.UnbinnedNLL( function, - data=data_set, - phsp=phsp_set, + data=data, + phsp=phsp, backend=backend, ) optimizer = tw.optimizer.Minuit2() - return optimizer.optimize(estimator, initial_parameters) class TestJPsiToGammaPiPi: expected_data = { - 0: [ + "p0": [ [1.50757377596, 0.37918944935, 0.73396599969, 1.26106620078], [1.41389525301, -0.07315064441, -0.21998573758, 1.39475985207], [1.52128570461, 0.06569896528, -1.51812710851, 0.0726906006], [1.51480310845, 1.40672331053, 0.49678572189, -0.26260603856], [1.52384281483, 0.79694939592, 1.29832389761, -0.03638188481], ], - 1: [ + "p1": [ [1.42066087326, -0.34871369761, -0.72119471428, -1.1654765212], [0.96610319301, -0.26739932067, -0.15455480956, -0.90539883872], [0.60647770024, 0.11616448713, 0.57584161239, -0.06714695611], [1.01045883083, -0.88651015826, -0.46024226278, 0.0713099651], [1.04324742713, -0.48051670276, -0.91259832182, -0.08009031815], ], - 2: [ + "p2": [ [0.16866535079, -0.03047575173, -0.01277128542, -0.09558967958], [0.71690155399, 0.34054996508, 0.37454054715, -0.48936101336], [0.96913659515, -0.18186345241, 0.94228549612, -0.00554364449], @@ -145,13 +146,15 @@ def model(self) -> "HelicityModel": ) @pytest.mark.benchmark(group="data", min_rounds=1) - @pytest.mark.parametrize("backend", ["jax"]) + @pytest.mark.parametrize("backend", ["jax", "numpy", "tf"]) @pytest.mark.parametrize("size", [10_000]) def test_data(self, backend, benchmark, model, size): n_data = size n_phsp = 10 * n_data function = create_function(model, backend) - data, phsp = benchmark(generate_data, model, function, n_data, n_phsp) + data, phsp = benchmark( + generate_data, model, function, n_data, n_phsp, backend + ) assert len(next(iter(data.values()))) == n_data assert len(next(iter(phsp.values()))) == n_phsp @@ -169,7 +172,9 @@ def test_fit(self, backend, benchmark, model, size): n_data = size n_phsp = 10 * n_data function = create_function(model, backend) - data, phsp = generate_data(model, function, n_data, n_phsp, True) + data, phsp = generate_data( + model, function, n_data, n_phsp, backend, transform=True + ) coefficients = [p for p in function.parameters if p.startswith("C_{")] assert len(coefficients) >= 1 diff --git a/docs/abbreviate_signature.py b/docs/abbreviate_signature.py index 4103a676..36c4f215 100644 --- a/docs/abbreviate_signature.py +++ b/docs/abbreviate_signature.py @@ -21,9 +21,8 @@ def replace_link(text: str) -> str: "a set-like object providing a view on D's items": "typing.ItemsView", "a set-like object providing a view on D's keys": "typing.KeysView", "an object providing a view on D's values": "typing.ValuesView", - "numpy.typing._array_like._SupportsArray": "numpy.typing.ArrayLike", - "numpy.typing._dtype_like._DTypeDict": "numpy.typing.DTypeLike", - "numpy.typing._dtype_like._SupportsDType": "numpy.typing.DTypeLike", + "sp.Expr": "sympy.core.expr.Expr", + "sp.Symbol": "sympy.core.symbol.Symbol", "typing_extensions.Protocol": "typing.Protocol", } for old, new in replacements.items(): diff --git a/docs/usage.ipynb b/docs/usage.ipynb index f7f9d1b6..a716b5fe 100644 --- a/docs/usage.ipynb +++ b/docs/usage.ipynb @@ -92,7 +92,7 @@ "from ampform.dynamics.builder import create_relativistic_breit_wigner_with_ff\n", "\n", "from tensorwaves.data import generate_data, generate_phsp\n", - "from tensorwaves.data.transform import HelicityTransformer\n", + "from tensorwaves.data.transform import SympyDataTransformer\n", "from tensorwaves.estimator import UnbinnedNLL\n", "from tensorwaves.function.sympy import create_parametrized_function\n", "from tensorwaves.optimizer import Minuit2\n", @@ -202,20 +202,33 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "intensity = create_parametrized_function(\n", " expression=model.expression.doit(),\n", " parameters=model.parameter_defaults,\n", " backend=\"jax\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "helicity_transformer = SympyDataTransformer.from_sympy(\n", + " model.kinematic_variables, backend=\"numpy\"\n", ")\n", - "helicity_transformer = HelicityTransformer(model.adapter)\n", - "reaction_info = model.adapter.reaction_info\n", - "initial_state_mass = reaction_info.initial_state[-1].mass\n", - "final_state_masses = {i: p.mass for i, p in reaction_info.final_state.items()}\n", - "phsp_sample = generate_phsp(100_000, initial_state_mass, final_state_masses)\n", - "data_sample = generate_data(\n", + "initial_state_mass = reaction.initial_state[-1].mass\n", + "final_state_masses = {i: p.mass for i, p in reaction.final_state.items()}\n", + "phsp_momenta = generate_phsp(100_000, initial_state_mass, final_state_masses)\n", + "data_momenta = generate_data(\n", " 10_000,\n", " initial_state_mass,\n", " final_state_masses,\n", @@ -232,9 +245,8 @@ }, "outputs": [], "source": [ - "phsp_set = helicity_transformer(phsp_sample)\n", - "data_set = helicity_transformer(data_sample)\n", - "data_frame = pd.DataFrame(data_set)" + "phsp = helicity_transformer(phsp_momenta)\n", + "data = helicity_transformer(data_momenta)" ] }, { @@ -254,33 +266,28 @@ "import numpy as np\n", "from matplotlib import cm\n", "\n", - "reaction_info = model.adapter.reaction_info\n", - "intermediate_states = sorted(\n", - " (\n", - " p\n", - " for p in model.particles\n", - " if p not in reaction_info.final_state.values()\n", - " and p not in reaction_info.initial_state.values()\n", - " ),\n", + "resonances = sorted(\n", + " reaction.get_intermediate_particles(),\n", " key=lambda p: p.mass,\n", ")\n", "\n", - "evenly_spaced_interval = np.linspace(0, 1, len(intermediate_states))\n", + "evenly_spaced_interval = np.linspace(0, 1, len(resonances))\n", "colors = [cm.rainbow(x) for x in evenly_spaced_interval]\n", "\n", "\n", "def indicate_masses():\n", " plt.xlabel(\"$m$ [GeV]\")\n", - " for i, p in enumerate(intermediate_states):\n", + " for i, p in enumerate(resonances):\n", " plt.gca().axvline(\n", " x=p.mass, linestyle=\"dotted\", label=p.name, color=colors[i]\n", " )\n", "\n", "\n", "fig, ax = plt.subplots(figsize=(9, 4))\n", - "data_frame[\"m_12\"].hist(bins=100, alpha=0.5, density=True, ax=ax)\n", + "ax.hist(data[\"m_12\"], bins=100, alpha=0.5, density=True)\n", "indicate_masses()\n", - "plt.legend();" + "plt.legend()\n", + "plt.show()" ] }, { @@ -320,24 +327,24 @@ "\n", "def compare_model(\n", " variable_name,\n", - " data_set,\n", - " phsp_set,\n", + " data,\n", + " phsp,\n", " intensity_model,\n", " bins=150,\n", "):\n", - " data = data_set[variable_name]\n", - " phsp = phsp_set[variable_name]\n", - " intensities = intensity_model(phsp_set)\n", + " intensities = intensity_model(phsp)\n", " _, ax = plt.subplots(figsize=(9, 4))\n", + " data_1d = data[variable_name]\n", " ax.hist(\n", - " data,\n", + " data_1d,\n", " bins=bins,\n", " alpha=0.5,\n", " label=\"data\",\n", " density=True,\n", " )\n", + " phsp_1d = phsp[variable_name]\n", " ax.hist(\n", - " phsp,\n", + " phsp_1d,\n", " weights=intensities,\n", " bins=bins,\n", " histtype=\"step\",\n", @@ -352,13 +359,15 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "estimator = UnbinnedNLL(\n", " intensity,\n", - " data_set,\n", - " phsp_set,\n", + " data,\n", + " phsp,\n", " backend=\"jax\",\n", ")\n", "initial_parameters = {\n", @@ -383,7 +392,7 @@ }, "outputs": [], "source": [ - "compare_model(\"m_12\", data_set, phsp_set, intensity)\n", + "compare_model(\"m_12\", data, phsp, intensity)\n", "print(\"Number of free parameters:\", len(initial_parameters))" ] }, @@ -427,7 +436,7 @@ "source": [ "optimized_parameters = fit_result.parameter_values\n", "intensity.update_parameters(optimized_parameters)\n", - "compare_model(\"m_12\", data_set, phsp_set, intensity)" + "compare_model(\"m_12\", data, phsp, intensity)" ] }, { diff --git a/docs/usage/analytic-continuation.ipynb b/docs/usage/analytic-continuation.ipynb index a01fc0cb..ef9e75ca 100644 --- a/docs/usage/analytic-continuation.ipynb +++ b/docs/usage/analytic-continuation.ipynb @@ -70,8 +70,8 @@ "import qrules\n", "from IPython.display import Math\n", "\n", - "from tensorwaves.data import generate_data, generate_phsp\n", - "from tensorwaves.data.transform import HelicityTransformer\n", + "from tensorwaves.data import generate_data\n", + "from tensorwaves.data.transform import SympyDataTransformer\n", "from tensorwaves.function.sympy import create_parametrized_function\n", "\n", "logger = logging.getLogger()\n", @@ -198,12 +198,12 @@ " parameters=model.parameter_defaults,\n", " backend=\"jax\",\n", ")\n", - "helicity_transformer = HelicityTransformer(model.adapter)\n", - "reaction_info = model.adapter.reaction_info\n", - "initial_state_mass = reaction_info.initial_state[-1].mass\n", - "final_state_masses = {i: p.mass for i, p in reaction_info.final_state.items()}\n", - "phsp_sample = generate_phsp(100_000, initial_state_mass, final_state_masses)\n", - "data_sample = generate_data(\n", + "helicity_transformer = SympyDataTransformer.from_sympy(\n", + " model.kinematic_variables, backend=\"numpy\"\n", + ")\n", + "initial_state_mass = reaction.initial_state[-1].mass\n", + "final_state_masses = {i: p.mass for i, p in reaction.final_state.items()}\n", + "data_momenta = generate_data(\n", " 2_000,\n", " initial_state_mass,\n", " final_state_masses,\n", @@ -228,24 +228,18 @@ "import numpy as np\n", "from matplotlib import cm\n", "\n", - "reaction_info = model.adapter.reaction_info\n", - "intermediate_states = sorted(\n", - " (\n", - " p\n", - " for p in model.particles\n", - " if p not in reaction_info.final_state.values()\n", - " and p not in reaction_info.initial_state.values()\n", - " ),\n", + "resonances = sorted(\n", + " reaction.get_intermediate_particles(),\n", " key=lambda p: p.mass,\n", ")\n", "\n", - "evenly_spaced_interval = np.linspace(0, 1, len(intermediate_states))\n", + "evenly_spaced_interval = np.linspace(0, 1, len(resonances))\n", "colors = [cm.rainbow(x) for x in evenly_spaced_interval]\n", "\n", "\n", "def indicate_masses():\n", " plt.xlabel(\"$m_{12}$ [GeV]\")\n", - " for i, p in enumerate(intermediate_states):\n", + " for i, p in enumerate(resonances):\n", " plt.gca().axvline(\n", " x=p.mass, linestyle=\"dotted\", label=p.name, color=colors[i]\n", " )" @@ -257,9 +251,8 @@ "metadata": {}, "outputs": [], "source": [ - "phsp_set = helicity_transformer(phsp_sample)\n", - "data_set = helicity_transformer(data_sample)\n", - "data_frame = pd.DataFrame(data_set)\n", + "data = helicity_transformer(data_momenta)\n", + "data_frame = pd.DataFrame(data)\n", "data_frame[\"m_12\"].hist(bins=50, alpha=0.5, density=True)\n", "indicate_masses()\n", "plt.legend();" diff --git a/docs/usage/step2.ipynb b/docs/usage/step2.ipynb index 05a28d4b..c0f76960 100644 --- a/docs/usage/step2.ipynb +++ b/docs/usage/step2.ipynb @@ -83,7 +83,7 @@ }, "outputs": [], "source": [ - "reaction_info = model.adapter.reaction_info\n", + "reaction_info = model.reaction_info\n", "initial_state = next(iter(reaction_info.initial_state.values()))\n", "print(\"Initial state:\")\n", "print(\" \", initial_state.name)\n", @@ -129,10 +129,9 @@ "from tensorwaves.data import TFUniformRealNumberGenerator, generate_phsp\n", "\n", "rng = TFUniformRealNumberGenerator(seed=0)\n", - "reaction_info = model.adapter.reaction_info\n", "initial_state_mass = reaction_info.initial_state[-1].mass\n", "final_state_masses = {i: p.mass for i, p in reaction_info.final_state.items()}\n", - "phsp_sample = generate_phsp(\n", + "phsp_momenta = generate_phsp(\n", " size=100_000,\n", " initial_state_mass=initial_state_mass,\n", " final_state_masses=final_state_masses,\n", @@ -141,7 +140,7 @@ "pd.DataFrame(\n", " {\n", " (k, label): np.transpose(v)[i]\n", - " for k, v in phsp_sample.items()\n", + " for k, v in phsp_momenta.items()\n", " for i, label in enumerate([\"E\", \"px\", \"py\", \"pz\"])\n", " }\n", ")" @@ -209,7 +208,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "A problem is that {class}`.ParametrizedBackendFunction` takes a {obj}`.DataSample` with kinematic variables for the helicity formalism as input, not a set of four-momenta. We therefore need to construct a {class}`.DataTransformer` to transform these four-momenta to function variables. In this case, we work with the helicity formalism, so we construct a {class}`.HelicityTransformer`:" + "A problem is that {class}`.ParametrizedBackendFunction` takes a {obj}`.DataSample` with kinematic variables for the helicity formalism as input, not a set of four-momenta. We therefore need to construct a {class}`.DataTransformer` to transform these four-momenta to function variables. In this case, we work with the helicity formalism, so we construct a {class}`.SympyDataTransformer`:" ] }, { @@ -218,9 +217,11 @@ "metadata": {}, "outputs": [], "source": [ - "from tensorwaves.data.transform import HelicityTransformer\n", + "from tensorwaves.data.transform import SympyDataTransformer\n", "\n", - "helicity_transformer = HelicityTransformer(model.adapter)" + "helicity_transformer = SympyDataTransformer.from_sympy(\n", + " model.kinematic_variables, backend=\"jax\"\n", + ")" ] }, { @@ -242,7 +243,7 @@ "source": [ "from tensorwaves.data import generate_data\n", "\n", - "data_sample = generate_data(\n", + "data_momenta = generate_data(\n", " size=10_000,\n", " initial_state_mass=initial_state_mass,\n", " final_state_masses=final_state_masses,\n", @@ -253,7 +254,7 @@ "pd.DataFrame(\n", " {\n", " (k, label): np.transpose(v)[i]\n", - " for k, v in data_sample.items()\n", + " for k, v in data_momenta.items()\n", " for i, label in enumerate([\"E\", \"px\", \"py\", \"pz\"])\n", " }\n", ")" @@ -277,7 +278,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We now have a phase space sample and an intensity-based sample. Their data structure isn't the most informative though: it's just a collection of four-momentum tuples. But we can again use the {class}`.HelicityTransformer` to convert these four-momenta to (in the case of the helicity formalism) invariant masses and helicity angles:" + "We now have a phase space sample and an intensity-based sample. Their data structure isn't the most informative though: it's just a collection of four-momentum tuples. But we can again use the {class}`.SympyDataTransformer` to convert these four-momenta to (in the case of the helicity formalism) invariant masses and helicity angles:" ] }, { @@ -286,16 +287,16 @@ "metadata": {}, "outputs": [], "source": [ - "phsp_set = helicity_transformer(phsp_sample)\n", - "data_set = helicity_transformer(data_sample)\n", - "list(data_set)" + "phsp = helicity_transformer(phsp_momenta)\n", + "data = helicity_transformer(data_momenta)\n", + "list(data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "The {obj}`~ampform.data.DataSet` is just a mapping of kinematic variables names to a sequence of values. The numbers you see here are final state IDs as defined in the {class}`~ampform.helicity.HelicityModel` member of the {class}`~ampform.helicity.HelicityModel`:" + "The {obj}`.DataSample` is a mapping of kinematic variables names to a 1-dimensional array of values. The numbers you see here are final state IDs as defined in the {class}`~ampform.helicity.HelicityModel` member of the {class}`~ampform.helicity.HelicityModel`:" ] }, { @@ -330,7 +331,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The {obj}`~ampform.data.DataSet` can easily be converted to a {class}`pandas.DataFrame`:" + "The {obj}`.DataSample` can easily be converted to a {class}`pandas.DataFrame`:" ] }, { @@ -345,8 +346,8 @@ "source": [ "import pandas as pd\n", "\n", - "data_frame = pd.DataFrame(data_set)\n", - "phsp_frame = pd.DataFrame(data_set)\n", + "data_frame = pd.DataFrame(data)\n", + "phsp_frame = pd.DataFrame(data)\n", "data_frame" ] }, @@ -369,14 +370,8 @@ "source": [ "from matplotlib import cm\n", "\n", - "reaction_info = model.adapter.reaction_info\n", "resonances = sorted(\n", - " (\n", - " p\n", - " for p in model.particles\n", - " if p not in reaction_info.final_state.values()\n", - " and p not in reaction_info.initial_state.values()\n", - " ),\n", + " reaction_info.get_intermediate_particles(),\n", " key=lambda p: p.mass,\n", ")\n", "\n", @@ -437,10 +432,10 @@ "source": [ "import pickle\n", "\n", - "with open(\"data_set.pickle\", \"wb\") as stream:\n", - " pickle.dump(data_set, stream)\n", - "with open(\"phsp_set.pickle\", \"wb\") as stream:\n", - " pickle.dump(phsp_set, stream)" + "with open(\"data.pickle\", \"wb\") as stream:\n", + " pickle.dump(data, stream)\n", + "with open(\"phsp.pickle\", \"wb\") as stream:\n", + " pickle.dump(phsp, stream)" ] }, { diff --git a/docs/usage/step3.ipynb b/docs/usage/step3.ipynb index c7ef03f6..9c46b61f 100644 --- a/docs/usage/step3.ipynb +++ b/docs/usage/step3.ipynb @@ -80,10 +80,10 @@ "reaction = qrules.io.load(\"transitions.json\")\n", "with open(\"helicity_model.pickle\", \"rb\") as stream:\n", " model: HelicityModel = pickle.load(stream)\n", - "with open(\"data_set.pickle\", \"rb\") as stream:\n", - " data_set = pickle.load(stream)\n", - "with open(\"phsp_set.pickle\", \"rb\") as stream:\n", - " phsp_set = pickle.load(stream)\n", + "with open(\"data.pickle\", \"rb\") as stream:\n", + " data = pickle.load(stream)\n", + "with open(\"phsp.pickle\", \"rb\") as stream:\n", + " phsp = pickle.load(stream)\n", "\n", "function = create_parametrized_function(\n", " expression=model.expression.doit(),\n", @@ -123,8 +123,8 @@ "\n", "estimator = UnbinnedNLL(\n", " function,\n", - " data=data_set,\n", - " phsp=phsp_set,\n", + " data=data,\n", + " phsp=phsp,\n", " backend=\"jax\",\n", ")" ] @@ -207,14 +207,9 @@ "import numpy as np\n", "from matplotlib import cm\n", "\n", - "reaction_info = model.adapter.reaction_info\n", + "reaction_info = model.reaction_info\n", "resonances = sorted(\n", - " (\n", - " p\n", - " for p in model.particles\n", - " if p not in reaction_info.final_state.values()\n", - " and p not in reaction_info.initial_state.values()\n", - " ),\n", + " reaction_info.get_intermediate_particles(),\n", " key=lambda p: p.mass,\n", ")\n", "\n", @@ -235,25 +230,25 @@ "\n", "def compare_model(\n", " variable_name,\n", - " data_set,\n", - " phsp_set,\n", + " data,\n", + " phsp,\n", " function,\n", " bins=100,\n", "):\n", - " data = data_set[variable_name]\n", - " phsp = phsp_set[variable_name]\n", - " intensities = function(phsp_set)\n", + " intensities = function(phsp)\n", " _, ax = plt.subplots(figsize=(9, 4))\n", + " data_1d = data[variable_name]\n", " ax = plt.gca()\n", " ax.hist(\n", - " data,\n", + " data_1d,\n", " bins=bins,\n", " alpha=0.5,\n", " label=\"data\",\n", " density=True,\n", " )\n", + " phsp_1d = phsp[variable_name]\n", " ax.hist(\n", - " phsp,\n", + " phsp_1d,\n", " weights=intensities,\n", " bins=bins,\n", " histtype=\"step\",\n", @@ -276,7 +271,7 @@ "outputs": [], "source": [ "function.update_parameters(initial_parameters)\n", - "compare_model(\"m_12\", data_set, phsp_set, function)" + "compare_model(\"m_12\", data, phsp, function)" ] }, { @@ -351,14 +346,16 @@ " def update_plot(self, nbins: int):\n", " if self.__fig is None or self.__ax is None:\n", " self.__fig, self.__ax = plt.subplots(1, figsize=(8, 5))\n", - " data = data_set[self.__variable]\n", - " phsp = phsp_set[self.__variable]\n", " function.update_parameters(self.__latest_parameters)\n", - " intensities = function(phsp_set)\n", + " intensities = function(phsp)\n", " self.__ax.cla()\n", - " self.__ax.hist(data, bins=nbins, alpha=0.5, label=\"data\", density=True)\n", + " data_1d = data[self.__variable]\n", + " self.__ax.hist(\n", + " data_1d, bins=nbins, alpha=0.5, label=\"data\", density=True\n", + " )\n", + " phsp_1d = phsp[self.__variable]\n", " self.__ax.hist(\n", - " phsp,\n", + " phsp_1d,\n", " weights=intensities,\n", " bins=nbins,\n", " histtype=\"step\",\n", @@ -512,7 +509,7 @@ "outputs": [], "source": [ "n_real_par = fit_result.count_number_of_parameters(complex_twice=True)\n", - "n_events = len(list(data_set.values())[0])\n", + "n_events = len(list(data.values())[0])\n", "log_likelihood = -fit_result.estimator_value\n", "\n", "bic = n_real_par * np.log(n_events) - 2 * log_likelihood\n", @@ -670,7 +667,7 @@ "outputs": [], "source": [ "function.update_parameters(latest_parameters)\n", - "compare_model(\"m_12\", data_set, phsp_set, function)" + "compare_model(\"m_12\", data, phsp, function)" ] }, { @@ -780,7 +777,7 @@ "source": [ "import numpy as np\n", "\n", - "difference = np.average(from_amplitudes(phsp_set) - from_intensity(phsp_set))\n", + "difference = np.average(from_amplitudes(phsp) - from_intensity(phsp))\n", "assert np.round(difference, decimals=15) == 0.0" ] }, @@ -807,15 +804,15 @@ "fig, ax = plt.subplots(1, figsize=(8, 5))\n", "bins = 150\n", "ax.hist(\n", - " phsp_set[\"m_12\"],\n", - " weights=function(phsp_set),\n", + " phsp[\"m_12\"],\n", + " weights=function(phsp),\n", " bins=bins,\n", " alpha=0.2,\n", " label=\"full intensity\",\n", ")\n", "ax.hist(\n", - " phsp_set[\"m_12\"],\n", - " weights=from_intensity(phsp_set),\n", + " phsp[\"m_12\"],\n", + " weights=from_intensity(phsp),\n", " bins=bins,\n", " histtype=\"step\",\n", " label=R\"$J/\\psi(1S)_{-1} \\to \\gamma_{-1} \\pi^0 \\pi^0$\",\n", @@ -850,7 +847,7 @@ "import logging\n", "\n", "from tensorwaves.data import generate_data\n", - "from tensorwaves.data.transform import HelicityTransformer\n", + "from tensorwaves.data.transform import SympyDataTransformer\n", "\n", "logging.basicConfig()\n", "logging.getLogger().setLevel(logging.ERROR)\n", @@ -865,21 +862,25 @@ "]\n", "initial_state_mass = reaction_info.initial_state[-1].mass\n", "final_state_masses = {i: p.mass for i, p in reaction_info.final_state.items()}\n", - "helicity_transformer = HelicityTransformer(model.adapter)\n", - "sub_intensities = [\n", - " generate_data(\n", + "\n", + "helicity_transformer = SympyDataTransformer.from_sympy(\n", + " model.kinematic_variables, backend=\"numpy\"\n", + ")\n", + "masses = []\n", + "for component in intensity_components:\n", + " sub_events = generate_data(\n", " size=5_000,\n", " initial_state_mass=initial_state_mass,\n", " final_state_masses=final_state_masses,\n", " data_transformer=helicity_transformer,\n", " intensity=component,\n", " )\n", - " for component in intensity_components\n", - "]\n", + " sub_dataset = helicity_transformer(sub_events)\n", + " masses.append(sub_dataset[\"m_12\"])\n", "\n", "fig, ax = plt.subplots(1, figsize=(8, 5))\n", "plt.hist(\n", - " [helicity_transformer(i)[\"m_12\"] for i in sub_intensities],\n", + " masses,\n", " bins=100,\n", " stacked=True,\n", " alpha=0.6,\n", diff --git a/pytest.ini b/pytest.ini index 945260f0..c4959692 100644 --- a/pytest.ini +++ b/pytest.ini @@ -20,7 +20,9 @@ filterwarnings = ignore:Passing a schema to Validator.iter_errors is deprecated.*:DeprecationWarning ignore:invalid value encountered in log.*:RuntimeWarning ignore:invalid value encountered in sqrt:RuntimeWarning + ignore:invalid value encountered in true_divide:RuntimeWarning ignore:numpy.ufunc size changed, may indicate binary incompatibility.*:RuntimeWarning + ignore:unclosed .*:ResourceWarning markers = slow: marks tests as slow (select with '-m slow') norecursedirs = diff --git a/setup.cfg b/setup.cfg index 4d24b0ea..80ee280b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,7 +42,7 @@ python_requires = >=3.6, <3.10 setup_requires = setuptools_scm install_requires = - ampform >=0.10, !=0.11.2, <0.12 + ampform ==0.12.* attrs >=20.1.0 # https://www.attrs.org/en/stable/api.html#next-gen iminuit >=2.0 numpy diff --git a/src/tensorwaves/data/__init__.py b/src/tensorwaves/data/__init__.py index 9ebb3b6e..be274c00 100644 --- a/src/tensorwaves/data/__init__.py +++ b/src/tensorwaves/data/__init__.py @@ -1,7 +1,8 @@ +# pylint: disable=too-many-arguments """The `.data` module takes care of data generation.""" import logging -from typing import Mapping, Optional, Tuple +from typing import Any, Mapping, Optional, Tuple import numpy as np from tqdm.auto import tqdm @@ -28,7 +29,7 @@ ] -def generate_data( # pylint: disable=too-many-arguments +def generate_data( # pylint: disable=too-many-arguments too-many-locals size: int, initial_state_mass: float, final_state_masses: Mapping[int, float], @@ -55,9 +56,6 @@ def generate_data( # pylint: disable=too-many-arguments generated from many smaller samples, aka bunches. """ - # pylint: disable=import-outside-toplevel - from ampform.data import EventCollection - if phsp_generator is None: phsp_gen_instance = TFPhaseSpaceGenerator() phsp_gen_instance.setup(initial_state_mass, final_state_masses) @@ -69,9 +67,9 @@ def generate_data( # pylint: disable=too-many-arguments desc="Generating intensity-based sample", disable=logging.getLogger().level > logging.WARNING, ) - momentum_pool = EventCollection({}) + momentum_pool: DataSample = {} current_max = 0.0 - while momentum_pool.n_events < size: + while _get_number_of_events(momentum_pool) < size: bunch, maxvalue = _generate_data_bunch( bunch_size, phsp_gen_instance, @@ -81,23 +79,25 @@ def generate_data( # pylint: disable=too-many-arguments ) if maxvalue > current_max: current_max = 1.05 * maxvalue - if momentum_pool.n_events > 0: + if _get_number_of_events(momentum_pool) > 0: logging.info( "processed bunch maximum of %s is over current" " maximum %s. Restarting generation!", maxvalue, current_max, ) - momentum_pool = EventCollection({}) + momentum_pool = {} progress_bar.update(n=-progress_bar.n) # reset progress bar continue - if np.size(momentum_pool, 0) > 0: # type: ignore[arg-type] - momentum_pool.append(bunch) # type: ignore[arg-type] + if len(momentum_pool): + momentum_pool = _concatenate_events(momentum_pool, bunch) else: - momentum_pool = EventCollection(bunch) # type: ignore[arg-type] - progress_bar.update(n=momentum_pool.n_events - progress_bar.n) + momentum_pool = bunch + progress_bar.update( + n=_get_number_of_events(momentum_pool) - progress_bar.n + ) _finalize_progress_bar(progress_bar) - return momentum_pool.select_events(slice(0, size)) + return {i: values[:size] for i, values in momentum_pool.items()} def _generate_data_bunch( @@ -105,23 +105,20 @@ def _generate_data_bunch( phsp_generator: PhaseSpaceGenerator, random_generator: UniformRealNumberGenerator, intensity: Function, - kinematics: DataTransformer, + adapter: DataTransformer, ) -> Tuple[DataSample, float]: - # pylint: disable=import-outside-toplevel - from ampform.data import EventCollection - - phsp_sample, weights = phsp_generator.generate( + phsp_momenta, weights = phsp_generator.generate( bunch_size, random_generator ) - momentum_pool = EventCollection(phsp_sample) # type: ignore[arg-type] - dataset = kinematics(momentum_pool) + dataset = adapter(phsp_momenta) intensities = intensity(dataset) maxvalue: float = np.max(intensities) uniform_randoms = random_generator(bunch_size, max_value=maxvalue) - hit_and_miss_sample = momentum_pool.select_events( - weights * intensities > uniform_randoms + hit_and_miss_sample = _select_events( + phsp_momenta, + selector=weights * intensities > uniform_randoms, ) return hit_and_miss_sample, maxvalue @@ -148,9 +145,6 @@ def generate_phsp( generated from many smaller samples, aka bunches. """ - # pylint: disable=import-outside-toplevel - from ampform.data import EventCollection - if phsp_generator is None: phsp_generator = TFPhaseSpaceGenerator() phsp_generator.setup(initial_state_mass, final_state_masses) @@ -162,23 +156,43 @@ def generate_phsp( desc="Generating phase space sample", disable=logging.getLogger().level > logging.WARNING, ) - momentum_pool = EventCollection({}) - while momentum_pool.n_events < size: - phsp_sample, weights = phsp_generator.generate( + momentum_pool: DataSample = {} + while _get_number_of_events(momentum_pool) < size: + phsp_momenta, weights = phsp_generator.generate( bunch_size, random_generator ) hit_and_miss_randoms = random_generator(bunch_size) - bunch = EventCollection(phsp_sample).select_events( # type: ignore[arg-type] - weights > hit_and_miss_randoms + bunch = _select_events( + phsp_momenta, selector=weights > hit_and_miss_randoms ) - - if momentum_pool.n_events > 0: - momentum_pool.append(bunch) - else: - momentum_pool = bunch - progress_bar.update(n=bunch.n_events) + momentum_pool = _concatenate_events(momentum_pool, bunch) + progress_bar.update(n=_get_number_of_events(bunch)) _finalize_progress_bar(progress_bar) - return momentum_pool.select_events(slice(0, size)) + return {i: values[:size] for i, values in momentum_pool.items()} + + +def _get_number_of_events(four_momenta: DataSample) -> int: + if len(four_momenta) == 0: + return 0 + return len(next(iter(four_momenta.values()))) + + +def _concatenate_events( + sample1: DataSample, sample2: DataSample +) -> DataSample: + if len(sample1) and len(sample2) and set(sample1) != set(sample2): + raise ValueError( + "Keys of data sets are not matching", set(sample2), set(sample1) + ) + if _get_number_of_events(sample1) == 0: + return sample2 + return { + i: np.vstack((values, sample2[i])) for i, values in sample1.items() + } + + +def _select_events(four_momenta: DataSample, selector: Any) -> DataSample: + return {i: values[selector] for i, values in four_momenta.items()} def _finalize_progress_bar(progress_bar: tqdm) -> None: diff --git a/src/tensorwaves/data/phasespace.py b/src/tensorwaves/data/phasespace.py index 0f0ee309..079e8d34 100644 --- a/src/tensorwaves/data/phasespace.py +++ b/src/tensorwaves/data/phasespace.py @@ -46,11 +46,11 @@ def generate( weights, particles = self.__phsp_gen.generate( n_events=size, seed=rng.generator ) - phsp_sample = { - int(label): momenta.numpy()[:, [3, 0, 1, 2]] + phsp_momenta = { + f"p{label}": momenta.numpy()[:, [3, 0, 1, 2]] for label, momenta in particles.items() } - return phsp_sample, weights.numpy() + return phsp_momenta, weights.numpy() class TFUniformRealNumberGenerator(UniformRealNumberGenerator): diff --git a/src/tensorwaves/data/transform.py b/src/tensorwaves/data/transform.py index ba8e3d89..c88c3f64 100644 --- a/src/tensorwaves/data/transform.py +++ b/src/tensorwaves/data/transform.py @@ -1,29 +1,65 @@ """Implementations of `.DataTransformer`.""" -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set -import numpy as np - -from tensorwaves.interface import DataSample, DataTransformer +from tensorwaves.function import PositionalArgumentFunction +from tensorwaves.function.sympy import _lambdify_normal_or_fast +from tensorwaves.interface import DataSample, DataTransformer, Function if TYPE_CHECKING: - from ampform.kinematics import HelicityAdapter + import sympy as sp -class HelicityTransformer(DataTransformer): - """Transform four-momentum tuples to variables in the helicity formalism. +class SympyDataTransformer(DataTransformer): + """Implementation of a `.DataTransformer`.""" - Implementation of a `.DataTransformer` based on the - `~ampform.kinematics.HelicityAdapter`. - """ + def __init__(self, functions: Mapping[str, Function]) -> None: + if any(map(lambda f: not isinstance(f, Function), functions.values())): + raise TypeError( + "Not all values in the mapping are an instance of" + f" {Function.__name__}" + ) + self.__functions = dict(functions) - def __init__(self, helicity_adapter: "HelicityAdapter") -> None: - self.__helicity_adapter = helicity_adapter + @property + def functions(self) -> Dict[str, Function]: + """Read-only access to the internal mapping of functions.""" + return dict(self.__functions) def __call__(self, dataset: DataSample) -> DataSample: - # pylint: disable=import-outside-toplevel - from ampform.kinematics import EventCollection + """Transform one `.DataSample` into another `.DataSample`.""" + return { + key: function(dataset) + for key, function in self.__functions.items() + } - events = EventCollection({int(k): v for k, v in dataset.items()}) - dataset = self.__helicity_adapter.transform(events) - return {key: np.array(values) for key, values in dataset.items()} + @classmethod + def from_sympy( + cls, + expressions: Dict["sp.Symbol", "sp.Expr"], + backend: str, + *, + use_cse: bool = True, + max_complexity: Optional[int] = None, + ) -> "SympyDataTransformer": + expanded_expressions: Dict[str, "sp.Expr"] = { + k.name: expr.doit() for k, expr in expressions.items() + } + free_symbols: Set["sp.Symbol"] = set() + for expr in expanded_expressions.values(): + free_symbols |= expr.free_symbols + ordered_symbols = tuple(sorted(free_symbols, key=lambda s: s.name)) + argument_order = tuple(map(str, ordered_symbols)) + functions = {} + for variable_name, expr in expanded_expressions.items(): + function = _lambdify_normal_or_fast( + expr, + ordered_symbols, + backend, + use_cse=use_cse, + max_complexity=max_complexity, + ) + functions[variable_name] = PositionalArgumentFunction( + function, argument_order + ) + return cls(functions) diff --git a/src/tensorwaves/function/__init__.py b/src/tensorwaves/function/__init__.py index e3b01276..22d54b88 100644 --- a/src/tensorwaves/function/__init__.py +++ b/src/tensorwaves/function/__init__.py @@ -1,16 +1,88 @@ """Express mathematical expressions in terms of computational functions.""" -from typing import Callable, Dict, Mapping, Sequence +import inspect +from typing import Callable, Dict, Iterable, Mapping, Sequence, Tuple +import attr import numpy as np from tensorwaves.interface import ( DataSample, + Function, ParameterValue, ParametrizedFunction, ) +def _all_str( + _: "PositionalArgumentFunction", __: attr.Attribute, value: Iterable[str] +) -> None: + if not all(map(lambda s: isinstance(s, str), value)): + raise TypeError(f"Not all arguments are of type {str.__name__}") + + +def _all_unique( + _: "PositionalArgumentFunction", __: attr.Attribute, value: Iterable[str] +) -> None: + argument_names = list(value) + if len(set(argument_names)) != len(argument_names): + duplicate_arguments = [] + for arg_name in argument_names: + n_occurrences = argument_names.count(arg_name) + if n_occurrences > 1: + duplicate_arguments.append(arg_name) + raise ValueError( + f"There are duplicate argument names: {duplicate_arguments}" + ) + + +def _validate_arguments( + instance: "PositionalArgumentFunction", _: attr.Attribute, value: Callable +) -> None: + if not callable(value): + raise TypeError("Function is not callable") + n_args = len(instance.argument_order) + signature = inspect.signature(value) + if len(signature.parameters) != n_args: + if len(signature.parameters) == 1: + parameter = next(iter(signature.parameters.values())) + if parameter.kind == parameter.VAR_POSITIONAL: + return + raise ValueError( + f"Lambdified function expects {len(signature.parameters)}" + f" arguments, but {n_args} sorted arguments were provided." + ) + + +def _to_tuple(argument_order: Iterable[str]) -> Tuple[str, ...]: + return tuple(argument_order) + + +@attr.s(frozen=True) +class PositionalArgumentFunction(Function): + """Wrapper around a function with positional arguments. + + This class provides a :meth:`__call__` that can take a `.DataSample` for a + function with `positional arguments + `_. Its + :attr:`argument_order` redirect the keys in the `.DataSample` to the + argument positions in its underlying :attr:`function`. + """ + + function: Callable[..., np.ndarray] = attr.ib( + validator=_validate_arguments + ) + """A function with positional arguments only.""" + argument_order: Tuple[str, ...] = attr.ib( + converter=_to_tuple, validator=[_all_str, _all_unique] + ) + """Ordered labels for each positional argument.""" + + def __call__(self, dataset: DataSample) -> np.ndarray: + args = [dataset[var_name] for var_name in self.argument_order] + return self.function(*args) + + class ParametrizedBackendFunction(ParametrizedFunction): """Implements `.ParametrizedFunction` for a specific computational back-end.""" diff --git a/src/tensorwaves/function/_backend.py b/src/tensorwaves/function/_backend.py index 13ed5e8b..d487f43a 100644 --- a/src/tensorwaves/function/_backend.py +++ b/src/tensorwaves/function/_backend.py @@ -47,10 +47,13 @@ def get_backend_modules( return np, np.__dict__ # returning only np.__dict__ does not work well with conditionals if backend in {"tensorflow", "tf"}: - # pylint: disable=import-error + # pylint: disable=import-error, no-name-in-module # pyright: reportMissingImports=false import tensorflow as tf import tensorflow.experimental.numpy as tnp + from tensorflow.python.ops.numpy_ops import np_config + + np_config.enable_numpy_behavior() return tnp.__dict__, tf diff --git a/src/tensorwaves/function/sympy.py b/src/tensorwaves/function/sympy/__init__.py similarity index 72% rename from src/tensorwaves/function/sympy.py rename to src/tensorwaves/function/sympy/__init__.py index eef84f48..4c19f290 100644 --- a/src/tensorwaves/function/sympy.py +++ b/src/tensorwaves/function/sympy/__init__.py @@ -1,9 +1,9 @@ -# pylint: disable=abstract-method invalid-name protected-access +# pylint: disable=import-outside-toplevel """Lambdify `sympy` expression trees to a `.Function`.""" import logging -import re from typing import ( + TYPE_CHECKING, Any, Callable, Dict, @@ -15,52 +15,91 @@ Union, ) -import sympy as sp -from sympy.printing.numpy import NumPyPrinter -from sympy.printing.printer import Printer from tqdm.auto import tqdm +from tensorwaves.function import ( + ParametrizedBackendFunction, + PositionalArgumentFunction, +) from tensorwaves.function._backend import get_backend_modules, jit_compile from tensorwaves.interface import ParameterValue -from . import ParametrizedBackendFunction +if TYPE_CHECKING: + import sympy as sp + from sympy.printing.printer import Printer -def create_parametrized_function( - expression: sp.Expr, - parameters: Mapping[sp.Symbol, ParameterValue], +def create_function( + expression: "sp.Expr", backend: str, + max_complexity: Optional[int] = None, use_cse: bool = True, +) -> PositionalArgumentFunction: + sorted_symbols = sorted(expression.free_symbols, key=lambda s: s.name) + lambdified_function = _lambdify_normal_or_fast( + expression=expression, + symbols=sorted_symbols, + backend=backend, + max_complexity=max_complexity, + use_cse=use_cse, + ) + return PositionalArgumentFunction( + function=lambdified_function, + argument_order=tuple(map(str, sorted_symbols)), + ) + + +def create_parametrized_function( + expression: "sp.Expr", + parameters: Mapping["sp.Symbol", ParameterValue], + backend: str, max_complexity: Optional[int] = None, + use_cse: bool = True, ) -> ParametrizedBackendFunction: sorted_symbols = sorted(expression.free_symbols, key=lambda s: s.name) - if max_complexity is None: - lambdified_function = lambdify( - expression=expression, - symbols=sorted_symbols, - backend=backend, - use_cse=use_cse, - ) - else: - lambdified_function = fast_lambdify( - expression=expression, - symbols=sorted_symbols, - backend=backend, - use_cse=use_cse, - max_complexity=max_complexity, - ) + lambdified_function = _lambdify_normal_or_fast( + expression=expression, + symbols=sorted_symbols, + backend=backend, + max_complexity=max_complexity, + use_cse=use_cse, + ) return ParametrizedBackendFunction( function=lambdified_function, - argument_order=list(map(str, sorted_symbols)), + argument_order=tuple(map(str, sorted_symbols)), parameters={ symbol.name: value for symbol, value in parameters.items() }, ) +def _lambdify_normal_or_fast( + expression: "sp.Expr", + symbols: Sequence["sp.Symbol"], + backend: str, + max_complexity: Optional[int], + use_cse: bool, +) -> Callable: + """Switch between `.lambdify` and `.fast_lambdify`.""" + if max_complexity is None: + return lambdify( + expression=expression, + symbols=symbols, + backend=backend, + use_cse=use_cse, + ) + return fast_lambdify( + expression=expression, + symbols=symbols, + backend=backend, + max_complexity=max_complexity, + use_cse=use_cse, + ) + + def lambdify( - expression: sp.Expr, - symbols: Sequence[sp.Symbol], + expression: "sp.Expr", + symbols: Sequence["sp.Symbol"], backend: str, use_cse: bool = True, ) -> Callable: @@ -77,17 +116,17 @@ def lambdify( function. use_cse: Lambdify with common sub-expressions (see :code:`cse` argument in :func:`~sympy.utilities.lambdify.lambdify`). - kwargs: Any additional key-word arguments passed to - :func:`sympy.utilities.lambdify.lambdify`. """ # pylint: disable=import-outside-toplevel, too-many-return-statements def jax_lambdify() -> Callable: + from ._printer import JaxPrinter + return jit_compile(backend="jax")( _sympy_lambdify( expression, symbols, modules=modules, - printer=_JaxPrinter(), + printer=JaxPrinter(), use_cse=use_cse, ) ) @@ -95,19 +134,25 @@ def jax_lambdify() -> Callable: def numba_lambdify() -> Callable: return jit_compile(backend="numba")( _sympy_lambdify( - expression, symbols, modules="numpy", use_cse=use_cse + expression, + symbols, + use_cse=use_cse, + modules="numpy", ) ) def tensorflow_lambdify() -> Callable: # pylint: disable=import-error - import tensorflow.experimental.numpy as tnp # pyright: reportMissingImports=false + # pyright: reportMissingImports=false + import tensorflow.experimental.numpy as tnp + + from ._printer import TensorflowPrinter return _sympy_lambdify( expression, symbols, modules=tnp, - printer=_TensorflowPrinter(), + printer=TensorflowPrinter(), use_cse=use_cse, ) @@ -131,17 +176,22 @@ def tensorflow_lambdify() -> Callable: return tensorflow_lambdify() return _sympy_lambdify( - expression, symbols, modules=modules, use_cse=use_cse + expression, + symbols, + modules=modules, + use_cse=use_cse, ) def _sympy_lambdify( - expression: sp.Expr, - symbols: Sequence[sp.Symbol], + expression: "sp.Expr", + symbols: Sequence["sp.Symbol"], modules: Union[str, tuple, dict], use_cse: bool, - printer: Optional[Printer] = None, + printer: Optional["Printer"] = None, ) -> Callable: + import sympy as sp + dummy_replacements = { symbol: sp.Symbol(f"z{i}", **symbol.assumptions0) for i, symbol in enumerate(symbols) @@ -151,15 +201,15 @@ def _sympy_lambdify( return sp.lambdify( dummy_symbols, expression, + cse=use_cse, modules=modules, printer=printer, - cse=use_cse, ) -def fast_lambdify( - expression: sp.Expr, - symbols: Sequence[sp.Symbol], +def fast_lambdify( # pylint: disable=too-many-locals + expression: "sp.Expr", + symbols: Sequence["sp.Symbol"], backend: str, *, min_complexity: int = 0, @@ -205,10 +255,10 @@ def recombined_function(*args: Any) -> Any: def split_expression( - expression: sp.Expr, + expression: "sp.Expr", max_complexity: int, min_complexity: int = 1, -) -> Tuple[sp.Expr, Dict[sp.Symbol, sp.Expr]]: +) -> Tuple["sp.Expr", Dict["sp.Symbol", "sp.Expr"]]: """Split an expression into a 'top expression' and several sub-expressions. Replace nodes in the expression tree of a `sympy.Expr @@ -218,6 +268,8 @@ def split_expression( .. seealso:: :doc:`/usage/faster-lambdify` """ + import sympy as sp + i = 0 symbol_mapping: Dict[sp.Symbol, sp.Expr] = {} n_operations = sp.count_ops(expression) @@ -256,43 +308,3 @@ def recursive_split(sub_expression: sp.Expr) -> sp.Expr: def _use_progress_bar() -> bool: return logging.getLogger().level <= logging.WARNING - - -def _replace_module( - mapping: Dict[str, str], old: str, new: str -) -> Dict[str, str]: - return { - k: re.sub(fr"^{old}\.(.*)$", fr"{new}.\1", v) - for k, v in mapping.items() - } - - -class _CustomNumPyPrinter(NumPyPrinter): - def __init__(self) -> None: - # https://github.com/sympy/sympy/blob/f291f2d/sympy/utilities/lambdify.py#L821-L823 - super().__init__( - settings={ - "fully_qualified_modules": False, - "inline": True, - "allow_unknown_functions": True, - } - ) - self._kc = _replace_module(NumPyPrinter._kc, "numpy", self._module) - self._kf = _replace_module(NumPyPrinter._kf, "numpy", self._module) - self.printmethod = "_numpycode" # force using _numpycode methods - - -class _JaxPrinter(_CustomNumPyPrinter): - module_imports = {"jax": {"numpy as jnp"}} - _module = "jnp" - - -class _TensorflowPrinter(_CustomNumPyPrinter): - module_imports = {"tensorflow.experimental": {"numpy as tnp"}} - _module = "tnp" - - def __init__(self) -> None: - # https://github.com/sympy/sympy/blob/f1384c2/sympy/printing/printer.py#L21-L72 - super().__init__() - self.known_functions["ComplexSqrt"] = "sqrt" - self.printmethod = "_tensorflow_code" diff --git a/src/tensorwaves/function/sympy/_printer.py b/src/tensorwaves/function/sympy/_printer.py new file mode 100644 index 00000000..b750c5f9 --- /dev/null +++ b/src/tensorwaves/function/sympy/_printer.py @@ -0,0 +1,82 @@ +# pylint: disable=abstract-method protected-access +import re +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Type, TypeVar + +from sympy.printing.numpy import NumPyPrinter # noqa: E402 + +if TYPE_CHECKING: + import sympy as sp + + +def _replace_module( + mapping: Dict[str, str], old: str, new: str +) -> Dict[str, str]: + return { + k: re.sub(fr"^{old}\.(.*)$", fr"{new}.\1", v) + for k, v in mapping.items() + } + + +class CustomNumPyPrinter(NumPyPrinter): + def __init__(self) -> None: + # https://github.com/sympy/sympy/blob/f291f2d/sympy/utilities/lambdify.py#L821-L823 + super().__init__( + settings={ + "fully_qualified_modules": False, + "inline": True, + "allow_unknown_functions": True, + } + ) + self._kc = _replace_module(NumPyPrinter._kc, "numpy", self._module) + self._kf = _replace_module(NumPyPrinter._kf, "numpy", self._module) + self.printmethod = "_numpycode" # force using _numpycode methods + + +class JaxPrinter(CustomNumPyPrinter): + module_imports = {"jax": {"numpy as jnp"}} + _module = "jnp" + + +_T = TypeVar("_T") + + +def _forward_to_numpy_printer( + class_names: Iterable[str], +) -> Callable[[Type[_T]], Type[_T]]: + """Decorator for a `~sympy.printing.printer.Printer` class. + + Args: + class_names: The names of classes that should be printed with their + :code:`_numpycode()` method. + """ + + def decorator(decorated_class: Type[_T]) -> Type[_T]: + def _get_numpy_code(self: _T, expr: "sp.Expr", *args: Any) -> str: + return expr._numpycode(self, *args) + + for class_name in class_names: + method_name = f"_print_{class_name}" + setattr(decorated_class, method_name, _get_numpy_code) + return decorated_class + + return decorator + + +@_forward_to_numpy_printer( + [ + "ArrayAxisSum", + "ArrayMultiplication", + "BoostZ", + "RotationY", + "RotationZ", + ] +) +class TensorflowPrinter(CustomNumPyPrinter): + module_imports = {"tensorflow.experimental": {"numpy as tnp"}} + _module = "tnp" + + def __init__(self) -> None: + # https://github.com/sympy/sympy/blob/f1384c2/sympy/printing/printer.py#L21-L72 + super().__init__() + self.known_functions["ComplexSqrt"] = "sqrt" + self.printmethod = "_tensorflow_code" diff --git a/src/tensorwaves/interface.py b/src/tensorwaves/interface.py index b071fd21..54c50f5a 100644 --- a/src/tensorwaves/interface.py +++ b/src/tensorwaves/interface.py @@ -42,7 +42,7 @@ def __call__(self, data: InputType) -> OutputType: ... -DataSample = Mapping[Union[int, str], np.ndarray] +DataSample = Dict[str, np.ndarray] """Mapping of variable names to a sequence of data points, used by `Function`.""" ParameterValue = Union[complex, float] """Allowed types for parameter values.""" diff --git a/tests/data/test_generate.py b/tests/data/test_generate.py index 90a1082d..0f6f4c3e 100644 --- a/tests/data/test_generate.py +++ b/tests/data/test_generate.py @@ -42,7 +42,7 @@ def test_generate_data(): intensity=FlatDistribution(), random_generator=TFUniformRealNumberGenerator(seed=0), ) - assert set(phsp) == set(final_state_masses) + assert set(phsp) == {f"p{i}" for i in final_state_masses} assert set(phsp) == set(data) for i in phsp: assert pytest.approx(phsp[i]) == data[i] @@ -55,17 +55,17 @@ def test_generate_data(): "J/psi(1S)", ("pi0", "pi0", "pi0"), { - 0: [ + "p0": [ [0.841233472, 0.799667989, 0.159823862, 0.156340839], [0.640234742, -0.364360112, -0.371962329, 0.347228344], [0.631540320, 0.403805561, 0.417294074, -0.208401449], ], - 1: [ + "p1": [ [1.09765205, -0.05378975, -0.53523771, -0.94723204], [1.426564296, 1.168326711, -0.060296302, -0.805136016], [1.243480165, 0.014812643, 0.081738919, 1.233338364], ], - 2: [ + "p2": [ [1.158014477, -0.745878234, 0.375413844, 0.790891204], [1.030100961, -0.803966599, 0.432258632, 0.457907671], [1.22187951, -0.41861820, -0.49903210, -1.02493691], @@ -76,22 +76,22 @@ def test_generate_data(): "J/psi(1S)", ("pi0", "pi0", "pi0", "gamma"), { - 0: [ + "p0": [ [0.520913076, 0.037458949, 0.339629143, -0.369297399], [1.180624927, -0.569078090, 0.687702756, -0.760836072], [0.606831154, 0.543652274, 0.220242315, -0.077206475], ], - 1: [ + "p1": [ [0.353305116, 0.130561009, 0.299006221, -0.012444727], [0.194507152, 0.123009165, 0.057692537, 0.033979586], [0.331482507, 0.224048290, -0.156048645, 0.130817046], ], - 2: [ + "p2": [ [1.276779728, 0.236609937, -0.366594420, 1.192296945], [1.339317905, 0.571746863, -0.586304492, 1.051145223], [0.820720580, 0.402982692, -0.697161285, 0.083274400], ], - 3: [ + "p3": [ [0.945902080, -0.40462990, -0.27204094, -0.81055482], [0.38245001, -0.12567794, -0.15909080, -0.32428874], [1.337865758, -1.170683257, 0.632967615, -0.136884971], @@ -102,27 +102,27 @@ def test_generate_data(): "J/psi(1S)", ("pi0", "pi0", "pi0", "pi0", "gamma"), { - 0: [ + "p0": [ [1.000150296, 0.715439409, -0.284844373, -0.623772405], [0.353592342, 0.134562969, 0.189723778, 0.229578969], [0.734241552, 0.655088513, -0.205095150, -0.222905673], ], - 1: [ + "p1": [ [0.537685901, -0.062423993, 0.008278542, -0.516645045], [0.440319420, -0.075102421, -0.215361523, 0.351626927], [0.621720722, -0.569846157, -0.063070826, 0.199036046], ], - 2: [ + "p2": [ [0.588463958, -0.190428491, -0.002167052, 0.540188288], [0.77747437, -0.11485659, -0.55477746, -0.51505105], [0.543908922, -0.120958419, 0.236101553, -0.455239823], ], - 3: [ + "p3": [ [0.513251926, -0.286712460, -0.089479316, 0.393698133], [0.593575359, 0.536198573, -0.215753382, -0.007385008], [0.564116725, -0.442948181, -0.261969339, 0.187557768], ], - 4: [ + "p4": [ [0.457347916, -0.175874464, 0.368212199, 0.206531028], [0.931938511, -0.480802535, 0.796168585, -0.058769834], [0.632912076, 0.478664245, 0.294033763, 0.291551681], @@ -139,7 +139,7 @@ def test_generate_phsp( ): sample_size = 3 rng = TFUniformRealNumberGenerator(seed=0) - phsp_sample = generate_phsp( + phsp_momenta = generate_phsp( sample_size, initial_state_mass=pdg[initial_state].mass, final_state_masses={ @@ -147,11 +147,11 @@ def test_generate_phsp( }, random_generator=rng, ) - assert set(phsp_sample) == set(expected_sample) + assert set(phsp_momenta) == set(expected_sample) n_events = len(next(iter(expected_sample.values()))) for i in expected_sample: # pylint: disable=consider-using-dict-items expected_momenta = expected_sample[i] - momenta = phsp_sample[i] + momenta = phsp_momenta[i] assert len(expected_momenta) == n_events assert len(momenta) == n_events assert pytest.approx(momenta, abs=1e-6) == expected_sample[i] diff --git a/tests/data/test_phasespace.py b/tests/data/test_phasespace.py index 7d24529b..612055b2 100644 --- a/tests/data/test_phasespace.py +++ b/tests/data/test_phasespace.py @@ -27,30 +27,30 @@ def test_generate_deterministic(pdg: "ParticleCollection"): i: pdg[name].mass for i, name in enumerate(final_state_names) }, ) - phsp_sample, weights = phsp_generator.generate(sample_size, rng) + phsp_momenta, weights = phsp_generator.generate(sample_size, rng) print("Expected values, get by running pytest with the -s flag") pprint( { i: np.round(four_momenta, decimals=10).tolist() - for i, four_momenta in phsp_sample.items() + for i, four_momenta in phsp_momenta.items() } ) expected_sample = { - 0: [ + "p0": [ [0.7059154068, 0.3572095625, 0.251997269, 0.2441281612], [0.6996310679, -0.3562654953, -0.1367339084, 0.3102348449], [0.7592776659, 0.0551489184, 0.3313621005, -0.4648049287], [0.7820530714, 0.4694971942, 0.2238765653, -0.3056827887], [0.6628957748, 0.1287045232, 0.1927256954, 0.3716262275], ], - 1: [ + "p1": [ [1.2268366211, 0.0530779071, 0.2808911915, -0.0938614524], [1.2983113985, 0.0580707314, -0.345843232, -0.3847489307], [1.3730435556, -0.264045346, -0.3231669721, 0.5445096619], [1.2694745247, -0.0510249037, -0.3895930085, 0.2063451448], [1.3387073694, -0.167841506, -0.5904119798, 0.0279167867], ], - 2: [ + "p2": [ [1.1641479721, -0.4102874697, -0.5328884605, -0.1502667089], [1.0989575336, 0.2981947639, 0.4825771404, 0.0745140857], [0.9645787786, 0.2088964277, -0.0081951284, -0.0797047333], @@ -59,10 +59,10 @@ def test_generate_deterministic(pdg: "ParticleCollection"): ], } n_events = len(next(iter(expected_sample.values()))) - assert set(phsp_sample) == set(expected_sample) + assert set(phsp_momenta) == set(expected_sample) for i in expected_sample: # pylint: disable=consider-using-dict-items expected_momenta = expected_sample[i] - momenta = phsp_sample[i] + momenta = phsp_momenta[i] assert len(expected_momenta) == n_events assert len(momenta) == n_events assert pytest.approx(momenta) == expected_sample[i] diff --git a/tests/data/test_transform.py b/tests/data/test_transform.py new file mode 100644 index 00000000..ae775e90 --- /dev/null +++ b/tests/data/test_transform.py @@ -0,0 +1,26 @@ +# pylint: disable=no-self-use, invalid-name +import numpy as np +import pytest +import sympy as sp +from numpy import sqrt + +from tensorwaves.data.transform import SympyDataTransformer + + +class TestSympyDataTransformer: + @pytest.mark.parametrize("backend", ["jax", "numba", "numpy", "tf"]) + def test_polar_to_cartesian_coordinates(self, backend): + r, phi, x, y = sp.symbols("r phi x y") + expressions = { + x: r * sp.cos(phi), + y: r * sp.sin(phi), + } + converter = SympyDataTransformer.from_sympy(expressions, backend) + assert set(converter.functions) == {"x", "y"} + input_data = { + "r": np.ones(4), + "phi": np.array([0, np.pi / 4, np.pi / 2, np.pi]), + } + output = converter(input_data) # type: ignore[arg-type] + assert pytest.approx(output["x"]) == [1, sqrt(2) / 2, 0, -1] + assert pytest.approx(output["y"]) == [0, sqrt(2) / 2, 1, 0] diff --git a/tests/function/test_ampform.py b/tests/function/test_ampform.py index a0ac5f7c..a32c7a4d 100644 --- a/tests/function/test_ampform.py +++ b/tests/function/test_ampform.py @@ -2,6 +2,7 @@ import numpy as np import pytest +from tensorwaves.data.transform import SympyDataTransformer from tensorwaves.function._backend import find_function from tensorwaves.function.sympy import create_parametrized_function @@ -9,7 +10,7 @@ @pytest.mark.parametrize("backend", ["jax", "math", "numba", "numpy", "tf"]) def test_complex_sqrt(backend: str): import sympy as sp - from ampform.dynamics.math import ComplexSqrt + from ampform.sympy.math import ComplexSqrt from numpy.lib.scimath import sqrt as complex_sqrt x = sp.Symbol("x") @@ -28,3 +29,42 @@ def test_complex_sqrt(backend: str): data = {"x": values} output_array = function(data) # type: ignore[arg-type] np.testing.assert_almost_equal(output_array, complex_sqrt(data["x"])) + + +@pytest.mark.parametrize("backend", ["jax", "numpy", "tf"]) +def test_four_momenta_to_helicity_angles(backend): + import ampform + import qrules + + reaction = qrules.generate_transitions( + initial_state=("J/psi(1S)", [+1]), + final_state=[("gamma", [+1]), "pi0", "pi0"], + allowed_intermediate_particles=["f(0)(500)"], + allowed_interaction_types=["EM", "strong"], + ) + + builder = ampform.get_builder(reaction) + model = builder.formulate() + + expressions = model.kinematic_variables + converter = SympyDataTransformer.from_sympy(expressions, backend) + assert set(converter.functions) == { + "m_0", + "m_012", + "m_1", + "m_12", + "m_2", + "phi_1+2", + "phi_1,1+2", + "theta_1+2", + "theta_1,1+2", + } + + zeros = np.zeros(shape=(1, 4)) + data_momenta = {"p0": zeros, "p1": zeros, "p2": zeros} + data = converter(data_momenta) + for var_name in converter.functions: + if var_name in {"phi_1,1+2", "theta_1+2", "theta_1,1+2"}: + assert np.isnan(data[var_name]) + else: + assert data[var_name] == 0 diff --git a/tests/function/test_function.py b/tests/function/test_function.py index 096e046a..a9d940c3 100644 --- a/tests/function/test_function.py +++ b/tests/function/test_function.py @@ -3,7 +3,10 @@ import pytest import sympy as sp -from tensorwaves.function import ParametrizedBackendFunction +from tensorwaves.function import ( + ParametrizedBackendFunction, + PositionalArgumentFunction, +) from tensorwaves.function.sympy import create_parametrized_function from tensorwaves.interface import DataSample @@ -50,3 +53,31 @@ def test_call( np.testing.assert_array_almost_equal( results, expected_results, decimal=4 ) + + +class TestPositionalArgumentFunction: + def test_call(self): + function = PositionalArgumentFunction( + function=lambda a, b, x, y: a * x ** 2 + b * y ** 2, + argument_order=("a", "b", "x", "y"), + ) + data: DataSample = { + "a": np.array([1, 0, +1, 1]), + "b": np.array([1, 0, -1, 1]), + "x": np.array([1, 1, +4, 2]), + "y": np.array([1, 1, -4, 3]), + } + output = function(data) + assert pytest.approx(output) == [2, 0, 0, 4 + 9] + + def test_variadic_args(self): + function = PositionalArgumentFunction( + function=lambda *args: args[0] + args[1], + argument_order=("a", "b"), + ) + data: DataSample = { + "a": np.array([1, 2, 3]), + "b": np.array([1, 2, 3]), + } + output = function(data) + assert pytest.approx(output) == [2, 4, 6] diff --git a/tests/function/test_sympy.py b/tests/function/test_sympy.py index 8de5d086..fc38cc29 100644 --- a/tests/function/test_sympy.py +++ b/tests/function/test_sympy.py @@ -6,13 +6,27 @@ import pytest import sympy as sp -from tensorwaves.function.sympy import fast_lambdify, split_expression +from tensorwaves.function.sympy import ( + create_function, + fast_lambdify, + split_expression, +) def create_expression(a, x, y, z) -> sp.Expr: return a * (x ** z + 2 * y) +@pytest.mark.parametrize("backend", ["jax", "math", "numpy", "tf"]) +def test_create_function(backend: str): + symbols: Tuple[sp.Symbol, ...] = sp.symbols("a x y z") + a, x, y, z = symbols + expression = create_expression(a, x, y, z) + function = create_function(expression, backend) + assert callable(function.function) + assert function.argument_order == ("a", "x", "y", "z") + + @pytest.mark.parametrize("backend", ["jax", "math", "numpy", "tf"]) @pytest.mark.parametrize("max_complexity", [0, 1, 2, 3, 4, 5]) @pytest.mark.parametrize("use_cse", [False, True])