File tree 4 files changed +10
-10
lines changed
4 files changed +10
-10
lines changed Original file line number Diff line number Diff line change 1
1
import jax
2
2
3
3
from pytensor .link .jax .dispatch import jax_funcify
4
- from pytensor .tensor .signal .conv import Conv1d
4
+ from pytensor .tensor .signal .conv import Convolve1d
5
5
6
6
7
- @jax_funcify .register (Conv1d )
8
- def jax_funcify_Conv1d (op , node , ** kwargs ):
7
+ @jax_funcify .register (Convolve1d )
8
+ def jax_funcify_Convolve1d (op , node , ** kwargs ):
9
9
mode = op .mode
10
10
11
11
def conv1d (data , kernel ):
Original file line number Diff line number Diff line change 2
2
3
3
from pytensor .link .numba .dispatch import numba_funcify
4
4
from pytensor .link .numba .dispatch .basic import numba_njit
5
- from pytensor .tensor .signal .conv import Conv1d
5
+ from pytensor .tensor .signal .conv import Convolve1d
6
6
7
7
8
- @numba_funcify .register (Conv1d )
9
- def numba_funcify_Conv1d (op , node , ** kwargs ):
8
+ @numba_funcify .register (Convolve1d )
9
+ def numba_funcify_Convolve1d (op , node , ** kwargs ):
10
10
mode = op .mode
11
11
12
12
@numba_njit
Original file line number Diff line number Diff line change 15
15
from pytensor .tensor import TensorLike
16
16
17
17
18
- class Conv1d (Op ):
18
+ class Convolve1d (Op ):
19
19
__props__ = ("mode" ,)
20
20
gufunc_signature = "(n),(k)->(o)"
21
21
@@ -129,4 +129,4 @@ def convolve1d(
129
129
)
130
130
mode = "valid"
131
131
132
- return cast (TensorVariable , Blockwise (Conv1d (mode = mode ))(in1 , in2 ))
132
+ return cast (TensorVariable , Blockwise (Convolve1d (mode = mode ))(in1 , in2 ))
Original file line number Diff line number Diff line change 8
8
from pytensor .graph import ancestors , rewrite_graph
9
9
from pytensor .tensor import matrix , vector
10
10
from pytensor .tensor .blockwise import Blockwise
11
- from pytensor .tensor .signal .conv import Conv1d , convolve1d
11
+ from pytensor .tensor .signal .conv import Convolve1d , convolve1d
12
12
from tests import unittest_tools as utt
13
13
14
14
@@ -81,4 +81,4 @@ def test_convolve1d_batch_graph(mode):
81
81
if var .owner is not None and isinstance (var .owner .op , Blockwise )
82
82
]
83
83
# Check any Blockwise are just Conv1d
84
- assert all (isinstance (node .op .core_op , Conv1d ) for node in blockwise_nodes )
84
+ assert all (isinstance (node .op .core_op , Convolve1d ) for node in blockwise_nodes )
You can’t perform that action at this time.
0 commit comments