Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve the behavior of control flow operations (dr.if_stmt(), drjit.while_loop()) in AD-suspended scopes #299

Merged
merged 3 commits into from
Oct 21, 2024

Conversation

wjakob
Copy link
Member

@wjakob wjakob commented Oct 16, 2024

Dr.Jit control flow operations (dr.if_stmt(), drjit.while_loop()) currently disable gradient tracking of all variable state when the operation takes place within an AD-disabled scope.

This can be surprising when a @dr.syntax transformation silently passes local variables to such an operation, which then become non-differentiable. @dvicini reported this in issue #253.

This commit carves out an exception: when variables aren't actually modified by the control flow operation, they can retain their AD identity.

The PR has 3 parts:

  • the first commit changes the semantics of ad_var_inc_ref so that it only does reference counting. The old behavior (drop ref if variable is currently non-differentiable) is moved to a new function ad_var_copy_ref() (see the description of this commit for details). This already fixes 80% of the problems.
  • Parts 2 and 3 fix dr.if_stmt() and dr.while_loop(), respectively.

The PR depends on a Dr.Jit-Core PR: mitsuba-renderer/drjit-core#104

@wjakob wjakob force-pushed the ad-suspend-fixes branch 2 times, most recently from 32128f8 to ce7b4d8 Compare October 18, 2024 10:04
Copy link
Member

@njroussel njroussel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me. It was not immediately obvious to me what would happen with a dr.scatter() that only involved attached state variables - but it's actually fine.

The AD layer exposes a function named ``ad_var_inc_ref()`` that
increases the reference count of a variable analogous to
``jit_var_inc_ref()``. However, one difference between the two is that
the former detaches AD variables when the underlying index has
derivative tracking disabled.

For example, this ensures that code like

```python
x = Float(0)
dr.enable_grad(x)
with dr.suspend_grad():
    y = Float(y)
```
creates a non-differentiable copy.

However, since there are many other operations throughout the Dr.Jit
codebase that require reference counting, there were quite a few places
that exhibited this detaching behavior, which is not always wanted.
(see issue #253).

This commit provides two reference counting functions:

- ``ad_var_inc_ref()`` which increases the reference count *without*
  detaching, and

- ``ad_var_copy_ref()``, which detaches (i.e., reproducing the former
  behavior)

Following this split, only the constructor of AD arrays uses the
detaching ``ad_var_copy_ref()``, while all other operations use
the new ``ad_var_inc_ref()``.
…pended mode

Dr.Jit control flow operations (``dr.if_stmt(), drjit.while_loop()``)
disable gradient tracking of all variable state when the operation takes
place within an AD-disabled scope.

This can be surprising when a ``@dr.syntax`` transformation silently
passes local variables to such an operation, which then become
non-differentiable.

This commit carves out an exception: when variables aren't actually
modified by the control flow operation, they can retain their AD
identity.

This is part #1 of the fix for issue #253 reported by @dvicini and
targets ``dr.if_stmt()`` only. The next commit will also fix the same
problem for while loops.
…AD-suspended mode

Dr.Jit control flow operations (``dr.if_stmt(), drjit.while_loop()``)
disable gradient tracking of all variable state when the operation takes
place within an AD-disabled scope.

This can be surprising when a ``@dr.syntax`` transformation silently
passes local variables to such an operation, which then become
non-differentiable.

This commit carves out an exception: when variables aren't actually
modified by the control flow operation, they can retain their AD
identity.

This is part #2 of the fix for issue #253 reported by @dvicini and
targets ``dr.while_lop()`` only. The previous commit fixed the same
problem for ``if`` statements.
@njroussel njroussel merged commit e41c103 into master Oct 21, 2024
@njroussel njroussel deleted the ad-suspend-fixes branch October 21, 2024 14:11
@andyyankai
Copy link

Hi, may I ask if this fix can relate to this issue? Thanks!
mitsuba-renderer/mitsuba3#1334

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants