diff --git a/nequip/nn/radial_basis.py b/nequip/nn/radial_basis.py index d0fc3bf2..21d04e06 100644 --- a/nequip/nn/radial_basis.py +++ b/nequip/nn/radial_basis.py @@ -1,5 +1,6 @@ +import math + import torch -import numpy as np from torch import nn @@ -31,12 +32,13 @@ def __init__(self, r_max, num_basis=8, trainable=True): self.r_max = float(r_max) self.prefactor = 2.0 / self.r_max - self.bessel_weights = torch.linspace( - start=1.0, end=num_basis, steps=num_basis - ) * torch.Tensor([np.pi]) - + bessel_weights = ( + torch.linspace(start=1.0, end=num_basis, steps=num_basis) * math.pi + ) if self.trainable: - self.bessel_weights = nn.Parameter(self.bessel_weights) + self.bessel_weights = nn.Parameter(bessel_weights) + else: + self.register_buffer("bessel_weights", bessel_weights) def forward(self, x): """