diff --git a/mace/modules/irreps_tools.py b/mace/modules/irreps_tools.py index 2e79c0ab..b5e79df6 100644 --- a/mace/modules/irreps_tools.py +++ b/mace/modules/irreps_tools.py @@ -86,19 +86,21 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor: for mul, d in zip(self.muls, self.dims): field = tensor[:, ix : ix + mul * d] # [batch, sample, mul * repr] ix += mul * d - if hasattr(self, "cueq_config") and self.cueq_config is not None: - if self.cueq_config.layout_str == "mul_ir": - field = field.reshape(batch, mul, d) - else: - field = field.reshape(batch, d, mul) + if hasattr(self, "cueq_config"): + if self.cueq_config is not None: + if self.cueq_config.layout_str == "mul_ir": + field = field.reshape(batch, mul, d) + else: + field = field.reshape(batch, d, mul) else: field = field.reshape(batch, mul, d) out.append(field) - if hasattr(self, "cueq_config") and self.cueq_config is not None: - if self.cueq_config.layout_str == "mul_ir": - return torch.cat(out, dim=-1) - return torch.cat(out, dim=-2) + if hasattr(self, "cueq_config"): + if self.cueq_config is not None: + if self.cueq_config.layout_str == "mul_ir": + return torch.cat(out, dim=-1) + return torch.cat(out, dim=-2) return torch.cat(out, dim=-1)