Skip to content

Commit

Permalink
Fix bessel GPU (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
Linux-cpp-lisp authored May 4, 2021
1 parent 349b269 commit 6de1cb9
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions nequip/nn/radial_basis.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math

import torch
import numpy as np

from torch import nn

Expand Down Expand Up @@ -31,12 +32,13 @@ def __init__(self, r_max, num_basis=8, trainable=True):
self.r_max = float(r_max)
self.prefactor = 2.0 / self.r_max

self.bessel_weights = torch.linspace(
start=1.0, end=num_basis, steps=num_basis
) * torch.Tensor([np.pi])

bessel_weights = (
torch.linspace(start=1.0, end=num_basis, steps=num_basis) * math.pi
)
if self.trainable:
self.bessel_weights = nn.Parameter(self.bessel_weights)
self.bessel_weights = nn.Parameter(bessel_weights)
else:
self.register_buffer("bessel_weights", bessel_weights)

def forward(self, x):
"""
Expand Down

0 comments on commit 6de1cb9

Please sign in to comment.