From d9f266523983fdea0baeacc16dece83f447e5a14 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Thu, 16 Nov 2023 11:36:48 -0800 Subject: [PATCH] fix TypeError: SphericalHarmonicsFunction.__call__() got an unexpected keyword argument 'costheta' tests/layers/test_basis.py:95: TypeError --- tests/layers/test_basis.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/layers/test_basis.py b/tests/layers/test_basis.py index b64d1bab..93ce241b 100644 --- a/tests/layers/test_basis.py +++ b/tests/layers/test_basis.py @@ -47,7 +47,7 @@ def test_gaussian(): rbf_gaussian.reset_parameters() -def test_sphericalbesselfunction(): +def test_spherical_bessel_function(): r = torch.linspace(1.0, 5.0, 11) rbf_sb = SphericalBesselFunction(max_n=3, max_l=3, cutoff=5.0, smooth=False) rbf = rbf_sb(r) @@ -58,7 +58,7 @@ def test_sphericalbesselfunction(): assert [rbf.size(dim=0), rbf.size(dim=1)] == [11, 3] -def test_sphericalbesselfunction_smooth(): +def test_spherical_bessel_function_smooth(): r = torch.linspace(1.0, 5.0, 11) rbf_sb = SphericalBesselFunction(max_n=3, max_l=3, cutoff=5.0, smooth=False) rbf = rbf_sb(r) @@ -92,7 +92,7 @@ def test_spherical_bessel_harmonics_function(): assert tuple(res2.shape) == (10, 3) shf = SphericalHarmonicsFunction(max_l=3, use_phi=True) - res_shf = shf(costheta=torch.linspace(-1, 1, 10), phi=torch.linspace(0, 2 * np.pi, 10)) + res_shf = shf(cos_theta=torch.linspace(-1, 1, 10), phi=torch.linspace(0, 2 * np.pi, 10)) assert res_shf.numpy().shape == (10, 9) combined = combine_sbf_shf(res, res_shf, max_n=3, max_l=3, use_phi=True) @@ -100,7 +100,7 @@ def test_spherical_bessel_harmonics_function(): assert combined.shape == (10, 27) res_shf2 = SphericalHarmonicsFunction(max_l=3, use_phi=False)( - costheta=torch.linspace(-1, 1, 10), phi=torch.linspace(0, 2 * np.pi, 10) + cos_theta=torch.linspace(-1, 1, 10), phi=torch.linspace(0, 2 * np.pi, 10) ) combined = combine_sbf_shf(res, res_shf2, max_n=3, max_l=3, use_phi=False)