Skip to content

Commit

Permalink
vector IrrepsArray
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Mar 15, 2023
1 parent a39b84e commit bde5097
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions mace_jax/modules/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def __init__(

def __call__(
self,
vectors: jnp.ndarray, # [n_edges, 3]
vectors: e3nn.IrrepsArray, # [n_edges, 3]
node_specie: jnp.ndarray, # [n_nodes] int between 0 and num_species-1
senders: jnp.ndarray, # [n_edges]
receivers: jnp.ndarray, # [n_edges]
Expand All @@ -121,8 +121,10 @@ def __call__(
) # [n_nodes, feature * irreps]
node_feats = profile("embedding: node_feats", node_feats, node_mask[:, None])

radial_embedding = self.radial_embedding(safe_norm(vectors, axis=-1))
vectors = e3nn.IrrepsArray("1o", vectors)
if not (hasattr(vectors, "irreps") and hasattr(vectors, "array")):
vectors = e3nn.IrrepsArray("1o", vectors)

radial_embedding = self.radial_embedding(safe_norm(vectors.array, axis=-1))

# Interactions
outputs = []
Expand Down

0 comments on commit bde5097

Please sign in to comment.