diff --git a/botorch/models/transforms/utils.py b/botorch/models/transforms/utils.py index a956ea32b2..b2dca8a4af 100644 --- a/botorch/models/transforms/utils.py +++ b/botorch/models/transforms/utils.py @@ -60,7 +60,7 @@ def norm_to_lognorm(mu: Tensor, Cov: Tensor) -> tuple[Tensor, Tensor]: diag = torch.diagonal(Cov, dim1=-1, dim2=-2) b = mu + 0.5 * diag mu_ln = torch.exp(b) - Cov_ln = (torch.exp(Cov) - 1) * torch.exp(b.unsqueeze(-1) + b.unsqueeze(-2)) + Cov_ln = torch.special.expm1(Cov) * torch.exp(b.unsqueeze(-1) + b.unsqueeze(-2)) return mu_ln, Cov_ln @@ -88,7 +88,7 @@ def norm_to_lognorm_variance(mu: Tensor, var: Tensor) -> Tensor: The `batch_shape x n` variance vector of the log-Normal distribution. """ b = mu + 0.5 * var - return (torch.exp(var) - 1) * torch.exp(2 * b) + return torch.special.expm1(var) * torch.exp(2 * b) def expand_and_copy_tensor(X: Tensor, batch_shape: torch.Size) -> Tensor: