add a non-linear embedding
ilyes319 committed Nov 12, 2024
1 parent 08096c9 commit 6686788
Showing 2 changed files with 12 additions and 9 deletions.
18 changes: 10 additions & 8 deletions mace/modules/blocks.py
@@ -74,24 +74,26 @@ def __init__(
         MLP_irreps: o3.Irreps,
         gate: Optional[Callable],
         irrep_out: o3.Irreps = o3.Irreps("0e"),
+        embedding_dim: int = 1,
         num_heads: int = 1,
     ):
         super().__init__()
-        self.hidden_irreps = MLP_irreps
+        self.hidden_irreps = (embedding_dim * MLP_irreps).simplify()
         self.num_heads = num_heads
+        self.embedding_dim = embedding_dim
+        self.num_mlp_irreps = MLP_irreps.count((0, 1))
         self.linear_1 = o3.Linear(irreps_in=irreps_in, irreps_out=self.hidden_irreps)
-        self.non_linearity = nn.Activation(irreps_in=self.hidden_irreps, acts=[gate])
-        self.linear_2 = o3.Linear(irreps_in=self.hidden_irreps, irreps_out=irrep_out)
+        self.non_linearity = nn.Activation(irreps_in=MLP_irreps, acts=[gate])
+        self.linear_2 = o3.Linear(irreps_in=MLP_irreps, irreps_out=irrep_out)
 
     def forward(
         self, x: torch.Tensor, node_heads_feats: Optional[torch.Tensor] = None
     ) -> torch.Tensor:  # [n_nodes, irreps]  # [..., ]
-        x = self.non_linearity(self.linear_1(x))
+        x = self.linear_1(x)
         if node_heads_feats is not None:
-            return torch.einsum(
-                "...h, ...h->...", self.linear_2(x), node_heads_feats
-            ).unsqueeze(-1)
-        return self.linear_2(x)  # [n_nodes, len(heads)]
+            x = x.view(-1, self.num_mlp_irreps, self.embedding_dim)
+            x = torch.einsum("b...h, bh-> b...", x, node_heads_feats)
+        return self.linear_2(self.non_linearity(x))  # [n_nodes, len(heads)]
 
 
 @compile_mode("script")
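In short: the readout used to contract the per-head outputs of linear_2 against node_heads_feats; after this commit, linear_1 widens the MLP features by embedding_dim, the head embedding is contracted out before the non-linearity, and linear_2 maps the mixed MLP_irreps features to the output irrep. A minimal shape sketch of the new path in plain torch (the sizes are illustrative assumptions, not values from the repository):

    import torch

    # Illustrative sizes only: 5 nodes, 16 scalar MLP channels, head embedding width 4.
    n_nodes, num_mlp_irreps, embedding_dim = 5, 16, 4

    x = torch.randn(n_nodes, num_mlp_irreps * embedding_dim)  # stand-in for the output of linear_1
    node_heads_feats = torch.randn(n_nodes, embedding_dim)    # per-node head embedding

    x = x.view(-1, num_mlp_irreps, embedding_dim)              # [5, 16, 4]
    x = torch.einsum("b...h, bh-> b...", x, node_heads_feats)  # sum out the embedding axis -> [5, 16]
    print(x.shape)  # torch.Size([5, 16])

The einsum "b...h, bh-> b..." broadcasts over the middle (MLP-channel) axis and sums out the embedding axis h, so each node ends up with a head-specific mixture of its MLP channels, which non_linearity and linear_2 then reduce to the output irrep.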
3 changes: 2 additions & 1 deletion mace/modules/models.py
@@ -181,7 +181,8 @@ def __init__(
                 hidden_irreps_out,
                 MLP_irreps,
                 gate,
-                o3.Irreps(f"{self.readout_dim}x0e"),
+                o3.Irreps(f"0e"),
+                head_emb_dim,
                 len(heads),
             )
         )
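At this call site the readout now receives a single scalar output irrep plus the head embedding width. A hedged usage sketch, assuming the constructed class is NonLinearReadoutBlock (its name sits above the visible hunk and is not shown in the diff) and using made-up shapes:

    import torch
    from e3nn import o3
    from mace.modules.blocks import NonLinearReadoutBlock  # assumed class name

    readout = NonLinearReadoutBlock(
        irreps_in=o3.Irreps("64x0e"),    # assumed node-feature irreps
        MLP_irreps=o3.Irreps("16x0e"),   # assumed scalar MLP channels
        gate=torch.nn.functional.silu,
        irrep_out=o3.Irreps("0e"),
        embedding_dim=4,                 # head_emb_dim at this call site
        num_heads=2,                     # len(heads)
    )

    node_feats = torch.randn(5, 64)            # [n_nodes, irreps_in.dim]
    head_feats = torch.randn(5, 4)             # [n_nodes, embedding_dim]
    out = readout(node_feats, head_feats)      # [n_nodes, 1]
    print(out.shape)

With these shapes, hidden_irreps becomes (4 * 16x0e).simplify() = 64x0e, num_mlp_irreps is 16, and the forward pass returns one scalar per node.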
