Skip to content

Commit

Permalink
Implement work-around for numba issue numba/numba#8215 causing a segf…
Browse files Browse the repository at this point in the history
…ault on M1 when using literal_unroll() with bools.

Closes #1023.
  • Loading branch information
Thomas Wiecki authored and brandonwillard committed Jul 1, 2022
1 parent d09e222 commit ecd6b49
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions aesara/link/numba/dispatch/tensor_basic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from textwrap import indent

import numba
import numpy as np

from aesara.link.numba.dispatch import basic as numba_basic
Expand Down Expand Up @@ -198,11 +197,13 @@ def makevector({", ".join(input_names)}):

@numba_funcify.register(Rebroadcast)
def numba_funcify_Rebroadcast(op, **kwargs):
op_axis = tuple(op.axis.items())
# Make sure op_axis only has ints. This way we can avoid literal_unroll
# which causes a segfault, see GH issue https://github.com/numba/numba/issues/8215
op_axis = tuple((axis, int(value)) for axis, value in op.axis.items())

@numba_basic.numba_njit
def rebroadcast(x):
for axis, value in numba.literal_unroll(op_axis):
for axis, value in op_axis:
if value and x.shape[axis] != 1:
raise ValueError(
("Dimension in Rebroadcast's input was supposed to be 1")
Expand Down

0 comments on commit ecd6b49

Please sign in to comment.