Skip to content

Commit c4145b6

Browse files
committed
[nnx] refactor remat
1 parent 7600ad1 commit c4145b6

File tree

1 file changed

+22
-21
lines changed

1 file changed

+22
-21
lines changed

flax/nnx/transforms/autodiff.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -866,27 +866,6 @@ def remat(
866866
static_argnums: int | tuple[int, ...] = (),
867867
policy: tp.Callable[..., bool] | None = None,
868868
) -> F | tp.Callable[[F], F]:
869-
if isinstance(f, Missing):
870-
return functools.partial(
871-
remat,
872-
prevent_cse=prevent_cse,
873-
static_argnums=static_argnums,
874-
policy=policy,
875-
) # type: ignore[return-value]
876-
877-
return resolve_kwargs()(
878-
graph.update_context('remat')(
879-
general.split_inputs(
880-
jax.checkpoint(
881-
general.merge_inputs(f, ctxtag='remat'),
882-
prevent_cse=prevent_cse,
883-
static_argnums=static_argnums,
884-
policy=policy,
885-
),
886-
ctxtag='remat',
887-
),
888-
)
889-
)
890869
"""A 'lifted' version of the
891870
`jax.checkpoint <https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html>`__
892871
(a.k.a. ``jax.remat``).
@@ -901,4 +880,26 @@ def remat(
901880
`fundamentals of jax.checkpoint <https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#fundamentals-of-jax-checkpoint>`_
902881
and `practical notes <https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#practical-notes>`_.
903882
"""
883+
if isinstance(f, Missing):
884+
return functools.partial(
885+
remat,
886+
prevent_cse=prevent_cse,
887+
static_argnums=static_argnums,
888+
policy=policy,
889+
) # type: ignore[return-value]
890+
891+
@resolve_kwargs()
892+
@graph.update_context('remat')
893+
@general.split_inputs(ctxtag='remat')
894+
@functools.partial(
895+
jax.checkpoint,
896+
prevent_cse=prevent_cse,
897+
static_argnums=static_argnums,
898+
policy=policy,
899+
)
900+
@general.merge_inputs(ctxtag='remat')
901+
@functools.wraps(f)
902+
def remat_wrapper(*args, **kwargs):
903+
return f(*args, **kwargs)
904904

905+
return remat_wrapper

0 commit comments

Comments
 (0)