From b675d80f52ed55d973f0dde894d0a0cbc7d3dc2a Mon Sep 17 00:00:00 2001
From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com>
Date: Thu, 16 Jan 2025 18:50:41 +0000
Subject: [PATCH 1/3] remove r_max from print ZBL

---
 mace/cli/run_train.py  |  2 +-
 mace/modules/radial.py | 71 +++++++++++++++++++++++++++++-------------
 2 files changed, 51 insertions(+), 22 deletions(-)

diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py
index b2a9c7bc..26d8a4f5 100644
--- a/mace/cli/run_train.py
+++ b/mace/cli/run_train.py
@@ -646,7 +646,7 @@ def run(args: argparse.Namespace) -> None:
         distributed_model = DDP(model, device_ids=[local_rank])
     else:
         distributed_model = None
-    print("MODEL", model)
+    
     tools.train(
         model=model,
         loss_fn=loss_fn,
diff --git a/mace/modules/radial.py b/mace/modules/radial.py
index cf626c78..d5f4e131 100644
--- a/mace/modules/radial.py
+++ b/mace/modules/radial.py
@@ -219,7 +219,7 @@ def forward(
         return V_ZBL.squeeze(-1)
 
     def __repr__(self):
-        return f"{self.__class__.__name__}(r_max={self.r_max}, c={self.c})"
+        return f"{self.__class__.__name__}(c={self.c})"
 
 
 @compile_mode("script")
@@ -282,13 +282,28 @@ def __repr__(self):
         )
 
 
-@simplify_if_compile
 @compile_mode("script")
 class SoftTransform(torch.nn.Module):
-    """Soft Transform."""
+    """
+    Tanh-based smooth transformation:
+    T(x) = p_0 + (x - p_0)*0.5*[1 + tanh(alpha*(x - m))],
+    which smoothly transitions from ~p_0 for x << p_0 to ~x for x >> p_1.
+    """
 
-    def __init__(self, a: float = 0.2, b: float = 3.0, trainable=False):
+    def __init__(self, alpha: float = 4.0, trainable=False):
+        """
+        Args:
+            alpha (float): Dimensionless steepness of the transition;
+                rescaled by the width p_1 - p_0 in forward().
+            trainable (bool): Whether to make alpha trainable.
+        """
         super().__init__()
+        # Initialize parameters
+        self.register_buffer(
+            "alpha", torch.tensor(alpha, dtype=torch.get_default_dtype())
+        )
+        if trainable:
+            self.alpha = torch.nn.Parameter(self.alpha.clone())
         self.register_buffer(
             "covalent_radii",
             torch.tensor(
@@ -296,20 +311,24 @@ def __init__(self, a: float = 0.2, b: float = 3.0, trainable=False):
                 dtype=torch.get_default_dtype(),
             ),
         )
-        if trainable:
-            self.a = torch.nn.Parameter(torch.tensor(a, requires_grad=True))
-            self.b = torch.nn.Parameter(torch.tensor(b, requires_grad=True))
-        else:
-            self.register_buffer("a", torch.tensor(a))
-            self.register_buffer("b", torch.tensor(b))
 
-    def forward(
+    def compute_r_0(
         self,
-        x: torch.Tensor,
         node_attrs: torch.Tensor,
         edge_index: torch.Tensor,
         atomic_numbers: torch.Tensor,
     ) -> torch.Tensor:
+        """
+        Compute r_0 based on atomic information.
+
+        Args:
+            node_attrs (torch.Tensor): Node attributes (one-hot encoding of atomic numbers).
+            edge_index (torch.Tensor): Edge index indicating connections.
+            atomic_numbers (torch.Tensor): Atomic numbers.
+
+        Returns:
+            torch.Tensor: r_0 values for each edge.
+        """
         sender = edge_index[0]
         receiver = edge_index[1]
         node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze(
@@ -317,14 +336,24 @@ def forward(
         )
         Z_u = node_atomic_numbers[sender]
         Z_v = node_atomic_numbers[receiver]
-        r_0 = (self.covalent_radii[Z_u] + self.covalent_radii[Z_v]) / 4
-        r_over_r_0 = x / r_0
-        y = (
-            x
-            + (1 / 2) * torch.tanh(-r_over_r_0 - self.a * torch.pow(r_over_r_0, self.b))
-            + 1 / 2
-        )
-        return y
+        r_0: torch.Tensor = self.covalent_radii[Z_u] + self.covalent_radii[Z_v]
+        return r_0
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        node_attrs: torch.Tensor,
+        edge_index: torch.Tensor,
+        atomic_numbers: torch.Tensor,
+    ) -> torch.Tensor:
+
+        r_0 = self.compute_r_0(node_attrs, edge_index, atomic_numbers)
+        p_0 = (3 / 4) * r_0
+        p_1 = (4 / 3) * r_0
+        m = 0.5 * (p_0 + p_1)
+        alpha = self.alpha / (p_1 - p_0)
+        s_x = 0.5 * (1.0 + torch.tanh(alpha * (x - m)))
+        return p_0 + (x - p_0) * s_x
 
     def __repr__(self):
-        return f"{self.__class__.__name__}(a={self.a.item()}, b={self.b.item()})"
+        return f"{self.__class__.__name__}(alpha={self.alpha.item():.4f})"

From c7adcbf034b477c11d7d539a378cd8ca20effc11 Mon Sep 17 00:00:00 2001
From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com>
Date: Thu, 16 Jan 2025 19:31:31 +0000
Subject: [PATCH 2/3] remove unused import from radial

---
 mace/modules/radial.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mace/modules/radial.py b/mace/modules/radial.py
index d5f4e131..ff69b43e 100644
--- a/mace/modules/radial.py
+++ b/mace/modules/radial.py
@@ -11,7 +11,6 @@
 import torch
 from e3nn.util.jit import compile_mode
 
-from mace.tools.compile import simplify_if_compile
 from mace.tools.scatter import scatter_sum
 
 

From e0f7a5b7a848c7eb5b511e2b751edb7f6d47ed85 Mon Sep 17 00:00:00 2001
From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com>
Date: Thu, 16 Jan 2025 19:36:48 +0000
Subject: [PATCH 3/3] fix formatting

---
 mace/cli/run_train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py
index 26d8a4f5..38feaa40 100644
--- a/mace/cli/run_train.py
+++ b/mace/cli/run_train.py
@@ -646,7 +646,7 @@ def run(args: argparse.Namespace) -> None:
         distributed_model = DDP(model, device_ids=[local_rank])
     else:
         distributed_model = None
-    
+
     tools.train(
         model=model,
         loss_fn=loss_fn,
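Note on the new transform (PATCH 1/3): SoftTransform now smoothly clamps
interatomic distances from below at p_0 = (3/4)*r_0 and leaves distances
beyond p_1 = (4/3)*r_0 essentially unchanged, where r_0 is the sum of the
two covalent radii and the tanh switch is centered at m = (p_0 + p_1)/2.
Below is a minimal standalone sketch of the same math, assuming only torch;
the function name soft_transform and the example r_0 value are illustrative
and not part of the patch:

    import torch

    def soft_transform(x: torch.Tensor, r_0: torch.Tensor, alpha: float = 4.0) -> torch.Tensor:
        """Smoothly clamp distances x from below at p_0 = 3/4 * r_0."""
        p_0 = 0.75 * r_0
        p_1 = (4.0 / 3.0) * r_0
        m = 0.5 * (p_0 + p_1)                      # center of the transition window
        a = alpha / (p_1 - p_0)                    # rescale steepness to the window width
        s = 0.5 * (1.0 + torch.tanh(a * (x - m)))  # ~0 for x << m, ~1 for x >> m
        return p_0 + (x - p_0) * s                 # ~p_0 below the window, ~x above it

    # Example: r_0 = 1.5 (two covalent radii of 0.75 each); small x saturates
    # near p_0 = 1.125 while large x passes through almost unchanged.
    x = torch.linspace(0.1, 3.0, 6)
    print(soft_transform(x, torch.tensor(1.5)))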