diff --git a/aesara/link/numba/dispatch/random.py b/aesara/link/numba/dispatch/random.py index b12e553d28..687ba0e91e 100644 --- a/aesara/link/numba/dispatch/random.py +++ b/aesara/link/numba/dispatch/random.py @@ -1,4 +1,5 @@ from copy import copy +from math import log from textwrap import dedent, indent from typing import Callable, Optional @@ -8,6 +9,7 @@ from numba import types from numba.extending import overload, overload_method, register_jitable from numba.np.random.distributions import random_beta, random_standard_gamma +from numba.np.random.generator_core import next_double from numba.np.random.generator_methods import check_size, check_types, is_nonelike import aesara.tensor.random.basic as aer @@ -373,3 +375,39 @@ def impl(inst, alphas, size=None): return random_dirichlet(inst.bit_generator, alphas, size) return impl + + +@register_jitable +def random_gumbel(bitgen, loc, scale): + """ + This implementation is adapted from ``numpy/random/src/distributions/distributions.c``. + """ + while True: + u = 1.0 - next_double(bitgen) + if u < 1.0: + return loc - scale * log(-log(u)) + + +@overload_method(types.NumPyRandomGeneratorType, "gumbel") +def NumPyRandomGeneratorType_gumbel(inst, loc=0.0, scale=1.0, size=None): + check_types(loc, [types.Float, types.Integer, int, float], "loc") + check_types(scale, [types.Float, types.Integer, int, float], "scale") + + if isinstance(size, types.Omitted): + size = size.value + + if is_nonelike(size): + + def impl(inst, loc=0.0, scale=1.0, size=None): + return random_gumbel(inst.bit_generator, loc, scale) + + else: + check_size(size) + + def impl(inst, loc=0.0, scale=1.0, size=None): + out = np.empty(size) + for i in np.ndindex(size): + out[i] = random_gumbel(inst.bit_generator, loc, scale) + return out + + return impl diff --git a/tests/link/numba/test_random.py b/tests/link/numba/test_random.py index 7426bcb70d..30cb0a7092 100644 --- a/tests/link/numba/test_random.py +++ b/tests/link/numba/test_random.py @@ -362,7 +362,7 @@ def test_aligned_RandomVariable(rv_op, dist_args, size): "chi2", lambda *args: args, ), - pytest.param( + ( aer.gumbel, [ set_test_value( @@ -377,9 +377,6 @@ def test_aligned_RandomVariable(rv_op, dist_args, size): (2,), "gumbel_r", lambda *args: args, - marks=pytest.mark.skip( - reason="Not yet supported in Numba via `Generator`s" - ), ), ( aer.negative_binomial,