Skip to content

Commit

Permalink
Adding logmeanexp and logdiffexp numerical utilities (pytorch#1657)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1657

This commit defines `logdiffexp` and `logmeanexp`, numerical utility functions that are going to be more generally useful for log-space computations.

Reviewed By: Balandat

Differential Revision: D43061278

fbshipit-source-id: a279a2ab7da5e1eb23a20fba0c3498dd52ea37b6
  • Loading branch information
SebastianAment authored and facebook-github-bot committed Feb 7, 2023
1 parent 6475503 commit ffcad4a
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 6 deletions.
4 changes: 2 additions & 2 deletions botorch/acquisition/analytic.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
ndtr as Phi,
phi,
)
from botorch.utils.safe_math import log1mexp
from botorch.utils.safe_math import log1mexp, logmeanexp
from botorch.utils.transforms import convert_to_target_pre_hook, t_batch_mode_transform
from torch import Tensor

Expand Down Expand Up @@ -587,7 +587,7 @@ def forward(self, X: Tensor) -> Tensor:
u = _scaled_improvement(mean, sigma, self.best_f, self.maximize)
log_ei = _log_ei_helper(u) + sigma.log()
# this is mathematically - though not numerically - equivalent to log(mean(ei))
return torch.logsumexp(log_ei, dim=-1) - math.log(log_ei.shape[-1])
return logmeanexp(log_ei, dim=-1)


class NoisyExpectedImprovement(ExpectedImprovement):
Expand Down
6 changes: 2 additions & 4 deletions botorch/utils/probability/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import Any, Callable, Iterable, Iterator, Optional, Tuple, Union

import torch
from botorch.utils.safe_math import log1mexp
from botorch.utils.safe_math import logdiffexp
from numpy.polynomial.legendre import leggauss as numpy_leggauss
from torch import BoolTensor, LongTensor, Tensor

Expand Down Expand Up @@ -214,9 +214,7 @@ def log_prob_normal_in(a: Tensor, b: Tensor) -> Tensor:
c = torch.where(rev_cond, -b, a)
b = torch.where(rev_cond, -a, b)
a = c # after we updated b, can assign c to a
log_Phi_b = log_ndtr(b)
# Phi(b) > Phi(a), so 0 > log(Phi(a) / Phi(b)) and we can use log1mexp
return log_Phi_b + log1mexp(log_ndtr(a) - log_Phi_b)
return logdiffexp(log_a=log_ndtr(a), log_b=log_ndtr(b))


def swap_along_dim_(
Expand Down
27 changes: 27 additions & 0 deletions botorch/utils/safe_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,30 @@ def log1mexp(x: Tensor) -> Tensor:
(-x.expm1()).log(),
(-x.exp()).log1p(),
)


def logdiffexp(log_a: Tensor, log_b: Tensor) -> Tensor:
"""Computes log(b - a) accurately given log(a) and log(b).
Assumes, log_b > log_a, i.e. b > a > 0.
Args:
log_a (Tensor): The logarithm of a, assumed to be less than log_b.
log_b (Tensor): The logarithm of b, assumed to be larger than log_a.
Returns:
A Tensor of values corresponding to log(b - a).
"""
return log_b + log1mexp(log_a - log_b)


def logmeanexp(X: Tensor, dim: int = -1) -> Tensor:
"""Computes log(mean(exp(X), dim=dim)).
Args:
X (Tensor): The logarithm of a, assumed to be less than log_b.
dim (int): The dimension over which to compute the mean. Default is -1.
Returns:
A Tensor of values corresponding to log(mean(exp(X), dim=dim)).
"""
return torch.logsumexp(X, dim=dim) - math.log(X.shape[dim])

0 comments on commit ffcad4a

Please sign in to comment.