Skip to content

Commit 97814fb

Browse files
rickyyx and comaniac authored
[v1] Refactor KVCacheManager for more hash input than token ids (#10507)
Signed-off-by: rickyx <[email protected]> Signed-off-by: Cody Yu <[email protected]> Co-authored-by: Cody Yu <[email protected]>
1 parent eebad39 commit 97814fb

File tree

3 files changed

+365
-186
lines changed

3 files changed

+365
-186
lines changed

tests/v1/core/test_prefix_caching.py

Lines changed: 206 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
"""Compare the with and without prefix caching."""
2+
import pytest
3+
24
from vllm.inputs import token_inputs
35
from vllm.sampling_params import SamplingParams
6+
from vllm.utils import cdiv
47
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
5-
from vllm.v1.core.kv_cache_utils import hash_block_tokens
8+
from vllm.v1.core.kv_cache_utils import KVCacheBlock, hash_block_tokens
69

710

811
def make_request(request_id, prompt_token_ids):
@@ -31,7 +34,8 @@ def test_prefill():
3134
# Fully cache miss
3235
# Incomplete 1 block (7 tokens)
3336
unique_token_ids = [3] * 7
34-
req0 = make_request("0", common_token_ids + unique_token_ids)
37+
all_token_ids = common_token_ids + unique_token_ids
38+
req0 = make_request("0", all_token_ids)
3539
computed_blocks = manager.get_computed_blocks(req0)
3640
assert not computed_blocks
3741
blocks = manager.allocate_slots(req0, 55, computed_blocks)
@@ -40,24 +44,16 @@ def test_prefill():
4044
# Check full block metadata
4145
parent_block_hash = None
4246
for block_id in (0, 1, 2):
43-
block_hash = hash_block_tokens(parent_block_hash,
44-
manager.block_pool[block_id].token_ids)
47+
block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
48+
block_hash = hash_block_tokens(parent_block_hash, block_tokens)
4549
assert manager.block_pool[block_id].block_hash == block_hash
4650
assert manager.block_pool[block_id].ref_cnt == 1
47-
assert manager.block_pool[block_id].num_hashed_tokens == 16 * (
48-
block_id + 1)
49-
assert manager.block_pool[block_id].token_ids == tuple([block_id] * 16)
5051
parent_block_hash = block_hash
5152

5253
# Check partial/preallocated block metadata
5354
for block_id in (3, 4):
5455
assert manager.block_pool[block_id].block_hash is None
5556
assert manager.block_pool[block_id].ref_cnt == 1
56-
assert manager.block_pool[block_id].num_hashed_tokens == 0
57-
if block_id == 3:
58-
assert manager.block_pool[block_id].token_ids == [3] * 7
59-
else:
60-
assert not manager.block_pool[block_id].token_ids
6157

6258
# Cache hit in the common prefix when the original block is still in use.
6359
# Incomplete 1 block (5 tokens)
@@ -113,7 +109,7 @@ def test_prefill():
113109
req3 = make_request("3", [99] * (16 * 9))
114110
computed_blocks = manager.get_computed_blocks(req3)
115111
assert not computed_blocks
116-
blocks = manager.allocate_slots(req2, 16 * 9, computed_blocks)
112+
blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks)
117113
# This block ID order also checks the eviction order.
118114
assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0]
119115
assert manager.free_block_queue.num_free_blocks == 0
@@ -148,7 +144,7 @@ def test_decode():
148144
req0.append_output_token_ids(8)
149145
new_blocks = manager.append_slots(req0, 4)
150146
assert new_blocks is not None and len(new_blocks) == 0
151-
assert len(manager.block_pool[3].token_ids) == 11
147+
assert manager.req_to_blocks[req0.request_id][-2].block_hash is None
152148

153149
# Append slots without allocating a new block, but start using the
154150
# preallocated block.
@@ -159,8 +155,7 @@ def test_decode():
159155
req0.append_output_token_ids(7)
160156
new_blocks = manager.append_slots(req0, 15)
161157
assert new_blocks is not None and len(new_blocks) == 0
162-
assert len(manager.block_pool[3].token_ids) == 16
163-
assert len(manager.block_pool[4].token_ids) == 10
158+
assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None
164159

165160
# Append slots with allocating a new block.
166161
req0.num_computed_tokens = 74
@@ -171,9 +166,6 @@ def test_decode():
171166
new_blocks = manager.append_slots(req0, 17)
172167
# Plus one preallocated block.
173168
assert new_blocks is not None and len(new_blocks) == 2
174-
assert len(manager.block_pool[4].token_ids) == 16
175-
assert len(manager.block_pool[5].token_ids) == 11
176-
assert len(manager.block_pool[6].token_ids) == 0
177169

178170

179171
def test_evict():
@@ -217,3 +209,198 @@ def test_evict():
217209
blocks = manager.allocate_slots(req2, 3, computed_blocks)
218210
assert [b.block_id for b in blocks] == [6, 5]
219211
assert manager.free_block_queue.num_free_blocks == 6
212+
213+
214+
def test_hash_block_correct_reuse():
    """
    Check that a block which was cached once and is then handed out again
    as a fresh block no longer carries its old hash metadata.
    """
    block_size = 16
    # A single-block pool forces the second request onto the same block.
    manager = KVCacheManager(
        block_size=block_size,
        num_gpu_blocks=1,
        sliding_window=False,
        enable_caching=True,
        num_preallocate_tokens=0,
    )

    # First request: exactly one full block, which gets cached.
    full_len = block_size
    first_req = make_request("0", list(range(full_len)))
    hits = manager.get_computed_blocks(first_req)
    assert not hits
    allocated = manager.allocate_slots(first_req, full_len, hits)
    assert len(allocated) == 1

    # Return the block to the pool.
    manager.free(first_req)

    # Second request: one token short of a full block, so the block it
    # receives must not keep any hash information.
    partial_len = full_len - 1
    second_req = make_request("1", list(range(partial_len)))
    hits = manager.get_computed_blocks(second_req)
    assert not hits
    allocated = manager.allocate_slots(second_req, partial_len, hits)
    assert len(allocated) == 1

    reused_block = manager.block_pool[allocated[0].block_id]
    assert reused_block.block_hash is None
248+
249+
250+
def test_computed_blocks_not_evicted():
    """
    Check that blocks returned as a cache hit are not the ones evicted
    when the same request needs new blocks and another free block exists.
    """
    block_size = 16
    manager = KVCacheManager(
        block_size=block_size,
        num_gpu_blocks=2,
        sliding_window=False,
        enable_caching=True,
        num_preallocate_tokens=0,
    )

    # Request "0" fills and caches one block (tokens 0..15).
    tokens_per_req = block_size
    first_req = make_request("0", list(range(tokens_per_req)))
    hits = manager.get_computed_blocks(first_req)
    assert not hits
    allocated = manager.allocate_slots(first_req, tokens_per_req, hits)
    assert len(allocated) == 1
    assert allocated[0].block_id == 0

    # Request "1" takes the other block with a disjoint token range.
    second_req = make_request(
        "1", list(range(tokens_per_req, tokens_per_req * 2)))
    hits = manager.get_computed_blocks(second_req)
    assert not hits
    allocated = manager.allocate_slots(second_req, tokens_per_req, hits)
    assert len(allocated) == 1
    assert allocated[0].block_id == 1

    # Release both blocks.
    manager.free(first_req)
    manager.free(second_req)

    # Request "2" shares its first block with request "0". The cache hit
    # must be block 0, and the new block must come from evicting block 1
    # rather than block 0.
    third_req = make_request("2", list(range(tokens_per_req * 2)))
    hits = manager.get_computed_blocks(third_req)
    assert len(hits) == 1
    assert hits[0].block_id == 0

    # Only the tokens not covered by the cache hit need slots.
    uncached_tokens = tokens_per_req * 2 - tokens_per_req
    allocated = manager.allocate_slots(third_req, uncached_tokens, hits)
    assert len(allocated) == 1
    assert allocated[0].block_id == 1
296+
297+
298+
def test_basic_prefix_caching_disabled():
    """
    Check that no cached blocks are ever reported when the manager is
    created with caching turned off.
    """
    block_size = 4
    manager = KVCacheManager(
        block_size=block_size,
        num_gpu_blocks=4,
        sliding_window=False,
        enable_caching=False,
        num_preallocate_tokens=0,
    )

    # 10 tokens with block_size 4: two full blocks plus a partial one.
    first_req = make_request("1", list(range(10)))

    hits = manager.get_computed_blocks(first_req)
    assert not hits
    allocated = manager.allocate_slots(first_req, 10, hits)
    assert len(allocated) == 3

    # Give the blocks back.
    manager.free(first_req)

    # Same leading tokens as the first request, but caching is off, so
    # nothing is reported as already computed.
    second_req = make_request("2", list(range(16)))
    hits = manager.get_computed_blocks(second_req)
    assert not hits
    allocated = manager.allocate_slots(second_req, 16, hits)
    assert len(allocated) == 4

    # The second request still holds all four blocks, so a further request
    # gets no cache hit and no blocks.
    third_req = make_request("3", list(range(4)))
    hits = manager.get_computed_blocks(third_req)
    assert not hits
    allocated = manager.allocate_slots(third_req, 4, hits)
    assert not allocated
334+
335+
336+
@pytest.mark.parametrize("num_preallocate_tokens", list(range(8)))
@pytest.mark.parametrize("block_size", [4])
def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
    """
    Check that each allocation returns the requested block plus the
    expected number of preallocated blocks.
    """
    manager = KVCacheManager(
        block_size=block_size,
        num_gpu_blocks=10,
        sliding_window=False,
        enable_caching=True,
        num_preallocate_tokens=num_preallocate_tokens,
    )
    # Preallocated tokens are rounded up to whole blocks.
    expected_blocks = 1 + cdiv(num_preallocate_tokens, block_size)

    req = make_request("0", list(range(block_size * 30)))
    hits = manager.get_computed_blocks(req)
    assert not hits

    # Request slots for a single block only.
    first_alloc = manager.allocate_slots(req, block_size, hits)
    assert len(first_alloc) == expected_blocks

    # Pretend every slot handed out so far is already filled, then ask
    # for one more block worth of slots.
    req.num_computed_tokens = block_size * len(first_alloc)
    appended = manager.append_slots(req, block_size)
    assert len(appended) == expected_blocks
362+
363+
364+
def test_cache_blocks():
    """
    Unit test for KVCacheManager._cache_full_blocks: full blocks passed to
    it should receive a hash and be registered in the hash-to-block map.
    """
    block_size = 4
    manager = KVCacheManager(
        block_size=block_size,
        num_gpu_blocks=5,
        sliding_window=False,
        enable_caching=True,
        num_preallocate_tokens=0,
    )
    # Token layout of the request with block_size 4:
    #   Block 0: [0, 1, 2, 3]
    #   Block 1: [4, 5, 6, 7]
    #   Block 2: [8, 9, 10, 11]
    #   Block 3: [12, 13]          (partial)
    req = make_request("0", list(range(14)))

    # Case 1: cache the first two full blocks, starting at block index 0.
    leading_blocks = [KVCacheBlock(block_id=i) for i in range(2)]
    manager._cache_full_blocks(
        request=req,
        blk_start_idx=0,
        full_blocks=leading_blocks,
        prev_block=None,
    )

    assert len(manager.cached_block_hash_to_block) == 2
    assert all(blk.block_hash is not None for blk in leading_blocks)

    # Case 2: cache a block that does not start at the beginning of the
    # request (block index 2).
    later_blocks = [KVCacheBlock(block_id=2)]
    manager._cache_full_blocks(
        request=req,
        blk_start_idx=2,
        full_blocks=later_blocks,
        prev_block=None,
    )
    assert len(manager.cached_block_hash_to_block) == 3
    assert later_blocks[0].block_hash is not None

0 commit comments

Comments
 (0)