Skip to content

Commit 2ada4b6

Browse files
committed
Faster implementation of numba convolve1d
1 parent 9530ffc commit 2ada4b6

File tree

2 files changed

+98
-7
lines changed

2 files changed

+98
-7
lines changed
+58-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
from numba.np.arraymath import _get_inner_prod
23

34
from pytensor.link.numba.dispatch import numba_funcify
45
from pytensor.link.numba.dispatch.basic import numba_njit
@@ -7,10 +8,63 @@
78

89
@numba_funcify.register(Convolve1d)
910
def numba_funcify_Convolve1d(op, node, **kwargs):
11+
# This specialized version is faster than the overloaded numba np.convolve
1012
mode = op.mode
13+
a_dtype, b_dtype = node.inputs[0].type.dtype, node.inputs[1].type.dtype
14+
out_dtype = node.outputs[0].type.dtype
15+
innerprod = _get_inner_prod(a_dtype, b_dtype)
1116

12-
@numba_njit
13-
def conv1d(data, kernel):
14-
return np.convolve(data, kernel, mode=mode)
17+
if mode == "valid":
1518

16-
return conv1d
19+
def valid_convolve1d(x, y):
20+
nx = len(x)
21+
ny = len(y)
22+
if nx < ny:
23+
x, y = y, x
24+
nx, ny = ny, nx
25+
y_flipped = y[::-1]
26+
27+
length = nx - ny + 1
28+
ret = np.empty(length, out_dtype)
29+
30+
for i in range(length):
31+
ret[i] = innerprod(x[i : i + ny], y_flipped)
32+
33+
return ret
34+
35+
return numba_njit(valid_convolve1d)
36+
37+
elif mode == "full":
38+
39+
def full_convolve1d(x, y):
40+
nx = len(x)
41+
ny = len(y)
42+
if nx < ny:
43+
x, y = y, x
44+
nx, ny = ny, nx
45+
y_flipped = y[::-1]
46+
47+
length = nx + ny - 1
48+
ret = np.empty(length, out_dtype)
49+
idx = 0
50+
51+
for i in range(ny - 1):
52+
k = i + 1
53+
ret[idx] = innerprod(x[:k], y_flipped[-k:])
54+
idx = idx + 1
55+
56+
for i in range(nx - ny + 1):
57+
ret[idx] = innerprod(x[i : i + ny], y_flipped)
58+
idx = idx + 1
59+
60+
for i in range(ny - 1):
61+
k = ny - i - 1
62+
ret[idx] = innerprod(x[-k:], y_flipped[:k])
63+
idx = idx + 1
64+
65+
return ret
66+
67+
return numba_njit(full_convolve1d)
68+
69+
else:
70+
raise ValueError(f"Unsupported mode: {mode}")

tests/link/numba/signal/test_conv.py

+40-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
from functools import partial
2+
13
import numpy as np
24
import pytest
35

4-
from pytensor.tensor import dmatrix
6+
from pytensor import function
7+
from pytensor.tensor import dmatrix, tensor
58
from pytensor.tensor.signal import convolve1d
69
from tests.link.numba.test_basic import compare_numba_and_py
710

@@ -10,13 +13,47 @@
1013

1114

1215
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
13-
def test_convolve1d(mode):
16+
@pytest.mark.parametrize("x_smaller", (False, True))
17+
def test_convolve1d(x_smaller, mode):
1418
x = dmatrix("x")
1519
y = dmatrix("y")
16-
out = convolve1d(x[None], y[:, None], mode=mode)
20+
if x_smaller:
21+
out = convolve1d(x[None], y[:, None], mode=mode)
22+
else:
23+
out = convolve1d(y[:, None], x[None], mode=mode)
1724

1825
rng = np.random.default_rng()
1926
test_x = rng.normal(size=(3, 5))
2027
test_y = rng.normal(size=(7, 11))
2128
# Blockwise dispatch for numba can't be run on object mode
2229
compare_numba_and_py([x, y], out, [test_x, test_y], eval_obj_mode=False)
30+
31+
32+
@pytest.mark.parametrize("mode", ("full", "valid"), ids=lambda x: f"mode={x}")
33+
@pytest.mark.parametrize("batch", (False, True), ids=lambda x: f"batch={x}")
34+
def test_convolve1d_benchmark(batch, mode, benchmark):
35+
x = tensor(
36+
shape=(
37+
7,
38+
183,
39+
)
40+
if batch
41+
else (183,)
42+
)
43+
y = tensor(shape=(7, 6) if batch else (6,))
44+
out = convolve1d(x, y, mode=mode)
45+
fn = function([x, y], out, mode="NUMBA", trust_input=True)
46+
47+
rng = np.random.default_rng()
48+
x_test = rng.normal(size=(x.type.shape)).astype(x.type.dtype)
49+
y_test = rng.normal(size=(y.type.shape)).astype(y.type.dtype)
50+
51+
np_convolve1d = np.vectorize(
52+
partial(np.convolve, mode=mode), signature="(x),(y)->(z)"
53+
)
54+
55+
np.testing.assert_allclose(
56+
fn(x_test, y_test),
57+
np_convolve1d(x_test, y_test),
58+
)
59+
benchmark(fn, x_test, y_test)

0 commit comments

Comments
 (0)