
Commit dc76879

[scan] Reduce memory usage (#8562)
1 parent 0260209 commit dc76879

File tree: 3 files changed, +148 -25 lines changed


test/scan/test_scan.py

Lines changed: 51 additions & 6 deletions
@@ -444,6 +444,46 @@ def fn(carry, x):
     self.assertEqual(bf16_ys.dtype, torch.bfloat16)
     self.assertEqual(f32_ys.dtype, torch.float32)

+  def test_scan_activation_aliases_input(self):
+    """Test that if an intermediate activation of fn aliases an input,
+    we directly save the input tensor into the context object, instead of
+    indexing into the leading dimension during the while loop and copying
+    those slices into a new output tensor. This is a memory saving optimization.
+    """
+
+    def fn(carry, x):
+      return carry, torch.sin(x)
+
+    carry = torch.randn(4, 4, requires_grad=True, device=self.device)
+    xs = torch.randn(20, 4, 4, requires_grad=True, device=self.device)
+    torch_xla.sync()
+
+    storage = []
+
+    def pack(x):
+      storage.append(x)
+      return len(storage) - 1
+
+    def unpack(x):
+      return storage[x]
+
+    # Intercept the tensors stored in the context object.
+    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
+      final_carry, ys = scan(fn, carry, xs)
+      ys.sum().backward()
+    torch_xla.sync()
+
+    # Find the input that is stored in the context object.
+    stored_xs = None
+    for s in storage:
+      if s.shape == xs.shape:
+        assert stored_xs is None
+        stored_xs = s
+
+    # Test that it's literally the same object as the input tensor,
+    # as opposed to just numerically identical but otherwise an extra copy.
+    assert id(stored_xs) == id(xs)
+

 class PyTreeTest(TestBase):

@@ -469,12 +509,16 @@ def fn(carry, x):
     xs = torch.tensor([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]],
                       requires_grad=True,
                       device=self.device)
-    forward, backward = value_and_grad_partitioned(fn, init, xs)
+    forward, alias_input, backward = value_and_grad_partitioned(fn, init, xs)

-    # Forward should return `(new_carry, (y, (carry, x)))`,
-    # because `(carry, x)` are the two intermediate activations (primals),
-    # and they will be packed alongside the original output `y`.
+    # Once we add back activations that are aliases to inputs, the result should
+    # be `(new_carry, (y, (carry, x)))`, because `(carry, x)` are the two
+    # intermediate activations (primals), and they will be packed alongside
+    # the original output `y`.
     out = forward(init, xs[0])
+    new_carry, (y, partial_activations) = out
+    activations = alias_input(partial_activations, xs[0])
+    out = (new_carry, (y, activations))
     torch_xla.sync()
     carry = init
     x = xs[0]
@@ -521,11 +565,12 @@ def fn(carry, x):
     }

     # Get the forward and backward functions using value_and_grad_partitioned
-    forward, backward = value_and_grad_partitioned(
+    forward, alias_input, backward = value_and_grad_partitioned(
         fn, init, tree_map(lambda v: v.unsqueeze(0), x))

     # Run the forward function
-    carry_out, (y_out, activations) = forward(init, x)
+    carry_out, (y_out, partial_activations) = forward(init, x)
+    activations = alias_input(partial_activations, x)
     torch_xla.sync()

     # Compute expected outputs and gradients using PyTorch autograd
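
The new test above intercepts whatever autograd saves for the backward pass via `torch.autograd.graph.saved_tensors_hooks`. A minimal standalone sketch of that interception technique, using plain CPU PyTorch and no torch_xla; the final assertion checks storage sharing, which is weaker than the `id()` equality the scan optimization above guarantees:

import torch

storage = []

def pack(t):
  # Called whenever autograd saves a tensor for the backward pass.
  storage.append(t)
  return len(storage) - 1  # store an index instead of the tensor itself

def unpack(idx):
  # Called during backward to look the saved tensor up again.
  return storage[idx]

x = torch.randn(20, 4, 4, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
  y = torch.sin(x)
y.sum().backward()

# sin saves its input for backward, so one of the intercepted tensors
# should share storage with `x`.
assert any(t.shape == x.shape and t.data_ptr() == x.data_ptr() for t in storage)
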

test/scan/test_scan_spmd.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ def test_scan_cumsum(self):
     """This test uses `scan` to implement `torch.cumsum`."""

     def fn(carry, x):
-      new_carry = carry + x
+      new_carry = torch.sin(carry + x)
       y = new_carry
       return new_carry, y


torch_xla/experimental/scan.py

Lines changed: 96 additions & 18 deletions
@@ -46,6 +46,7 @@
 import torch_xla
 import torch_xla.core.xla_builder as xb
 from torch_xla.experimental.pytreeify import pytreeify
+import torch_xla.debug.profiler as xp

 Carry = TypeVar('Carry')
 X = TypeVar('X')
@@ -154,35 +155,44 @@ def scan(fn, init, xs):
   if xs_length is None:
     raise ValueError(f"`xs` {xs} is an empty PyTree.")

-  forward, backward = value_and_grad_partitioned(
+  forward, alias_input, backward = value_and_grad_partitioned(
       fn, init, xs, partition_fn=partition_fn)
-  carry, ys = Scan.apply(forward, backward, init, xs)  # type: ignore
+  carry, ys = Scan.apply(forward, alias_input, backward, init,
+                         xs)  # type: ignore
   return carry, ys


 def value_and_grad_partitioned(
     fn: Callable[[Carry, X], tuple[Carry, Y]],
     init: Carry,
     xs: X,
-    partition_fn=default_partition) -> tuple[Callable, Callable]:
+    partition_fn=default_partition) -> tuple[Callable, Callable, Callable]:
   """
   Given a user `fn` to be scanned over the leading dimension of the input `xs`
   PyTree and an initial carry object `init`, symbolically traces `fn` and
-  returns two functions, `forward` and `backward`, which wrap the forward and
-  backward graphs of `fn` and plumbs through intermediate activations.
-  Specifically, given
+  returns three functions, `forward`, `alias_input`, and `backward`.
+  `forward` and `backward` wrap the forward and backward graphs of `fn` and
+  plumb through intermediate activations, while `alias_input` is a memory
+  saving optimization. Specifically, given

   `fn(carry, x) -> (new_carry, y)`
-
+
   this function will build and return

-  `forward(carry, x) -> (new_carry, (y, activations))`
+  `forward(carry, x) -> (new_carry, (y, partial_activations))`
+
+  `alias_input(stack(partial_activations), xs) -> stack(activations)`

   `backward(grad_new_carry, (grad_y, activations)) -> (grad_carry, grad_x)`

   where `grad_y` is the gradient w.r.t `y`, and `grad_new_carry` is the gradient
   w.r.t. `new_carry`.
-
+
+  The `partial_activations` returned by `forward` are intermediate activations
+  that do not alias any input tensors. You may pass a stack of `partial_activations`
+  and the original input `xs` PyTree to `alias_input` to reconstitute the full
+  list of `activations`.
+
   `activations` will always be a flat list of tensors.

   This is similar to the `value_and_grad` transform found in JAX, but additionally
@@ -201,7 +211,7 @@ def value_and_grad_partitioned(
       forward and backward graphs.

   Returns:
-    A tuple of `(forward, backward)`, detailed in the docstring of this function.
+    A tuple of `(forward, alias_input, backward)`, detailed in the docstring of this function.
   """

   # Make some fake tensors to trace the user function and obtain the
@@ -253,24 +263,92 @@ def fn_no_output_aliasing(*args):
   fwd_graph = get_fwd()
   bwd_graph = get_bwd()

-  def forward(carry, x):
+  # Figure out which activations are aliases to the inputs. We don't need to
+  # pass them through the scan logic unchanged. That would use more memory.
+  input_activation_aliases = _find_input_activation_aliases(
+      fake_carry_pytree, fake_x_pytree, num_out, fwd_graph)
+  aliased_activations = set(input_activation_aliases.values())
+
+  def forward_core(carry, x):
     flat_carry, _ = tree_flatten(carry)
     flat_x, _ = tree_flatten(x)
-    out = fwd_graph(*flat_carry, *flat_x)
+    with xp.Trace('aot_forward'):
+      out = fwd_graph(*flat_carry, *flat_x)
     actual_out, activations = split(out, num_out)
     carry, y = unflatten_fwd_out(actual_out)
     y = (y, activations)
     return carry, y

+  def forward(carry, x):
+    carry, (y, activations) = forward_core(carry, x)
+
+    # Remove activations that alias to inputs. Those will be added back
+    # in `alias_input`.
+    partial_activations = tuple(
+        v for i, v in enumerate(activations) if i not in aliased_activations)
+
+    y = (y, partial_activations)
+    return carry, y
+
+  def alias_input(partial_activations, xs):
+    """
+    Add back activations that are aliases to input tensors.
+
+    In principle, we could have `forward` return all the intermediate activations,
+    including those that are aliases to an input tensor. However, those inputs will
+    then be duplicated as part of the output of a `scan` call, because we want to
+    save all activations during the forward pass of a `scan`. The XLA compiler can't
+    optimize away this duplication, likely because they're behind a DynamicSlice +
+    DynamicUpdateSlice, so we end up doubling the memory usage from those inputs.
+
+    To reduce memory usage, we can have `forward` return the activations that
+    don't alias to inputs, called `partial_activations`. The autograd implementation
+    of `scan` will call `alias_input` to add back activations that are aliases
+    of input tensors outside of a scan, turning the partial activations back to
+    full activations.
+    """
+    activations = list(partial_activations)
+    aliased_inputs = [
+        v for i, v in enumerate(tree_iter(xs)) if i in input_activation_aliases
+    ]
+    for (i, activation_idx) in enumerate(input_activation_aliases.values()):
+      activations.insert(activation_idx, aliased_inputs[i])
+    return tuple(activations)
+
   def backward(carry, x):
     grad_new_carry, _ = tree_flatten(carry)
     (grad_y, activations) = x
     grad_y, _ = tree_flatten_none(grad_y)
-    out = bwd_graph(*activations, *grad_new_carry, *grad_y)
+    with xp.Trace('aot_backward'):
+      out = bwd_graph(*activations, *grad_new_carry, *grad_y)
     grad_carry, grad_x = unflatten_bwd_out(out)
     return grad_carry, grad_x

-  return forward, backward
+  return forward, alias_input, backward
+
+
+def _find_input_activation_aliases(fake_carry_pytree, fake_x_pytree, num_out,
+                                   fwd_graph):
+  """
+  Find which activations are aliases to input tensors.
+
+  Returns:
+
+    A mapping from index into the flattened
+    input pytree to the index into the list of intermediate activations.
+
+  """
+  flat_carry, _ = tree_flatten(fake_carry_pytree)
+  flat_x, _ = tree_flatten(fake_x_pytree)
+  _actual_out, activations = split(fwd_graph(*flat_carry, *flat_x), num_out)
+  input_id_to_index = {
+      v: i for i, v in enumerate(id(v) for v in tree_iter(flat_x))
+  }
+  input_activation_aliases = {}
+  for idx, i in enumerate(id(v) for v in activations):
+    if i in input_id_to_index:
+      input_activation_aliases[input_id_to_index[i]] = idx
+  return input_activation_aliases


 def _make_get_graph_compiler():
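
The alias bookkeeping above boils down to two small pieces: `_find_input_activation_aliases` matches activations to inputs by object identity, and `alias_input` re-inserts the matched inputs at their original positions. A contrived, self-contained sketch of that bookkeeping, where `toy_fwd` is an invented stand-in for the traced `fwd_graph`:

import torch

def toy_fwd(a, x):
  # Stand-in for the traced forward graph: returns (outputs, activations),
  # and the activation list happens to contain the input `x` itself.
  y = torch.sin(x)
  return (y,), (x, y)

a = torch.randn(4)
x = torch.randn(4)
inputs = [a, x]
_outputs, activations = toy_fwd(a, x)

# Match activations to inputs by id(), as _find_input_activation_aliases does.
input_id_to_index = {id(v): i for i, v in enumerate(inputs)}
input_activation_aliases = {}
for act_idx, act in enumerate(activations):
  if id(act) in input_id_to_index:
    input_activation_aliases[input_id_to_index[id(act)]] = act_idx
assert input_activation_aliases == {1: 0}  # input x is activation 0

# Drop the aliased activations (what `forward` returns) ...
aliased = set(input_activation_aliases.values())
partial_activations = tuple(
    v for i, v in enumerate(activations) if i not in aliased)

# ... and splice them back in outside the loop (what `alias_input` does).
restored = list(partial_activations)
aliased_inputs = [v for i, v in enumerate(inputs) if i in input_activation_aliases]
for i, act_idx in enumerate(input_activation_aliases.values()):
  restored.insert(act_idx, aliased_inputs[i])

# The full activation list is recovered, with `x` back at its original slot.
assert restored[0] is x
assert all(r is orig for r, orig in zip(restored, activations))
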
@@ -297,12 +375,12 @@ def get_graph():
 class Scan(torch.autograd.Function):

   @staticmethod
-  def forward(ctx, forward, backward, init, xs):
-    # Forward pass, save activations for backward
+  def forward(ctx, forward, alias_input, backward, init, xs):
     ctx._backward = backward
     with torch.no_grad():
       carry, ys = _scan_impl_pytree(forward, init, xs)
-    ys, activations = ys
+    ys, partial_activations = ys
+    activations = alias_input(partial_activations, xs)
     ctx.save_for_backward(*activations)
     return carry, ys

@@ -314,7 +392,7 @@ def backward(ctx, grad_carry, grad_ys): # type: ignore
     # Reverse loop to propagate gradients from last iteration to first.
     grad_init, grad_xs = _scan_impl_pytree(
         backward, grad_carry, (grad_ys, activations), reverse=True)
-    return None, None, grad_init, grad_xs
+    return None, None, None, grad_init, grad_xs


 def _scan_impl_pytree(fn, init, xs, reverse: bool = False):
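
Putting the commit together: `value_and_grad_partitioned` now returns a `(forward, alias_input, backward)` triple, `forward` keeps input-aliased activations out of the scanned outputs, and `alias_input` restores them once, outside the while loop, before they are saved for backward. A hedged single-step sketch of that contract, mirroring the updated tests above; the `fn`, shapes, and gradient seeds here are illustrative and not taken from the commit:

import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.experimental.scan import value_and_grad_partitioned

def fn(carry, x):
  return carry + x, torch.sin(x)

device = xm.xla_device()
init = torch.zeros(4, 4, requires_grad=True, device=device)
xs = torch.randn(8, 4, 4, requires_grad=True, device=device)
torch_xla.sync()

forward, alias_input, backward = value_and_grad_partitioned(fn, init, xs)

# One scan step: `forward` only returns activations that do not alias its
# inputs, so the while loop never copies `x` slices into a stacked output.
new_carry, (y, partial_activations) = forward(init, xs[0])

# Outside the loop, the aliased inputs are spliced back in ...
activations = alias_input(partial_activations, xs[0])

# ... and `backward` consumes the full activation list with upstream grads.
grad_carry, grad_x = backward(torch.ones_like(new_carry),
                              (torch.ones_like(y), activations))
torch_xla.sync()
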
