Skip to content

Commit

Permalink
add option skip_connection_first_layer
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Mar 23, 2023
1 parent 9fb8ed2 commit 605c645
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions mace_jax/modules/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(
off_diagonal: bool = False,
interaction_irreps: Union[str, e3nn.Irreps] = "o3_restricted", # or o3_full
node_embedding: hk.Module = LinearNodeEmbeddingBlock,
skip_connection_first_layer: bool = False,
):
super().__init__()

Expand Down Expand Up @@ -87,6 +88,7 @@ def __init__(
self.symmetric_tensor_product_basis = symmetric_tensor_product_basis
self.off_diagonal = off_diagonal
self.max_ell = max_ell
self.skip_connection_first_layer = skip_connection_first_layer

# Embeddings
self.node_embedding = node_embedding(
Expand Down Expand Up @@ -154,6 +156,7 @@ def __call__(
readout_mlp_irreps=self.readout_mlp_irreps,
symmetric_tensor_product_basis=self.symmetric_tensor_product_basis,
off_diagonal=self.off_diagonal,
skip_connection_first_layer=self.skip_connection_first_layer,
name=f"layer_{i}",
)(
vectors,
Expand Down Expand Up @@ -192,6 +195,7 @@ def __init__(
# ReadoutBlock:
output_irreps: e3nn.Irreps,
readout_mlp_irreps: e3nn.Irreps,
skip_connection_first_layer: bool = False,
) -> None:
super().__init__(name=name)

Expand All @@ -210,6 +214,7 @@ def __init__(
self.readout_mlp_irreps = readout_mlp_irreps
self.symmetric_tensor_product_basis = symmetric_tensor_product_basis
self.off_diagonal = off_diagonal
self.skip_connection_first_layer = skip_connection_first_layer

def __call__(
self,
Expand All @@ -227,7 +232,7 @@ def __call__(
node_feats = profile(f"{self.name}: input", node_feats, node_mask[:, None])

sc = None
if not self.first:
if not self.first or self.skip_connection_first_layer:
sc = e3nn.haiku.Linear(
self.num_features * self.hidden_irreps,
num_indexed_weights=self.num_species,
Expand Down Expand Up @@ -269,7 +274,6 @@ def __call__(
node_feats = profile(
f"{self.name}: skip_tp_first", node_feats, node_mask[:, None]
)
sc = None

node_feats = EquivariantProductBasisBlock(
target_irreps=self.num_features * self.hidden_irreps,
Expand Down

0 comments on commit 605c645

Please sign in to comment.