bernstein class and config add
epens94 committed Oct 24, 2024
1 parent b70361f commit 8314f86
Showing 2 changed files with 121 additions and 1 deletion.
@@ -0,0 +1,4 @@
_target_: schnetpack.nn.radial.BernsteinRBF
n_rbf: 32
cutoff: ${globals.cutoff}
init_alpha: 0.95
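
For illustration, a minimal sketch of what this config produces when instantiated directly in Python (the literal cutoff of 5.0 below is a placeholder for the interpolated ${globals.cutoff} value):

import torch
from schnetpack.nn.radial import BernsteinRBF

# mirrors the YAML config above; 5.0 stands in for ${globals.cutoff}
rbf = BernsteinRBF(n_rbf=32, cutoff=5.0, init_alpha=0.95)
d = torch.linspace(0.1, 5.0, steps=10)  # example pairwise distances
y = rbf(d)                              # basis expansion of shape (10, 32)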
118 changes: 117 additions & 1 deletion src/schnetpack/nn/radial.py
@@ -3,7 +3,7 @@
import torch
import torch.nn as nn

__all__ = ["gaussian_rbf", "GaussianRBF", "GaussianRBFCentered", "BesselRBF"]
__all__ = ["gaussian_rbf", "GaussianRBF", "GaussianRBFCentered", "BesselRBF","BernsteinRBF","PhysNetBasisRBF"]

from torch import nn as nn

@@ -108,3 +108,119 @@ def forward(self, inputs):
norm = torch.where(inputs == 0, torch.tensor(1.0, device=inputs.device), inputs)
y = sinax / norm[..., None]
return y


class BernsteinRBF(torch.nn.Module):
    r"""Bernstein radial basis functions.

    The k-th Bernstein basis polynomial of degree n is

        B_{k,n}(x) = \binom{n}{k} x^k (1 - x)^{n - k},

    where \binom{n}{k} = n! / (k! (n - k)!) is the binomial coefficient and the
    index k runs from 0 to n. To avoid numerical overflow and ensure stability,
    each polynomial is evaluated in logarithmic form,

        log B_{k,n}(x) = logBinomCoeff + k * log(x) + (n - k) * log(1 - x),

    where logBinomCoeff = log(n!) - log(k!) - log((n - k)!) is a scalar, while
    the k * log(x) and (n - k) * log(1 - x) terms are vectors over k. The
    argument is the exponentially mapped distance x = exp(-alpha * d), which
    maps d in [0, inf) onto x in (0, 1].
    """

    def __init__(self, n_rbf: int, cutoff: float, init_alpha: float = 0.95):
        r"""
        Args:
            n_rbf: total number of Bernstein functions, :math:`N_g`.
            cutoff: cutoff radius.
            init_alpha: scaling :math:`\alpha` of the exponential mapping
                :math:`x = \exp(-\alpha d)` applied to the distances.
        """

super(BernsteinRBF, self).__init__()
self.n_rbf = n_rbf

        # log binomial coefficient vector
b = self.calculate_log_binomial_coefficients(n_rbf)
n_idx = torch.arange(0, n_rbf)
n_k_idx = n_rbf - 1 - n_idx

        # register buffers and parameters
        self.register_buffer("cutoff", torch.tensor(cutoff))
        self.register_buffer("b", b)
        self.register_buffer("n", n_idx)
        self.register_buffer("n_k", n_k_idx)
        self.register_buffer("init_alpha", torch.tensor(init_alpha))

    # log of a factorial (n!, k!, or (n - k)!)
    def log_factorial(self, n):
        # sum of log(i) for i = 1..n equals log(n!)
        return torch.sum(torch.log(torch.arange(1, n + 1)))

    # calculate a log binomial coefficient
    def log_binomial_coefficient(self, n, k):
        # log(n!) - log(k!) - log((n - k)!)
        return self.log_factorial(n) - (self.log_factorial(k) + self.log_factorial(n - k))

    # vector of log binomial coefficients
    def calculate_log_binomial_coefficients(self, n_rbf):
        # log binomial coefficient for each k from 0 to n_rbf - 1
        log_binomial_coeffs = [
            self.log_binomial_coefficient(n_rbf - 1, x) for x in range(n_rbf)
        ]
        return torch.tensor(log_binomial_coeffs)

    def forward(self, inputs):
        # exponential mapping of distances: x = exp(-alpha * d) in (0, 1]
        x = torch.exp(-self.init_alpha * inputs[..., None])
        # k * log(x); the k = 0 entry is masked so 0 * log(0) cannot produce nan
        k_term = self.n * torch.where(self.n != 0, torch.log(x), torch.zeros_like(x))
        # (n - k) * log(1 - x); the k = n entry is masked likewise (x = 1 at d = 0)
        n_k_term = self.n_k * torch.where(self.n_k != 0, torch.log(1 - x), torch.zeros_like(x))
        # B_{k,n}(x) = exp(logBinomCoeff + k * log(x) + (n - k) * log(1 - x))
        y = torch.exp(self.b + k_term + n_k_term)
        return y
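
As a quick cross-check of the log-space evaluation (an illustrative sketch, not part of the commit), the output should match a direct evaluation of the Bernstein form for small n_rbf:

import math
import torch
from schnetpack.nn.radial import BernsteinRBF

n_rbf = 5
rbf = BernsteinRBF(n_rbf=n_rbf, cutoff=5.0, init_alpha=0.95)
d = torch.tensor([0.5, 1.0, 2.0])
x = torch.exp(-0.95 * d)[..., None]  # x = exp(-alpha * d), shape (3, 1)
k = torch.arange(n_rbf)
coeff = torch.tensor([math.comb(n_rbf - 1, int(v)) for v in k], dtype=torch.float32)
direct = coeff * x**k * (1 - x) ** (n_rbf - 1 - k)  # binom(n, k) x^k (1 - x)^(n - k)
assert torch.allclose(rbf(d), direct, atol=1e-6)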


class PhysNetBasisRBF(torch.nn.Module):
    r"""Expand distances in the basis used in PhysNet
    (see https://arxiv.org/abs/1902.08408):

        width:  beta_k = ((2 / K) * (1 - exp(-cutoff)))^(-2)
        center: mu_k equally spaced between exp(-cutoff) and 1
    """

    def __init__(self, n_rbf: int, cutoff: float, trainable: bool):
        """
        Args:
            n_rbf: total number of basis functions.
            cutoff: cutoff radius of the basis.
            trainable: if True, widths and centers are optimized during training.
        """

super(PhysNetBasisRBF, self).__init__()
self.n_rbf = n_rbf

        # compute widths and centers of the Gaussian functions
        widths = ((2 / self.n_rbf) * (1 - torch.exp(torch.Tensor([-cutoff])))) ** (-2)
        r_0 = torch.exp(torch.Tensor([-cutoff])).item()
        centers = torch.linspace(r_0, 1, self.n_rbf)

if trainable:
self.widths = torch.nn.Parameter(widths)
self.centers = torch.nn.Parameter(centers)
else:
self.register_buffer("widths", widths)
self.register_buffer("centers", centers)


    def forward(self, inputs: torch.Tensor):
        # Gaussians in exp(-d) space: exp(-|beta_k| * (exp(-d) - mu_k)^2)
        return torch.exp(
            -torch.abs(self.widths)
            * (torch.exp(-inputs[..., None]) - self.centers) ** 2
        )
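
A minimal usage sketch (illustrative values) showing the trainable variant, where widths and centers become optimizable parameters:

import torch
from schnetpack.nn.radial import PhysNetBasisRBF

basis = PhysNetBasisRBF(n_rbf=64, cutoff=5.0, trainable=True)
d = torch.rand(8) * 5.0  # example pairwise distances in [0, 5)
features = basis(d)      # expansion of shape (8, 64)
# widths and centers are nn.Parameters and show up in basis.parameters()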
