diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py
index b480267b..9a4cc8e2 100644
--- a/mace/modules/blocks.py
+++ b/mace/modules/blocks.py
@@ -74,24 +74,26 @@ def __init__(
         MLP_irreps: o3.Irreps,
         gate: Optional[Callable],
         irrep_out: o3.Irreps = o3.Irreps("0e"),
+        embedding_dim: int = 1,
         num_heads: int = 1,
     ):
         super().__init__()
-        self.hidden_irreps = MLP_irreps
+        self.hidden_irreps = (embedding_dim * MLP_irreps).simplify()
         self.num_heads = num_heads
+        self.embedding_dim = embedding_dim
+        self.num_mlp_irreps = MLP_irreps.count((0, 1))
         self.linear_1 = o3.Linear(irreps_in=irreps_in, irreps_out=self.hidden_irreps)
-        self.non_linearity = nn.Activation(irreps_in=self.hidden_irreps, acts=[gate])
-        self.linear_2 = o3.Linear(irreps_in=self.hidden_irreps, irreps_out=irrep_out)
+        self.non_linearity = nn.Activation(irreps_in=MLP_irreps, acts=[gate])
+        self.linear_2 = o3.Linear(irreps_in=MLP_irreps, irreps_out=irrep_out)
 
     def forward(
         self, x: torch.Tensor, node_heads_feats: Optional[torch.Tensor] = None
     ) -> torch.Tensor:  # [n_nodes, irreps]  # [..., ]
-        x = self.non_linearity(self.linear_1(x))
+        x = self.linear_1(x)
         if node_heads_feats is not None:
-            return torch.einsum(
-                "...h, ...h->...", self.linear_2(x), node_heads_feats
-            ).unsqueeze(-1)
-        return self.linear_2(x)  # [n_nodes, len(heads)]
+            x = x.view(-1, self.num_mlp_irreps, self.embedding_dim)
+            x = torch.einsum("b...h,bh->b...", x, node_heads_feats)
+        return self.linear_2(self.non_linearity(x))  # [n_nodes, len(heads)]
 
 
 @compile_mode("script")
diff --git a/mace/modules/models.py b/mace/modules/models.py
index 7f8210e6..73e07004 100644
--- a/mace/modules/models.py
+++ b/mace/modules/models.py
@@ -181,7 +181,8 @@ def __init__(
                 hidden_irreps_out,
                 MLP_irreps,
                 gate,
-                o3.Irreps(f"{self.readout_dim}x0e"),
+                o3.Irreps("0e"),
+                head_emb_dim,
                 len(heads),
             )
         )
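Note (reviewer sketch, not part of the patch): the new readout widens linear_1 to embedding_dim copies of MLP_irreps, contracts each node's head-embedding vector against that extra axis, and only then applies the gate and linear_2. A minimal shape check of that contraction, assuming hypothetical sizes MLP_irreps = "16x0e" and embedding_dim = 4 (plain torch, no e3nn needed):

    import torch

    # Hypothetical sizes: 16 scalar channels, each carrying a 4-dim head embedding.
    n_nodes, num_mlp_irreps, embedding_dim = 10, 16, 4

    x = torch.randn(n_nodes, num_mlp_irreps * embedding_dim)  # stand-in for linear_1 output
    node_heads_feats = torch.randn(n_nodes, embedding_dim)    # per-node head embedding

    # The patched forward() contracts the embedding axis per node:
    x = x.view(-1, num_mlp_irreps, embedding_dim)
    x = torch.einsum("b...h,bh->b...", x, node_heads_feats)

    assert x.shape == (n_nodes, num_mlp_irreps)  # matches MLP_irreps, ready for the gate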