From a7902a170577a6bf1335b3aa214b7ead59eb097a Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Tue, 14 Jun 2022 10:52:15 -0400 Subject: [PATCH] jax implementation of log1mexp op --- aesara/link/jax/dispatch.py | 12 +++++++++++- tests/link/test_jax.py | 10 +++++++++- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/aesara/link/jax/dispatch.py b/aesara/link/jax/dispatch.py index af5d849277..26647a14bb 100644 --- a/aesara/link/jax/dispatch.py +++ b/aesara/link/jax/dispatch.py @@ -17,7 +17,7 @@ from aesara.raise_op import CheckAndRaise from aesara.scalar import Softplus from aesara.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second -from aesara.scalar.math import Erf, Erfc, Erfinv, Psi +from aesara.scalar.math import Erf, Erfc, Erfinv, Log1mexp, Psi from aesara.scan.op import Scan from aesara.scan.utils import ScanArgs from aesara.tensor.basic import ( @@ -1119,6 +1119,16 @@ def erfc(x): return erfc +@jax_funcify.register(Log1mexp) +def jax_funcify_Log1mexp(op, node, **kwargs): + def log1mexp(x): + return jnp.where( + x < jnp.log(0.5), jnp.log1p(-jnp.exp(x)), jnp.log(-jnp.expm1(x)) + ) + + return log1mexp + + # Commented out because jax.scipy does not have erfcx, # but leaving the implementation in here just in case we ever see # a JAX implementation of Erfcx. diff --git a/tests/link/test_jax.py b/tests/link/test_jax.py index b5303978f4..ca9d28ab7b 100644 --- a/tests/link/test_jax.py +++ b/tests/link/test_jax.py @@ -32,7 +32,7 @@ from aesara.tensor.elemwise import Elemwise from aesara.tensor.math import MaxAndArgmax from aesara.tensor.math import all as at_all -from aesara.tensor.math import clip, cosh, erf, erfc, erfinv, gammaln, log +from aesara.tensor.math import clip, cosh, erf, erfc, erfinv, gammaln, log, log1mexp from aesara.tensor.math import max as at_max from aesara.tensor.math import maximum, prod, psi, sigmoid, softplus from aesara.tensor.math import sum as at_sum @@ -1394,3 +1394,11 @@ def test_psi(): out = psi(x) fg = FunctionGraph([x], [out]) compare_jax_and_py(fg, [3.0]) + + +def test_log1mexp(): + x = vector("x") + out = log1mexp(x) + fg = FunctionGraph([x], [out]) + + compare_jax_and_py(fg, [[-1.0, -0.75, -0.5, -0.25]])