Skip to content

Commit

Permalink
fixing device, argparser and normalization
Browse files Browse the repository at this point in the history
  • Loading branch information
CheukHinHoJerry committed Oct 30, 2024
1 parent 42b9728 commit ac44b5e
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
6 changes: 3 additions & 3 deletions mace/modules/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,12 +694,12 @@ def _setup(self) -> None:
self.reshape = reshape_irreps(self.irreps_out)

elif self.tensor_format in ["non_symmetric_cp", "non_symmetric_tucker"]:
self.linear = []
self.linear = torch.nn.ModuleList([])
# Selector TensorProduct
self.skip_tp = o3.FullyConnectedTensorProduct(
self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps
)
self.reshape = []
self.reshape = torch.nn.ModuleList([])
for _ in range(self.correlation):
self.linear.append(o3.Linear(
irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
Expand Down Expand Up @@ -744,7 +744,7 @@ def forward(
message = torch.cat((message, _message), dim = -1)
print("shape of message: ", message.shape)
return (
message,
message / self.avg_num_neighbors,
sc
)

Expand Down
5 changes: 3 additions & 2 deletions mace/modules/symmetric_contraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
+[f"bk{out_channel_idx[nu-1]}"]
+[f"->bk{out_channel_idx[:nu]}"]
),
*[c_tensor for _ in range(nu)])
*[c_tensor for _ in range(nu)]) / torch.jit._builtins.math.factorial(nu)

else:
if self.tensor_format == "non_symmetric_tucker":
Expand All @@ -334,7 +334,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
+[f"bk{out_channel_idx[nu-1]}"]
+[f"->bk{out_channel_idx[:nu]}"]
),
*[c_tensor for _ in range(nu)])
*[c_tensor for _ in range(nu)]) / torch.jit._builtins.math.factorial(nu)


outs[nu] = torch.einsum(
Expand All @@ -357,6 +357,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
# reshape kLM
outs[nu] = outs[nu].reshape(outs[nu].shape[0], -1)

# / factorial(nu) because of extra work done for convenience
return torch.cat([outs[nu] for nu in range(self.correlation, 0, -1)], dim = 1)

elif "cp" in self.tensor_format:
Expand Down
1 change: 1 addition & 0 deletions mace/tools/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
choices=["symmetric_cp",
"non_symmetric_cp",
"symmetric_tucker",
"non_symmetric_tucker",
]
)

Expand Down

0 comments on commit ac44b5e

Please sign in to comment.