Skip to content

Commit 9530ffc

Browse files
committed
Rename core Conv1d to Convolve1d
1 parent 35c6999 commit 9530ffc

File tree

4 files changed

+10
-10
lines changed

4 files changed

+10
-10
lines changed

pytensor/link/jax/dispatch/signal/conv.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import jax
22

33
from pytensor.link.jax.dispatch import jax_funcify
4-
from pytensor.tensor.signal.conv import Conv1d
4+
from pytensor.tensor.signal.conv import Convolve1d
55

66

7-
@jax_funcify.register(Conv1d)
8-
def jax_funcify_Conv1d(op, node, **kwargs):
7+
@jax_funcify.register(Convolve1d)
8+
def jax_funcify_Convolve1d(op, node, **kwargs):
99
mode = op.mode
1010

1111
def conv1d(data, kernel):

pytensor/link/numba/dispatch/signal/conv.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
from pytensor.link.numba.dispatch import numba_funcify
44
from pytensor.link.numba.dispatch.basic import numba_njit
5-
from pytensor.tensor.signal.conv import Conv1d
5+
from pytensor.tensor.signal.conv import Convolve1d
66

77

8-
@numba_funcify.register(Conv1d)
9-
def numba_funcify_Conv1d(op, node, **kwargs):
8+
@numba_funcify.register(Convolve1d)
9+
def numba_funcify_Convolve1d(op, node, **kwargs):
1010
mode = op.mode
1111

1212
@numba_njit

pytensor/tensor/signal/conv.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from pytensor.tensor import TensorLike
1616

1717

18-
class Conv1d(Op):
18+
class Convolve1d(Op):
1919
__props__ = ("mode",)
2020
gufunc_signature = "(n),(k)->(o)"
2121

@@ -129,4 +129,4 @@ def convolve1d(
129129
)
130130
mode = "valid"
131131

132-
return cast(TensorVariable, Blockwise(Conv1d(mode=mode))(in1, in2))
132+
return cast(TensorVariable, Blockwise(Convolve1d(mode=mode))(in1, in2))

tests/tensor/signal/test_conv.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from pytensor.graph import ancestors, rewrite_graph
99
from pytensor.tensor import matrix, vector
1010
from pytensor.tensor.blockwise import Blockwise
11-
from pytensor.tensor.signal.conv import Conv1d, convolve1d
11+
from pytensor.tensor.signal.conv import Convolve1d, convolve1d
1212
from tests import unittest_tools as utt
1313

1414

@@ -81,4 +81,4 @@ def test_convolve1d_batch_graph(mode):
8181
if var.owner is not None and isinstance(var.owner.op, Blockwise)
8282
]
8383
# Check any Blockwise are just Conv1d
84-
assert all(isinstance(node.op.core_op, Conv1d) for node in blockwise_nodes)
84+
assert all(isinstance(node.op.core_op, Convolve1d) for node in blockwise_nodes)

0 commit comments

Comments
 (0)