Is it possible to relax V shape requirements to have different head dim than q/k? #753

Closed
Maykeye opened this issue Jan 5, 2024 · 2 comments · May be fixed by #980
Comments


Maykeye commented Jan 5, 2024

Torch's SDPA doesn't require V to have the same head dimension as the other inputs; the docs even note this by using separate dimensions E and Ev, since by the time V is multiplied in, the head dimension is already gone and all that's left is an L x L attention matrix.

In [23]: qk = torch.randn(4, 4, 4, 8).bfloat16().cuda()  

In [24]: v = torch.randn(4, 4, 4, 16).bfloat16().cuda()

In [25]: F.scaled_dot_product_attention(qk, qk, v).shape
Out[25]: torch.Size([4, 4, 4, 16])

The same goes for xformers; its docs use K and Kv for the two head dimensions.

In [26]: xops.memory_efficient_attention(qk, qk,v).shape
Out[26]: torch.Size([4, 4, 4, 16])
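
For reference, a minimal shape check in plain PyTorch (not tied to either library above, with B/H/L/E/Ev just as illustrative names) that spells out why the head dims can differ: the softmax'd attention matrix is L x L, so V's head dim Ev only shows up in the output.

import math
import torch

B, H, L, E, Ev = 4, 4, 4, 8, 16            # Ev deliberately differs from E
q = torch.randn(B, H, L, E)
k = torch.randn(B, H, L, E)
v = torch.randn(B, H, L, Ev)

attn = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(E), dim=-1)
print(attn.shape)        # (B, H, L, L) -- E is already gone here
print((attn @ v).shape)  # (B, H, L, Ev) -- only the output picks up Ev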

However, FlashAttention 2 (2.4.2) requires the head dimensions to match.

In [27]: flash_attn.flash_attn_func(qk, qk, v)
...
RuntimeError: v must have shape (batch_size, seqlen_k, num_heads_k, head_size_og)

(As documented, it requires all tensors to share the same head dimension; the error message just uses a different name for it than the documentation does.)

Can this be relaxed so that v can have a different head_size, or does the implementation depend on the head dimensions matching?

Contributor

tridao commented Jan 5, 2024

While it's theoretically possible, we don't plan to do that. The reason is that we're already templating on the head dimension (32, 64, 96, 128, 160, 192, 224, 256). If V had a different head dimension we'd need to increase the number of templates by 8x, and compilation time would increase by 8x.
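
Given that constraint, one possible workaround (just a sketch, not part of flash-attn, with the sizes below purely illustrative) when q/k have a smaller head dim than v: zero-pad q and k up to v's head dim, which leaves QK^T unchanged, and pass the original softmax scale explicitly, since the default would become 1/sqrt of the padded dim. In the opposite direction you'd pad v instead and slice the extra columns off the output.

import math
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

B, L, H, E, Ev = 4, 4, 4, 8, 16   # q/k head dim E, v head dim Ev
q = torch.randn(B, L, H, E, dtype=torch.bfloat16, device="cuda")
k = torch.randn(B, L, H, E, dtype=torch.bfloat16, device="cuda")
v = torch.randn(B, L, H, Ev, dtype=torch.bfloat16, device="cuda")

# Zero-pad the last dim of q and k to Ev: the extra columns contribute 0 to QK^T.
q_pad = F.pad(q, (0, Ev - E))
k_pad = F.pad(k, (0, Ev - E))

# Keep the scale of the unpadded head dim; by default flash_attn would use 1/sqrt(Ev).
out = flash_attn_func(q_pad, k_pad, v, softmax_scale=1.0 / math.sqrt(E))
print(out.shape)  # torch.Size([4, 4, 4, 16])

This spends some extra compute on the zero columns, but keeps everything within the supported head-dimension templates.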

Author

Maykeye commented Jan 6, 2024

I see
