diff --git a/mace/modules/radial.py b/mace/modules/radial.py index ea48ce30..cae2aa71 100644 --- a/mace/modules/radial.py +++ b/mace/modules/radial.py @@ -130,7 +130,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.calculate_envelope(x, self.r_max, self.p.to(torch.int)) @staticmethod - def calculate_envelope(x: torch.Tensor, r_max: float, p: int) -> torch.Tensor: + def calculate_envelope( + x: torch.Tensor, r_max: torch.Tensor, p: int + ) -> torch.Tensor: r_over_r_max = x / r_max envelope = ( 1.0