Skip to content

Commit

Permalink
fix TypeError: SphericalHarmonicsFunction.__call__() got an unexpecte…
Browse files Browse the repository at this point in the history
…d keyword argument 'costheta'

tests/layers/test_basis.py:95: TypeError
  • Loading branch information
janosh committed Nov 16, 2023
1 parent 0c1dde0 commit d9f2665
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tests/layers/test_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_gaussian():
rbf_gaussian.reset_parameters()


def test_sphericalbesselfunction():
def test_spherical_bessel_function():
r = torch.linspace(1.0, 5.0, 11)
rbf_sb = SphericalBesselFunction(max_n=3, max_l=3, cutoff=5.0, smooth=False)
rbf = rbf_sb(r)
Expand All @@ -58,7 +58,7 @@ def test_sphericalbesselfunction():
assert [rbf.size(dim=0), rbf.size(dim=1)] == [11, 3]


def test_sphericalbesselfunction_smooth():
def test_spherical_bessel_function_smooth():
r = torch.linspace(1.0, 5.0, 11)
rbf_sb = SphericalBesselFunction(max_n=3, max_l=3, cutoff=5.0, smooth=False)
rbf = rbf_sb(r)
Expand Down Expand Up @@ -92,15 +92,15 @@ def test_spherical_bessel_harmonics_function():
assert tuple(res2.shape) == (10, 3)

shf = SphericalHarmonicsFunction(max_l=3, use_phi=True)
res_shf = shf(costheta=torch.linspace(-1, 1, 10), phi=torch.linspace(0, 2 * np.pi, 10))
res_shf = shf(cos_theta=torch.linspace(-1, 1, 10), phi=torch.linspace(0, 2 * np.pi, 10))

assert res_shf.numpy().shape == (10, 9)
combined = combine_sbf_shf(res, res_shf, max_n=3, max_l=3, use_phi=True)

assert combined.shape == (10, 27)

res_shf2 = SphericalHarmonicsFunction(max_l=3, use_phi=False)(
costheta=torch.linspace(-1, 1, 10), phi=torch.linspace(0, 2 * np.pi, 10)
cos_theta=torch.linspace(-1, 1, 10), phi=torch.linspace(0, 2 * np.pi, 10)
)
combined = combine_sbf_shf(res, res_shf2, max_n=3, max_l=3, use_phi=False)

Expand Down

0 comments on commit d9f2665

Please sign in to comment.