Skip to content

Commit

Permalink
Add a Numba implementation for Generator.gumbel
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed May 9, 2023
1 parent 8da6847 commit 4b26667
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 4 deletions.
38 changes: 38 additions & 0 deletions aesara/link/numba/dispatch/random.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from copy import copy
from math import log
from textwrap import dedent, indent
from typing import Callable, Optional

Expand All @@ -8,6 +9,7 @@
from numba import types
from numba.extending import overload, overload_method, register_jitable
from numba.np.random.distributions import random_beta, random_standard_gamma
from numba.np.random.generator_core import next_double
from numba.np.random.generator_methods import check_size, check_types, is_nonelike

import aesara.tensor.random.basic as aer
Expand Down Expand Up @@ -373,3 +375,39 @@ def impl(inst, alphas, size=None):
return random_dirichlet(inst.bit_generator, alphas, size)

return impl


@register_jitable
def random_gumbel(bitgen, loc, scale):
"""
This implementation is adapted from ``numpy/random/src/distributions/distributions.c``.
"""
while True:
u = 1.0 - next_double(bitgen)
if u < 1.0:
return loc - scale * log(-log(u))


@overload_method(types.NumPyRandomGeneratorType, "gumbel")
def NumPyRandomGeneratorType_gumbel(inst, loc=0.0, scale=1.0, size=None):
check_types(loc, [types.Float, types.Integer, int, float], "loc")
check_types(scale, [types.Float, types.Integer, int, float], "scale")

if isinstance(size, types.Omitted):
size = size.value

if is_nonelike(size):

def impl(inst, loc=0.0, scale=1.0, size=None):
return random_gumbel(inst.bit_generator, loc, scale)

else:
check_size(size)

def impl(inst, loc=0.0, scale=1.0, size=None):
out = np.empty(size)
for i in np.ndindex(size):
out[i] = random_gumbel(inst.bit_generator, loc, scale)
return out

return impl
5 changes: 1 addition & 4 deletions tests/link/numba/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
"chi2",
lambda *args: args,
),
pytest.param(
(
aer.gumbel,
[
set_test_value(
Expand All @@ -377,9 +377,6 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
(2,),
"gumbel_r",
lambda *args: args,
marks=pytest.mark.skip(
reason="Not yet supported in Numba via `Generator`s"
),
),
(
aer.negative_binomial,
Expand Down

0 comments on commit 4b26667

Please sign in to comment.