From 6de1cb9b4f2ce27326a42e877a52657b03ba240e Mon Sep 17 00:00:00 2001 From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Tue, 4 May 2021 10:32:24 -0600 Subject: [PATCH] Fix bessel GPU (#27) --- nequip/nn/radial_basis.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/nequip/nn/radial_basis.py b/nequip/nn/radial_basis.py index d0fc3bf2..21d04e06 100644 --- a/nequip/nn/radial_basis.py +++ b/nequip/nn/radial_basis.py @@ -1,5 +1,6 @@ +import math + import torch -import numpy as np from torch import nn @@ -31,12 +32,13 @@ def __init__(self, r_max, num_basis=8, trainable=True): self.r_max = float(r_max) self.prefactor = 2.0 / self.r_max - self.bessel_weights = torch.linspace( - start=1.0, end=num_basis, steps=num_basis - ) * torch.Tensor([np.pi]) - + bessel_weights = ( + torch.linspace(start=1.0, end=num_basis, steps=num_basis) * math.pi + ) if self.trainable: - self.bessel_weights = nn.Parameter(self.bessel_weights) + self.bessel_weights = nn.Parameter(bessel_weights) + else: + self.register_buffer("bessel_weights", bessel_weights) def forward(self, x): """