Skip to content

Commit

Permalink
zero emb for zero edges
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Mar 9, 2023
1 parent d78dc6c commit 6cd52fd
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 4 deletions.
4 changes: 3 additions & 1 deletion mace_jax/modules/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,9 @@ def func(lengths):
)
factor = jnp.mean(func(samples) ** 2).item() ** -0.5

embedding = factor * func(edge_lengths) # [n_edges, num_basis]
embedding = factor * jnp.where(
edge_lengths == 0.0, 0.0, func(edge_lengths)
) # [n_edges, num_basis]
return e3nn.IrrepsArray(f"{embedding.shape[-1]}x0e", embedding)


Expand Down
4 changes: 2 additions & 2 deletions mace_jax/modules/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ def __call__(
self.radial_embedding(lengths),
e3nn.spherical_harmonics(
self.sh_irreps,
vectors / lengths[..., None],
normalize=False,
vectors,
normalize=True,
normalization="component",
),
]
Expand Down
2 changes: 1 addition & 1 deletion mace_jax/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def unflatten_dict(xs, sep=None):
def safe_norm(x: jnp.ndarray, axis: int = None, keepdims=False) -> jnp.ndarray:
"""nan-safe norm."""
x2 = jnp.sum(x**2, axis=axis, keepdims=keepdims)
return jnp.where(x2 == 0, 1, x2) ** 0.5
return jnp.where(x2 == 0.0, 0.0, jnp.where(x2 == 0, 1.0, x2) ** 0.5)


def compute_mean_std_atomic_inter_energy(
Expand Down

0 comments on commit 6cd52fd

Please sign in to comment.