From 030f5844d68ce66f287c76b8a3f2f98af6eb63b3 Mon Sep 17 00:00:00 2001
From: Wenzel Jakob
Date: Wed, 16 Oct 2024 22:30:00 +0900
Subject: [PATCH] ``dr.while_loop()``: maintain AD status of unchanged
 variables in AD-suspended mode

Dr.Jit control flow operations (``dr.if_stmt()``, ``dr.while_loop()``)
disable gradient tracking of all variable state when the operation takes
place within an AD-disabled scope. This can be surprising when a
``@dr.syntax`` transformation silently passes local variables to such an
operation, which then become non-differentiable.

This commit carves out an exception: when variables aren't actually
modified by the control flow operation, they can retain their AD identity.
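
The sketch below illustrates the resulting behavior. It is a condensed
form of the new ``test29_preserve_differentiability_suspend`` test added
by this patch; the LLVM backend is an arbitrary choice, and any
differentiable array type behaves the same way:

    import drjit as dr
    from drjit.llvm.ad import Float

    x, y = Float(0, 0), Float(1, 2)
    dr.enable_grad(x, y)

    with dr.suspend_grad():
        # 'y' is loop state that the body reads but never modifies
        x, y = dr.while_loop(
            state=(x, y),
            cond=lambda x, y: x < 10,
            body=lambda x, y: (x + y, y))

    assert not dr.grad_enabled(x)  # modified inside the loop -> detached
    assert dr.grad_enabled(y)      # unchanged -> AD identity preserved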

This is part #2 of the fix for issue #253 reported by @dvicini and targets
``dr.while_loop()`` only. The previous commit fixed the same problem for
``if`` statements.
---
 src/extra/loop.cpp       | 32 +++++++++++++++++++++++++++++---
 tests/test_if_stmt.py    |  4 ++--
 tests/test_while_loop.py | 28 ++++++++++++++++++++++++++++
 3 files changed, 59 insertions(+), 5 deletions(-)

diff --git a/src/extra/loop.cpp b/src/extra/loop.cpp
index 35f0e7e2b..ddcb66aba 100644
--- a/src/extra/loop.cpp
+++ b/src/extra/loop.cpp
@@ -155,6 +155,7 @@ static size_t ad_loop_evaluated_mask(JitBackend backend, const char *name,
     index64_vector indices2;
     JitVar active_it;
     size_t it = 0;
+    bool grad_suspended = ad_grad_suspended();
 
     while (true) {
         // Evaluate the loop state
@@ -189,8 +190,15 @@ static size_t ad_loop_evaluated_mask(JitBackend backend, const char *name,
         }
 
         for (size_t i = 0; i < indices2.size(); ++i) {
-            uint64_t i1 = indices2[i];
-            uint64_t i2 = ad_var_copy(i1);
+            // Kernel caching: Must create an AD copy so that gradient
+            // computation steps involving this variable (even if unchanged
+            // & only used as a read-only dependency) are correctly placed
+            // within their associated loop iterations. This does not create
+            // a copy of the underlying JIT variable.
+
+            uint64_t i1 = indices2[i],
+                     i2 = grad_suspended ? ad_var_inc_ref(i1) : ad_var_copy(i1);
+
             ad_var_dec_ref(i1);
             ad_mark_loop_boundary(i2);
             int unused = 0;
@@ -926,8 +934,26 @@ bool ad_loop(JitBackend backend, int symbolic, int compress,
                                  write_cb, cond_cb, body_cb, indices_in,
                                  implicit_in, implicit_out);
     }
+    needs_ad &= ad;
 
-    if (needs_ad && ad) {
+    if (needs_ad && ad_grad_suspended()) {
+        // Maintain differentiability of unchanged variables
+        bool rewrite = false;
+        index64_vector indices_out;
+
+        read_cb(payload, indices_out);
+        for (size_t i = 0; i < indices_out.size(); ++i) {
+            if ((uint32_t) indices_in[i] == (uint32_t) indices_out[i] &&
+                indices_in[i] != indices_out[i]) {
+                ad_var_inc_ref(indices_in[i]);
+                jit_var_dec_ref((uint32_t) indices_out[i]);
+                indices_out[i] = indices_in[i];
+                rewrite = true;
+            }
+        }
+        if (rewrite)
+            write_cb(payload, indices_out, false);
+    } else if (needs_ad) {
         index64_vector indices_out;
         read_cb(payload, indices_out);
 
diff --git a/tests/test_if_stmt.py b/tests/test_if_stmt.py
index 7b53be563..64f430931 100644
--- a/tests/test_if_stmt.py
+++ b/tests/test_if_stmt.py
@@ -592,8 +592,7 @@ def test18_if_stmt_preserve_unused_ad(t, mode):
     x = t(0, 1)
     y = t(1, 3)
     dr.enable_grad(x, y)
-    print(x.index)
-    print(y.index)
+    y_id = y.index_ad
 
     with dr.suspend_grad():
         def true_fn(x, y):
@@ -613,3 +612,4 @@ def false_fn(x, y):
 
     assert not dr.grad_enabled(x)
     assert dr.grad_enabled(y)
+    assert y.index_ad == y_id
diff --git a/tests/test_while_loop.py b/tests/test_while_loop.py
index 0a2f28c49..0ae0a3378 100644
--- a/tests/test_while_loop.py
+++ b/tests/test_while_loop.py
@@ -664,3 +664,31 @@ def loop(t, x, y: t, n = 10):
 
     dr.make_opaque(x, y)
     y = loop(t, [x, x], y)
+
+
+@pytest.test_arrays('float32,diff,shape=(*)')
+@pytest.mark.parametrize('mode', ['symbolic', 'evaluated'])
+def test29_preserve_differentiability_suspend(t, mode):
+    x = t(0, 0)
+    y = t(1, 2)
+    dr.enable_grad(x, y)
+    y_id = y.index_ad
+
+    with dr.suspend_grad():
+        def cond_fn(x, _):
+            return x < 10
+
+        def body_fn(x, y):
+            return x + y, y
+
+        x, y = dr.while_loop(
+            state=(x, y),
+            cond=cond_fn,
+            labels=('x', 'y'),
+            body=body_fn,
+            mode=mode
+        )
+
+    assert not dr.grad_enabled(x)
+    assert dr.grad_enabled(y)
+    assert y.index_ad == y_id