diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index fcef18fb..e14248e5 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -19,11 +19,12 @@ jobs: submodules: recursive - uses: actions/setup-python@v4 with: - python-version: "3.10" + python-version: "3.12" - name: Install deps run: | curl -LsSf https://astral.sh/uv/install.sh | sh uv pip install --system -r requirements/dev-requirements.txt + uv pip install --system torch --index-url https://download.pytorch.org/whl/cpu - name: Install module run: | uv pip install --system . diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 396e8cb4..16fb3702 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -29,6 +29,7 @@ jobs: run: | curl -LsSf https://astral.sh/uv/install.sh | sh uv pip install --system ".[fitting,dev]" + uv pip install --system torch --index-url https://download.pytorch.org/whl/cpu - name: Test with pytest run: | pytest --nbmake diff --git a/.github/workflows/requirements.yml b/.github/workflows/requirements.yml index 20e9fd9f..1f2ec2c0 100644 --- a/.github/workflows/requirements.yml +++ b/.github/workflows/requirements.yml @@ -21,6 +21,7 @@ jobs: run: | curl -LsSf https://astral.sh/uv/install.sh | sh uv pip install --system --no-deps .[fitting,dev] + uv pip install --system torch --index-url https://download.pytorch.org/whl/cpu - name: Install dev requirements run: | uv pip install --system -r requirements/dev-requirements.txt diff --git a/requirements/dev-requirements.txt b/requirements/dev-requirements.txt index 4a0fd62f..ad8e5c6b 100644 --- a/requirements/dev-requirements.txt +++ b/requirements/dev-requirements.txt @@ -6,10 +6,6 @@ appdirs==1.4.4 \ # via # -r requirements/fitting-requirements.txt # pint -appnope==0.1.4 \ - --hash=sha256:1de3860566df9caf38f01f86f65e0e13e379af54f9e4bee1e66b48f2efffd1ee \ - --hash=sha256:502575ee11cd7a28c0205f379b525beefebab9d161b7c964670864014ed7213c - # via ipykernel asteval==0.9.33 \ --hash=sha256:94981701f4d252c88aa5e821121b1aabef73a003da138fc6405169c9e675d24d \ --hash=sha256:aae3a0308575a545c8cecc43a6632219e6a90963a56380c74632cf54311e43bf @@ -187,13 +183,6 @@ distlib==0.3.8 \ --hash=sha256:034db59a0b96f8ca18035f36290806a9a6e6bd9d1ff91e45a7f172eb17e51784 \ --hash=sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64 # via virtualenv -exceptiongroup==1.2.1 \ - --hash=sha256:5258b9ed329c5bbdd31a309f53cbfb0b155341807f6ff7606a1e801a891b29ad \ - --hash=sha256:a4785e48b045528f5bfe627b6ad554ff32def154f42372786903b7abcfe1aa16 - # via - # -r requirements/fitting-requirements.txt - # ipython - # pytest executing==2.0.1 \ --hash=sha256:35afe2ce3affba8ee97f2d69927fa823b08b472b7b994e36a52a964b93d16147 \ --hash=sha256:eac49ca94516ccc753f9fb5ce82603156e590b27525a8bc32cce8ae302eb61bc @@ -204,9 +193,9 @@ fastjsonschema==2.20.0 \ --hash=sha256:3d48fc5300ee96f5d116f10fe6f28d938e6008f59a6a025c2649475b87f76a23 \ --hash=sha256:5875f0b0fa7a0043a91e93a9b8f793bcbbba9691e7fd83dca95c28ba26d21f0a # via nbformat -filelock==3.15.1 \ - --hash=sha256:58a2549afdf9e02e10720eaa4d4470f56386d7a6f72edd7d0596337af8ed7ad8 \ - --hash=sha256:71b3102950e91dfc1bb4209b64be4dc8854f40e5f534428d8684f953ac847fac +filelock==3.15.3 \ + --hash=sha256:0151273e5b5d6cf753a61ec83b3a9b7d8821c39ae9af9d7ecf2f9e2f17404103 \ + --hash=sha256:e1199bf5194a2277273dacd50269f0d87d0682088a3c561c15674ea9005d8635 # via virtualenv flexcache==0.3 \ --hash=sha256:18743bd5a0621bfe2cf8d519e4c3bfdf57a269c15d1ced3fb4b64e0ff4600656 \ @@ -1253,12 +1242,6 @@ tenacity==8.4.1 \ # via # -r requirements/fitting-requirements.txt # plotly -tomli==2.0.1 \ - --hash=sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc \ - --hash=sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f - # via - # coverage - # pytest tornado==6.4.1 \ --hash=sha256:163b0aafc8e23d8cdc3c9dfb24c5368af84a81e3364745ccb4427669bf84aec8 \ --hash=sha256:25486eb223babe3eed4b8aecbac33b37e3dd6d776bc730ca14e1bf93888b979f \ diff --git a/requirements/fitting-requirements.txt b/requirements/fitting-requirements.txt index c3ffcff6..7e80f51d 100644 --- a/requirements/fitting-requirements.txt +++ b/requirements/fitting-requirements.txt @@ -76,10 +76,6 @@ dill==0.3.8 \ --hash=sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca \ --hash=sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7 # via lmfit -exceptiongroup==1.2.1 \ - --hash=sha256:5258b9ed329c5bbdd31a309f53cbfb0b155341807f6ff7606a1e801a891b29ad \ - --hash=sha256:a4785e48b045528f5bfe627b6ad554ff32def154f42372786903b7abcfe1aa16 - # via ipython executing==2.0.1 \ --hash=sha256:35afe2ce3affba8ee97f2d69927fa823b08b472b7b994e36a52a964b93d16147 \ --hash=sha256:eac49ca94516ccc753f9fb5ce82603156e590b27525a8bc32cce8ae302eb61bc diff --git a/src/elli/solver4x4.py b/src/elli/solver4x4.py index 863da5bf..6c3a10cd 100644 --- a/src/elli/solver4x4.py +++ b/src/elli/solver4x4.py @@ -1,9 +1,18 @@ # Encoding: utf-8 from abc import ABC, abstractmethod +from typing import Literal import numpy as np import numpy.typing as npt import scipy.constants as sc + +try: + import torch +except ImportError: + TORCH_AVAILABLE = False +else: + TORCH_AVAILABLE = True + from numpy.lib.scimath import sqrt from scipy.linalg import expm as scipy_expm @@ -56,6 +65,30 @@ def calculate_propagation( class PropagatorExpm(Propagator): """Propagator class using the Padé approximation of the matrix exponential.""" + def __init__(self, backend: Literal["torch", "scipy", "automatic"] = "automatic"): + backends = { + "torch": lambda mats: torch.linalg.matrix_exp( + torch.from_numpy(mats) + ).numpy(), + "scipy": lambda mats: scipy_expm(mats), + } + + if backend == "automatic" and TORCH_AVAILABLE: + backend = "torch" + elif backend == "automatic" and not TORCH_AVAILABLE: + backend = "scipy" + elif backend == "torch" and not TORCH_AVAILABLE: + raise ImportError( + "PyTorch is not installed. If you want to use the PyTorch backend, \ + please follow the install instructions on https://pytorch.org/get-started/locally/" + ) + elif backend not in backends: + raise ValueError( + "Backend should be one of 'torch', 'scipy' or 'automatic'." + ) + + self.expm = backends[backend] + def calculate_propagation( self, delta: npt.NDArray, thickness: float, lbda: npt.ArrayLike ) -> npt.NDArray: @@ -71,7 +104,7 @@ def calculate_propagation( """ mats = 1j * thickness * np.einsum("nij,n->nij", delta, 2 * sc.pi / lbda) - propagator = np.asarray([scipy_expm(mat) for mat in mats]) + propagator = self.expm(mats) return propagator diff --git a/tests/benchmark_propagators_TiO2.py b/tests/benchmark_propagators_TiO2.py index bcdc6464..a5865f3e 100644 --- a/tests/benchmark_propagators_TiO2.py +++ b/tests/benchmark_propagators_TiO2.py @@ -1,10 +1,9 @@ """Testing benchmark for each solver""" -from pytest import fixture -import numpy as np - import elli +import numpy as np from elli.fitting import ParamsHist +from pytest import fixture @fixture @@ -79,7 +78,24 @@ def test_solver4x4_expm(benchmark, structure): benchmark.pedantic( structure.evaluate, args=(lbda, PHI), - kwargs={"solver": elli.Solver4x4, "propagator": elli.PropagatorExpm()}, + kwargs={ + "solver": elli.Solver4x4, + "propagator": elli.PropagatorExpm(backend="scipy"), + }, + iterations=1, + rounds=10, + ) + + +def test_solver4x4_expm_pytorch(benchmark, structure): + """Benchmarks expm-torch propagator with solver4x4""" + benchmark.pedantic( + structure.evaluate, + args=(lbda, PHI), + kwargs={ + "solver": elli.Solver4x4, + "propagator": elli.PropagatorExpm(backend="torch"), + }, iterations=1, rounds=10, ) diff --git a/tests/test_TiO2.py b/tests/test_TiO2.py index 20afda8c..3dfa2d1a 100644 --- a/tests/test_TiO2.py +++ b/tests/test_TiO2.py @@ -5,11 +5,10 @@ import os from shutil import copytree, rmtree -import numpy as np -from pytest import fixture - import elli +import numpy as np from elli.fitting import ParamsHist +from pytest import fixture @fixture @@ -125,7 +124,22 @@ def test_solver4x4_expm(self, si_dispersion, meas_data): meas_data.index, 70, solver=elli.Solver4x4, - propagator=elli.PropagatorExpm(), + propagator=elli.PropagatorExpm(backend="scipy"), + ) + .rho + ) + + assert TestTiO2.chisqr(meas_data, sim_data) < 0.0456 + + def test_solver4x4_expm_torch(self, si_dispersion, meas_data): + """The solver4x4 with pytorch propagator is within chi square accuracy""" + sim_data = ( + elli.Structure(elli.AIR, self.Layer, si_dispersion) + .evaluate( + meas_data.index, + 70, + solver=elli.Solver4x4, + propagator=elli.PropagatorExpm(backend="torch"), ) .rho )