import torch_xla
import torch_xla.core.xla_builder as xb
from torch_xla.experimental.pytreeify import pytreeify
+import torch_xla.debug.profiler as xp

Carry = TypeVar('Carry')
X = TypeVar('X')
@@ -154,35 +155,44 @@ def scan(fn, init, xs):
  if xs_length is None:
    raise ValueError(f"`xs` {xs} is an empty PyTree.")

-  forward, backward = value_and_grad_partitioned(
+  forward, alias_input, backward = value_and_grad_partitioned(
      fn, init, xs, partition_fn=partition_fn)
-  carry, ys = Scan.apply(forward, backward, init, xs)  # type: ignore
+  carry, ys = Scan.apply(forward, alias_input, backward, init,
+                         xs)  # type: ignore
  return carry, ys


def value_and_grad_partitioned(
    fn: Callable[[Carry, X], tuple[Carry, Y]],
    init: Carry,
    xs: X,
-    partition_fn=default_partition) -> tuple[Callable, Callable]:
+    partition_fn=default_partition) -> tuple[Callable, Callable, Callable]:
  """
  Given a user `fn` to be scanned over the leading dimension of the input `xs`
  PyTree and an initial carry object `init`, symbolically traces `fn` and
-  returns two functions, `forward` and `backward`, which wrap the forward and
-  backward graphs of `fn` and plumbs through intermediate activations.
-  Specifically, given
+  returns three functions, `forward`, `alias_input`, and `backward`.
+  `forward` and `backward` wrap the forward and backward graphs of `fn` and
+  plumb through intermediate activations, while `alias_input` is a
+  memory-saving optimization. Specifically, given

    `fn(carry, x) -> (new_carry, y)`
-
+
  this function will build and return

-    `forward(carry, x) -> (new_carry, (y, activations))`
+    `forward(carry, x) -> (new_carry, (y, partial_activations))`
+
+    `alias_input(stack(partial_activations), xs) -> stack(activations)`

    `backward(grad_new_carry, (grad_y, activations)) -> (grad_carry, grad_x)`

  where `grad_y` is the gradient w.r.t `y`, and `grad_new_carry` is the gradient
  w.r.t. `new_carry`.
-
+
+  The `partial_activations` returned by `forward` are intermediate activations
+  that do not alias any input tensors. You may pass a stack of
+  `partial_activations` and the original input `xs` PyTree to `alias_input` to
+  reconstitute the full list of `activations`.
+
  `activations` will always be a flat list of tensors.

  This is similar to the `value_and_grad` transform found in JAX, but additionally
@@ -201,7 +211,7 @@ def value_and_grad_partitioned(
    forward and backward graphs.

  Returns:
-    A tuple of `(forward, backward)`, detailed in the docstring of this function.
+    A tuple of `(forward, alias_input, backward)`, detailed in the docstring of this function.
  """

  # Make some fake tensors to trace the user function and obtain the
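
# A minimal, self-contained sketch of the `forward` / `alias_input` contract
# described in the docstring above. The names and values (`toy_alias_input`,
# `xs_flat`, 'a0', 'x0', ...) are purely illustrative and not part of this
# change. Suppose tracing `fn` saves three activations and the second one is
# just an input tensor itself: `forward` then only returns the other two as
# `partial_activations`, and `alias_input` splices the input back in.
def toy_alias_input(partial_activations, xs_flat, input_activation_aliases):
  # `input_activation_aliases` maps input index -> activation index.
  activations = list(partial_activations)
  aliased_inputs = [
      v for i, v in enumerate(xs_flat) if i in input_activation_aliases
  ]
  for i, activation_idx in enumerate(input_activation_aliases.values()):
    activations.insert(activation_idx, aliased_inputs[i])
  return tuple(activations)

# Activation 1 aliases input 0, so only ('a0', 'a2') were carried through scan.
assert toy_alias_input(('a0', 'a2'), ('x0',), {0: 1}) == ('a0', 'x0', 'a2')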
@@ -253,24 +263,92 @@ def fn_no_output_aliasing(*args):
  fwd_graph = get_fwd()
  bwd_graph = get_bwd()

-  def forward(carry, x):
+  # Figure out which activations are aliases of the inputs. We don't need to
+  # pass those through the scan logic unchanged; that would only use more memory.
+  input_activation_aliases = _find_input_activation_aliases(
+      fake_carry_pytree, fake_x_pytree, num_out, fwd_graph)
+  aliased_activations = set(input_activation_aliases.values())
+
+  def forward_core(carry, x):
    flat_carry, _ = tree_flatten(carry)
    flat_x, _ = tree_flatten(x)
-    out = fwd_graph(*flat_carry, *flat_x)
+    with xp.Trace('aot_forward'):
+      out = fwd_graph(*flat_carry, *flat_x)
    actual_out, activations = split(out, num_out)
    carry, y = unflatten_fwd_out(actual_out)
    y = (y, activations)
    return carry, y

+  def forward(carry, x):
+    carry, (y, activations) = forward_core(carry, x)
+
+    # Remove activations that alias inputs. Those will be added back
+    # in `alias_input`.
+    partial_activations = tuple(
+        v for i, v in enumerate(activations) if i not in aliased_activations)
+
+    y = (y, partial_activations)
+    return carry, y
+
+  def alias_input(partial_activations, xs):
+    """
+    Add back activations that are aliases of input tensors.
+
+    In principle, we could have `forward` return all the intermediate
+    activations, including those that are aliases of an input tensor. However,
+    those inputs would then be duplicated in the output of a `scan` call,
+    because we save all activations during the forward pass of a `scan`. The
+    XLA compiler can't optimize away this duplication, likely because the
+    tensors sit behind a DynamicSlice + DynamicUpdateSlice, so the memory usage
+    from those inputs would be doubled.
+
+    To reduce memory usage, `forward` returns only the activations that don't
+    alias inputs, called `partial_activations`. The autograd implementation of
+    `scan` then calls `alias_input` outside of the scan to add back the
+    activations that are aliases of input tensors, turning the partial
+    activations back into full activations.
+    """
+    activations = list(partial_activations)
+    aliased_inputs = [
+        v for i, v in enumerate(tree_iter(xs)) if i in input_activation_aliases
+    ]
+    for i, activation_idx in enumerate(input_activation_aliases.values()):
+      activations.insert(activation_idx, aliased_inputs[i])
+    return tuple(activations)
+
  def backward(carry, x):
    grad_new_carry, _ = tree_flatten(carry)
    (grad_y, activations) = x
    grad_y, _ = tree_flatten_none(grad_y)
-    out = bwd_graph(*activations, *grad_new_carry, *grad_y)
+    with xp.Trace('aot_backward'):
+      out = bwd_graph(*activations, *grad_new_carry, *grad_y)
    grad_carry, grad_x = unflatten_bwd_out(out)
    return grad_carry, grad_x

-  return forward, backward
+  return forward, alias_input, backward
+
+
+def _find_input_activation_aliases(fake_carry_pytree, fake_x_pytree, num_out,
+                                   fwd_graph):
+  """
+  Find which activations are aliases of input tensors.
+
+  Returns:
+    A mapping from index into the flattened input pytree to the index into the
+    list of intermediate activations.
+  """
+  flat_carry, _ = tree_flatten(fake_carry_pytree)
+  flat_x, _ = tree_flatten(fake_x_pytree)
+  _actual_out, activations = split(fwd_graph(*flat_carry, *flat_x), num_out)
+  input_id_to_index = {
+      v: i for i, v in enumerate(id(v) for v in tree_iter(flat_x))
+  }
+  input_activation_aliases = {}
+  for idx, i in enumerate(id(v) for v in activations):
+    if i in input_id_to_index:
+      input_activation_aliases[input_id_to_index[i]] = idx
+  return input_activation_aliases


def _make_get_graph_compiler():
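
# A small self-contained illustration of the id()-based matching used by
# `_find_input_activation_aliases` above: an activation counts as an alias of
# an input only if it is the very same Python object (same `id`). The names
# below (`toy_find_aliases`, `_x`) are purely illustrative, not from this diff.
def toy_find_aliases(flat_inputs, activations):
  input_id_to_index = {id(v): i for i, v in enumerate(flat_inputs)}
  aliases = {}
  for activation_idx, act in enumerate(activations):
    if id(act) in input_id_to_index:
      aliases[input_id_to_index[id(act)]] = activation_idx
  return aliases

_x = object()
# The traced forward graph saved `_x` itself as its second activation.
assert toy_find_aliases([_x], [object(), _x, object()]) == {0: 1}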
@@ -297,12 +375,12 @@ def get_graph():
class Scan(torch.autograd.Function):

  @staticmethod
-  def forward(ctx, forward, backward, init, xs):
-    # Forward pass, save activations for backward
+  def forward(ctx, forward, alias_input, backward, init, xs):
    ctx._backward = backward
    with torch.no_grad():
      carry, ys = _scan_impl_pytree(forward, init, xs)
-    ys, activations = ys
+    ys, partial_activations = ys
+    activations = alias_input(partial_activations, xs)
    ctx.save_for_backward(*activations)
    return carry, ys

@@ -314,7 +392,7 @@ def backward(ctx, grad_carry, grad_ys): # type: ignore
    # Reverse loop to propagate gradients from last iteration to first.
    grad_init, grad_xs = _scan_impl_pytree(
        backward, grad_carry, (grad_ys, activations), reverse=True)
-    return None, None, grad_init, grad_xs
+    return None, None, None, grad_init, grad_xs


def _scan_impl_pytree(fn, init, xs, reverse: bool = False):
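
# End-to-end usage sketch. Assumes this file is importable as
# `torch_xla.experimental.scan` and that the public `scan(fn, init, xs)` API is
# unchanged by this change: input aliasing and the `aot_forward` /
# `aot_backward` profiler traces are internal details.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.scan import scan

def step(carry, x):
  new_carry = carry + x
  y = carry * x  # backward needs both `carry` and `x` here
  return new_carry, y

device = xm.xla_device()
init = torch.zeros(3, device=device, requires_grad=True)
xs = torch.ones(5, 3, device=device, requires_grad=True)
carry, ys = scan(step, init, xs)
(carry.sum() + ys.sum()).backward()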