Merge pull request #789 from ACEsuit/develop
remove r_max from print ZBL
ilyes319 authored Jan 16, 2025
2 parents 0bcbdb6 + e0f7a5b commit 5a91b92
Showing 2 changed files with 51 additions and 23 deletions.
2 changes: 1 addition & 1 deletion mace/cli/run_train.py
@@ -646,7 +646,7 @@ def run(args: argparse.Namespace) -> None:
         distributed_model = DDP(model, device_ids=[local_rank])
     else:
         distributed_model = None
-    print("MODEL", model)
+
     tools.train(
         model=model,
         loss_fn=loss_fn,

72 changes: 50 additions & 22 deletions mace/modules/radial.py
@@ -11,7 +11,6 @@
 import torch
 from e3nn.util.jit import compile_mode
 
-from mace.tools.compile import simplify_if_compile
 from mace.tools.scatter import scatter_sum

@@ -219,7 +218,7 @@ def forward(
         return V_ZBL.squeeze(-1)
 
     def __repr__(self):
-        return f"{self.__class__.__name__}(r_max={self.r_max}, c={self.c})"
+        return f"{self.__class__.__name__}(c={self.c})"
 
 
 @compile_mode("script")
@@ -282,49 +281,78 @@ def __repr__(self):
         )
 
 
-@simplify_if_compile
 @compile_mode("script")
 class SoftTransform(torch.nn.Module):
-    """Soft Transform."""
"""
Tanh-based smooth transformation:
T(x) = p1 + (x - p1)*0.5*[1 + tanh(alpha*(x - m))],
which smoothly transitions from ~p1 for x << p1 to ~x for x >> r0.
"""

-    def __init__(self, a: float = 0.2, b: float = 3.0, trainable=False):
+    def __init__(self, alpha: float = 4.0, trainable=False):
"""
Args:
p1 (float): Lower "clamp" point.
alpha (float): Steepness; if None, defaults to ~6/(r0-p1).
trainable (bool): Whether to make parameters trainable.
"""
         super().__init__()
+        # Initialize parameters
+        self.register_buffer(
+            "alpha", torch.tensor(alpha, dtype=torch.get_default_dtype())
+        )
+        if trainable:
+            self.alpha = torch.nn.Parameter(self.alpha.clone())
         self.register_buffer(
             "covalent_radii",
             torch.tensor(
                 ase.data.covalent_radii,
                 dtype=torch.get_default_dtype(),
             ),
         )
-        if trainable:
-            self.a = torch.nn.Parameter(torch.tensor(a, requires_grad=True))
-            self.b = torch.nn.Parameter(torch.tensor(b, requires_grad=True))
-        else:
-            self.register_buffer("a", torch.tensor(a))
-            self.register_buffer("b", torch.tensor(b))
 
-    def forward(
+    def compute_r_0(
         self,
-        x: torch.Tensor,
         node_attrs: torch.Tensor,
         edge_index: torch.Tensor,
         atomic_numbers: torch.Tensor,
     ) -> torch.Tensor:
"""
Compute r_0 based on atomic information.
Args:
node_attrs (torch.Tensor): Node attributes (one-hot encoding of atomic numbers).
edge_index (torch.Tensor): Edge index indicating connections.
atomic_numbers (torch.Tensor): Atomic numbers.
Returns:
torch.Tensor: r_0 values for each edge.
"""
         sender = edge_index[0]
         receiver = edge_index[1]
         node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze(
             -1
         )
         Z_u = node_atomic_numbers[sender]
         Z_v = node_atomic_numbers[receiver]
-        r_0 = (self.covalent_radii[Z_u] + self.covalent_radii[Z_v]) / 4
-        r_over_r_0 = x / r_0
-        y = (
-            x
-            + (1 / 2) * torch.tanh(-r_over_r_0 - self.a * torch.pow(r_over_r_0, self.b))
-            + 1 / 2
-        )
-        return y
+        r_0: torch.Tensor = self.covalent_radii[Z_u] + self.covalent_radii[Z_v]
+        return r_0

+    def forward(
+        self,
+        x: torch.Tensor,
+        node_attrs: torch.Tensor,
+        edge_index: torch.Tensor,
+        atomic_numbers: torch.Tensor,
+    ) -> torch.Tensor:
+        r_0 = self.compute_r_0(node_attrs, edge_index, atomic_numbers)
+        p_0 = (3 / 4) * r_0
+        p_1 = (4 / 3) * r_0
+        m = 0.5 * (p_0 + p_1)
+        alpha = self.alpha / (p_1 - p_0)
+        s_x = 0.5 * (1.0 + torch.tanh(alpha * (x - m)))
+        return p_0 + (x - p_0) * s_x

     def __repr__(self):
-        return f"{self.__class__.__name__}(a={self.a.item()}, b={self.b.item()})"
+        return f"{self.__class__.__name__}(alpha={self.alpha.item():.4f})"
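For reference, a minimal sketch of how the rewritten SoftTransform behaves at its two limits. It assumes edge lengths arrive with shape (n_edges, 1), as the per-edge broadcasting in forward suggests; the carbon covalent radius of ~0.76 Å is the value tabulated in ase.data.covalent_radii:

import torch
from mace.modules.radial import SoftTransform

transform = SoftTransform(alpha=4.0)

# Two carbon atoms (Z = 6) joined by one edge in each direction;
# node_attrs is a one-hot over a single-element table.
node_attrs = torch.tensor([[1.0], [1.0]])
atomic_numbers = torch.tensor([6])
edge_index = torch.tensor([[0, 1], [1, 0]])

# ase.data.covalent_radii[6] is about 0.76, so r_0 = 1.52 per edge,
# giving p_0 = 1.14 and p_1 of about 2.03.
for r in (0.1, 1.0, 3.0):
    x = torch.full((2, 1), r)
    print(r, transform(x, node_attrs, edge_index, atomic_numbers)[0].item())
# Small r collapses to ~p_0 (~1.14); large r passes through (~3.0).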
