Skip to content

Commit

Permalink
add mask to profiling
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Mar 2, 2023
1 parent f18eb46 commit d78dc6c
Showing 1 changed file with 21 additions and 15 deletions.
36 changes: 21 additions & 15 deletions mace_jax/modules/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from profile_nn_jax import profile
except ImportError:

def profile(_, x):
def profile(_, x, __=None):
return x


Expand Down Expand Up @@ -106,17 +106,21 @@ def __call__(
node_specie: jnp.ndarray, # [n_nodes] int between 0 and num_species-1
senders: jnp.ndarray, # [n_edges]
receivers: jnp.ndarray, # [n_edges]
node_mask: Optional[jnp.ndarray] = None, # [n_nodes] only used for profiling
) -> e3nn.IrrepsArray:
assert vectors.ndim == 2 and vectors.shape[1] == 3
assert node_specie.ndim == 1
assert senders.ndim == 1 and receivers.ndim == 1
assert vectors.shape[0] == senders.shape[0] == receivers.shape[0]

if node_mask is None:
node_mask = jnp.ones(node_specie.shape[0], dtype=jnp.bool_)

# Embeddings
node_feats = self.node_embedding(node_specie).astype(
vectors.dtype
) # [n_nodes, feature * irreps]
node_feats = profile("embedding: node_feats", node_feats)
node_feats = profile("embedding: node_feats", node_feats, node_mask[:, None])

lengths = safe_norm(vectors, axis=-1)

Expand Down Expand Up @@ -162,13 +166,7 @@ def __call__(
symmetric_tensor_product_basis=self.symmetric_tensor_product_basis,
off_diagonal=self.off_diagonal,
name=f"layer_{i}",
)(
node_feats,
node_specie,
edge_attrs,
senders,
receivers,
)
)(node_feats, node_specie, edge_attrs, senders, receivers, node_mask)
outputs += [node_outputs] # list of [n_nodes, output_irreps]

return e3nn.stack(outputs, axis=1) # [n_nodes, num_interactions, output_irreps]
Expand Down Expand Up @@ -221,8 +219,12 @@ def __call__(
edge_attrs: e3nn.IrrepsArray, # [n_edges, irreps]
senders: jnp.ndarray, # [n_edges]
receivers: jnp.ndarray, # [n_edges]
node_mask: Optional[jnp.ndarray] = None, # [n_nodes] only used for profiling
):
node_feats = profile(f"{self.name}: node_feats", node_feats)
if node_mask is None:
node_mask = jnp.ones(node_specie.shape[0], dtype=jnp.bool_)

node_feats = profile(f"{self.name}: input", node_feats, node_mask[:, None])

sc = None
if not self.first:
Expand All @@ -233,7 +235,7 @@ def __call__(
)(
node_specie, node_feats
) # [n_nodes, feature * hidden_irreps]
sc = profile(f"{self.name}: self-connexion", sc)
sc = profile(f"{self.name}: self-connexion", sc, node_mask[:, None])

node_feats = InteractionBlock(
target_irreps=self.num_features * self.interaction_irreps,
Expand All @@ -251,7 +253,9 @@ def __call__(
else:
node_feats /= jnp.sqrt(self.avg_num_neighbors)

node_feats = profile(f"{self.name}: node_feats after interaction", node_feats)
node_feats = profile(
f"{self.name}: interaction", node_feats, node_mask[:, None]
)

if self.first:
# Selector TensorProduct
Expand All @@ -261,7 +265,7 @@ def __call__(
name="skip_tp_first",
)(node_specie, node_feats)
node_feats = profile(
f"{self.name}: node_feats after skip_tp_first", node_feats
f"{self.name}: skip_tp_first", node_feats, node_mask[:, None]
)
sc = None

Expand All @@ -273,7 +277,9 @@ def __call__(
off_diagonal=self.off_diagonal,
)(node_feats=node_feats, node_specie=node_specie)

node_feats = profile(f"{self.name}: node_feats after tensor power", node_feats)
node_feats = profile(
f"{self.name}: tensor power", node_feats, node_mask[:, None]
)

if sc is not None:
node_feats = node_feats + sc # [n_nodes, feature * hidden_irreps]
Expand All @@ -291,5 +297,5 @@ def __call__(
node_feats
) # [n_nodes, output_irreps]

node_outputs = profile(f"{self.name}: node_outputs", node_outputs)
node_outputs = profile(f"{self.name}: output", node_outputs, node_mask[:, None])
return node_outputs, node_feats

0 comments on commit d78dc6c

Please sign in to comment.