Skip to content

Commit

Permalink
dr.if_while_loop(): maintain AD status of unchanged variables in …
Browse files Browse the repository at this point in the history
…AD-suspended mode

Dr.Jit control flow operations (``dr.if_stmt()``, ``dr.while_loop()``)
disable gradient tracking of all variable state when the operation takes
place within an AD-disabled scope.

This can be surprising when a ``@dr.syntax`` transformation silently
passes local variables to such an operation, which then become
non-differentiable.

This commit carves out an exception: when variables aren't actually
modified by the control flow operation, they can retain their AD
identity.

This is part #2 of the fix for issue #253 reported by @dvicini and
targets ``dr.while_loop()`` only. The previous commit fixed the same
problem for ``if`` statements.
  • Loading branch information
wjakob authored and njroussel committed Oct 21, 2024
1 parent 494c571 commit 030f584
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 5 deletions.
32 changes: 29 additions & 3 deletions src/extra/loop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ static size_t ad_loop_evaluated_mask(JitBackend backend, const char *name,
index64_vector indices2;
JitVar active_it;
size_t it = 0;
bool grad_suspended = ad_grad_suspended();

while (true) {
// Evaluate the loop state
Expand Down Expand Up @@ -189,8 +190,15 @@ static size_t ad_loop_evaluated_mask(JitBackend backend, const char *name,
}

for (size_t i = 0; i < indices2.size(); ++i) {
uint64_t i1 = indices2[i];
uint64_t i2 = ad_var_copy(i1);
// Kernel caching: Must create an AD copy so that gradient
// computation steps involving this variable (even if unchanged
// & only used as a read-only dependency) are correctly placed
// within their associated loop iterations. This does not create
// a copy of the underlying JIT variable.

uint64_t i1 = indices2[i],
i2 = grad_suspended ? ad_var_inc_ref(i1) : ad_var_copy(i1);

ad_var_dec_ref(i1);
ad_mark_loop_boundary(i2);
int unused = 0;
Expand Down Expand Up @@ -926,8 +934,26 @@ bool ad_loop(JitBackend backend, int symbolic, int compress,
write_cb, cond_cb, body_cb, indices_in,
implicit_in, implicit_out);
}
needs_ad &= ad;

if (needs_ad && ad) {
if (needs_ad && ad_grad_suspended()) {
// Maintain differentiability of unchanged variables
bool rewrite = false;
index64_vector indices_out;

read_cb(payload, indices_out);
for (size_t i = 0; i < indices_out.size(); ++i) {
if ((uint32_t) indices_in[i] == (uint32_t) indices_out[i] &&
indices_in[i] != indices_out[i]) {
ad_var_inc_ref(indices_in[i]);
jit_var_dec_ref((uint32_t) indices_out[i]);
indices_out[i] = indices_in[i];
rewrite = true;
}
}
if (rewrite)
write_cb(payload, indices_out, false);
} else if (needs_ad) {
index64_vector indices_out;
read_cb(payload, indices_out);

Expand Down
4 changes: 2 additions & 2 deletions tests/test_if_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,8 +592,7 @@ def test18_if_stmt_preserve_unused_ad(t, mode):
x = t(0, 1)
y = t(1, 3)
dr.enable_grad(x, y)
print(x.index)
print(y.index)
y_id = y.index_ad

with dr.suspend_grad():
def true_fn(x, y):
Expand All @@ -613,3 +612,4 @@ def false_fn(x, y):

assert not dr.grad_enabled(x)
assert dr.grad_enabled(y)
assert y.index_ad == y_id
28 changes: 28 additions & 0 deletions tests/test_while_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,3 +664,31 @@ def loop(t, x, y: t, n = 10):
dr.make_opaque(x, y)

y = loop(t, [x, x], y)


@pytest.test_arrays('float32,diff,shape=(*)')
@pytest.mark.parametrize('mode', ['symbolic', 'evaluated'])
def test29_preserve_differentiability_suspend(t, mode):
    # A loop executed inside ``dr.suspend_grad()`` must not strip the AD
    # identity of state variables it never modifies: ``y`` is only read by
    # the body, so it should keep both its grad-enabled status and its
    # AD index once the loop completes.
    x = t(0, 0)
    y = t(1, 2)
    dr.enable_grad(x, y)
    ad_index_before = y.index_ad

    with dr.suspend_grad():
        def body_fn(x, y):
            # ``y`` is returned unmodified -- only ``x`` changes per iteration
            return x + y, y

        def cond_fn(x, _):
            return x < 10

        x, y = dr.while_loop(
            state=(x, y),
            cond=cond_fn,
            body=body_fn,
            labels=('x', 'y'),
            mode=mode
        )

    # ``x`` was rewritten inside the suspended scope and, as expected,
    # loses its differentiability
    assert not dr.grad_enabled(x)
    # ``y`` was untouched: still differentiable and the very same AD variable
    assert dr.grad_enabled(y)
    assert y.index_ad == ad_index_before

0 comments on commit 030f584

Please sign in to comment.