Torch gradient_checkpoint_scope, fix potential segfault
Maybe fix #1581
albertz committed Jul 13, 2024
1 parent 90f56c9 commit 78af734
Showing 1 changed file with 14 additions and 0 deletions.
returnn/torch/util/gradient_checkpoint.py
@@ -26,6 +26,7 @@
 import contextlib
 from weakref import ref, WeakSet
 import threading
+import atexit
 
 import torch
 from torch.utils.weak import WeakTensorKeyDictionary  # needs Torch >=2.0.0
@@ -178,6 +179,8 @@ def _maybe_exit_saved_tensors_hooks_scope(self):
         self.exit_saved_tensors_hooks_scope()
 
     def __del__(self):
+        if _python_exit:
+            return
         # Note, be very careful what we do in __del__ because it might be called in a different thread!
         # Note that the __del__ will likely be called very late,
         # as the reference to the _Graph is kept alive until we used it for backprop,
@@ -220,6 +223,8 @@ def _unpack_hook(x: Union[torch.Tensor, _GraphTensor]) -> torch.Tensor:
         return x
 
     def _tensor_del_hook(self):
+        if _python_exit:
+            return
         # Some of the relevant tensors got deleted.
         # If we are in the right thread, maybe we can do the cleanup now.
         self._maybe_exit_saved_tensors_hooks_scope()
@@ -601,3 +606,12 @@ def _custom_saved_tensors_hooks_call_callbacks():
         assert not _custom_saved_tensors_hooks_tls_ctx.callbacks and not _custom_saved_tensors_hooks_tls_ctx.stack
     finally:
         _custom_saved_tensors_hooks_tls_ctx.in_callback = False
+
+
+def _python_exit_handler():
+    global _python_exit
+    _python_exit = True
+
+
+_python_exit = False
+atexit.register(_python_exit_handler)
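
For context on the fix: `atexit.register` runs the handler when the interpreter begins shutdown, so the `_python_exit` flag marks the point after which finalizer-time cleanup is no longer safe. Finalizers such as `__del__` and the tensor del hook can fire during shutdown, when module globals and native Torch state may already be partially torn down, which is the suspected segfault path. Below is a minimal standalone sketch of the same guard pattern; the `Resource` class and `_cleanup` method are hypothetical names for illustration, not part of the actual change.

import atexit

_python_exit = False


def _python_exit_handler():
    global _python_exit
    _python_exit = True


atexit.register(_python_exit_handler)


class Resource:  # hypothetical example class
    def __del__(self):
        # Finalizers can run during interpreter shutdown, when module globals
        # and extension-module state may already be partially torn down.
        # Touching them at that point can crash (e.g. a segfault in native code),
        # so bail out early once the atexit handler has fired.
        if _python_exit:
            return
        self._cleanup()

    def _cleanup(self):
        pass  # normal cleanup logic would go here

Note that the guard is the very first statement in each finalizer, mirroring the diff above: once shutdown has begun, nothing else in a `__del__` body can be assumed safe.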
