Skip to content

Commit

Permalink
Add a Numba implementation for Generator.dirichlet
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed May 9, 2023
1 parent 2d84709 commit 8da6847
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 2 deletions.
79 changes: 78 additions & 1 deletion aesara/link/numba/dispatch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import numba.np.unsafe.ndarray as numba_ndarray
import numpy as np
from numba import types
from numba.extending import overload
from numba.extending import overload, overload_method, register_jitable
from numba.np.random.distributions import random_beta, random_standard_gamma
from numba.np.random.generator_methods import check_size, check_types, is_nonelike

import aesara.tensor.random.basic as aer
from aesara.graph.basic import Apply
Expand Down Expand Up @@ -296,3 +298,78 @@ def dirichlet_rv(rng, size, dtype, alphas):
return (rng, rng.dirichlet(alphas, size))

return dirichlet_rv


@register_jitable
def random_dirichlet(bitgen, alpha, size):
"""
This implementation is straight from ``numpy/random/_generator.pyx``.
"""

k = len(alpha)
alpha_arr = np.asarray(alpha, dtype=np.float64)

if np.any(np.less_equal(alpha_arr, 0)):
raise ValueError("alpha <= 0")

shape = size + (k,)

diric = np.zeros(shape, np.float64)

i = 0
totsize = diric.size

if (k > 0) and (alpha_arr.max() < 0.1):
alpha_csum_arr = np.empty_like(alpha_arr)
csum = 0.0
for j in range(k - 1, -1, -1):
csum += alpha_arr[j]
alpha_csum_arr[j] = csum

while i < totsize:
acc = 1.0
for j in range(k - 1):
v = random_beta(bitgen, alpha_arr[j], alpha_csum_arr[j + 1])
diric[i + j] = acc * v
acc *= 1.0 - v
diric[i + k - 1] = acc
i = i + k

else:
while i < totsize:
acc = 0.0
for j in range(k):
diric[i + j] = random_standard_gamma(bitgen, alpha_arr[j])
acc = acc + diric[i + j]
invacc = 1.0 / acc
for j in range(k):
diric[i + j] = diric[i + j] * invacc
i = i + k

return diric


@overload_method(types.NumPyRandomGeneratorType, "dirichlet")
def NumPyRandomGeneratorType_dirichlet(inst, alphas, size=None):
check_types(alphas, [types.Array, types.List], "alphas")

if isinstance(size, types.Omitted):
size = size.value

if is_nonelike(size):

def impl(inst, alphas, size=None):
return random_dirichlet(inst.bit_generator, alphas, ())

elif isinstance(size, (int, types.Integer)):

def impl(inst, alphas, size=None):
return random_dirichlet(inst.bit_generator, alphas, (size,))

else:
check_size(size)

def impl(inst, alphas, size=None):
return random_dirichlet(inst.bit_generator, alphas, size)

return impl
1 change: 0 additions & 1 deletion tests/link/numba/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,6 @@ def test_CategoricalRV(dist_args, size, cm):
)


@pytest.mark.skip(reason="Not yet supported in Numba via `Generator`s")
@pytest.mark.parametrize(
"a, size, cm",
[
Expand Down

0 comments on commit 8da6847

Please sign in to comment.