@@ -866,27 +866,6 @@ def remat(
866
866
static_argnums : int | tuple [int , ...] = (),
867
867
policy : tp .Callable [..., bool ] | None = None ,
868
868
) -> F | tp .Callable [[F ], F ]:
869
- if isinstance (f , Missing ):
870
- return functools .partial (
871
- remat ,
872
- prevent_cse = prevent_cse ,
873
- static_argnums = static_argnums ,
874
- policy = policy ,
875
- ) # type: ignore[return-value]
876
-
877
- return resolve_kwargs ()(
878
- graph .update_context ('remat' )(
879
- general .split_inputs (
880
- jax .checkpoint (
881
- general .merge_inputs (f , ctxtag = 'remat' ),
882
- prevent_cse = prevent_cse ,
883
- static_argnums = static_argnums ,
884
- policy = policy ,
885
- ),
886
- ctxtag = 'remat' ,
887
- ),
888
- )
889
- )
890
869
"""A 'lifted' version of the
891
870
`jax.checkpoint <https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html>`__
892
871
(a.k.a. ``jax.remat``).
@@ -901,4 +880,26 @@ def remat(
901
880
`fundamentals of jax.checkpoint <https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#fundamentals-of-jax-checkpoint>`_
902
881
and `practical notes <https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#practical-notes>`_.
903
882
"""
883
+ if isinstance (f , Missing ):
884
+ return functools .partial (
885
+ remat ,
886
+ prevent_cse = prevent_cse ,
887
+ static_argnums = static_argnums ,
888
+ policy = policy ,
889
+ ) # type: ignore[return-value]
890
+
891
+ @resolve_kwargs ()
892
+ @graph .update_context ('remat' )
893
+ @general .split_inputs (ctxtag = 'remat' )
894
+ @functools .partial (
895
+ jax .checkpoint ,
896
+ prevent_cse = prevent_cse ,
897
+ static_argnums = static_argnums ,
898
+ policy = policy ,
899
+ )
900
+ @general .merge_inputs (ctxtag = 'remat' )
901
+ @functools .wraps (f )
902
+ def remat_wrapper (* args , ** kwargs ):
903
+ return f (* args , ** kwargs )
904
904
905
+ return remat_wrapper
0 commit comments