-
Notifications
You must be signed in to change notification settings - Fork 64
/
Copy pathmace_radial.py
111 lines (92 loc) · 3.62 KB
/
mace_radial.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
###########################################################################################
# Radial basis and cutoff
# Authors: Ilyes Batatia, Gregor Simm
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################
import numpy as np
import torch
from e3nn.util.jit import compile_mode
@compile_mode("script")
class BesselBasis(torch.nn.Module):
    """
    Sine-based radial (Bessel) basis.

    Klicpera, J.; Groß, J.; Günnemann, S. Directional Message Passing for
    Molecular Graphs; ICLR 2020. Equation (7).
    """

    def __init__(self, r_max: float, num_basis=8, trainable=False):
        super().__init__()
        dtype = torch.get_default_dtype()
        # Frequencies n * pi / r_max for n = 1 .. num_basis.
        frequencies = (np.pi / r_max) * torch.linspace(
            start=1.0,
            end=num_basis,
            steps=num_basis,
            dtype=dtype,
        )
        if trainable:
            # Learnable frequencies participate in optimisation.
            self.bessel_weights = torch.nn.Parameter(frequencies)
        else:
            # Fixed frequencies still follow the module across devices.
            self.register_buffer("bessel_weights", frequencies)
        self.register_buffer("r_max", torch.tensor(r_max, dtype=dtype))
        # Normalisation constant sqrt(2 / r_max).
        self.register_buffer(
            "prefactor", torch.tensor(np.sqrt(2.0 / r_max), dtype=dtype)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # [..., 1]
        # sqrt(2/r_max) * sin(n pi x / r_max) / x  -> [..., num_basis]
        return self.prefactor * (torch.sin(self.bessel_weights * x) / x)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(r_max={self.r_max}, num_basis={len(self.bessel_weights)}, "
            f"trainable={self.bessel_weights.requires_grad})"
        )
@compile_mode("script")
class GaussianBasis(torch.nn.Module):
    """
    Gaussian basis functions with centres spaced evenly on [0, r_max].
    """

    def __init__(self, r_max: float, num_basis=128, trainable=False):
        super().__init__()
        # Evenly spaced Gaussian centres from 0 to r_max inclusive.
        centers = torch.linspace(
            start=0.0, end=r_max, steps=num_basis, dtype=torch.get_default_dtype()
        )
        if trainable:
            # Centres become learnable parameters.
            self.gaussian_weights = torch.nn.Parameter(centers, requires_grad=True)
        else:
            self.register_buffer("gaussian_weights", centers)
        # -1 / (2 * spacing^2), where spacing = r_max / (num_basis - 1);
        # plain Python float, intentionally not part of the state_dict.
        self.coeff = -0.5 / (r_max / (num_basis - 1)) ** 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # [..., 1]
        # Gaussian bump around each centre: exp(-(x - c)^2 / (2 spacing^2)).
        offsets = x - self.gaussian_weights
        return torch.exp(self.coeff * torch.pow(offsets, 2))
@compile_mode("script")
class PolynomialCutoff(torch.nn.Module):
    """
    Smooth polynomial cutoff envelope that decays to zero at r_max.

    Klicpera, J.; Groß, J.; Günnemann, S. Directional Message Passing for
    Molecular Graphs; ICLR 2020. Equation (8).
    """

    p: torch.Tensor
    r_max: torch.Tensor

    def __init__(self, r_max: float, p=6):
        """
        :param r_max: cutoff radius; output is exactly zero for x >= r_max
        :param p: polynomial order controlling smoothness of the decay
        """
        super().__init__()
        # Buffers so both values follow the module across devices/dtypes.
        self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype()))
        self.register_buffer(
            "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype())
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the envelope at distances x (same shape in, same shape out)."""
        # Hoist the invariant scaled distance: the original recomputed
        # x / r_max once per polynomial term (three divisions per call).
        r_scaled = x / self.r_max
        p = self.p
        envelope = (
            1.0
            - ((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(r_scaled, p)
            + p * (p + 2.0) * torch.pow(r_scaled, p + 1)
            - (p * (p + 1.0) / 2) * torch.pow(r_scaled, p + 2)
        )
        # Hard mask: the envelope is exactly zero at and beyond the cutoff.
        # noinspection PyUnresolvedReferences
        return envelope * (x < self.r_max)

    def __repr__(self):
        return f"{self.__class__.__name__}(p={self.p}, r_max={self.r_max})"