Skip to content

Commit

Permalink
solve jit backward compatibility
Browse files Browse the repository at this point in the history
ilyes319 committed Jan 13, 2025

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent f9de62c commit 04652dc
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions mace/modules/irreps_tools.py
Original file line number Diff line number Diff line change
@@ -86,19 +86,21 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
for mul, d in zip(self.muls, self.dims):
field = tensor[:, ix : ix + mul * d] # [batch, sample, mul * repr]
ix += mul * d
if hasattr(self, "cueq_config") and self.cueq_config is not None:
if self.cueq_config.layout_str == "mul_ir":
field = field.reshape(batch, mul, d)
else:
field = field.reshape(batch, d, mul)
if hasattr(self, "cueq_config"):
if self.cueq_config is not None:
if self.cueq_config.layout_str == "mul_ir":
field = field.reshape(batch, mul, d)
else:
field = field.reshape(batch, d, mul)
else:
field = field.reshape(batch, mul, d)
out.append(field)

if hasattr(self, "cueq_config") and self.cueq_config is not None:
if self.cueq_config.layout_str == "mul_ir":
return torch.cat(out, dim=-1)
return torch.cat(out, dim=-2)
if hasattr(self, "cueq_config"):
if self.cueq_config is not None:
if self.cueq_config.layout_str == "mul_ir":
return torch.cat(out, dim=-1)
return torch.cat(out, dim=-2)
return torch.cat(out, dim=-1)


0 comments on commit 04652dc

Please sign in to comment.