Skip to content

Commit

Permalink
Fix save_from_both_policies in presence of `save_and_offload_only_t…
Browse files Browse the repository at this point in the history
…hese_names` by comparing the enum

PiperOrigin-RevId: 706874882
  • Loading branch information
yashk2810 authored and Google-ML-Automation committed Dec 17, 2024
1 parent 772339e commit 7dd401c
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,14 @@ def policy(prim, *_, **params):
def save_from_both_policies(policy_1, policy_2):

def policy(prim, *args, **params):
return policy_1(prim, *args, **params) or policy_2(prim, *args, **params)

out1 = policy_1(prim, *args, **params)
out2 = policy_2(prim, *args, **params)
if not (isinstance(out1, bool) and isinstance(out2, bool)):
raise ValueError(
"The return value of the policies should be a boolean. Got:"
f" {out1} and {out2}. Please write a custom policy function directly,"
" rather than using this helper function.")
return out1 or out2
return policy


Expand Down

0 comments on commit 7dd401c

Please sign in to comment.