From ecd6b49ca988c6f6280a4372461cd11f4ad77c03 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Fri, 1 Jul 2022 17:51:35 +0200 Subject: [PATCH] Implement work-around for numba issue https://github.com/numba/numba/issues/8215 causing a segfault on M1 when using literal_unroll() with bools. Closes #1023. --- aesara/link/numba/dispatch/tensor_basic.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/aesara/link/numba/dispatch/tensor_basic.py b/aesara/link/numba/dispatch/tensor_basic.py index 34c14e716a..3f1662e919 100644 --- a/aesara/link/numba/dispatch/tensor_basic.py +++ b/aesara/link/numba/dispatch/tensor_basic.py @@ -1,6 +1,5 @@ from textwrap import indent -import numba import numpy as np from aesara.link.numba.dispatch import basic as numba_basic @@ -198,11 +197,13 @@ def makevector({", ".join(input_names)}): @numba_funcify.register(Rebroadcast) def numba_funcify_Rebroadcast(op, **kwargs): - op_axis = tuple(op.axis.items()) + # Make sure op_axis only has ints. This way we can avoid literal_unroll + # which causes a segfault, see GH issue https://github.com/numba/numba/issues/8215 + op_axis = tuple((axis, int(value)) for axis, value in op.axis.items()) @numba_basic.numba_njit def rebroadcast(x): - for axis, value in numba.literal_unroll(op_axis): + for axis, value in op_axis: if value and x.shape[axis] != 1: raise ValueError( ("Dimension in Rebroadcast's input was supposed to be 1")