Is fwd_kvcache compatible with torch.compile in 2.7.2post1? #1386

Open
vince62s opened this issue Dec 14, 2024 · 6 comments
@vince62s

I am getting this warning, and then many subsequent recompiles, because I am using dynamic shapes (and dynamic=True in torch.compile):

/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py:725: UserWarning: Graph break due to unsupported builtin flash_attn_2_cuda.PyCapsule.fwd_kvcache. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.
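
For what it's worth, a minimal sketch of the allow_in_graph workaround the warning suggests, assuming the same import as in the benchmark script further down; this only tells Dynamo not to break the graph on the call and does not register a fake/meta kernel, so it may not solve the recompiles with dynamic shapes:

```python
import torch
from flash_attn.flash_attn_interface import flash_attn_with_kvcache

# Stopgap from the warning text: keep the call in the compiled graph instead
# of graph-breaking on the pybind11 builtin. Note this does not add shape
# metadata, so torch.compile(dynamic=True) may still specialize/recompile.
torch.compiler.allow_in_graph(flash_attn_with_kvcache)

@torch.compile(dynamic=True)
def decode_step(q, k_cache, v_cache, k, v, cache_seqlens):
    return flash_attn_with_kvcache(q, k_cache, v_cache, k, v, cache_seqlens=cache_seqlens)
```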

@tridao (Contributor) commented Dec 14, 2024

Sure, would love to see a PR fixing this.

@vince62s (Author)

@ani300, in case you know how to fix this.

@vince62s (Author)

By the way, there is a slight speed regression for inference with kvcache between 2.5.9.post1 and 2.6.1 (4% in my case).

@tridao (Contributor) commented Dec 16, 2024

Can you send a short script to reproduce the speed regression? E.g., with this input, 2.5.9.post1 takes XXX seconds and 2.6.1 takes YYY seconds.

@ani300 (Contributor) commented Dec 16, 2024

@vince62s I probably forgot to add the torch.compile() wrapping for this function when I did the rest. I can probably take a stab at it later in the week, as I'm wrapped up with a work deadline until Wednesday.
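
For reference, a hypothetical sketch of what such a wrapping could look like with torch.library.custom_op (the op name flash_attn_sketch::fwd_kvcache and the reduced argument list are made up for illustration; the actual registration in flash-attn may differ):

```python
import torch
from flash_attn.flash_attn_interface import flash_attn_with_kvcache

# Hypothetical wrapper: expose the kvcache path as a custom op so torch.compile
# can keep it in the graph and reason about output shapes via the fake kernel.
@torch.library.custom_op(
    "flash_attn_sketch::fwd_kvcache",
    mutates_args=("k_cache", "v_cache"),  # the caches are updated in place
)
def fwd_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cache_seqlens: torch.Tensor,
) -> torch.Tensor:
    return flash_attn_with_kvcache(q, k_cache, v_cache, k, v, cache_seqlens=cache_seqlens)

@fwd_kvcache.register_fake
def _(q, k_cache, v_cache, k, v, cache_seqlens):
    # The output has the same shape and dtype as q.
    return torch.empty_like(q)
```

Passing cache_seqlens as a tensor (rather than a Python int) is assumed here so that the length is a graph input instead of a baked-in constant.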

@vince62s (Author)

I am lazy, so I did not rebuild 2.6.1 (it takes too long to compile), but 2.6.1 and 2.7.2post1 are similar in speed.

```python
import torch
import time
from flash_attn.flash_attn_interface import flash_attn_with_kvcache

def test_flash_attn_with_kvcache():
    # Tensor dimensions
    batch_size = 32
    num_heads = 16
    head_dim = 64
    seqlen_q = 1
    cache_len = 1024

    torch.cuda.synchronize()
    starttime = time.time()
    for _ in range(100000):
        # Generate random tensors for the query, the new key/value, and the cache
        q = torch.randn(batch_size, seqlen_q, num_heads, head_dim, dtype=torch.float16, device="cuda")
        k = torch.randn(batch_size, 1, num_heads, head_dim, dtype=torch.float16, device="cuda")
        v = torch.randn(batch_size, 1, num_heads, head_dim, dtype=torch.float16, device="cuda")
        k_cache = torch.randn(batch_size, cache_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
        v_cache = torch.randn(batch_size, cache_len, num_heads, head_dim, dtype=torch.float16, device="cuda")

        # Non-causal case
        attn_output = flash_attn_with_kvcache(q, k_cache, v_cache, k, v, cache_seqlens=cache_len)
    torch.cuda.synchronize()
    print(time.time() - starttime)


# Run the benchmark
if __name__ == "__main__":
    test_flash_attn_with_kvcache()
```

With 2.5.9.post1: 28.4653 sec
With 2.7.2post1: 29.2737 sec

That's about 3%, but my real-world use case shows 4% (maybe because I also use rotary cos/sin).
