Skip to content

Commit

Permalink
refactor!: adapt implementation to AmpForm v0.12.x (#345)
Browse files Browse the repository at this point in the history
* build: switch to AmpForm v0.12
* ci: update pip constraints and pre-commit config
* feat: compute kinematic helicity angles with different backends
* feat: define PositionalArgumentFunction
* feat: define create_function
* fix: force-push to matching branches
* refactor: accept only str as DataSample keys
* refactor: extract _printer module from sympy module
* refactor: implement SympyDataTransformer
* test: benchmark data generation with numpy and jax

Co-authored-by: GitHub <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Dec 6, 2021
1 parent acf5770 commit b7a4efd
Show file tree
Hide file tree
Showing 29 changed files with 667 additions and 335 deletions.
4 changes: 2 additions & 2 deletions .constraints/py3.6.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#
absl-py==0.15.0
alabaster==0.7.12
ampform==0.11.5
ampform==0.12.0
anyio==3.4.0
appdirs==1.4.4
aquirdturtle-collapsible-headings==3.1.0
Expand Down Expand Up @@ -175,7 +175,7 @@ python-dateutil==2.8.2
pytz==2021.3
pyyaml==6.0
pyzmq==22.3.0
qrules==0.9.4
qrules==0.9.5
qtconsole==5.2.1
qtpy==1.11.3
regex==2021.11.10
Expand Down
7 changes: 3 additions & 4 deletions .constraints/py3.7.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#
absl-py==1.0.0
alabaster==0.7.12
ampform==0.11.5
ampform==0.12.0
anyio==3.4.0
aquirdturtle-collapsible-headings==3.1.0
argcomplete==1.12.3
Expand All @@ -18,7 +18,7 @@ babel==2.9.1
backcall==0.2.0
backports.entry-points-selectable==1.1.1
beautifulsoup4==4.10.0
black==21.11b1 ; python_version >= "3.7.0"
black==21.12b0 ; python_version >= "3.7.0"
bleach==4.1.0
cached-property==1.5.2
cachetools==4.2.4
Expand Down Expand Up @@ -174,10 +174,9 @@ python-dateutil==2.8.2
pytz==2021.3
pyyaml==6.0
pyzmq==22.3.0
qrules==0.9.4
qrules==0.9.5
qtconsole==5.2.1
qtpy==1.11.3
regex==2021.11.10
requests==2.26.0
requests-oauthlib==1.3.0
restructuredtext-lint==1.3.2
Expand Down
7 changes: 3 additions & 4 deletions .constraints/py3.8.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#
absl-py==1.0.0
alabaster==0.7.12
ampform==0.11.5
ampform==0.12.0
anyio==3.4.0
aquirdturtle-collapsible-headings==3.1.0
argon2-cffi==21.1.0
Expand All @@ -17,7 +17,7 @@ babel==2.9.1
backcall==0.2.0
backports.entry-points-selectable==1.1.1
beautifulsoup4==4.10.0
black==21.11b1 ; python_version >= "3.7.0"
black==21.12b0 ; python_version >= "3.7.0"
bleach==4.1.0
cachetools==4.2.4
certifi==2021.10.8
Expand Down Expand Up @@ -173,10 +173,9 @@ python-dateutil==2.8.2
pytz==2021.3
pyyaml==6.0
pyzmq==22.3.0
qrules==0.9.4
qrules==0.9.5
qtconsole==5.2.1
qtpy==1.11.3
regex==2021.11.10
requests==2.26.0
requests-oauthlib==1.3.0
restructuredtext-lint==1.3.2
Expand Down
7 changes: 3 additions & 4 deletions .constraints/py3.9.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#
absl-py==1.0.0
alabaster==0.7.12
ampform==0.11.5
ampform==0.12.0
anyio==3.4.0
aquirdturtle-collapsible-headings==3.1.0
argon2-cffi==21.1.0
Expand All @@ -17,7 +17,7 @@ babel==2.9.1
backcall==0.2.0
backports.entry-points-selectable==1.1.1
beautifulsoup4==4.10.0
black==21.11b1 ; python_version >= "3.7.0"
black==21.12b0 ; python_version >= "3.7.0"
bleach==4.1.0
cachetools==4.2.4
certifi==2021.10.8
Expand Down Expand Up @@ -172,10 +172,9 @@ python-dateutil==2.8.2
pytz==2021.3
pyyaml==6.0
pyzmq==22.3.0
qrules==0.9.4
qrules==0.9.5
qtconsole==5.2.1
qtpy==1.11.3
regex==2021.11.10
requests==2.26.0
requests-oauthlib==1.3.0
restructuredtext-lint==1.3.2
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@ jobs:
git config --global user.email "41898282+github-actions[bot]@users.noreply.github.com"
- name: Push to stable branch
run: |
git push origin HEAD:refs/heads/stable
git push origin HEAD:refs/heads/stable --force
- name: Push to matching minor version branch
env:
TAG: ${{ github.ref_name }}
run: |
re='^([0-9]+)\.([0-9]+)\.[0-9]+'
if [[ $TAG =~ $re ]]; then
MINOR_VERSION_BRANCH="${BASH_REMATCH[1]}.${BASH_REMATCH[2]}.x"
git push origin HEAD:refs/heads/$MINOR_VERSION_BRANCH
git push origin HEAD:refs/heads/$MINOR_VERSION_BRANCH --force
fi
pypi:
Expand Down
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ repos:
- id: trailing-whitespace

- repo: https://github.com/ComPWA/repo-maintenance
rev: 0.0.92
rev: 0.0.93
hooks:
- id: check-dev-files
args:
Expand All @@ -57,7 +57,7 @@ repos:
- id: format-setup-cfg

- repo: https://github.com/psf/black
rev: 21.11b1
rev: 21.12b0
hooks:
- id: black

Expand Down Expand Up @@ -120,7 +120,7 @@ repos:
metadata.toc-showtags
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v2.5.0
rev: v2.5.1
hooks:
- id: prettier

Expand All @@ -130,7 +130,7 @@ repos:
- id: pydocstyle

- repo: https://github.com/ComPWA/mirrors-pyright
rev: v1.1.191
rev: v1.1.192
hooks:
- id: pyright

Expand Down
1 change: 1 addition & 0 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ disable=
too-few-public-methods, # data containers (attrs) and interface classes
unspecified-encoding, # http://pylint.pycqa.org/en/latest/whatsnew/2.10.html
unused-import, # https://www.flake8rules.com/rules/F401
wrong-import-order, # handled by isort

[SIMILARITIES]
ignore-imports=yes # https://stackoverflow.com/a/30007053
Expand Down
39 changes: 22 additions & 17 deletions benchmarks/ampform.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import tensorwaves as tw
from tensorwaves.data.phasespace import TFUniformRealNumberGenerator
from tensorwaves.data.transform import HelicityTransformer
from tensorwaves.data.transform import SympyDataTransformer
from tensorwaves.function.sympy import create_parametrized_function
from tensorwaves.interface import (
DataSample,
Expand Down Expand Up @@ -64,9 +64,10 @@ def generate_data(
function: ParametrizedFunction,
data_sample_size: int,
phsp_sample_size: int,
backend: str,
transform: bool = False,
) -> Tuple[DataSample, DataSample]:
reaction = model.adapter.reaction_info
reaction = model.reaction_info
final_state = reaction.final_state
phsp = tw.data.generate_phsp(
size=phsp_sample_size,
Expand All @@ -75,57 +76,57 @@ def generate_data(
random_generator=TFUniformRealNumberGenerator(seed=0),
)

helicity_transformer = HelicityTransformer(model.adapter)
expressions = model.kinematic_variables
converter = SympyDataTransformer.from_sympy(expressions, backend)
data = tw.data.generate_data(
size=data_sample_size,
initial_state_mass=reaction.initial_state[-1].mass,
final_state_masses={i: p.mass for i, p in final_state.items()},
data_transformer=helicity_transformer,
data_transformer=converter,
intensity=function,
random_generator=TFUniformRealNumberGenerator(seed=0),
)

if transform:
data = helicity_transformer(data)
phsp = helicity_transformer(phsp)
data = converter(data)
phsp = converter(phsp)
return data, phsp


def fit(
data_set: DataSample,
phsp_set: DataSample,
data: DataSample,
phsp: DataSample,
function: ParametrizedFunction,
initial_parameters: Mapping[str, ParameterValue],
backend: str,
) -> FitResult:
estimator = tw.estimator.UnbinnedNLL(
function,
data=data_set,
phsp=phsp_set,
data=data,
phsp=phsp,
backend=backend,
)
optimizer = tw.optimizer.Minuit2()

return optimizer.optimize(estimator, initial_parameters)


class TestJPsiToGammaPiPi:
expected_data = {
0: [
"p0": [
[1.50757377596, 0.37918944935, 0.73396599969, 1.26106620078],
[1.41389525301, -0.07315064441, -0.21998573758, 1.39475985207],
[1.52128570461, 0.06569896528, -1.51812710851, 0.0726906006],
[1.51480310845, 1.40672331053, 0.49678572189, -0.26260603856],
[1.52384281483, 0.79694939592, 1.29832389761, -0.03638188481],
],
1: [
"p1": [
[1.42066087326, -0.34871369761, -0.72119471428, -1.1654765212],
[0.96610319301, -0.26739932067, -0.15455480956, -0.90539883872],
[0.60647770024, 0.11616448713, 0.57584161239, -0.06714695611],
[1.01045883083, -0.88651015826, -0.46024226278, 0.0713099651],
[1.04324742713, -0.48051670276, -0.91259832182, -0.08009031815],
],
2: [
"p2": [
[0.16866535079, -0.03047575173, -0.01277128542, -0.09558967958],
[0.71690155399, 0.34054996508, 0.37454054715, -0.48936101336],
[0.96913659515, -0.18186345241, 0.94228549612, -0.00554364449],
Expand All @@ -145,13 +146,15 @@ def model(self) -> "HelicityModel":
)

@pytest.mark.benchmark(group="data", min_rounds=1)
@pytest.mark.parametrize("backend", ["jax"])
@pytest.mark.parametrize("backend", ["jax", "numpy", "tf"])
@pytest.mark.parametrize("size", [10_000])
def test_data(self, backend, benchmark, model, size):
n_data = size
n_phsp = 10 * n_data
function = create_function(model, backend)
data, phsp = benchmark(generate_data, model, function, n_data, n_phsp)
data, phsp = benchmark(
generate_data, model, function, n_data, n_phsp, backend
)
assert len(next(iter(data.values()))) == n_data
assert len(next(iter(phsp.values()))) == n_phsp

Expand All @@ -169,7 +172,9 @@ def test_fit(self, backend, benchmark, model, size):
n_data = size
n_phsp = 10 * n_data
function = create_function(model, backend)
data, phsp = generate_data(model, function, n_data, n_phsp, True)
data, phsp = generate_data(
model, function, n_data, n_phsp, backend, transform=True
)

coefficients = [p for p in function.parameters if p.startswith("C_{")]
assert len(coefficients) >= 1
Expand Down
5 changes: 2 additions & 3 deletions docs/abbreviate_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@ def replace_link(text: str) -> str:
"a set-like object providing a view on D's items": "typing.ItemsView",
"a set-like object providing a view on D's keys": "typing.KeysView",
"an object providing a view on D's values": "typing.ValuesView",
"numpy.typing._array_like._SupportsArray": "numpy.typing.ArrayLike",
"numpy.typing._dtype_like._DTypeDict": "numpy.typing.DTypeLike",
"numpy.typing._dtype_like._SupportsDType": "numpy.typing.DTypeLike",
"sp.Expr": "sympy.core.expr.Expr",
"sp.Symbol": "sympy.core.symbol.Symbol",
"typing_extensions.Protocol": "typing.Protocol",
}
for old, new in replacements.items():
Expand Down
Loading

0 comments on commit b7a4efd

Please sign in to comment.