I am trying to figure out whether the output of flash_attn_with_kvcache() matches SDPA.
I first run this script without rotary embeddings:
```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import scaled_dot_product_attention
from flash_attn import flash_attn_with_kvcache

batchsize = 4
heads = 16
dim_per_head = 16

for seqlen in [1, 2]:
    query = torch.rand((batchsize, heads, seqlen, dim_per_head), device=torch.device("cuda")).half()
    key = query
    value = query

    ########### Compute attn with SDPA
    with sdpa_kernel([SDPBackend.MATH]):
        attn_output1 = scaled_dot_product_attention(query, key, value, None, 0.0, is_causal=False)
    with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION]):
        attn_output2 = scaled_dot_product_attention(query, key, value, None, 0.0, is_causal=False)
    with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
        attn_output3 = scaled_dot_product_attention(query, key, value, None, 0.0, is_causal=False)
    with sdpa_kernel([SDPBackend.CUDNN_ATTENTION]):
        attn_output4 = scaled_dot_product_attention(query, key, value, None, 0.0, is_causal=False)

    ########### flash kvcache
    kcache = torch.zeros_like(key)
    vcache = torch.zeros_like(value)
    attn_output5 = flash_attn_with_kvcache(
        query.transpose(1, 2),
        kcache.transpose(1, 2),
        vcache.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        cache_seqlens=0,
    ).transpose(1, 2)

    print("======> results for sequence length: ", seqlen)
    for atol in [1e-3, 1e-4, 1e-5]:
        print("atol: ", atol)
        print("MATH vs EFFICIENT: ", torch.allclose(attn_output1, attn_output2, atol=atol))
        print("EFFICIENT vs SDPAFLASH: ", torch.allclose(attn_output2, attn_output3, atol=atol))
        print("MATH vs SDPAFLASH: ", torch.allclose(attn_output1, attn_output3, atol=atol))
        print("MATH vs CUDNN: ", torch.allclose(attn_output1, attn_output4, atol=atol))
        print("EFFICIENT vs CUDNN: ", torch.allclose(attn_output2, attn_output4, atol=atol))
        print("SDPAFLASH vs CUDNN: ", torch.allclose(attn_output3, attn_output4, atol=atol))
        print("EFFICIENT vs FLASHCACHE: ", torch.allclose(attn_output2, attn_output5, atol=atol))
        print("SDPAFLASH vs FLASHCACHE: ", torch.allclose(attn_output3, attn_output5, atol=atol))
        print()
    print()
```
The output is as follows:
```
======> results for sequence length: 1
atol: 0.001
MATH vs EFFICIENT: True
EFFICIENT vs SDPAFLASH: True
MATH vs SDPAFLASH: True
MATH vs CUDNN: True
EFFICIENT vs CUDNN: True
SDPAFLASH vs CUDNN: True
EFFICIENT vs FLASHCACHE: True
SDPAFLASH vs FLASHCACHE: True
atol: 0.0001
MATH vs EFFICIENT: True
EFFICIENT vs SDPAFLASH: True
MATH vs SDPAFLASH: True
MATH vs CUDNN: True
EFFICIENT vs CUDNN: True
SDPAFLASH vs CUDNN: True
EFFICIENT vs FLASHCACHE: True
SDPAFLASH vs FLASHCACHE: True
atol: 1e-05
MATH vs EFFICIENT: True
EFFICIENT vs SDPAFLASH: True
MATH vs SDPAFLASH: True
MATH vs CUDNN: True
EFFICIENT vs CUDNN: True
SDPAFLASH vs CUDNN: True
EFFICIENT vs FLASHCACHE: True
SDPAFLASH vs FLASHCACHE: True
======> results for sequence length: 2
atol: 0.001
MATH vs EFFICIENT: True
EFFICIENT vs SDPAFLASH: True
MATH vs SDPAFLASH: True
MATH vs CUDNN: True
EFFICIENT vs CUDNN: True
SDPAFLASH vs CUDNN: True
EFFICIENT vs FLASHCACHE: True
SDPAFLASH vs FLASHCACHE: True
atol: 0.0001
MATH vs EFFICIENT: False
EFFICIENT vs SDPAFLASH: True
MATH vs SDPAFLASH: False
MATH vs CUDNN: False
EFFICIENT vs CUDNN: True
SDPAFLASH vs CUDNN: True
EFFICIENT vs FLASHCACHE: True
SDPAFLASH vs FLASHCACHE: True
atol: 1e-05
MATH vs EFFICIENT: False
EFFICIENT vs SDPAFLASH: True
MATH vs SDPAFLASH: False
MATH vs CUDNN: False
EFFICIENT vs CUDNN: True
SDPAFLASH vs CUDNN: True
EFFICIENT vs FLASHCACHE: True
SDPAFLASH vs FLASHCACHE: True
```
We can see that for seqlen=1 all outputs match down to atol=1e-5.
When we switch to seqlen=2 (or more):

- at atol=1e-3 everything matches;
- from atol=1e-4 on, everything matches except MATH, probably because the MATH backend computes in float32.

If I switch the inputs to float32, then MATH and EFFICIENT match.
Since SDPAFLASH and FLASHCACHE match, let's say it's just a precision issue.
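As a quick check of that hypothesis (a minimal sketch, not part of the original script; it reuses the tensors from the loop above), upcasting the inputs to float32 should make MATH and EFFICIENT agree even at atol=1e-5:

```python
# Hypothetical float32 check (not in the original script): if the mismatch is
# only due to half precision, MATH and EFFICIENT should agree in float32.
q32 = query.float()
with sdpa_kernel([SDPBackend.MATH]):
    out_math32 = scaled_dot_product_attention(q32, q32, q32, None, 0.0, is_causal=False)
with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION]):
    out_eff32 = scaled_dot_product_attention(q32, q32, q32, None, 0.0, is_causal=False)
print("MATH vs EFFICIENT (float32): ", torch.allclose(out_math32, out_eff32, atol=1e-5))
```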
Now I try to do the same with rotary embeddings, using the following code:
```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import scaled_dot_product_attention
from flash_attn import flash_attn_with_kvcache

maxseqlen = 512
batchsize = 2
heads = 16
dim_per_head = 16
rotary_dim = 16
rotary_theta = 10000

inv_freq = 1.0 / (
    rotary_theta ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim)
).to(torch.device("cuda"))
tmax = torch.arange(maxseqlen, device=torch.device("cuda"))
rope = torch.outer(tmax, inv_freq)
cos = torch.cos(rope)
sin = torch.sin(rope)
cos = torch.cat((cos, cos), dim=-1)  # Double the size by repeating `cos`
sin = torch.cat((sin, sin), dim=-1)  # Double the size by repeating `sin`
position_embeddings = (cos, sin)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_emb(query, key, rope):
    # rope is a tuple (cos, sin)
    cos, sin = rope
    q_embed = (query * cos) + (rotate_half(query) * sin)
    k_embed = (key * cos) + (rotate_half(key) * sin)
    return q_embed.type_as(query), k_embed.type_as(key)


for seqlen in [1, 2]:
    query = torch.rand((batchsize, heads, seqlen, dim_per_head), device=torch.device("cuda")).half()
    key = query
    value = query

    ########### Compute attn with SDPA
    start_pos = 0
    cos = position_embeddings[0][start_pos : start_pos + seqlen].to(query.dtype)
    sin = position_embeddings[1][start_pos : start_pos + seqlen].to(query.dtype)
    query, key = apply_rotary_emb(query, key, (cos, sin))
    with sdpa_kernel([SDPBackend.MATH]):
        attn_output1 = scaled_dot_product_attention(query, key, value, None, 0.0, is_causal=False)
    with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION]):
        attn_output2 = scaled_dot_product_attention(query, key, value, None, 0.0, is_causal=False)
    with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
        attn_output3 = scaled_dot_product_attention(query, key, value, None, 0.0, is_causal=False)
    with sdpa_kernel([SDPBackend.CUDNN_ATTENTION]):
        attn_output4 = scaled_dot_product_attention(query, key, value, None, 0.0, is_causal=False)

    ########### flash kvcache
    kcache = torch.zeros_like(key)
    vcache = torch.zeros_like(value)
    # flash expects dim1 of size half the rotary_dim
    cos = position_embeddings[0][:, : cos.size(1) // 2].to(query.dtype)
    sin = position_embeddings[1][:, : sin.size(1) // 2].to(query.dtype)
    attn_output5 = flash_attn_with_kvcache(
        query.transpose(1, 2),
        kcache.transpose(1, 2),
        vcache.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        rotary_cos=cos,
        rotary_sin=sin,
        cache_seqlens=0,
        rotary_interleaved=False,
    ).transpose(1, 2)

    print("results for sequence length: ", seqlen)
    for atol in [1e-3, 1e-4, 1e-5]:
        print("atol: ", atol)
        print("MATH vs EFFICIENT: ", torch.allclose(attn_output1, attn_output2, atol=atol))
        print("EFFICIENT vs SDPAFLASH: ", torch.allclose(attn_output2, attn_output3, atol=atol))
        print("MATH vs SDPAFLASH: ", torch.allclose(attn_output1, attn_output3, atol=atol))
        print("MATH vs CUDNN: ", torch.allclose(attn_output1, attn_output4, atol=atol))
        print("EFFICIENT vs CUDNN: ", torch.allclose(attn_output2, attn_output4, atol=atol))
        print("SDPAFLASH vs CUDNN: ", torch.allclose(attn_output3, attn_output4, atol=atol))
        print("EFFICIENT vs FLASHCACHE: ", torch.allclose(attn_output2, attn_output5, atol=atol))
        print("SDPAFLASH vs FLASHCACHE: ", torch.allclose(attn_output3, attn_output5, atol=atol))
        print()
    print()
```
The output is:
```
results for sequence length: 1
atol: 0.001
MATH vs EFFICIENT: True
EFFICIENT vs SDPAFLASH: True
MATH vs SDPAFLASH: True
MATH vs CUDNN: True
EFFICIENT vs CUDNN: True
SDPAFLASH vs CUDNN: True
EFFICIENT vs FLASHCACHE: True
SDPAFLASH vs FLASHCACHE: True
atol: 0.0001
MATH vs EFFICIENT: True
EFFICIENT vs SDPAFLASH: True
MATH vs SDPAFLASH: True
MATH vs CUDNN: True
EFFICIENT vs CUDNN: True
SDPAFLASH vs CUDNN: True
EFFICIENT vs FLASHCACHE: True
SDPAFLASH vs FLASHCACHE: True
atol: 1e-05
MATH vs EFFICIENT: True
EFFICIENT vs SDPAFLASH: True
MATH vs SDPAFLASH: True
MATH vs CUDNN: True
EFFICIENT vs CUDNN: True
SDPAFLASH vs CUDNN: True
EFFICIENT vs FLASHCACHE: True
SDPAFLASH vs FLASHCACHE: True
results for sequence length: 2
atol: 0.001
MATH vs EFFICIENT: True
EFFICIENT vs SDPAFLASH: True
MATH vs SDPAFLASH: True
MATH vs CUDNN: True
EFFICIENT vs CUDNN: True
SDPAFLASH vs CUDNN: True
**EFFICIENT vs FLASHCACHE: False**
**SDPAFLASH vs FLASHCACHE: False**
atol: 0.0001
MATH vs EFFICIENT: False
EFFICIENT vs SDPAFLASH: True
MATH vs SDPAFLASH: False
MATH vs CUDNN: False
EFFICIENT vs CUDNN: True
SDPAFLASH vs CUDNN: True
**EFFICIENT vs FLASHCACHE: False**
**SDPAFLASH vs FLASHCACHE: False**
atol: 1e-05
MATH vs EFFICIENT: False
EFFICIENT vs SDPAFLASH: True
MATH vs SDPAFLASH: False
MATH vs CUDNN: False
EFFICIENT vs CUDNN: True
SDPAFLASH vs CUDNN: True
**EFFICIENT vs FLASHCACHE: False**
**SDPAFLASH vs FLASHCACHE: False**
```
Again, for seqlen=1 everything is fine.
For seqlen=2 (or more), FLASHCACHE no longer matches SDPAFLASH (or EFFICIENT).
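One way to narrow this down (an untested sketch, reusing the variables from the loop above): since `query` and `key` are already rotated by `apply_rotary_emb` before the flash call, drop `rotary_cos`/`rotary_sin` so the kernel does not apply rotary a second time, and compare against the SDPA outputs again. At `start_pos=0` the position-0 rotation is the identity, which could be why only seqlen=1 matches as written.

```python
# Hypothetical isolation test: query/key were already rotated outside, so let
# flash_attn_with_kvcache run without its in-kernel rotary and compare again.
attn_output6 = flash_attn_with_kvcache(
    query.transpose(1, 2),
    torch.zeros_like(key).transpose(1, 2),
    torch.zeros_like(value).transpose(1, 2),
    key.transpose(1, 2),
    value.transpose(1, 2),
    cache_seqlens=0,
).transpose(1, 2)
print("EFFICIENT vs FLASHCACHE (rotary applied outside): ",
      torch.allclose(attn_output2, attn_output6, atol=1e-3))
```

If this matches, the remaining question is how flash_attn_with_kvcache applies rotary internally for seqlen > 1 (for example, which positions it uses for the query tokens when causal is not set), rather than the attention computation itself.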