From 5c433ec46171da3de2d67f242bb300b96e171c35 Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Wed, 16 Aug 2023 14:05:51 -0400 Subject: [PATCH 1/5] fix bug in QSBoozer --- desc/objectives/_qs.py | 1 + setup.cfg | 2 +- tests/test_objective_funs.py | 14 ++++++++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/desc/objectives/_qs.py b/desc/objectives/_qs.py index be6c53da92..4731928c08 100644 --- a/desc/objectives/_qs.py +++ b/desc/objectives/_qs.py @@ -143,6 +143,7 @@ def build(self, eq=None, use_jit=True, verbose=1): helicity=self.helicity, NFP=self._transforms["B"].basis.NFP, ) + self._idx = np.where(self._idx)[0] self._constants = { "transforms": self._transforms, "profiles": self._profiles, diff --git a/setup.cfg b/setup.cfg index 99a320c0f1..544fec618c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -47,7 +47,7 @@ markers= filterwarnings= error ignore::pytest.PytestUnraisableExceptionWarning - ignore::RuntimeWarning:desc.compute + ignore::RuntimeWarning # Ignore division by zero warnings. ignore::DeprecationWarning:ml_dtypes.* diff --git a/tests/test_objective_funs.py b/tests/test_objective_funs.py index 6c34c1547d..59dc77b061 100644 --- a/tests/test_objective_funs.py +++ b/tests/test_objective_funs.py @@ -197,6 +197,20 @@ def test(eq): test(Equilibrium(L=2, M=2, N=1, iota=PowerSeriesProfile(0))) test(Equilibrium(L=2, M=2, N=1, current=PowerSeriesProfile(0))) + @pytest.mark.unit + def test_jax_compile_boozer(self): + """Test compilation of Boozer QA metric in ObjectiveFunction.""" + # making sure that compiles without any errors from JAX + # Related to issue #625 + def test(eq): + obj = ObjectiveFunction(QuasisymmetryBoozer(eq=eq)) + obj.build() + obj.compile() + fb = obj.compute_unscaled(obj.x(eq)) + np.testing.assert_allclose(fb, 0, atol=1e-12) + + test(Equilibrium(L=2, M=2, N=1, current=PowerSeriesProfile(0))) + @pytest.mark.unit def test_qh_boozer(self): """Test calculation of Boozer QH metric.""" From 7c42833238eb7c4954ab9821a3b2e67e22229034 Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Wed, 16 Aug 2023 15:12:36 -0400 Subject: [PATCH 2/5] move fix to vmec_utils --- desc/objectives/_qs.py | 2 +- desc/vmec_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/desc/objectives/_qs.py b/desc/objectives/_qs.py index 4731928c08..444cfcb3e5 100644 --- a/desc/objectives/_qs.py +++ b/desc/objectives/_qs.py @@ -143,7 +143,7 @@ def build(self, eq=None, use_jit=True, verbose=1): helicity=self.helicity, NFP=self._transforms["B"].basis.NFP, ) - self._idx = np.where(self._idx)[0] + self._constants = { "transforms": self._transforms, "profiles": self._profiles, diff --git a/desc/vmec_utils.py b/desc/vmec_utils.py index a1d081cbf9..ce804c95f8 100644 --- a/desc/vmec_utils.py +++ b/desc/vmec_utils.py @@ -267,7 +267,7 @@ def ptolemy_linear_transform(desc_modes, vmec_modes=None, helicity=None, NFP=Non else: idx_MN = np.nonzero(vmec_modes[:, 1] * N == vmec_modes[:, 2] * M)[0] idx[idx_MN] = False - + idx = np.where(idx)[0] return matrix, vmec_modes, idx return matrix, vmec_modes From c5ece70ef596105dfdf58e9aff8ac7cfc51176d8 Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Wed, 16 Aug 2023 16:26:08 -0400 Subject: [PATCH 3/5] fix qh and test_vmec tests --- setup.cfg | 2 +- tests/test_objective_funs.py | 2 +- tests/test_vmec.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.cfg b/setup.cfg index 544fec618c..99a320c0f1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -47,7 +47,7 @@ markers= filterwarnings= error ignore::pytest.PytestUnraisableExceptionWarning - ignore::RuntimeWarning + ignore::RuntimeWarning:desc.compute # Ignore division by zero warnings. ignore::DeprecationWarning:ml_dtypes.* diff --git a/tests/test_objective_funs.py b/tests/test_objective_funs.py index 59dc77b061..c72cf03e6f 100644 --- a/tests/test_objective_funs.py +++ b/tests/test_objective_funs.py @@ -245,7 +245,7 @@ def test_qh_boozer(self): idx_B = np.argsort(np.abs(B_mn)) # check that largest amplitudes are the QH modes - np.testing.assert_allclose(B_mn[idx_B[-3:]], np.flip(B_mn[~idx][:3])) + np.testing.assert_allclose(B_mn[idx_B[-3:]], np.flip(np.delete(B_mn, idx)[:3])) # check that these QH modes are not returned by the objective assert [b not in f for b in B_mn[idx_B[-3:]]] # check that the objective returns the lowest amplitudes diff --git a/tests/test_vmec.py b/tests/test_vmec.py index aef921a123..7723e36bda 100644 --- a/tests/test_vmec.py +++ b/tests/test_vmec.py @@ -182,7 +182,7 @@ def test_ptolemy_linear_transform(self): [1, 3, 3], ] ) - np.testing.assert_allclose(modes[~idx], sym_modes) + np.testing.assert_allclose(np.delete(modes, idx), sym_modes) @pytest.mark.unit def test_fourier_to_zernike(self): From f59a1033d440dbda3c9098a0fb916b40bb4d2da9 Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Wed, 16 Aug 2023 16:52:39 -0400 Subject: [PATCH 4/5] change where call to nonzero --- desc/vmec_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/desc/vmec_utils.py b/desc/vmec_utils.py index ce804c95f8..5194da3bb3 100644 --- a/desc/vmec_utils.py +++ b/desc/vmec_utils.py @@ -267,7 +267,7 @@ def ptolemy_linear_transform(desc_modes, vmec_modes=None, helicity=None, NFP=Non else: idx_MN = np.nonzero(vmec_modes[:, 1] * N == vmec_modes[:, 2] * M)[0] idx[idx_MN] = False - idx = np.where(idx)[0] + idx = np.nonzero(idx)[0] return matrix, vmec_modes, idx return matrix, vmec_modes From 9fa9d691305f51d12f7ef66f4e8cb8f2fd6f6d54 Mon Sep 17 00:00:00 2001 From: Dario Panici Date: Wed, 16 Aug 2023 20:06:19 -0400 Subject: [PATCH 5/5] fix test --- tests/test_vmec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_vmec.py b/tests/test_vmec.py index 7723e36bda..abca87455c 100644 --- a/tests/test_vmec.py +++ b/tests/test_vmec.py @@ -182,7 +182,7 @@ def test_ptolemy_linear_transform(self): [1, 3, 3], ] ) - np.testing.assert_allclose(np.delete(modes, idx), sym_modes) + np.testing.assert_allclose(np.delete(modes, idx, axis=0), sym_modes) @pytest.mark.unit def test_fourier_to_zernike(self):