Skip to content

Commit

Permalink
Merge branch 'main' of github.com:ACEsuit/mace-jax
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Apr 7, 2023
2 parents d988c08 + 605c645 commit 59b30f0
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions mace_jax/modules/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
off_diagonal: bool = False,
interaction_irreps: Union[str, e3nn.Irreps] = "o3_restricted", # or o3_full
node_embedding: hk.Module = LinearNodeEmbeddingBlock,
skip_connection_first_layer: bool = False,
):
super().__init__()

Expand Down Expand Up @@ -89,6 +90,7 @@ def __init__(
self.off_diagonal = off_diagonal
self.max_ell = max_ell
self.soft_normalization = soft_normalization
self.skip_connection_first_layer = skip_connection_first_layer

# Embeddings
self.node_embedding = node_embedding(
Expand Down Expand Up @@ -157,6 +159,7 @@ def __call__(
symmetric_tensor_product_basis=self.symmetric_tensor_product_basis,
off_diagonal=self.off_diagonal,
soft_normalization=self.soft_normalization,
skip_connection_first_layer=self.skip_connection_first_layer,
name=f"layer_{i}",
)(
vectors,
Expand Down Expand Up @@ -196,6 +199,7 @@ def __init__(
# ReadoutBlock:
output_irreps: e3nn.Irreps,
readout_mlp_irreps: e3nn.Irreps,
skip_connection_first_layer: bool = False,
) -> None:
super().__init__(name=name)

Expand All @@ -215,6 +219,7 @@ def __init__(
self.symmetric_tensor_product_basis = symmetric_tensor_product_basis
self.off_diagonal = off_diagonal
self.soft_normalization = soft_normalization
self.skip_connection_first_layer = skip_connection_first_layer

def __call__(
self,
Expand All @@ -232,7 +237,7 @@ def __call__(
node_feats = profile(f"{self.name}: input", node_feats, node_mask[:, None])

sc = None
if not self.first:
if not self.first or self.skip_connection_first_layer:
sc = e3nn.haiku.Linear(
self.num_features * self.hidden_irreps,
num_indexed_weights=self.num_species,
Expand Down Expand Up @@ -274,7 +279,6 @@ def __call__(
node_feats = profile(
f"{self.name}: skip_tp_first", node_feats, node_mask[:, None]
)
sc = None

node_feats = EquivariantProductBasisBlock(
target_irreps=self.num_features * self.hidden_irreps,
Expand Down

0 comments on commit 59b30f0

Please sign in to comment.