diff --git a/aesara/link/numba/dispatch/random.py b/aesara/link/numba/dispatch/random.py index 8ac77ad52f..b12e553d28 100644 --- a/aesara/link/numba/dispatch/random.py +++ b/aesara/link/numba/dispatch/random.py @@ -6,7 +6,9 @@ import numba.np.unsafe.ndarray as numba_ndarray import numpy as np from numba import types -from numba.extending import overload +from numba.extending import overload, overload_method, register_jitable +from numba.np.random.distributions import random_beta, random_standard_gamma +from numba.np.random.generator_methods import check_size, check_types, is_nonelike import aesara.tensor.random.basic as aer from aesara.graph.basic import Apply @@ -296,3 +298,78 @@ def dirichlet_rv(rng, size, dtype, alphas): return (rng, rng.dirichlet(alphas, size)) return dirichlet_rv + + +@register_jitable +def random_dirichlet(bitgen, alpha, size): + """ + This implementation is straight from ``numpy/random/_generator.pyx``. + """ + + k = len(alpha) + alpha_arr = np.asarray(alpha, dtype=np.float64) + + if np.any(np.less_equal(alpha_arr, 0)): + raise ValueError("alpha <= 0") + + shape = size + (k,) + + diric = np.zeros(shape, np.float64) + + i = 0 + totsize = diric.size + + if (k > 0) and (alpha_arr.max() < 0.1): + alpha_csum_arr = np.empty_like(alpha_arr) + csum = 0.0 + for j in range(k - 1, -1, -1): + csum += alpha_arr[j] + alpha_csum_arr[j] = csum + + while i < totsize: + acc = 1.0 + for j in range(k - 1): + v = random_beta(bitgen, alpha_arr[j], alpha_csum_arr[j + 1]) + diric[i + j] = acc * v + acc *= 1.0 - v + diric[i + k - 1] = acc + i = i + k + + else: + while i < totsize: + acc = 0.0 + for j in range(k): + diric[i + j] = random_standard_gamma(bitgen, alpha_arr[j]) + acc = acc + diric[i + j] + invacc = 1.0 / acc + for j in range(k): + diric[i + j] = diric[i + j] * invacc + i = i + k + + return diric + + +@overload_method(types.NumPyRandomGeneratorType, "dirichlet") +def NumPyRandomGeneratorType_dirichlet(inst, alphas, size=None): + check_types(alphas, [types.Array, types.List], "alphas") + + if isinstance(size, types.Omitted): + size = size.value + + if is_nonelike(size): + + def impl(inst, alphas, size=None): + return random_dirichlet(inst.bit_generator, alphas, ()) + + elif isinstance(size, (int, types.Integer)): + + def impl(inst, alphas, size=None): + return random_dirichlet(inst.bit_generator, alphas, (size,)) + + else: + check_size(size) + + def impl(inst, alphas, size=None): + return random_dirichlet(inst.bit_generator, alphas, size) + + return impl diff --git a/tests/link/numba/test_random.py b/tests/link/numba/test_random.py index 0b920b4861..7426bcb70d 100644 --- a/tests/link/numba/test_random.py +++ b/tests/link/numba/test_random.py @@ -520,7 +520,6 @@ def test_CategoricalRV(dist_args, size, cm): ) -@pytest.mark.skip(reason="Not yet supported in Numba via `Generator`s") @pytest.mark.parametrize( "a, size, cm", [