Correct me if I am wrong: it seems there is no existing test for this. Could you provide such a test, or point one out if there is one?
FWIW, I looked at test_flash_attn_kvcache in tests/test_flash_attn_ck.py and tests/test_flash_attn.py.
Here is my failed attempt; I am likely not using this function correctly.
from flash_attn import flash_attn_with_kvcache, flash_attn_func
import torch

n_layers = 2
dim = 3
num_kv_heads = 1
batch_size = 1
head_size = dim // num_kv_heads
block_size = 256
device = 'cuda'
dtype = torch.float16

# One (k_cache, v_cache) pair per layer.
kv_cache = [
    (torch.zeros((batch_size, block_size, num_kv_heads, head_size), dtype=dtype, device=device),
     torch.zeros((batch_size, block_size, num_kv_heads, head_size), dtype=dtype, device=device))
    for _ in range(n_layers)
]
layer = 0
k_cache, v_cache = kv_cache[layer]

k = torch.ones((batch_size, block_size, num_kv_heads, head_size), dtype=dtype, device=device)
v = torch.ones((batch_size, block_size, num_kv_heads, head_size), dtype=dtype, device=device)
q = torch.ones((batch_size, block_size, num_kv_heads, head_size), dtype=dtype, device=device)

max_num_blocks_per_seq = 2
block_tables = torch.randint(0, batch_size, (batch_size, max_num_blocks_per_seq),
                             dtype=torch.int32, device=device)
cache_seqlens = torch.tensor([i for i in range(batch_size)], dtype=torch.int32, device=device)

# First call: pass k/v so they are written into the cache in place, then attend.
y = flash_attn_with_kvcache(q, k_cache=k_cache, v_cache=v_cache, k=k, v=v,
                            block_table=block_tables, cache_seqlens=cache_seqlens,
                            causal=True)

# Just for inspection of the cache contents.
to_print_v = v.view(batch_size * block_size, num_kv_heads * head_size).float().cpu().numpy()
to_print_v_cache = v_cache.view(batch_size * block_size, num_kv_heads * head_size).float().cpu().numpy()

# Second call: k=None, v=None, so attention runs against the cached keys/values only.
y_with_cached_KV = flash_attn_with_kvcache(q, k_cache=k_cache, v_cache=v_cache, k=None, v=None,
                                           block_table=block_tables,
                                           cache_seqlens=cache_seqlens,
                                           causal=True)

yy = flash_attn_func(q, k, v)
assert y_with_cached_KV == yy  # RuntimeError: Boolean value of Tensor with more than one value is ambiguous
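For completeness, here is a sketch of the kind of comparison I was trying to write, but using a contiguous (non-paged) cache, no block_table, and a head size that the kernels actually support (a multiple of 8; my dim // num_kv_heads = 3 above is not). This is only my reading of the flash_attn_with_kvcache docstring, so the bottom-right causal alignment and the fp16 tolerances are assumptions on my part and may need correcting:

import torch
from flash_attn import flash_attn_func, flash_attn_with_kvcache

torch.manual_seed(0)
device, dtype = 'cuda', torch.float16
batch, nheads, headdim = 2, 4, 64   # head size must be a multiple of 8
seqlen, split = 128, 96             # first 96 tokens are pre-filled into the cache, last 32 are "new"

q = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype)
k = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype)
v = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype)

# Reference: ordinary causal attention over the full sequence, no cache involved.
out_ref = flash_attn_func(q, k, v, causal=True)

# Pre-fill a contiguous cache with the first `split` tokens.
k_cache = torch.zeros(batch, seqlen, nheads, headdim, device=device, dtype=dtype)
v_cache = torch.zeros_like(k_cache)
k_cache[:, :split] = k[:, :split]
v_cache[:, :split] = v[:, :split]
cache_seqlens = torch.full((batch,), split, dtype=torch.int32, device=device)

# Pass only the new tokens: they are appended to the cache in place at offset cache_seqlens,
# and each new query attends to the cached tokens plus the new tokens before it
# (causal mask aligned to the bottom-right corner).
out_new = flash_attn_with_kvcache(
    q[:, split:], k_cache, v_cache,
    k=k[:, split:], v=v[:, split:],
    cache_seqlens=cache_seqlens, causal=True,
)

# Element-wise comparison needs allclose, not ==; loose fp16 tolerances.
assert torch.allclose(out_new, out_ref[:, split:], atol=1e-2, rtol=1e-3)
# The cache should now contain the appended keys/values.
assert torch.equal(k_cache[:, split:], k[:, split:])
assert torch.equal(v_cache[:, split:], v[:, split:])

With the paged cache, the docstring says k_cache/v_cache should instead have shape (num_blocks, page_block_size, nheads_k, headdim) and block_table should map each sequence to its block indices, which is probably where my snippet above goes wrong.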