From d077801994308b6b428a546f13a796abb7a0fcd7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marius=20M=C3=BCller?= <49639740+MarJMue@users.noreply.github.com> Date: Thu, 20 Jun 2024 10:20:55 +0200 Subject: [PATCH 1/6] Re-implement PyTorch 4x4Solver --- src/elli/solver4x4.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/src/elli/solver4x4.py b/src/elli/solver4x4.py index 863da5bf..ec093798 100644 --- a/src/elli/solver4x4.py +++ b/src/elli/solver4x4.py @@ -4,6 +4,7 @@ import numpy as np import numpy.typing as npt import scipy.constants as sc +import torch from numpy.lib.scimath import sqrt from scipy.linalg import expm as scipy_expm @@ -71,7 +72,30 @@ def calculate_propagation( """ mats = 1j * thickness * np.einsum("nij,n->nij", delta, 2 * sc.pi / lbda) - propagator = np.asarray([scipy_expm(mat) for mat in mats]) + propagator = scipy_expm(mats) + + return propagator + + +class PropagatorExpmTorch(Propagator): + """Propagator class using the matrix exponential provided by PyTorch.""" + + def calculate_propagation( + self, delta: npt.NDArray, thickness: float, lbda: npt.ArrayLike + ) -> npt.NDArray: + """Calculates propagation for a given Delta matrix and layer thickness with the Padé approximation of the matrix exponential. + + Args: + delta (npt.NDArray): Delta Matrix + thickness (float): Thickness of layer (nm) + lbda (npt.ArrayLike): Wavelengths to evaluate (nm) + + Returns: + npt.NDArray: Propagator for the given layer + """ + mats = 1j * thickness * np.einsum("nij,n->nij", delta, 2 * sc.pi / lbda) + + propagator = torch.linalg.matrix_exp(torch.from_numpy(mats)).numpy() return propagator From 7da8d055df98d77da541f1bbe1c3140b3f80abc0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marius=20M=C3=BCller?= <49639740+MarJMue@users.noreply.github.com> Date: Thu, 20 Jun 2024 10:25:42 +0200 Subject: [PATCH 2/6] Include PyTorch solver in tests --- tests/benchmark_propagators_TiO2.py | 16 +++++++++++++--- tests/test_TiO2.py | 20 +++++++++++++++++--- 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/tests/benchmark_propagators_TiO2.py b/tests/benchmark_propagators_TiO2.py index bcdc6464..8c6b0ffa 100644 --- a/tests/benchmark_propagators_TiO2.py +++ b/tests/benchmark_propagators_TiO2.py @@ -1,10 +1,9 @@ """Testing benchmark for each solver""" -from pytest import fixture -import numpy as np - import elli +import numpy as np from elli.fitting import ParamsHist +from pytest import fixture @fixture @@ -85,6 +84,17 @@ def test_solver4x4_expm(benchmark, structure): ) +def test_solver4x4_expm_pytorch(benchmark, structure): + """Benchmarks expm-torch propagator with solver4x4""" + benchmark.pedantic( + structure.evaluate, + args=(lbda, PHI), + kwargs={"solver": elli.Solver4x4, "propagator": elli.PropagatorExpmTorch()}, + iterations=1, + rounds=10, + ) + + def test_solver4x4_linear(benchmark, structure): """Benchmarks linear propagator with solver4x4""" benchmark.pedantic( diff --git a/tests/test_TiO2.py b/tests/test_TiO2.py index 20afda8c..83630382 100644 --- a/tests/test_TiO2.py +++ b/tests/test_TiO2.py @@ -5,11 +5,10 @@ import os from shutil import copytree, rmtree -import numpy as np -from pytest import fixture - import elli +import numpy as np from elli.fitting import ParamsHist +from pytest import fixture @fixture @@ -132,6 +131,21 @@ def test_solver4x4_expm(self, si_dispersion, meas_data): assert TestTiO2.chisqr(meas_data, sim_data) < 0.0456 + def test_solver4x4_expm_torch(self, si_dispersion, meas_data): + """The solver4x4 with scipy propagator is within chi square accuracy""" + sim_data = ( + elli.Structure(elli.AIR, self.Layer, si_dispersion) + .evaluate( + meas_data.index, + 70, + solver=elli.Solver4x4, + propagator=elli.PropagatorExpmTorch(), + ) + .rho + ) + + assert TestTiO2.chisqr(meas_data, sim_data) < 0.0456 + def test_solver4x4_eig(self, si_dispersion, meas_data): """The solver4x4 with eig propagator is within chi square accuracy""" sim_data = ( From 2be64b9ae5a164d266af42e963a5b76746b87844 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marius=20M=C3=BCller?= <49639740+MarJMue@users.noreply.github.com> Date: Thu, 20 Jun 2024 10:55:49 +0200 Subject: [PATCH 3/6] Update dependencies and make torch optional --- .github/workflows/benchmark.yml | 3 ++- .github/workflows/pytest.yml | 1 + requirements/dev-requirements.txt | 23 +++-------------------- requirements/fitting-requirements.txt | 4 ---- src/elli/solver4x4.py | 6 +++++- 5 files changed, 11 insertions(+), 26 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index fcef18fb..e14248e5 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -19,11 +19,12 @@ jobs: submodules: recursive - uses: actions/setup-python@v4 with: - python-version: "3.10" + python-version: "3.12" - name: Install deps run: | curl -LsSf https://astral.sh/uv/install.sh | sh uv pip install --system -r requirements/dev-requirements.txt + uv pip install --system torch --index-url https://download.pytorch.org/whl/cpu - name: Install module run: | uv pip install --system . diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 396e8cb4..16fb3702 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -29,6 +29,7 @@ jobs: run: | curl -LsSf https://astral.sh/uv/install.sh | sh uv pip install --system ".[fitting,dev]" + uv pip install --system torch --index-url https://download.pytorch.org/whl/cpu - name: Test with pytest run: | pytest --nbmake diff --git a/requirements/dev-requirements.txt b/requirements/dev-requirements.txt index 4a0fd62f..ad8e5c6b 100644 --- a/requirements/dev-requirements.txt +++ b/requirements/dev-requirements.txt @@ -6,10 +6,6 @@ appdirs==1.4.4 \ # via # -r requirements/fitting-requirements.txt # pint -appnope==0.1.4 \ - --hash=sha256:1de3860566df9caf38f01f86f65e0e13e379af54f9e4bee1e66b48f2efffd1ee \ - --hash=sha256:502575ee11cd7a28c0205f379b525beefebab9d161b7c964670864014ed7213c - # via ipykernel asteval==0.9.33 \ --hash=sha256:94981701f4d252c88aa5e821121b1aabef73a003da138fc6405169c9e675d24d \ --hash=sha256:aae3a0308575a545c8cecc43a6632219e6a90963a56380c74632cf54311e43bf @@ -187,13 +183,6 @@ distlib==0.3.8 \ --hash=sha256:034db59a0b96f8ca18035f36290806a9a6e6bd9d1ff91e45a7f172eb17e51784 \ --hash=sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64 # via virtualenv -exceptiongroup==1.2.1 \ - --hash=sha256:5258b9ed329c5bbdd31a309f53cbfb0b155341807f6ff7606a1e801a891b29ad \ - --hash=sha256:a4785e48b045528f5bfe627b6ad554ff32def154f42372786903b7abcfe1aa16 - # via - # -r requirements/fitting-requirements.txt - # ipython - # pytest executing==2.0.1 \ --hash=sha256:35afe2ce3affba8ee97f2d69927fa823b08b472b7b994e36a52a964b93d16147 \ --hash=sha256:eac49ca94516ccc753f9fb5ce82603156e590b27525a8bc32cce8ae302eb61bc @@ -204,9 +193,9 @@ fastjsonschema==2.20.0 \ --hash=sha256:3d48fc5300ee96f5d116f10fe6f28d938e6008f59a6a025c2649475b87f76a23 \ --hash=sha256:5875f0b0fa7a0043a91e93a9b8f793bcbbba9691e7fd83dca95c28ba26d21f0a # via nbformat -filelock==3.15.1 \ - --hash=sha256:58a2549afdf9e02e10720eaa4d4470f56386d7a6f72edd7d0596337af8ed7ad8 \ - --hash=sha256:71b3102950e91dfc1bb4209b64be4dc8854f40e5f534428d8684f953ac847fac +filelock==3.15.3 \ + --hash=sha256:0151273e5b5d6cf753a61ec83b3a9b7d8821c39ae9af9d7ecf2f9e2f17404103 \ + --hash=sha256:e1199bf5194a2277273dacd50269f0d87d0682088a3c561c15674ea9005d8635 # via virtualenv flexcache==0.3 \ --hash=sha256:18743bd5a0621bfe2cf8d519e4c3bfdf57a269c15d1ced3fb4b64e0ff4600656 \ @@ -1253,12 +1242,6 @@ tenacity==8.4.1 \ # via # -r requirements/fitting-requirements.txt # plotly -tomli==2.0.1 \ - --hash=sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc \ - --hash=sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f - # via - # coverage - # pytest tornado==6.4.1 \ --hash=sha256:163b0aafc8e23d8cdc3c9dfb24c5368af84a81e3364745ccb4427669bf84aec8 \ --hash=sha256:25486eb223babe3eed4b8aecbac33b37e3dd6d776bc730ca14e1bf93888b979f \ diff --git a/requirements/fitting-requirements.txt b/requirements/fitting-requirements.txt index c3ffcff6..7e80f51d 100644 --- a/requirements/fitting-requirements.txt +++ b/requirements/fitting-requirements.txt @@ -76,10 +76,6 @@ dill==0.3.8 \ --hash=sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca \ --hash=sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7 # via lmfit -exceptiongroup==1.2.1 \ - --hash=sha256:5258b9ed329c5bbdd31a309f53cbfb0b155341807f6ff7606a1e801a891b29ad \ - --hash=sha256:a4785e48b045528f5bfe627b6ad554ff32def154f42372786903b7abcfe1aa16 - # via ipython executing==2.0.1 \ --hash=sha256:35afe2ce3affba8ee97f2d69927fa823b08b472b7b994e36a52a964b93d16147 \ --hash=sha256:eac49ca94516ccc753f9fb5ce82603156e590b27525a8bc32cce8ae302eb61bc diff --git a/src/elli/solver4x4.py b/src/elli/solver4x4.py index ec093798..85a271aa 100644 --- a/src/elli/solver4x4.py +++ b/src/elli/solver4x4.py @@ -4,7 +4,11 @@ import numpy as np import numpy.typing as npt import scipy.constants as sc -import torch + +try: + import torch +except ImportError: + ... from numpy.lib.scimath import sqrt from scipy.linalg import expm as scipy_expm From 75f07d7f72fbe6249c8321c1490c347d4198675f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marius=20M=C3=BCller?= <49639740+MarJMue@users.noreply.github.com> Date: Thu, 20 Jun 2024 11:01:32 +0200 Subject: [PATCH 4/6] Update requirements.yml --- .github/workflows/requirements.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/requirements.yml b/.github/workflows/requirements.yml index 20e9fd9f..1f2ec2c0 100644 --- a/.github/workflows/requirements.yml +++ b/.github/workflows/requirements.yml @@ -21,6 +21,7 @@ jobs: run: | curl -LsSf https://astral.sh/uv/install.sh | sh uv pip install --system --no-deps .[fitting,dev] + uv pip install --system torch --index-url https://download.pytorch.org/whl/cpu - name: Install dev requirements run: | uv pip install --system -r requirements/dev-requirements.txt From db9f41806732bcba7d7387a190458cb09f9fd7eb Mon Sep 17 00:00:00 2001 From: MarJMue <49639740+MarJMue@users.noreply.github.com> Date: Mon, 24 Jun 2024 13:26:51 +0200 Subject: [PATCH 5/6] Make solver math provider a parameter --- src/elli/solver4x4.py | 51 ++++++++++++++++++++++++------------------- tests/test_TiO2.py | 6 ++--- 2 files changed, 31 insertions(+), 26 deletions(-) diff --git a/src/elli/solver4x4.py b/src/elli/solver4x4.py index 85a271aa..6c3a10cd 100644 --- a/src/elli/solver4x4.py +++ b/src/elli/solver4x4.py @@ -1,5 +1,6 @@ # Encoding: utf-8 from abc import ABC, abstractmethod +from typing import Literal import numpy as np import numpy.typing as npt @@ -8,7 +9,10 @@ try: import torch except ImportError: - ... + TORCH_AVAILABLE = False +else: + TORCH_AVAILABLE = True + from numpy.lib.scimath import sqrt from scipy.linalg import expm as scipy_expm @@ -61,28 +65,29 @@ def calculate_propagation( class PropagatorExpm(Propagator): """Propagator class using the Padé approximation of the matrix exponential.""" - def calculate_propagation( - self, delta: npt.NDArray, thickness: float, lbda: npt.ArrayLike - ) -> npt.NDArray: - """Calculates propagation for a given Delta matrix and layer thickness with the Padé approximation of the matrix exponential. - - Args: - delta (npt.NDArray): Delta Matrix - thickness (float): Thickness of layer (nm) - lbda (npt.ArrayLike): Wavelengths to evaluate (nm) - - Returns: - npt.NDArray: Propagator for the given layer - """ - mats = 1j * thickness * np.einsum("nij,n->nij", delta, 2 * sc.pi / lbda) - - propagator = scipy_expm(mats) - - return propagator - + def __init__(self, backend: Literal["torch", "scipy", "automatic"] = "automatic"): + backends = { + "torch": lambda mats: torch.linalg.matrix_exp( + torch.from_numpy(mats) + ).numpy(), + "scipy": lambda mats: scipy_expm(mats), + } + + if backend == "automatic" and TORCH_AVAILABLE: + backend = "torch" + elif backend == "automatic" and not TORCH_AVAILABLE: + backend = "scipy" + elif backend == "torch" and not TORCH_AVAILABLE: + raise ImportError( + "PyTorch is not installed. If you want to use the PyTorch backend, \ + please follow the install instructions on https://pytorch.org/get-started/locally/" + ) + elif backend not in backends: + raise ValueError( + "Backend should be one of 'torch', 'scipy' or 'automatic'." + ) -class PropagatorExpmTorch(Propagator): - """Propagator class using the matrix exponential provided by PyTorch.""" + self.expm = backends[backend] def calculate_propagation( self, delta: npt.NDArray, thickness: float, lbda: npt.ArrayLike @@ -99,7 +104,7 @@ def calculate_propagation( """ mats = 1j * thickness * np.einsum("nij,n->nij", delta, 2 * sc.pi / lbda) - propagator = torch.linalg.matrix_exp(torch.from_numpy(mats)).numpy() + propagator = self.expm(mats) return propagator diff --git a/tests/test_TiO2.py b/tests/test_TiO2.py index 83630382..3dfa2d1a 100644 --- a/tests/test_TiO2.py +++ b/tests/test_TiO2.py @@ -124,7 +124,7 @@ def test_solver4x4_expm(self, si_dispersion, meas_data): meas_data.index, 70, solver=elli.Solver4x4, - propagator=elli.PropagatorExpm(), + propagator=elli.PropagatorExpm(backend="scipy"), ) .rho ) @@ -132,14 +132,14 @@ def test_solver4x4_expm(self, si_dispersion, meas_data): assert TestTiO2.chisqr(meas_data, sim_data) < 0.0456 def test_solver4x4_expm_torch(self, si_dispersion, meas_data): - """The solver4x4 with scipy propagator is within chi square accuracy""" + """The solver4x4 with pytorch propagator is within chi square accuracy""" sim_data = ( elli.Structure(elli.AIR, self.Layer, si_dispersion) .evaluate( meas_data.index, 70, solver=elli.Solver4x4, - propagator=elli.PropagatorExpmTorch(), + propagator=elli.PropagatorExpm(backend="torch"), ) .rho ) From 3c8be9c8880cd0991aff8185a169331b52529078 Mon Sep 17 00:00:00 2001 From: MarJMue <49639740+MarJMue@users.noreply.github.com> Date: Mon, 24 Jun 2024 19:51:28 +0200 Subject: [PATCH 6/6] Fix benchmarks as well --- tests/benchmark_propagators_TiO2.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/benchmark_propagators_TiO2.py b/tests/benchmark_propagators_TiO2.py index 8c6b0ffa..a5865f3e 100644 --- a/tests/benchmark_propagators_TiO2.py +++ b/tests/benchmark_propagators_TiO2.py @@ -78,7 +78,10 @@ def test_solver4x4_expm(benchmark, structure): benchmark.pedantic( structure.evaluate, args=(lbda, PHI), - kwargs={"solver": elli.Solver4x4, "propagator": elli.PropagatorExpm()}, + kwargs={ + "solver": elli.Solver4x4, + "propagator": elli.PropagatorExpm(backend="scipy"), + }, iterations=1, rounds=10, ) @@ -89,7 +92,10 @@ def test_solver4x4_expm_pytorch(benchmark, structure): benchmark.pedantic( structure.evaluate, args=(lbda, PHI), - kwargs={"solver": elli.Solver4x4, "propagator": elli.PropagatorExpmTorch()}, + kwargs={ + "solver": elli.Solver4x4, + "propagator": elli.PropagatorExpm(backend="torch"), + }, iterations=1, rounds=10, )