From 8314f866b58f9098ebfd10530662ff2acd216fa5 Mon Sep 17 00:00:00 2001 From: epens94 Date: Thu, 24 Oct 2024 16:22:13 +0200 Subject: [PATCH] bernstein class and config add --- .../radial_basis/bernstein.yaml | 4 + src/schnetpack/nn/radial.py | 118 +++++++++++++++++- 2 files changed, 121 insertions(+), 1 deletion(-) create mode 100644 src/schnetpack/configs/model/representation/radial_basis/bernstein.yaml diff --git a/src/schnetpack/configs/model/representation/radial_basis/bernstein.yaml b/src/schnetpack/configs/model/representation/radial_basis/bernstein.yaml new file mode 100644 index 000000000..fac40c51a --- /dev/null +++ b/src/schnetpack/configs/model/representation/radial_basis/bernstein.yaml @@ -0,0 +1,4 @@ +_target_: schnetpack.nn.radial.BernsteinRBF +n_rbf: 32 +cutoff: ${globals.cutoff} +init_alpha: 0.95 diff --git a/src/schnetpack/nn/radial.py b/src/schnetpack/nn/radial.py index 8c63c2a07..41354560b 100644 --- a/src/schnetpack/nn/radial.py +++ b/src/schnetpack/nn/radial.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn -__all__ = ["gaussian_rbf", "GaussianRBF", "GaussianRBFCentered", "BesselRBF"] +__all__ = ["gaussian_rbf", "GaussianRBF", "GaussianRBFCentered", "BesselRBF","BernsteinRBF","PhysNetBasisRBF"] from torch import nn as nn @@ -108,3 +108,119 @@ def forward(self, inputs): norm = torch.where(inputs == 0, torch.tensor(1.0, device=inputs.device), inputs) y = sinax / norm[..., None] return y + + +class BernsteinRBF(torch.nn.Module): + + + r"""Bernstein radial basis functions. + + According to + B_{v,n}(x) = \binom{n}{v} x^v (1 - x)^{n - v} + with + B as the Bernstein polynomial of degree v + binom{k}{n} as the binomial coefficient n! / (k! * (n - k)!) + they become in logaritmic form log(n!) - log(k!) - log((n - k)!) + n as index running from 0 to degree k + + The logarithmic form of the k-th Bernstein polynominal of degree n is + + log(B_{k}_{n}) = logBinomCoeff + k * log(x) - (n-k) * log(1-x) + k_term is here k*log(x) + n_k_term is here (n-k)*log(1-x) + x is here the radial basis expansion : exp[-alpha*d] + + logBinomCoeff is a scalar + k_term is a vector + n_k_term is also a vector + + log to avoid numerical overflow errors, and ensure stability + """ + + def __init__( + self, n_rbf: int, cutoff:float, init_alpha:float = 0.95): + """ + Args: + n_rbf: total number of Bernstein functions, :math:`N_g`. + cutoff: center of last Bernstein function, :math:`\mu_{N_g}` + """ + + super(BernsteinRBF, self).__init__() + self.n_rbf = n_rbf + + # log binomal coefficient vector + b = self.calculate_log_binomial_coefficients(n_rbf) + n_idx = torch.arange(0, n_rbf) + n_k_idx = n_rbf - 1 - n_idx + + # register buffers and parameters + self.register_buffer("cutoff",torch.tensor(cutoff)) + self.register_buffer("b", b) + self.register_buffer("n", n_idx) + self.register_buffer("n_k", n_k_idx) + self.register_buffer("init_alpha",torch.tensor(init_alpha)) + + # log of factorial (n! or k! or n-k!) + def log_factorial(self,n): + # log of factorial degree n + return torch.sum(torch.log(torch.arange(1, n + 1))) + + # calculate log binominal coefficient + def log_binomial_coefficient(self,n, k): + # n_factorial - k_factorial - n_k_factorial + return self.log_factorial(n) - (self.log_factorial(k) + self.log_factorial(n - k)) + + # vector of log binominal coefficients + def calculate_log_binomial_coefficients(self,n_rbf): + # store the log binomial coefficients + # Loop through each value from 0 to n_rbf-1 + log_binomial_coeffs = [ + self.log_binomial_coefficient(n_rbf - 1, x) for x in range(n_rbf) + ] + return torch.tensor(log_binomial_coeffs) + + def forward(self, inputs): + exp_x = -self.init_alpha * inputs[...,None] + x = torch.exp(exp_x) + k_term = self.n * torch.where(self.n != 0, torch.log(x), torch.zeros_like(x)) + n_k_term = self.n_k * torch.where(self.n_k != 0, torch.log(1 - x), torch.zeros_like(x)) + y = torch.exp(self.b + k_term + n_k_term) + return y + + +class PhysNetBasisRBF(torch.nn.Module): + + """ + Expand distances in the basis used in PhysNet (see https://arxiv.org/abs/1902.08408) + + width (beta_k) = (2K^⁻1 * (1 - exp(-cutoff)))^-2) + center (mu_k) = equally spaced between exp(-cutoff) and 1 + + """ + + def __init__(self, n_rbf: int, cutoff:float, trainable:bool): + + """ + Args: + n_rbf: total number of basis functions. + cutoff: cutoff basis functions + """ + + super(PhysNetBasisRBF, self).__init__() + self.n_rbf = n_rbf + + # compute offset and width of Gaussian functions + widths = ((2 / self.n_rbf) * (1 - torch.exp(torch.Tensor([-cutoff])))) ** (-2) + r_0 = torch.exp(torch.Tensor([-cutoff])).item() + centers = torch.linspace(r_0,1,self.n_rbf) + + if trainable: + self.widths = torch.nn.Parameter(widths) + self.centers = torch.nn.Parameter(centers) + else: + self.register_buffer("widths", widths) + self.register_buffer("centers", centers) + + + def forward(self, inputs: torch.Tensor): + return torch.exp(-abs(self.widths) * (torch.exp(-inputs[...,None]) - self.centers) ** 2) \ No newline at end of file