
looking for a test to compare the result with the KV cache updated in place and without the KV cache #1414

Open
chakpongchung opened this issue Dec 26, 2024 · 0 comments


chakpongchung commented Dec 26, 2024

Correct me if I am wrong: it seems there is no such existing test. Could you provide one, or point me to an existing one if there is any?

FWIW, I looked at the test test_flash_attn_kvcache in tests/test_flash_attn_ck.py and tests/test_flash_attn.py.

Here is my failed attempt. Likely I am not using this function correctly.

import torch

from flash_attn import flash_attn_func, flash_attn_with_kvcache

n_layers = 2
dim = 64  # head_size = dim // num_kv_heads; head sizes that are multiples of 8 are the safe choice for flash-attn
num_kv_heads = 1
batch_size = 1
head_size = dim // num_kv_heads
block_size = 256

device = 'cuda'
dtype = torch.float16

# Note: without block_table, k_cache/v_cache have shape (batch_size, seqlen_cache, nheads_k, head_size);
# with a block_table they are expected to be (num_blocks, page_block_size, nheads_k, head_size) instead.
kv_cache = [(torch.zeros((batch_size, block_size, num_kv_heads, head_size), dtype=dtype, device=device),
             torch.zeros((batch_size, block_size, num_kv_heads, head_size), dtype=dtype, device=device))
            for _ in range(n_layers)]

layer = 0

k_cache, v_cache = kv_cache[layer]

k = torch.ones((batch_size, block_size, num_kv_heads, head_size), dtype=dtype, device=device)
v = torch.ones((batch_size, block_size, num_kv_heads, head_size), dtype=dtype, device=device)
q = torch.ones((batch_size, block_size, num_kv_heads, head_size), dtype=dtype, device=device)

max_num_blocks_per_seq = 2
# The block indices should range over the number of physical blocks in the paged cache,
# not over batch_size; with batch_size = 1 every entry here ends up as 0.
block_tables = torch.randint(0,
                             batch_size,
                             (batch_size, max_num_blocks_per_seq),
                             dtype=torch.int32,
                             device=device)

# Number of tokens already in the cache for each sequence (all zero here). Note that
# flash_attn_with_kvcache does not advance this; the caller has to update it after appending k/v.
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)

# This appends k/v into k_cache/v_cache in place at offset cache_seqlens and
# computes attention over the cached tokens plus the newly appended ones.
y = flash_attn_with_kvcache(q, k_cache=k_cache, v_cache=v_cache, k=k, v=v, block_table=block_tables,
                            cache_seqlens=cache_seqlens,
                            causal=True)

# Pull v and the in-place updated v_cache back to the host for manual inspection.
to_print_v = v.view(batch_size * block_size, num_kv_heads * head_size).float().cpu().numpy()
to_print_v_cache = v_cache.view(batch_size * block_size, num_kv_heads * head_size).float().cpu().numpy()

# Second call reads purely from the (in-place updated) cache. cache_seqlens was never
# advanced after the call above, so this still treats the cache as empty.
y_with_cached_KV = flash_attn_with_kvcache(q, k_cache=k_cache, v_cache=v_cache, k=None, v=None,
                                           block_table=block_tables,
                                           cache_seqlens=cache_seqlens,
                                           causal=True)



yy = flash_attn_func(q, k, v)  # note: causal defaults to False here, unlike the calls above


# `y_with_cached_KV == yy` raises: RuntimeError: Boolean value of Tensor with more than one value is ambiguous
# The tensors have to be compared element-wise instead:
assert torch.allclose(y_with_cached_KV, yy, atol=1e-3)
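
For what it's worth, here is the kind of comparison I was trying to write, as a rough sketch: it assumes the flash-attn 2.x Python API, uses a contiguous (non-paged) cache so that block_table can be omitted, and the fp16 tolerances are just guesses on my part.

import torch
from flash_attn import flash_attn_func, flash_attn_with_kvcache

torch.manual_seed(0)
device, dtype = 'cuda', torch.float16
batch_size, seqlen, nheads, head_size = 1, 128, 4, 64
max_cache_len = 256

q = torch.randn(batch_size, seqlen, nheads, head_size, dtype=dtype, device=device)
k = torch.randn(batch_size, seqlen, nheads, head_size, dtype=dtype, device=device)
v = torch.randn(batch_size, seqlen, nheads, head_size, dtype=dtype, device=device)

# Contiguous (non-paged) cache: (batch_size, max_seqlen, nheads_k, head_size), initially empty.
k_cache = torch.zeros(batch_size, max_cache_len, nheads, head_size, dtype=dtype, device=device)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)

# Reference: plain causal attention with no cache involved.
out_ref = flash_attn_func(q, k, v, causal=True)

# 1) Prefill: k/v are written into the cache in place at offset cache_seqlens,
#    and attention runs over the (previously empty) cache plus the new tokens.
out_prefill = flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v,
                                      cache_seqlens=cache_seqlens, causal=True)
torch.testing.assert_close(out_prefill, out_ref, atol=2e-3, rtol=2e-3)

# 2) Attend purely to the in-place updated cache: advance cache_seqlens ourselves
#    (the kernel does not do it) and pass no new k/v.
cache_seqlens += seqlen
out_from_cache = flash_attn_with_kvcache(q, k_cache, v_cache,
                                         cache_seqlens=cache_seqlens, causal=True)
torch.testing.assert_close(out_from_cache, out_ref, atol=2e-3, rtol=2e-3)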


@chakpongchung chakpongchung changed the title looking for a test to compare the result with and without the KV cache looking for a test to compare the result with the KV cache updated in place and without the KV cache Dec 27, 2024