diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 93376c7bd170..3a81ecc85af9 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -142,8 +142,14 @@ def policy(prim, *_, **params): def save_from_both_policies(policy_1, policy_2): def policy(prim, *args, **params): - return policy_1(prim, *args, **params) or policy_2(prim, *args, **params) - + out1 = policy_1(prim, *args, **params) + out2 = policy_2(prim, *args, **params) + if not (isinstance(out1, bool) and isinstance(out2, bool)): + raise ValueError( + "The return value of the policies should be a boolean. Got:" + f" {out1} and {out2}. Please write a custom policy function directly," + " rather than using this helper function.") + return out1 or out2 return policy