I am trying to run benchmark_flash_attention.py on a ROCm system with MI300X GPUs. The benchmark completes sequence lengths 512, 1024, 2048, and 4096 at headdim=64, but on the next configuration it hangs and then fails with a Memory Access Fault. I observe this with flash-attention v2.6.3 on PyTorch 2.5.1 and on nightly builds of PyTorch 2.6.
Here is the pystack stack trace:
Traceback for thread 3560650 [] (most recent call last):
(Python) File "/home/tensorwave/nikhil/fa-tests/flash-attention/benchmarks/benchmark_flash_attention.py", line 126, in <module>
_, b0 = time_fwd_bwd(
(Python) File "/home/tensorwave/nikhil/fa-tests/flash-attention/benchmarks/benchmark_flash_attention.py", line 66, in time_fwd_bwd
time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
(Python) File "/home/tensorwave/nikhil/fa-tests/torch2.6_11-18/lib/python3.10/site-packages/flash_attn-2.6.3-py3.10-linux-x86_64.egg/flash_attn/utils/benchmark.py", line 140, in benchmark_fwd_bwd
benchmark_backward(
(Python) File "/home/tensorwave/nikhil/fa-tests/torch2.6_11-18/lib/python3.10/site-packages/flash_attn-2.6.3-py3.10-linux-x86_64.egg/flash_attn/utils/benchmark.py", line 66, in benchmark_backward
m = t.timeit(repeats)
(Python) File "/home/tensorwave/nikhil/fa-tests/torch2.6_11-18/lib/python3.10/site-packages/torch/utils/benchmark/utils/timer.py", line 274, in timeit
self._timeit(number=max(int(number // 100), 2))
(Python) File "/home/tensorwave/nikhil/fa-tests/torch2.6_11-18/lib/python3.10/site-packages/torch/utils/benchmark/utils/timer.py", line 264, in _timeit
return max(self._timer.timeit(number), 1e-9)
(Python) File "/usr/lib/python3.10/timeit.py", line 178, in timeit
timing = self.inner(it, self.timer)
(Python) File "<timeit-src>", line 8, in inner
(Python) File "/home/tensorwave/nikhil/fa-tests/torch2.6_11-18/lib/python3.10/site-packages/torch/utils/benchmark/utils/timer.py", line 18, in timer
torch.cuda.synchronize()
(Python) File "/home/tensorwave/nikhil/fa-tests/torch2.6_11-18/lib/python3.10/site-packages/torch/cuda/__init__.py", line 969, in synchronize
return torch._C._cuda_synchronize()
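For reference, a minimal standalone sketch of the pattern the benchmark is timing when it hangs: a forward+backward pass of flash-attention followed by a device synchronize, which is where the trace above stops. This is not the exact benchmark code; flash_attn_func is the public API from the flash_attn package, and the batch size, head count, and the seqlen (the next power of two after 4096) are assumptions chosen for illustration.

import torch
from flash_attn import flash_attn_func

device = "cuda"  # ROCm builds of PyTorch expose HIP devices through the cuda API
# Assumed shapes for illustration; headdim=64 matches the failing sweep,
# seqlen=8192 is the next size after the last one that completes (4096).
batch, seqlen, nheads, headdim = 2, 8192, 32, 64

q, k, v = [
    torch.randn(batch, seqlen, nheads, headdim, device=device,
                dtype=torch.float16, requires_grad=True)
    for _ in range(3)
]

# Forward and backward through flash-attention, then synchronize.
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=False)
out.backward(torch.randn_like(out))
torch.cuda.synchronize()  # the hang / Memory Access Fault surfaces here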