Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an option to disable numba JIT #1470

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 9 additions & 69 deletions tests/link/numba/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import contextlib
import inspect
from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, Tuple, Union
from unittest import mock

import numba
import numpy as np
Expand Down Expand Up @@ -108,73 +106,15 @@ def compare_shape_dtype(x, y):
def eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode):
"""Evaluate the Numba implementation in pure Python for coverage purposes."""

def py_tuple_setitem(t, i, v):
ll = list(t)
ll[i] = v
return tuple(ll)

def py_to_scalar(x):
if isinstance(x, np.ndarray):
return x.item()
else:
return x

def njit_noop(*args, **kwargs):
if len(args) == 1 and callable(args[0]):
return args[0]
else:
return lambda x: x

def vectorize_noop(*args, **kwargs):
def wrap(fn):
# `numba.vectorize` allows an `out` positional argument. We need
# to account for that
sig = inspect.signature(fn)
nparams = len(sig.parameters)

def inner_vec(*args):
if len(args) > nparams:
# An `out` argument has been specified for an in-place
# operation
out = args[-1]
out[...] = np.vectorize(fn)(*args[:nparams])
return out
else:
return np.vectorize(fn)(*args)

return inner_vec

if len(args) == 1 and callable(args[0]):
return wrap(args[0], **kwargs)
else:
return wrap

mocks = [
mock.patch("numba.njit", njit_noop),
mock.patch("numba.vectorize", vectorize_noop),
mock.patch("aesara.link.numba.dispatch.basic.tuple_setitem", py_tuple_setitem),
mock.patch("aesara.link.numba.dispatch.basic.numba_njit", njit_noop),
mock.patch("aesara.link.numba.dispatch.basic.numba_vectorize", vectorize_noop),
mock.patch("aesara.link.numba.dispatch.basic.direct_cast", lambda x, dtype: x),
mock.patch("aesara.link.numba.dispatch.basic.to_scalar", py_to_scalar),
mock.patch(
"aesara.link.numba.dispatch.basic.numba.np.numpy_support.from_dtype",
lambda dtype: dtype,
),
mock.patch("numba.np.unsafe.ndarray.to_fixed_tuple", lambda x, n: tuple(x)),
]

with contextlib.ExitStack() as stack:
for ctx in mocks:
stack.enter_context(ctx)

aesara_numba_fn = function(
fn_inputs,
fn_outputs,
mode=mode,
accept_inplace=True,
)
_ = aesara_numba_fn(*inputs)
numba.config.DISABLE_JIT = True
aesara_numba_fn = function(
fn_inputs,
fn_outputs,
mode=mode,
accept_inplace=True,
)
_ = aesara_numba_fn(*inputs)
numba.config.DISABLE_JIT = False


def compare_numba_and_py(
Expand Down