Torch multiple simultaneous gradient_checkpoint_scope #1583

albertz opened this issue Jul 15, 2024 · 0 comments

There is only one saved_tensors_hooks active at a time, namely the one of the most recent gradient_checkpoint_scope. So when there are multiple simultaneous gradient_checkpoint_scopes, the pack hooks of the earlier scopes will not be used.

Example code:

# Imports/variables added for completeness; the import path assumes RETURNN's Torch gradient_checkpoint util.
import torch
from returnn.torch.util.gradient_checkpoint import gradient_checkpoint_scope

var1 = torch.nn.Parameter(torch.randn(10))
var2 = torch.nn.Parameter(torch.randn(10))

def get_var1():
    with gradient_checkpoint_scope():
        return var1 + torch.randn_like(var1)

def get_var2():
    with gradient_checkpoint_scope():
        return var2 + torch.randn_like(var2)

x = get_var1() * get_var2()
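
For reference, this is plain PyTorch behavior, independent of gradient_checkpoint_scope: when several saved_tensors_hooks context managers are pushed, only the innermost pair of hooks is applied. A minimal standalone demonstration:

# Standalone demonstration (not RETURNN code): with stacked saved_tensors_hooks,
# only the innermost pair of hooks is applied to newly saved tensors.
import torch
from torch.autograd.graph import saved_tensors_hooks

def make_hooks(name):
    def pack(x):
        print(f"pack hook of {name} called")  # only the innermost hooks print
        return x
    def unpack(x):
        return x
    return pack, unpack

a = torch.randn(3, requires_grad=True)
with saved_tensors_hooks(*make_hooks("outer")):
    with saved_tensors_hooks(*make_hooks("inner")):
        y = a * a  # saves `a` for backward; only the "inner" pack hook is called
y.sum().backward()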

A solution is to keep a global weak tensor key dictionary over the registered tensors of all gradient_checkpoint_scope instances, and in the pack hook, check that global dictionary instead of the scope-local one (see the sketch below).
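
A rough sketch of that idea (not the actual RETURNN implementation; GradCheckpointScopeSketch, register_tensor and recompute_fn are hypothetical names — the point is only the global WeakTensorKeyDictionary lookup in the pack hook):

import torch
from torch.utils.weak import WeakTensorKeyDictionary

# One global map shared by all gradient_checkpoint_scope instances:
# tensor -> (owning scope, function to recompute that tensor).
_registered_tensors_global = WeakTensorKeyDictionary()

class GradCheckpointScopeSketch:
    def register_tensor(self, tensor, recompute_fn):
        # Hypothetical registration point: record globally how to recompute
        # this tensor, instead of keeping it only in a per-scope dict.
        _registered_tensors_global[tensor] = (self, recompute_fn)

    def pack_hook(self, tensor):
        # Consult the global dict, not only this scope's own registered
        # tensors, so tensors registered by earlier scopes are handled too.
        entry = _registered_tensors_global.get(tensor)
        if entry is not None:
            _scope, recompute_fn = entry
            return recompute_fn  # save the recompute fn instead of the tensor
        return tensor  # unrelated tensor: save as usual

    @staticmethod
    def unpack_hook(packed):
        # Recompute on demand in backward if we stored a recompute fn.
        return packed() if callable(packed) else packed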

It's currently maybe not so important, as this is a case we likely do not run into (yet; I guess).
