Torch gradient_checkpoint_scope could trigger segmentation fault? #1581

Open
albertz opened this issue Jul 12, 2024 · 16 comments

albertz commented Jul 12, 2024

I just saw this in the CI (at commit d5b954b):

============================= test session starts ==============================
platform linux -- Python 3.10.14, pytest-8.2.2, pluggy-1.5.0
rootdir: /home/runner/work/returnn/returnn
configfile: pytest.ini
collected 2 items

tests/test_torch_util.py ..                                              [100%]

=============================== warnings summary ===============================
tests/test_torch_util.py::test_gradient_checkpoint_scope
  /home/runner/work/returnn/returnn/tests/test_torch_util.py:151: FutureWarning: `torch.testing.assert_allclose()` is deprecated since 1.12 and will be removed in a future release. Please use `torch.testing.assert_close()` instead. You can find detailed upgrade instructions in https://github.com/pytorch/pytorch/issues/61844.
    torch.testing.assert_allclose(param_post_state[k], param_post_state_[k])

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================= 2 passed, 1 warning in 1.65s =========================
/home/runner/work/_temp/f14cefc5-56ba-4a81-9170-4e80a8ecf45f.sh: line 2:  1990 Segmentation fault      (core dumped) python -m pytest tests/test_$TEST.py
Error: Process completed with exit code 139.

So the tests ran through, but at exit we got a segmentation fault. Maybe the gradient scope was only cleaned up at that late point?


albertz commented Jul 13, 2024

I got this now a second time (CI log). It occurs in roughly 10% of the cases.

I assume the Tensor.__del__ handler maybe runs very late and calls into the PyTorch API, which is not really expected anymore at that point.


albertz commented Jul 13, 2024

I assume the Tensor.__del__ handler maybe runs very late and calls into the PyTorch API, which is not really expected anymore at that point.

I just pushed something which should check for this. So let's see if this occurs again.
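For reference, one way such a check can look (a minimal sketch with a hypothetical class name, not necessarily what was pushed): skip any calls into the PyTorch API inside __del__ once the interpreter is already shutting down.

import sys


class _ScopeWithDel:  # hypothetical example class, for illustration only
    def __del__(self):
        # During interpreter shutdown, extension modules and their state may already
        # be partially torn down, so do not call back into torch at that point.
        if sys.is_finalizing():
            return
        # ... normal cleanup that may touch torch.autograd goes here ...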

albertz reopened this Jul 15, 2024

albertz commented Jul 15, 2024

It also happened afterwards (in 6e2ce01, CI log). Actually, it happens much more often now; it seems to be 100% of the cases?


albertz commented Jul 15, 2024

I can reproduce the crash locally.

(gdb) bt
#0  0x00007ffff7c86054 in visit_decref () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#1  0x00007ffff7c3baa7 in _PyObject_VisitInstanceAttributes () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#2  0x00007ffff7c4687c in subtype_traverse () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#3  0x00007ffff7cf9d15 in deduce_unreachable () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#4  0x00007ffff7cf9c4c in gc_collect_main () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#5  0x00007ffff7cf925c in gc_collect_with_callback () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#6  0x00007ffff7cf9ef2 in PyGC_Collect () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#7  0x00007ffff7cee923 in Py_FinalizeEx () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#8  0x00007ffff7cf8d40 in Py_RunMain () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#9  0x00007ffff7cf8ab9 in Py_BytesMain () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#10 0x00007ffff778f1b7 in __libc_start_call_main (main=main@entry=0x401040 <main>, argc=argc@entry=4, argv=argv@entry=0x7fffffffda28) at ../sysdeps/nptl/libc_start_call_main.h:58
#11 0x00007ffff778f26c in __libc_start_main_impl (main=0x401040 <main>, argc=4, argv=0x7fffffffda28, init=<optimized out>, fini=<optimized out>, rtld_fini=<optimized out>, 
    stack_end=0x7fffffffda18) at ../csu/libc-start.c:392
#12 0x0000000000401071 in _start () at ../sysdeps/x86_64/start.S:115


albertz commented Jul 15, 2024

I was playing around with iterating through all live objects at the end, and that also triggers the crash.

Something like this:

print("**** remaining objects:")
import gc

for obj in gc.get_objects():
    if type(obj) in {tuple, list, dict}:
        continue
    print(type(obj), obj)

Crash:

0x00007ffff7c86054 in visit_decref () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
(gdb) bt
#0  0x00007ffff7c86054 in visit_decref () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#1  0x00007ffff7c3baa7 in _PyObject_VisitInstanceAttributes () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#2  0x00007ffff7c4687c in subtype_traverse () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#3  0x00007ffff7cf9d15 in deduce_unreachable () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#4  0x00007ffff7cf9457 in gc_collect_main () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#5  0x00007ffff7cf925c in gc_collect_with_callback () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#6  0x00007ffff7c85ddf in _PyObject_GC_Link () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#7  0x00007ffff7c85d12 in _PyObject_GC_New () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#8  0x00007ffff7c43771 in tuple_iter () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#9  0x00007ffff7c194e9 in PyObject_GetIter () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#10 0x00007ffff7c63c5c in _PyEval_EvalFrameDefault () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#11 0x00007ffff7c60cf2 in _PyEval_Vector () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#12 0x00007ffff7c6584a in _PyEval_EvalFrameDefault () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#13 0x00007ffff7c60cf2 in _PyEval_Vector () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#14 0x00007ffff7c1fbba in PyObject_CallOneArg () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#15 0x00007ffff7c3e3e5 in _PyObject_GenericGetAttrWithDict () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#16 0x00007ffff7c3dcae in PyObject_GetAttr () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#17 0x00007ffff7c62947 in _PyEval_EvalFrameDefault () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#18 0x00007ffff7c60cf2 in _PyEval_Vector () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#19 0x00007ffff7c46fe0 in vectorcall_method () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#20 0x00007ffff7cc0fb6 in slot_tp_str () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#21 0x00007ffff7c3eac7 in PyObject_Str () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#22 0x00007ffff7cad4c4 in PyFile_WriteObject () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#23 0x00007ffff7cdabaf in builtin_print () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#24 0x00007ffff7c6502f in _PyEval_EvalFrameDefault () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#25 0x00007ffff7c60cf2 in _PyEval_Vector () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#26 0x00007ffff7cdbec6 in PyEval_EvalCode () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#27 0x00007ffff7cf0884 in run_eval_code_obj () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#28 0x00007ffff7cf0806 in run_mod () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#29 0x00007ffff7cf0ff1 in pyrun_file () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#30 0x00007ffff7cf0c6b in _PyRun_SimpleFileObject () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#31 0x00007ffff7cf0a83 in _PyRun_AnyFileObject () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#32 0x00007ffff7cf8e8c in Py_RunMain () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#33 0x00007ffff7cf8ab9 in Py_BytesMain () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0


albertz commented Jul 15, 2024

With python3-dbg, there is some more detail:

Thread 1 "python3.10" received signal SIGSEGV, Segmentation fault.
0x000055555567fb51 in _PyObject_IS_GC (obj=<unknown at remote 0x7fff08b03ac0>) at ../Include/internal/pycore_object.h:166
166     ../Include/internal/pycore_object.h: No such file or directory.
(gdb) bt
#0  0x000055555567fb51 in _PyObject_IS_GC (obj=<unknown at remote 0x7fff08b03ac0>) at ../Include/internal/pycore_object.h:166
#1  visit_decref (parent=<optimized out>, op=<unknown at remote 0x7fff08b03ac0>) at ../Modules/gcmodule.c:456
#2  dict_traverse (op={'pack_hook': <unknown at remote 0x7fff08b03ac0>, 'unpack_hook': <function at remote 0x7fff0aa0c670>}, 
    visit=<optimized out>, arg=<optimized out>) at ../Objects/dictobject.c:3250
#3  0x000055555567f3e5 in subtract_refs (containers=<optimized out>) at ../Modules/gcmodule.c:482
#4  deduce_unreachable (base=base@entry=0x555555b3dfa0, unreachable=unreachable@entry=0x7fffffffd410) at ../Modules/gcmodule.c:1105
#5  0x000055555567e94c in gc_collect_main (tstate=0x555555b59b90, generation=2, n_collected=0x7fffffffd4d8, 
    n_uncollectable=0x7fffffffd4d0, nofail=0) at ../Modules/gcmodule.c:1239
#6  0x0000555555785e80 in gc_collect_with_callback (tstate=0x555555b59b90, generation=2) at ../Modules/gcmodule.c:1413
#7  0x00005555557b74de in PyGC_Collect () at ../Modules/gcmodule.c:2099
#8  0x00005555557b4ef0 in Py_FinalizeEx () at ../Python/pylifecycle.c:1781
#9  0x00005555557a6313 in Py_RunMain () at ../Modules/main.c:668
#10 0x000055555577ca3d in Py_BytesMain (argc=<optimized out>, argv=<optimized out>) at ../Modules/main.c:720
#11 0x00007ffff7c29d90 in __libc_start_call_main (main=main@entry=0x55555577ca00 <main>, argc=argc@entry=2, 
    argv=argv@entry=0x7fffffffd7b8) at ../sysdeps/nptl/libc_start_call_main.h:58
#12 0x00007ffff7c29e40 in __libc_start_main_impl (main=0x55555577ca00 <main>, argc=2, argv=0x7fffffffd7b8, init=<optimized out>, 
    fini=<optimized out>, rtld_fini=<optimized out>, stack_end=0x7fffffffd7a8) at ../csu/libc-start.c:392
#13 0x000055555577c935 in _start ()


albertz commented Jul 15, 2024

With:

print("**** remaining objects:")
import gc

for obj in gc.get_objects():
    print("0x%x" % id(obj), type(obj), obj)
print("**** done.")

Another variant of the crash:

Thread 1 "python3.10" received signal SIGSEGV, Segmentation fault.
0x00005555556c0ff6 in PyObject_Repr (v=<unknown at remote 0x7fff08be3d80>) at ../Objects/object.c:422
422     ../Objects/object.c: No such file or directory.
(gdb) bt
#0  0x00005555556c0ff6 in PyObject_Repr (v=<unknown at remote 0x7fff08be3d80>) at ../Objects/object.c:422
#1  0x00005555557ca2c0 in dict_repr (mp=0x7fff08d92300) at ../Objects/dictobject.c:2148
#2  0x00005555556c4f4d in object_str (
    self={'pack_hook': <unknown at remote 0x7fff08be3d80>, 'unpack_hook': <function at remote 0x7fff0a9e8f70>})
    at ../Objects/typeobject.c:4550
#3  PyObject_Str (v={'pack_hook': <unknown at remote 0x7fff08be3d80>, 'unpack_hook': <function at remote 0x7fff0a9e8f70>})
    at ../Objects/object.c:499
#4  PyObject_Str (v={'pack_hook': <unknown at remote 0x7fff08be3d80>, 'unpack_hook': <function at remote 0x7fff0a9e8f70>})
    at ../Objects/object.c:462
#5  0x000055555578f24d in PyFile_WriteObject (
    v={'pack_hook': <unknown at remote 0x7fff08be3d80>, 'unpack_hook': <function at remote 0x7fff0a9e8f70>}, f=<optimized out>, 
    flags=<optimized out>) at ../Objects/fileobject.c:132
#6  0x000055555578e8f2 in builtin_print (self=<optimized out>, args=0x7ffff7529db0, nargs=3, kwnames=<optimized out>)
    at ../Python/bltinmodule.c:2003
#7  0x00005555556a22eb in cfunction_vectorcall_FASTCALL_KEYWORDS (
    func=<built-in method print of module object at remote 0x7ffff7594950>, args=0x7ffff7529db0, nargsf=<optimized out>, kwnames=0x0)
    at ../Objects/methodobject.c:446
#8  0x0000555555697827 in _PyObject_VectorcallTstate (kwnames=0x0, nargsf=<optimized out>, args=0x7ffff7529db0, 
    callable=<built-in method print of module object at remote 0x7ffff7594950>, tstate=0x555555b59b90)
    at ../Include/cpython/abstract.h:114
#9  PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x7ffff7529db0, 
    callable=<built-in method print of module object at remote 0x7ffff7594950>) at ../Include/cpython/abstract.h:123
#10 call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, trace_info=0x7fffffffd300, 
    tstate=<optimized out>) at ../Python/ceval.c:5893
#11 _PyEval_EvalFrameDefault (tstate=<optimized out>, f=<optimized out>, throwflag=<optimized out>) at ../Python/ceval.c:4213
#12 0x0000555555693f96 in _PyEval_EvalFrame (throwflag=0, 
    f=Frame 0x7ffff7529c40, for file /home/az/Programmierung/returnn/tests/test_torch_util.py, line 354, in <module> (), 
    tstate=0x555555b59b90) at ../Include/internal/pycore_ceval.h:46
#13 _PyEval_Vector (tstate=0x555555b59b90, con=<optimized out>, locals=<optimized out>, args=<optimized out>, 
    argcount=<optimized out>, kwnames=<optimized out>) at ../Python/ceval.c:5067
#14 0x0000555555789c66 in PyEval_EvalCode (co=<code at remote 0x7ffff74a6d90>, 
    globals={'__name__': '__main__', '__doc__': '\nTest :mod:`returnn.torch.util`.\n', '__package__': None, '__loader__': <SourceFileLoader(name='__main__', path='/home/az/Programmierung/returnn/tests/test_torch_util.py') at remote 0x7ffff73dd9f0>, '__spec__': None, '__annotations__': {}, '__builtins__': <module at remote 0x7ffff7594950>, '__file__': '/home/az/Programmierung/returnn/tests/test_torch_util.py', '__cached__': None, 'annotations': <_Feature(optional=(3, 7, 0, 'beta', 1), mandatory=(3, 11, 0, 'alpha', 0), compiler_flag=16777216) at remote 0x7ffff731ddb0>, '_setup_test_env': <module at remote 0x7ffff7315f80>, 'os': <module at remote 0x7ffff7404e00>, 'sys': <module at remote 0x7ffff7582390>, 'unittest': <module at remote 0x7fffb95e19e0>, 'torch': <module at remote 0x7fffb95e1530>, 'better_exchook': <module at remote 0x7ffff739eca0>, 'gradient_checkpoint_scope': <type at remote 0x5555562d9ee0>, 'test_gradient_checkpoint_scope': <function at remote 0x7ffff74b56c0>, 'test_gradient_checkpoint_scope_twice': <functio...(truncated), 
    locals=<optimized out>) at ../Python/ceval.c:1134
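(A possible refinement of that debug loop, just a sketch and not from the issue: print only the id and the type name, and avoid str()/repr() of the object itself, since PyObject_Repr on the already-freed entry is exactly what faults in the trace above.)

import gc

for obj in gc.get_objects():
    # Only touch id() and type(); calling str()/repr() on the corrupted object
    # is what crashes in PyObject_Repr / dict_repr in the backtrace above.
    print("0x%x %s" % (id(obj), type(obj).__name__))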


albertz commented Jul 15, 2024

Note, the object you see here in object_str looks very much like the __dict__ of a saved_tensors_hooks instance, which has pack_hook and unpack_hook.
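For context, torch.autograd.graph.saved_tensors_hooks is a thin Python wrapper that stores exactly these two attributes; roughly paraphrased from the PyTorch source (details can differ between versions):

import torch


class saved_tensors_hooks:
    def __init__(self, pack_hook, unpack_hook):
        # The instance __dict__ is thus {'pack_hook': ..., 'unpack_hook': ...},
        # matching the dict shown in the backtrace.
        self.pack_hook = pack_hook
        self.unpack_hook = unpack_hook

    def __enter__(self):
        torch._C._autograd._push_saved_tensors_default_hooks(self.pack_hook, self.unpack_hook)

    def __exit__(self, *args):
        torch._C._autograd._pop_saved_tensors_default_hooks()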


albertz commented Jul 15, 2024

Ok, I added this print in gradient_checkpoint_scope.__init__ to print the address of the pack_hook method:

    def __init__(self):
        self.record_graph_scope = _RecordGraph()
        self.record_graph_scope.graph.gradient_checkpoint_scope_backref = self
        # Note: saved_tensors_hooks is thread local.
        self.saved_tensors_hooks_scope = torch.autograd.graph.saved_tensors_hooks(self._pack_hook, self._unpack_hook)
        print("*** pack hook: 0x%x" % id(self.saved_tensors_hooks_scope.pack_hook))

Then I get this at the end:

Executing: test_gradient_checkpoint_scope_twice
*** pack hook: 0x7fff0935a080
*** pack hook: 0x7fff08c07c80
*** _pack_hook
*** _pack_hook
*** _unpack_hook
*** _unpack_hook
*** _custom_saved_tensors_hooks_exit
*** pack hook: 0x7fff08be1240
*** _custom_saved_tensors_hooks_exit
*** pack hook: 0x7fff08be5200
*** _pack_hook
*** _pack_hook
*** _unpack_hook
*** _unpack_hook
*** _custom_saved_tensors_hooks_exit
----------------------------------------
Finished all tests.
**** remaining objects:
...
0x7fff08bc4e80 <class 'torch.autograd.graph.saved_tensors_hooks'> <torch.autograd.graph.saved_tensors_hooks object at 0x7fff08bc4e80>

Thread 1 "python3.10" received signal SIGSEGV, Segmentation fault.
0x00005555556c0ff6 in PyObject_Repr (v=<unknown at remote 0x7fff08c07c80>) at ../Objects/object.c:422
422     ../Objects/object.c: No such file or directory.
(gdb) bt
#0  0x00005555556c0ff6 in PyObject_Repr (v=<unknown at remote 0x7fff08c07c80>) at ../Objects/object.c:422
#1  0x00005555557ca2c0 in dict_repr (mp=0x7fff08c07240) at ../Objects/dictobject.c:2148
#2  0x00005555556c4f4d in object_str (
    self={'pack_hook': <unknown at remote 0x7fff08c07c80>, 'unpack_hook': <function at remote 0x7fff0a9e1000>})
    at ../Objects/typeobject.c:4550
#3  PyObject_Str (v={'pack_hook': <unknown at remote 0x7fff08c07c80>, 'unpack_hook': <function at remote 0x7fff0a9e1000>})
    at ../Objects/object.c:499
#4  PyObject_Str (v={'pack_hook': <unknown at remote 0x7fff08c07c80>, 'unpack_hook': <function at remote 0x7fff0a9e1000>})
    at ../Objects/object.c:462
#5  0x000055555578f24d in PyFile_WriteObject (
    v={'pack_hook': <unknown at remote 0x7fff08c07c80>, 'unpack_hook': <function at remote 0x7fff0a9e1000>}, f=<optimized out>, 
    flags=<optimized out>) at ../Objects/fileobject.c:132
#6  0x000055555578e8f2 in builtin_print (self=<optimized out>, args=0x7ffff7529db0, nargs=3, kwnames=<optimized out>)
    at ../Python/bltinmodule.c:2003
...

So I guess we have already freed the method, but we are still trying to access it here.
Could this be an error in PyTorch regarding the refcounting of the pack_hook?


albertz commented Jul 15, 2024

Added some debug code:

def _custom_saved_tensors_hooks_exit(
    self: torch.autograd.graph.saved_tensors_hooks, exc_type=None, exc_val=None, exc_tb=None
):
    print(f"*** _custom_saved_tensors_hooks_exit, stack {_custom_saved_tensors_hooks_tls_ctx.stack}")
    f = sys._getframe()
    while f:
        co = f.f_code
        print("-", co.co_name, co.co_filename, f.f_lineno)
        f = f.f_back
    ...

Then:

**** iter 0
*** pack hook: 0x7fff091c2a00
*** gradient_checkpoint_scope.__enter__
*** pack hook: 0x7fff08b3c640
*** gradient_checkpoint_scope.__enter__
*** _custom_saved_tensors_hooks_enter
*** _pack_hook
*** _pack_hook
[New Thread 0x7fff08aff640 (LWP 232551)]
[New Thread 0x7ffefcfde640 (LWP 232555)]
*** _unpack_hook
*** _unpack_hook
*** exit_saved_tensors_hooks_scope __exit__ now, pack_hook: 0x7fff08b3c640
*** _custom_saved_tensors_hooks_exit, stack [<torch.autograd.graph.saved_tensors_hooks object at 0x7fff08bdecb0>, <torch.autograd.graph.saved_tensors_hooks object at 0x7fff08bdf190>]
- _custom_saved_tensors_hooks_exit /home/az/Programmierung/returnn/tests/test_torch_util.py 641
- exit_saved_tensors_hooks_scope /home/az/Programmierung/returnn/tests/test_torch_util.py 296
- _maybe_exit_saved_tensors_hooks_scope /home/az/Programmierung/returnn/tests/test_torch_util.py 270
- _unpack_hook /home/az/Programmierung/returnn/tests/test_torch_util.py 315
- backward /home/az/.local/lib/python3.10/site-packages/torch/autograd/__init__.py 266
- backward /home/az/.local/lib/python3.10/site-packages/torch/_tensor.py 522
- demo_run /home/az/Programmierung/returnn/tests/test_torch_util.py 722
- test_saved_tensors_hooks_gc_segfault /home/az/Programmierung/returnn/tests/test_torch_util.py 730
- <module> /home/az/Programmierung/returnn/tests/test_torch_util.py 880
*** _custom_saved_tensors_hooks_exit: exit now, scope <torch.autograd.graph.saved_tensors_hooks object at 0x7fff08bdf190>, pack_hook 0x7fff08b3c640
...

**** iter 4

Thread 1 "python3.10" received signal SIGSEGV, Segmentation fault.
0x000055555567fb51 in _PyObject_IS_GC (obj=<unknown at remote 0x7fff08b3c640>) at ../Include/internal/pycore_object.h:166
166     ../Include/internal/pycore_object.h: No such file or directory.

So, maybe the problem is that we call saved_tensors_hooks.__exit__ inside the unpack hook?


albertz commented Jul 15, 2024

I have a standalone test case:

def test_saved_tensors_hooks_gc_segfault2():
    # https://github.com/rwth-i6/returnn/issues/1581
    shape = (101, 103)
    for i in range(10):
        v1 = torch.nn.Parameter(torch.randn(shape))
        v2 = torch.nn.Parameter(torch.randn(shape))

        class _Handler:
            def __init__(self, exit_in_unpack: bool = False):
                self.scope = torch.autograd.graph.saved_tensors_hooks(self._pack_hook, self._unpack_hook)
                self.exit_in_unpack = exit_in_unpack
                self.exited = False

            def _pack_hook(self, x):
                print(f"*** _pack_hook {self}")
                return self, x

            @staticmethod
            def _unpack_hook(x):
                self, x = x
                print(f"*** _unpack_hook {self}")
                if self.exit_in_unpack and not self.exited:
                    self.exited = True
                    self.scope.__exit__()
                return x

        handler1 = _Handler(exit_in_unpack=False)
        handler1.scope.__enter__()
        v1_ = v1 + torch.randn(shape)

        handler2 = _Handler(exit_in_unpack=True)
        handler2.scope.__enter__()
        v2_ = v2 + torch.randn(shape)

        x = v1_ * v2_
        x.sum().backward()
        del x
        handler1.scope.__exit__()

I'm now trying to simplify this further.


albertz commented Jul 15, 2024

Slightly different version:

def test_saved_tensors_hooks_gc_segfault2():
    # https://github.com/rwth-i6/returnn/issues/1581
    shape = (101, 103)
    for i in range(10):
        print("**** iter", i)
        v = torch.nn.Parameter(torch.randn(shape))

        class _Handler:
            def __init__(self):
                self.scope = torch.autograd.graph.saved_tensors_hooks(self._pack_hook, self._unpack_hook)
                self.scope.__enter__()
                self.exited = False

            def _pack_hook(self, x):
                print(f"*** _pack_hook {self}")
                return x

            def _unpack_hook(self, x):
                print(f"*** _unpack_hook {self}")
                if not self.exited:
                    self.exited = True
                    self.scope.__exit__()
                return x

        with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
            handler = _Handler()  # keep ref...  # noqa
            x = v * torch.randn(shape)
            x.sum().backward()


albertz commented Jul 15, 2024

I reported that upstream: pytorch/pytorch#130734


albertz commented Jul 15, 2024

I pushed a workaround now. See _can_exit_saved_tensors_hooks_inside_hooks. If possible, I would like to extend this logic later. But let's wait for the response in pytorch/pytorch#130734.
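The general idea of the workaround, as a rough sketch with hypothetical names (not the actual RETURNN code): if exiting inside the pack/unpack hooks is considered unsafe, only mark the scope for exit and perform the real saved_tensors_hooks.__exit__ later, at a safe point outside the hooks.

import torch

_can_exit_saved_tensors_hooks_inside_hooks = False  # assumed to be some version/feature gate


class _DeferredExitScope:  # hypothetical, for illustration only
    def __init__(self, pack_hook, unpack_hook):
        self.scope = torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook)
        self.exit_pending = False

    def request_exit(self, *, inside_hook: bool):
        if inside_hook and not _can_exit_saved_tensors_hooks_inside_hooks:
            # Popping the hooks stack from within a hook can crash (see above),
            # so just remember that we want to exit.
            self.exit_pending = True
            return
        self.scope.__exit__(None, None, None)

    def maybe_exit_at_safe_point(self):
        # Called from a safe point, e.g. the next saved_tensors_hooks enter/exit
        # or some Tensor.__del__ in the right thread.
        if self.exit_pending:
            self.exit_pending = False
            self.scope.__exit__(None, None, None)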


albertz commented Jul 15, 2024

Actually, let's keep this open until we get some response, and then wait until we can update _can_exit_saved_tensors_hooks_inside_hooks.

albertz reopened this Jul 15, 2024

albertz commented Jul 15, 2024

Also note, the current solution is maybe not optimal. These are the current potential ways that we would exit the torch.autograd.graph.saved_tensors_hooks scope:

  • gradient_checkpoint_scope.__exit__. But likely not, as there are likely refs to the registered tensors.
  • gradient_checkpoint_scope.__del__ if in the right thread. But likely not, as there are likely still refs to the registered tensors.
  • Tensor.__del__ if in the right thread.
  • Any future call to torch.autograd.graph.saved_tensors_hooks.__enter__ or torch.autograd.graph.saved_tensors_hooks.__exit__.
  • In the pack hook or unpack hook. But not anymore now if _can_exit_saved_tensors_hooks_inside_hooks is false.

So this means, in practice, with the current _can_exit_saved_tensors_hooks_inside_hooks check, the only really realistic way that it gets cleaned up is via the next saved_tensors_hooks.__enter__ or Tensor.__del__. Tensor.__del__ would be fine, but we cannot guarantee that it happens in the right thread.

The _GraphTensors are cleaned up independently of that, so the only problem here is the additional overhead we get from a few pack/unpack hooks which don't do anything. Some solution to pytorch/pytorch#129867 would allow us to reduce this (and also simplify the whole logic).
