
Commit 4378d48

Rewrite sliced full convolutions as valid
These show up in the gradient of Convolve1d
1 parent 2ada4b6 commit 4378d48
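For reference, the identity this commit exploits (slicing the interior of a full convolution recovers the valid one) can be checked directly with NumPy. The sketch below is illustrative only and is not part of the commit:

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=128)   # longer input
y = rng.normal(size=64)    # shorter input

full = np.convolve(x, y, mode="full")    # length len(x) + len(y) - 1
valid = np.convolve(x, y, mode="valid")  # length len(x) - len(y) + 1

# Dropping the first and last len(y) - 1 entries of the full result gives the valid result
assert np.allclose(full[len(y) - 1 : len(x)], valid)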

4 files changed (+110 -3 lines)


pytensor/tensor/rewriting/__init__.py (+1)
@@ -3,6 +3,7 @@
 import pytensor.tensor.rewriting.blas_c
 import pytensor.tensor.rewriting.blas_scipy
 import pytensor.tensor.rewriting.blockwise
+import pytensor.tensor.rewriting.conv
 import pytensor.tensor.rewriting.einsum
 import pytensor.tensor.rewriting.elemwise
 import pytensor.tensor.rewriting.extra_ops

pytensor/tensor/rewriting/conv.py (+78, new file)
@@ -0,0 +1,78 @@
+from pytensor.graph.basic import Constant
+from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
+from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.rewriting.basic import register_specialize, register_stabilize
+from pytensor.tensor.signal import convolve1d
+from pytensor.tensor.signal.conv import Convolve1d
+from pytensor.tensor.subtensor import Subtensor, indices_from_subtensor
+
+
+@register_stabilize
+@register_specialize
+@node_rewriter([Subtensor])
+def local_sliced_full_conv_to_valid_conv(fgraph, node):
+    """Rewrite sliced full conv that are equivalent to valid.
+
+    The gradient of a valid Conv1d always implements the worst case scenario - full convolution -
+    because it would need to know which input is larger to do something smarter.
+    If we find out (through rewrites or static shape) we provide the direct implementation
+    which can be orders of magnitude faster.
+
+    # if x.shape[-1] > y.shape[-1]
+    # z = convolve1d(x, y, mode="full")
+    # z[..., y.shape[-1] - 1: z.shape[-1] - y.shape[-1] - 1] -> convolve1d(x, y, mode="valid")
+    """
+    conv, *other_idx_vars = node.inputs
+
+    if not (
+        conv.owner is not None
+        and isinstance(conv.owner.op, Blockwise)
+        and isinstance(conv.owner.op.core_op, Convolve1d)
+        and conv.owner.op.core_op.mode == "full"
+    ):
+        return None
+
+    # Check we have an (a:b) constant slice at the last axis of the input
+    idx_list = node.op.idx_list
+    if not (len(idx_list) == conv.type.ndim and isinstance(idx_list[-1], slice)):
+        return None
+
+    last_slice = idx_list[-1]
+    if not (
+        last_slice.start is not None
+        and last_slice.stop is not None
+        and last_slice.step is None
+    ):
+        return None
+
+    *other_idx_vars, start, stop = other_idx_vars
+    if not (isinstance(start, Constant) and isinstance(stop, Constant)):
+        return None
+
+    x, y = conv.owner.inputs
+    len_x = x.type.shape[-1]
+    len_y = y.type.shape[-1]
+    if len_x is None or len_y is None:
+        return None
+
+    start, stop = start.data, stop.data
+    if len_x < len_y:
+        # Convolution is symmetric with input order
+        x, y = y, x
+        len_x, len_y = len_y, len_x
+
+    if (
+        start == len_y - 1
+        # equivalent to stop = conv.shape[-1] - len_y - 1
+        and stop == start + (len_x - len_y) + 1
+    ):
+        new_conv = convolve1d(x, y, mode="valid")
+        copy_stack_trace(conv, new_conv)
+
+        if other_idx_vars:
+            # If there were more than just empty slices besides the last one
+            new_indices = indices_from_subtensor(idx_list[:-1], other_idx_vars)
+            new_conv = new_conv[new_indices]
+            copy_stack_trace(node.out, new_conv)
+
+        return [new_conv]
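Not part of the commit: a rough sketch of how the new rewrite might be exercised on a hand-built graph, assuming a PyTensor build that includes this commit. The rewrite is registered under the stabilize and specialize passes, so those tags are requested explicitly; the mode extraction below handles both a Blockwise-wrapped and an unwrapped Convolve1d node:

from pytensor.graph.basic import io_toposort
from pytensor.graph.rewriting import rewrite_graph
from pytensor.tensor import vector
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.signal import convolve1d
from pytensor.tensor.signal.conv import Convolve1d

x = vector("x", shape=(128,))  # static shapes are what allow the rewrite to apply
y = vector("y", shape=(64,))
z = convolve1d(x, y, mode="full")
sliced = z[64 - 1 : 128]       # the constant slice that is equivalent to mode="valid"

rewritten = rewrite_graph(sliced, include=("canonicalize", "stabilize", "specialize"))

# Collect the convolution ops left in the rewritten graph and inspect their mode
ops = [
    node.op.core_op if isinstance(node.op, Blockwise) else node.op
    for node in io_toposort([x, y], [rewritten])
]
modes = [op.mode for op in ops if isinstance(op, Convolve1d)]
print(modes)  # expected: ["valid"]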

pytensor/tensor/signal/conv.py (+3 -2)
@@ -75,13 +75,14 @@ def L_op(self, inputs, outputs, output_grads):
         n = in1.shape[0]
         k = in2.shape[0]
         kmn = maximum(0, k - n)
-        nkm = maximum(0, n - k)
+        nmk = maximum(0, n - k)
         # We need mode="full" if k >= n else "valid" for `in1_bar` (opposite for `in2_bar`), but mode is not symbolic.
         # Instead, we always use mode="full" and slice the result so it behaves like "valid" for the input that's shorter.
+        # There is a rewrite that optimizes this case when n, k are static
         in1_bar = full_conv(grad, in2[::-1])
         in1_bar = in1_bar[kmn : in1_bar.shape[0] - kmn]
         in2_bar = full_conv(grad, in1[::-1])
-        in2_bar = in2_bar[nkm : in2_bar.shape[0] - nkm]
+        in2_bar = in2_bar[nmk : in2_bar.shape[0] - nmk]

         return [in1_bar, in2_bar]

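As a sanity check on the slicing bookkeeping described in the comments above, here is a small NumPy sketch (illustrative only, with hypothetical lengths n=8 and k=3; the gradient values themselves are stand-ins):

import numpy as np

n, k = 8, 3                      # len(in1), len(in2); the valid output has length n - k + 1
grad = np.ones(n - k + 1)        # stand-in for the upstream gradient of the valid output
in1 = np.arange(n, dtype=float)
in2 = np.arange(k, dtype=float)

kmn = max(0, k - n)              # 0: in1 is the longer input, so in1_bar needs no trimming
nmk = max(0, n - k)              # 5: in2_bar must be trimmed on both sides

in1_bar = np.convolve(grad, in2[::-1], mode="full")
in1_bar = in1_bar[kmn : in1_bar.shape[0] - kmn]
in2_bar = np.convolve(grad, in1[::-1], mode="full")
in2_bar = in2_bar[nmk : in2_bar.shape[0] - nmk]

assert in1_bar.shape == (n,) and in2_bar.shape == (k,)  # gradients match the input lengths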

tests/tensor/signal/test_conv.py (+28 -1)
@@ -5,7 +5,8 @@
 from scipy.signal import convolve as scipy_convolve

 from pytensor import config, function, grad
-from pytensor.graph import ancestors, rewrite_graph
+from pytensor.graph.basic import ancestors, io_toposort
+from pytensor.graph.rewriting import rewrite_graph
 from pytensor.tensor import matrix, vector
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.signal.conv import Convolve1d, convolve1d
@@ -82,3 +83,29 @@ def test_convolve1d_batch_graph(mode):
     ]
     # Check any Blockwise are just Conv1d
     assert all(isinstance(node.op.core_op, Convolve1d) for node in blockwise_nodes)
+
+
+@pytest.mark.parametrize("static_shape", [False, True])
+def test_convolve1d_valid_grad_rewrite(static_shape):
+    """Test that we don't do a useless full convolve1d when taking the gradient of a valid convolve wrt to the smallest input.
+
+    This can only be achieved when the two inputs have static shapes, so we know which one is larger
+    """
+    larger = vector("larger", shape=(128 if static_shape else None,))
+    smaller = vector("smaller", shape=(64 if static_shape else None,))
+    out = convolve1d(larger, smaller, mode="valid")
+    grad_out = rewrite_graph(
+        grad(out.sum(), wrt=smaller),
+        include=(
+            "ShapeOpt",
+            "canonicalize",
+            "stabilize",
+            "local_useless_unbatched_blockwise",
+        ),
+    )
+    [conv_op] = [
+        node.op
+        for node in io_toposort([larger, smaller], [grad_out])
+        if isinstance(node.op, Convolve1d)
+    ]
+    assert conv_op.mode == ("valid" if static_shape else "full")
