Skip to content

Commit

Permalink
jax implementation of log1mexp op
Browse files Browse the repository at this point in the history
  • Loading branch information
Kyle Caron authored and Ricardo Vieira committed Jun 15, 2022
1 parent 2ccd9cc commit a7902a1
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
12 changes: 11 additions & 1 deletion aesara/link/jax/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from aesara.raise_op import CheckAndRaise
from aesara.scalar import Softplus
from aesara.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second
from aesara.scalar.math import Erf, Erfc, Erfinv, Psi
from aesara.scalar.math import Erf, Erfc, Erfinv, Log1mexp, Psi
from aesara.scan.op import Scan
from aesara.scan.utils import ScanArgs
from aesara.tensor.basic import (
Expand Down Expand Up @@ -1119,6 +1119,16 @@ def erfc(x):
return erfc


@jax_funcify.register(Log1mexp)
def jax_funcify_Log1mexp(op, node, **kwargs):
def log1mexp(x):
return jnp.where(
x < jnp.log(0.5), jnp.log1p(-jnp.exp(x)), jnp.log(-jnp.expm1(x))
)

return log1mexp


# Commented out because jax.scipy does not have erfcx,
# but leaving the implementation in here just in case we ever see
# a JAX implementation of Erfcx.
Expand Down
10 changes: 9 additions & 1 deletion tests/link/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.math import MaxAndArgmax
from aesara.tensor.math import all as at_all
from aesara.tensor.math import clip, cosh, erf, erfc, erfinv, gammaln, log
from aesara.tensor.math import clip, cosh, erf, erfc, erfinv, gammaln, log, log1mexp
from aesara.tensor.math import max as at_max
from aesara.tensor.math import maximum, prod, psi, sigmoid, softplus
from aesara.tensor.math import sum as at_sum
Expand Down Expand Up @@ -1394,3 +1394,11 @@ def test_psi():
out = psi(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [3.0])


def test_log1mexp():
x = vector("x")
out = log1mexp(x)
fg = FunctionGraph([x], [out])

compare_jax_and_py(fg, [[-1.0, -0.75, -0.5, -0.25]])

0 comments on commit a7902a1

Please sign in to comment.