From 402a38fe5ee91d2d29f9589af867c4133f7d7a4e Mon Sep 17 00:00:00 2001 From: Smit-create Date: Tue, 14 Mar 2023 12:26:41 +0530 Subject: [PATCH] Use disable numba JIT --- tests/link/numba/test_basic.py | 78 ++++------------------------------ 1 file changed, 9 insertions(+), 69 deletions(-) diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 15799a3134..c20fd3d99f 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -1,7 +1,5 @@ import contextlib -import inspect from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, Tuple, Union -from unittest import mock import numba import numpy as np @@ -108,73 +106,15 @@ def compare_shape_dtype(x, y): def eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode): """Evaluate the Numba implementation in pure Python for coverage purposes.""" - def py_tuple_setitem(t, i, v): - ll = list(t) - ll[i] = v - return tuple(ll) - - def py_to_scalar(x): - if isinstance(x, np.ndarray): - return x.item() - else: - return x - - def njit_noop(*args, **kwargs): - if len(args) == 1 and callable(args[0]): - return args[0] - else: - return lambda x: x - - def vectorize_noop(*args, **kwargs): - def wrap(fn): - # `numba.vectorize` allows an `out` positional argument. We need - # to account for that - sig = inspect.signature(fn) - nparams = len(sig.parameters) - - def inner_vec(*args): - if len(args) > nparams: - # An `out` argument has been specified for an in-place - # operation - out = args[-1] - out[...] = np.vectorize(fn)(*args[:nparams]) - return out - else: - return np.vectorize(fn)(*args) - - return inner_vec - - if len(args) == 1 and callable(args[0]): - return wrap(args[0], **kwargs) - else: - return wrap - - mocks = [ - mock.patch("numba.njit", njit_noop), - mock.patch("numba.vectorize", vectorize_noop), - mock.patch("aesara.link.numba.dispatch.basic.tuple_setitem", py_tuple_setitem), - mock.patch("aesara.link.numba.dispatch.basic.numba_njit", njit_noop), - mock.patch("aesara.link.numba.dispatch.basic.numba_vectorize", vectorize_noop), - mock.patch("aesara.link.numba.dispatch.basic.direct_cast", lambda x, dtype: x), - mock.patch("aesara.link.numba.dispatch.basic.to_scalar", py_to_scalar), - mock.patch( - "aesara.link.numba.dispatch.basic.numba.np.numpy_support.from_dtype", - lambda dtype: dtype, - ), - mock.patch("numba.np.unsafe.ndarray.to_fixed_tuple", lambda x, n: tuple(x)), - ] - - with contextlib.ExitStack() as stack: - for ctx in mocks: - stack.enter_context(ctx) - - aesara_numba_fn = function( - fn_inputs, - fn_outputs, - mode=mode, - accept_inplace=True, - ) - _ = aesara_numba_fn(*inputs) + numba.config.DISABLE_JIT = True + aesara_numba_fn = function( + fn_inputs, + fn_outputs, + mode=mode, + accept_inplace=True, + ) + _ = aesara_numba_fn(*inputs) + numba.config.DISABLE_JIT = False def compare_numba_and_py(