Skip to content

Commit 43e5155

Browse files
Add scan op (#19681)
* Add `scan` * Fix lint * Increase test coverage * Increase test coverage * Replace `TypeError` with `ValueError` for invalid `unroll`
1 parent 10c27c0 commit 43e5155

File tree

8 files changed

+465
-7
lines changed

8 files changed

+465
-7
lines changed

keras/api/_tf_keras/keras/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from keras.src.ops.core import custom_gradient
1616
from keras.src.ops.core import fori_loop
1717
from keras.src.ops.core import is_tensor
18+
from keras.src.ops.core import scan
1819
from keras.src.ops.core import scatter
1920
from keras.src.ops.core import scatter_update
2021
from keras.src.ops.core import shape

keras/api/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from keras.src.ops.core import custom_gradient
1616
from keras.src.ops.core import fori_loop
1717
from keras.src.ops.core import is_tensor
18+
from keras.src.ops.core import scan
1819
from keras.src.ops.core import scatter
1920
from keras.src.ops.core import scatter_update
2021
from keras.src.ops.core import shape

keras/src/backend/jax/core.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,18 @@ def vectorized_map(function, elements):
253253
return jax.vmap(function)(elements)
254254

255255

256+
def scan(f, init, xs=None, length=None, reverse=False, unroll=1):
    """Scan `f` over the leading axis of `xs` by delegating to JAX.

    Arguments are validated and then forwarded unchanged to `jax.lax.scan`.

    Args:
        f: Callable forwarded to `jax.lax.scan` as the step function.
        init: Initial loop carry value.
        xs: Optional value(s) to scan over. Defaults to `None`.
        length: Optional number of loop iterations. Defaults to `None`.
        reverse: Whether to run the scan in reverse. Defaults to `False`.
        unroll: Positive integer or boolean controlling loop unrolling.
            Defaults to `1`.

    Returns:
        Whatever `jax.lax.scan` returns for the given arguments.

    Raises:
        TypeError: If `f` is not callable.
        ValueError: If `unroll` is neither a boolean nor a positive integer.
    """
    # Fail with the same message as the NumPy, TensorFlow and Torch
    # implementations rather than a backend-specific one.
    if not callable(f):
        raise TypeError(f"`f` should be a callable. Received: f={f}")
    # `bool` is a subclass of `int`, so booleans are accepted up front and
    # only non-boolean values go through the positive-integer check.
    if not isinstance(unroll, bool):
        if not isinstance(unroll, int) or unroll < 1:
            raise ValueError(
                "`unroll` must be an positive integer or boolean. "
                f"Received: unroll={unroll}"
            )
    return jax.lax.scan(
        f, init=init, xs=xs, length=length, reverse=reverse, unroll=unroll
    )
266+
267+
256268
def scatter(indices, values, shape):
257269
zeros = jnp.zeros(shape, values.dtype)
258270
key = tuple(jnp.moveaxis(indices, -1, 0))

keras/src/backend/numpy/core.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,55 @@ def convert_numpy_to_keras_tensor(x):
140140
return output_spec
141141

142142

143+
def scan(f, init, xs=None, length=None, reverse=False, unroll=1):
    """Run `f` step by step over the leading axis of `xs` in a Python loop.

    The leaves of `xs` and `init` are flattened and converted to tensors.
    On every step `f(carry, x)` is called with the current carry and the
    current slice of `xs` (or `None` when there is nothing to slice), and
    must return a `(new_carry, y)` pair. The per-step `y` values are stacked
    along a new leading axis.

    Args:
        f: Callable taking `(carry, x)` and returning `(carry, y)`.
        init: Initial carry; a single value or a nested structure.
        xs: Optional value or nested structure to scan over.
        length: Optional number of steps. When omitted, the size of the
            leading axis of the first leaf of `xs` is used.
        reverse: If truthy, steps are visited from last to first and the
            collected outputs are put back in forward order.
        unroll: Validated only; this implementation does not otherwise use
            it.

    Returns:
        A `(carry, stacked_y)` pair. When `f` returns `None` for `y`, a list
        of zero arrays shaped like the leaves of `init` is stacked instead.

    Raises:
        TypeError: If `f` is not callable.
        ValueError: If `unroll` is invalid, or if both `xs` and `length` are
            `None`.
    """
    # Ref: jax.lax.scan
    if not callable(f):
        raise TypeError(f"`f` should be a callable. Received: f={f}")
    unroll_is_valid = isinstance(unroll, bool) or (
        isinstance(unroll, int) and unroll >= 1
    )
    if not unroll_is_valid:
        raise ValueError(
            "`unroll` must be an positive integer or boolean. "
            f"Received: unroll={unroll}"
        )
    if xs is None and length is None:
        raise ValueError("Got no `xs` to scan over and `length` not provided.")

    xs_is_nested = tree.is_nested(xs)
    init_is_nested = tree.is_nested(init)

    if xs is None:
        xs_leaves = []
        num_steps = int(length)
    else:
        xs_leaves = [convert_to_tensor(leaf) for leaf in tree.flatten(xs)]
        if length is not None:
            num_steps = int(length)
        else:
            num_steps = shape(xs_leaves[0])[0]

    init_leaves = [convert_to_tensor(leaf) for leaf in tree.flatten(init)]
    if init_is_nested:
        carry = tree.pack_sequence_as(init, init_leaves)
    else:
        carry = init_leaves[0]
    # Stand-in output for steps where `f` returns `None` as its `y`.
    zero_ys = [np.zeros_like(leaf) for leaf in init_leaves]

    step_order = range(num_steps)
    if reverse:
        step_order = reversed(step_order)

    outputs = []
    for step in step_order:
        if xs_leaves:
            slices = [leaf[step] for leaf in xs_leaves]
            if xs_is_nested:
                x = tree.pack_sequence_as(xs, slices)
            else:
                x = slices[0]
        else:
            x = None
        carry, y = f(carry, x)
        outputs.append(zero_ys if y is None else y)

    # Outputs were collected in visiting order; restore forward order.
    if reverse:
        outputs.reverse()
    stacked_y = tree.map_structure(
        lambda *steps: np.stack(steps), *outputs
    )
    return carry, stacked_y
190+
191+
143192
def scatter(indices, values, shape):
144193
indices = convert_to_tensor(indices)
145194
values = convert_to_tensor(values)

keras/src/backend/tensorflow/core.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,127 @@ def vectorized_map(function, elements):
210210
return tf.vectorized_map(function, elements)
211211

212212

213+
def scan(f, init, xs=None, length=None, reverse=False, unroll=1):
    """Scan `f` over the leading axis of `xs` using a `tf.while_loop`.

    The leaves of `xs` are unstacked into `tf.TensorArray`s, the carry is
    kept in size-1 `tf.TensorArray`s, and the per-step outputs are written
    into `tf.TensorArray`s created with the dtype and shape of the carry
    leaves. Both results are packed using the structure of `init`.

    Args:
        f: Callable taking `(carry, x)` and returning `(carry, y)`.
        init: Initial carry; a single value or a nested structure.
        xs: Optional value or nested structure to scan over.
        length: Optional number of steps. When omitted, the dynamic size of
            the leading axis of the first leaf of `xs` is used.
        reverse: If truthy, the loop index runs from `n - 1` down to `0`.
        unroll: Positive integer or boolean; passed to `tf.while_loop` as
            `parallel_iterations` (`True` becomes `max(n, 1)`, `False`
            becomes `1`).

    Returns:
        A `(carry, ys)` pair, both packed with the structure of `init`.

    Raises:
        TypeError: If `f` is not callable.
        ValueError: If `unroll` is invalid, if both `xs` and `length` are
            `None`, or if the leaves of `xs` have incompatible static
            leading dimensions.
    """
    # We have reimplemented `scan` to match the behavior of `jax.lax.scan`
    # Ref: tf.scan, jax.lax.scan
    if not callable(f):
        raise TypeError(f"`f` should be a callable. Received: f={f}")
    if not isinstance(unroll, bool):
        if not isinstance(unroll, int) or unroll < 1:
            raise ValueError(
                "`unroll` must be an positive integer or boolean. "
                f"Received: unroll={unroll}"
            )
    if xs is None and length is None:
        raise ValueError("Got no `xs` to scan over and `length` not provided.")

    input_is_sequence = tree.is_nested(xs)
    output_is_sequence = tree.is_nested(init)

    def pack_input(x):
        return tree.pack_sequence_as(xs, x) if input_is_sequence else x[0]

    def pack_output(x):
        return tree.pack_sequence_as(init, x) if output_is_sequence else x[0]

    if xs is None:
        xs_flat = []
        n = int(length)
    else:
        xs_flat = tree.flatten(xs)
        xs_flat = [tf.convert_to_tensor(elem) for elem in xs_flat]
        n = int(length) if length is not None else tf.shape(xs_flat[0])[0]

    # TensorArrays are always flat
    xs_array = [
        tf.TensorArray(
            dtype=x.dtype,
            size=n,
            dynamic_size=False,
            element_shape=x.shape[1:],
            infer_shape=True,
        )
        for x in xs_flat
    ]
    xs_array = [x_ta.unstack(x) for x_ta, x in zip(xs_array, xs_flat)]

    init_flat = tree.flatten(init)
    carry_flat = [tf.convert_to_tensor(leaf) for leaf in init_flat]

    # Store the intermediate values
    # Note: there is a constraint that the output of `f` must have the same
    # shape and dtype as carry (`init`).
    ys_array = [
        tf.TensorArray(
            dtype=c.dtype,
            size=n,
            dynamic_size=False,
            element_shape=c.shape,
            infer_shape=True,
        )
        for c in carry_flat
    ]
    carry_array = [
        tf.TensorArray(
            dtype=c.dtype,
            size=1,
            dynamic_size=False,
            # The carry is read at the start of every iteration, so it must
            # survive reads.
            clear_after_read=False,
            element_shape=c.shape,
            infer_shape=True,
        )
        for c in carry_flat
    ]
    carry_array = [
        c_ta.write(0, c) for (c_ta, c) in zip(carry_array, carry_flat)
    ]

    def loop_body(i, carry_array, ys_array):
        packed_xs = (
            pack_input([x_ta.read(i) for x_ta in xs_array])
            if len(xs_array) > 0
            else None
        )
        packed_carry = pack_output([c_ta.read(0) for c_ta in carry_array])

        carry, ys = f(packed_carry, packed_xs)

        if ys is not None:
            flat_ys = tree.flatten(ys)
            ys_array = [
                y_ta.write(i, v) for (y_ta, v) in zip(ys_array, flat_ys)
            ]
        if carry is not None:
            flat_carry = tree.flatten(carry)
            carry_array = [
                c_ta.write(0, v) for (c_ta, v) in zip(carry_array, flat_carry)
            ]
        next_i = i + 1 if not reverse else i - 1
        return (next_i, carry_array, ys_array)

    if isinstance(unroll, bool):
        unroll = max(n, 1) if unroll else 1

    _, carry_array, ys_array = tf.while_loop(
        lambda i, _1, _2: i >= 0 if reverse else i < n,
        loop_body,
        (n - 1 if reverse else 0, carry_array, ys_array),
        parallel_iterations=unroll,
    )

    ys_flat = [y_ta.stack() for y_ta in ys_array]
    carry_flat = [c_ta.read(0) for c_ta in carry_array]
    if xs is not None:
        # The static leading dimension may be an `int`, `None` (unknown) or
        # a `Dimension` object depending on the TF shape behavior. Wrapping
        # each one in a rank-1 `tf.TensorShape` lets the compatibility check
        # and the concatenation work in all of these cases, instead of
        # calling a method on a possibly-`None` dimension.
        leading_dims = [
            x.get_shape().with_rank_at_least(1)[0] for x in xs_flat
        ]
        n_static = tf.TensorShape([leading_dims[0]])
        for dim in leading_dims[1:]:
            n_static.assert_is_compatible_with(tf.TensorShape([dim]))
        for r in ys_flat:
            r.set_shape(n_static.concatenate(r.get_shape()[1:]))
    return pack_output(carry_flat), pack_output(ys_flat)
332+
333+
213334
def scatter(indices, values, shape):
214335
return tf.scatter_nd(indices, values, shape)
215336

keras/src/backend/torch/core.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,55 @@ def vectorized_map(function, elements):
340340
return torch.vmap(function)(elements)
341341

342342

343+
def scan(f, init, xs=None, length=None, reverse=False, unroll=1):
    """Run `f` step by step over the leading axis of `xs` in a Python loop.

    The leaves of `xs` and `init` are flattened and converted to tensors.
    Each step calls `f(carry, x)` with the current carry and the current
    slice of `xs` (or `None` when there is nothing to slice) and expects a
    `(new_carry, y)` pair back. The `y` values are stacked with
    `torch.stack`.

    Args:
        f: Callable taking `(carry, x)` and returning `(carry, y)`.
        init: Initial carry; a single value or a nested structure.
        xs: Optional value or nested structure to scan over.
        length: Optional number of steps. When omitted, the size of the
            leading axis of the first leaf of `xs` is used.
        reverse: If truthy, steps are visited from last to first and the
            collected outputs are put back in forward order.
        unroll: Validated only; this implementation does not otherwise use
            it.

    Returns:
        A `(carry, stacked_y)` pair. When `f` returns `None` for `y`, a list
        of zero tensors shaped like the leaves of `init` is stacked instead.

    Raises:
        TypeError: If `f` is not callable.
        ValueError: If `unroll` is invalid, or if both `xs` and `length` are
            `None`.
    """
    # Ref: jax.lax.scan
    if not callable(f):
        raise TypeError(f"`f` should be a callable. Received: f={f}")
    if not isinstance(unroll, bool) and (
        not isinstance(unroll, int) or unroll < 1
    ):
        raise ValueError(
            "`unroll` must be an positive integer or boolean. "
            f"Received: unroll={unroll}"
        )
    if xs is None and length is None:
        raise ValueError("Got no `xs` to scan over and `length` not provided.")

    nested_xs = tree.is_nested(xs)
    nested_init = tree.is_nested(init)

    if xs is None:
        flat_xs = []
        steps = int(length)
    else:
        flat_xs = [convert_to_tensor(leaf) for leaf in tree.flatten(xs)]
        steps = int(length) if length is not None else shape(flat_xs[0])[0]

    flat_init = [convert_to_tensor(leaf) for leaf in tree.flatten(init)]
    carry = (
        tree.pack_sequence_as(init, flat_init) if nested_init else flat_init[0]
    )
    # Stand-in output for steps where `f` returns `None` as its `y`.
    placeholder_y = [torch.zeros_like(leaf) for leaf in flat_init]

    indices = range(steps)
    if reverse:
        indices = reversed(indices)

    collected = []
    for idx in indices:
        current = [leaf[idx] for leaf in flat_xs]
        if not current:
            x = None
        elif nested_xs:
            x = tree.pack_sequence_as(xs, current)
        else:
            x = current[0]
        carry, y = f(carry, x)
        collected.append(placeholder_y if y is None else y)

    # Outputs were collected in visiting order; restore forward order.
    if reverse:
        collected.reverse()
    stacked_y = tree.map_structure(
        lambda *parts: torch.stack(parts), *collected
    )
    return carry, stacked_y
390+
391+
343392
def scatter(indices, values, shape):
344393
indices = convert_to_tensor(indices)
345394
values = convert_to_tensor(values)

keras/src/ops/core.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""
2+
scan
23
scatter
34
scatter_update
45
slice
@@ -25,6 +26,113 @@
2526
from keras.src.utils import traceback_utils
2627

2728

29+
class Scan(Operation):
    """Operation form of `scan`, holding the `reverse` and `unroll` options.

    `call` forwards to `backend.core.scan`; `compute_output_spec` derives the
    output specs by evaluating `f` on a single step of `xs`.
    """

    def __init__(self, reverse=False, unroll=1):
        super().__init__()
        self.reverse = reverse
        self.unroll = unroll

    def call(self, f, init, xs, length):
        return backend.core.scan(
            f, init, xs, length, reverse=self.reverse, unroll=self.unroll
        )

    def compute_output_spec(self, f, init, xs, length):
        """Return `(carry_spec, y_spec)` for a scan of `f` over `xs`.

        The number of steps `n` is `int(length)` when `length` is given,
        otherwise the leading dimension of the first leaf of `xs`. The spec
        of `y` is the per-step spec with `n` prepended to its shape.
        """
        if xs is None:
            n = int(length)
            x = None
        else:
            n = (
                int(length)
                if length is not None
                else tree.flatten(xs)[0].shape[0]
            )
            # Take one step from *every* leaf so that `f` receives the same
            # structure it gets at run time. Indexing `xs` directly would
            # pick the first leaf of a nested `xs` instead of a slice.
            x = tree.map_structure(lambda leaf: leaf[0], xs)

        carry_spec, y_spec = backend.compute_output_spec(f, init, x)

        def add_scan_axis(spec):
            spec.shape = (n,) + spec.shape
            return spec

        # `y` may itself be a nested structure; prepend the scan axis to
        # each leaf spec rather than assuming a single tensor.
        y_spec = tree.map_structure(add_scan_axis, y_spec)
        return carry_spec, y_spec
55+
56+
57+
@keras_export("keras.ops.scan")
def scan(f, init, xs=None, length=None, reverse=False, unroll=1):
    """Scan a function over leading array axes while carrying along state.

    When the type of `xs` is an array type or `None`, and the type of `ys` is an
    array type, the semantics of `scan()` are given roughly by this Python
    implementation:

    ```python
    def scan(f, init, xs, length=None):
        if xs is None:
            xs = [None] * length
        carry = init
        ys = []
        for x in xs:
            carry, y = f(carry, x)
            ys.append(y)
        return carry, np.stack(ys)
    ```

    The loop-carried value `carry` (`init`) must hold a fixed shape and dtype
    across all iterations.

    In TensorFlow, `y` must match `carry` in shape and dtype. This is not
    required in other backends.

    Args:
        f: Callable defines the logic for each loop iteration. This accepts two
            arguments where the first is a value of the loop carry and the
            second is a slice of `xs` along its leading axis.
            This callable returns a pair where the first represents a new value
            for the loop carry and the second represents a slice of the output.
        init: The initial loop carry value. This can be a scalar, tensor, or any
            nested structure. It must match the structure of the first element
            returned by `f`.
        xs: Optional value to scan along its leading axis. This can be a tensor
            or any nested structure. If `xs` is not provided, you must specify
            `length` to define the number of loop iterations.
            Defaults to `None`.
        length: Optional integer specifying the number of loop iterations.
            If `length` is not provided, it defaults to the sizes of leading
            axis of the arrays in `xs`. Defaults to `None`.
        reverse: Optional boolean specifying whether to run the scan iteration
            forward or in reverse, equivalent to reversing the leading axes of
            the arrays in both `xs` and in `ys`.
        unroll: Optional positive integer or boolean specifying how many scan
            iterations to unroll within a single iteration of a loop. If an
            integer is provided, it determines how many unrolled loop iterations
            to run within a single rolled iteration of the loop. If a boolean is
            provided, it will determine if the loop is completely unrolled
            (`unroll=True`) or left completely rolled (`unroll=False`).
            Note that unrolling is only supported by JAX and TensorFlow
            backends.

    Returns:
        A pair where the first element represents the final loop carry value and
        the second element represents the stacked outputs of `f` when scanned
        over the leading axis of the inputs.

    Examples:

    >>> sum_fn = lambda c, x: (c + x, c + x)
    >>> init = keras.ops.array(0)
    >>> xs = keras.ops.array([1, 2, 3, 4, 5])
    >>> carry, result = keras.ops.scan(sum_fn, init, xs)
    >>> carry
    15
    >>> result
    [1, 3, 6, 10, 15]
    """
    # Symbolic inputs go through the `Scan` operation so that only output
    # specs are computed; concrete inputs are dispatched to the backend.
    if any_symbolic_tensors((init, xs)):
        return Scan(reverse=reverse, unroll=unroll).symbolic_call(
            f, init, xs, length
        )
    return backend.core.scan(
        f, init, xs, length, reverse=reverse, unroll=unroll
    )
134+
135+
28136
class Scatter(Operation):
29137
def call(self, indices, values, shape):
30138
return backend.core.scatter(indices, values, shape)

0 commit comments

Comments
 (0)