1
1
"""Compare the with and without prefix caching."""
2
+ import pytest
3
+
2
4
from vllm .inputs import token_inputs
3
5
from vllm .sampling_params import SamplingParams
6
+ from vllm .utils import cdiv
4
7
from vllm .v1 .core .kv_cache_manager import KVCacheManager , Request
5
- from vllm .v1 .core .kv_cache_utils import hash_block_tokens
8
+ from vllm .v1 .core .kv_cache_utils import KVCacheBlock , hash_block_tokens
6
9
7
10
8
11
def make_request (request_id , prompt_token_ids ):
@@ -31,7 +34,8 @@ def test_prefill():
31
34
# Fully cache miss
32
35
# Incomplete 1 block (7 tokens)
33
36
unique_token_ids = [3 ] * 7
34
- req0 = make_request ("0" , common_token_ids + unique_token_ids )
37
+ all_token_ids = common_token_ids + unique_token_ids
38
+ req0 = make_request ("0" , all_token_ids )
35
39
computed_blocks = manager .get_computed_blocks (req0 )
36
40
assert not computed_blocks
37
41
blocks = manager .allocate_slots (req0 , 55 , computed_blocks )
@@ -40,24 +44,16 @@ def test_prefill():
40
44
# Check full block metadata
41
45
parent_block_hash = None
42
46
for block_id in (0 , 1 , 2 ):
43
- block_hash = hash_block_tokens ( parent_block_hash ,
44
- manager . block_pool [ block_id ]. token_ids )
47
+ block_tokens = tuple ( all_token_ids [ block_id * 16 :( block_id + 1 ) * 16 ])
48
+ block_hash = hash_block_tokens ( parent_block_hash , block_tokens )
45
49
assert manager .block_pool [block_id ].block_hash == block_hash
46
50
assert manager .block_pool [block_id ].ref_cnt == 1
47
- assert manager .block_pool [block_id ].num_hashed_tokens == 16 * (
48
- block_id + 1 )
49
- assert manager .block_pool [block_id ].token_ids == tuple ([block_id ] * 16 )
50
51
parent_block_hash = block_hash
51
52
52
53
# Check partial/preallocated block metadata
53
54
for block_id in (3 , 4 ):
54
55
assert manager .block_pool [block_id ].block_hash is None
55
56
assert manager .block_pool [block_id ].ref_cnt == 1
56
- assert manager .block_pool [block_id ].num_hashed_tokens == 0
57
- if block_id == 3 :
58
- assert manager .block_pool [block_id ].token_ids == [3 ] * 7
59
- else :
60
- assert not manager .block_pool [block_id ].token_ids
61
57
62
58
# Cache hit in the common prefix when the original block is still in use.
63
59
# Incomplete 1 block (5 tokens)
@@ -113,7 +109,7 @@ def test_prefill():
113
109
req3 = make_request ("3" , [99 ] * (16 * 9 ))
114
110
computed_blocks = manager .get_computed_blocks (req3 )
115
111
assert not computed_blocks
116
- blocks = manager .allocate_slots (req2 , 16 * 9 , computed_blocks )
112
+ blocks = manager .allocate_slots (req3 , 16 * 9 , computed_blocks )
117
113
# This block ID order also checks the eviction order.
118
114
assert [b .block_id for b in blocks ] == [9 , 4 , 3 , 6 , 5 , 8 , 7 , 2 , 1 , 0 ]
119
115
assert manager .free_block_queue .num_free_blocks == 0
@@ -148,7 +144,7 @@ def test_decode():
148
144
req0 .append_output_token_ids (8 )
149
145
new_blocks = manager .append_slots (req0 , 4 )
150
146
assert new_blocks is not None and len (new_blocks ) == 0
151
- assert len ( manager .block_pool [ 3 ]. token_ids ) == 11
147
+ assert manager .req_to_blocks [ req0 . request_id ][ - 2 ]. block_hash is None
152
148
153
149
# Append slots without allocating a new block, but start using the
154
150
# preallocated block.
@@ -159,8 +155,7 @@ def test_decode():
159
155
req0 .append_output_token_ids (7 )
160
156
new_blocks = manager .append_slots (req0 , 15 )
161
157
assert new_blocks is not None and len (new_blocks ) == 0
162
- assert len (manager .block_pool [3 ].token_ids ) == 16
163
- assert len (manager .block_pool [4 ].token_ids ) == 10
158
+ assert manager .req_to_blocks [req0 .request_id ][- 2 ].block_hash is not None
164
159
165
160
# Append slots with allocating a new block.
166
161
req0 .num_computed_tokens = 74
@@ -171,9 +166,6 @@ def test_decode():
171
166
new_blocks = manager .append_slots (req0 , 17 )
172
167
# Plus one preallocated block.
173
168
assert new_blocks is not None and len (new_blocks ) == 2
174
- assert len (manager .block_pool [4 ].token_ids ) == 16
175
- assert len (manager .block_pool [5 ].token_ids ) == 11
176
- assert len (manager .block_pool [6 ].token_ids ) == 0
177
169
178
170
179
171
def test_evict ():
@@ -217,3 +209,198 @@ def test_evict():
217
209
blocks = manager .allocate_slots (req2 , 3 , computed_blocks )
218
210
assert [b .block_id for b in blocks ] == [6 , 5 ]
219
211
assert manager .free_block_queue .num_free_blocks == 6
212
+
213
+
214
+ def test_hash_block_correct_reuse ():
215
+ """
216
+ This tests when a previously cached block is reused as a new block,
217
+ its hash metadata should be correctly reset.
218
+ """
219
+ block_size = 16
220
+ manager = KVCacheManager (
221
+ block_size = block_size ,
222
+ num_gpu_blocks = 1 ,
223
+ sliding_window = False ,
224
+ enable_caching = True ,
225
+ num_preallocate_tokens = 0 ,
226
+ )
227
+
228
+ # Allocate 1 block and cache it.
229
+ num_tokens = block_size * 1
230
+ req = make_request ("0" , list (range (num_tokens )))
231
+ computed_blocks = manager .get_computed_blocks (req )
232
+ assert not computed_blocks
233
+ blocks = manager .allocate_slots (req , num_tokens , computed_blocks )
234
+ assert len (blocks ) == 1
235
+
236
+ # Deallocate the block.
237
+ manager .free (req )
238
+
239
+ # Allocate a new block that's not full, make sure hash info on the
240
+ # block is cleared.
241
+ req = make_request ("1" , list (range (num_tokens - 1 )))
242
+ computed_blocks = manager .get_computed_blocks (req )
243
+ assert not computed_blocks
244
+ blocks = manager .allocate_slots (req , num_tokens - 1 , computed_blocks )
245
+ assert len (blocks ) == 1
246
+
247
+ assert manager .block_pool [blocks [0 ].block_id ].block_hash is None
248
+
249
+
250
+ def test_computed_blocks_not_evicted ():
251
+ """
252
+ Test that the computed blocks are not evicted when getting new blocks
253
+ for a request if there are any other free blocks.
254
+ """
255
+ block_size = 16
256
+ manager = KVCacheManager (
257
+ block_size = block_size ,
258
+ num_gpu_blocks = 2 ,
259
+ sliding_window = False ,
260
+ enable_caching = True ,
261
+ num_preallocate_tokens = 0 ,
262
+ )
263
+
264
+ # Allocate a block and cache it.
265
+ num_tokens = block_size * 1
266
+ req0 = make_request ("0" , list (range (num_tokens )))
267
+ computed_blocks = manager .get_computed_blocks (req0 )
268
+ assert not computed_blocks
269
+ blocks = manager .allocate_slots (req0 , num_tokens , computed_blocks )
270
+ assert len (blocks ) == 1
271
+ assert blocks [0 ].block_id == 0
272
+
273
+ # Allocate another block.
274
+ req1 = make_request ("1" , list (range (num_tokens , num_tokens * 2 )))
275
+ computed_blocks = manager .get_computed_blocks (req1 )
276
+ assert not computed_blocks
277
+ blocks = manager .allocate_slots (req1 , num_tokens , computed_blocks )
278
+ assert len (blocks ) == 1
279
+ assert blocks [0 ].block_id == 1
280
+
281
+ # Free the blocks.
282
+ manager .free (req0 )
283
+ manager .free (req1 )
284
+
285
+ # Now if we have a cache hit on the first block, we should evict the second
286
+ # cached block rather than the first one.
287
+ req2 = make_request ("2" , list (range (num_tokens * 2 )))
288
+ computed_blocks = manager .get_computed_blocks (req2 )
289
+ assert len (computed_blocks ) == 1
290
+ assert computed_blocks [0 ].block_id == 0
291
+
292
+ blocks = manager .allocate_slots (req2 , num_tokens * 2 - num_tokens ,
293
+ computed_blocks )
294
+ assert len (blocks ) == 1
295
+ assert blocks [0 ].block_id == 1
296
+
297
+
298
+ def test_basic_prefix_caching_disabled ():
299
+ """
300
+ This tests that the prefix caching is disabled.
301
+ """
302
+ block_size = 4
303
+ manager = KVCacheManager (
304
+ block_size = block_size ,
305
+ num_gpu_blocks = 4 ,
306
+ sliding_window = False ,
307
+ enable_caching = False ,
308
+ num_preallocate_tokens = 0 ,
309
+ )
310
+
311
+ req1 = make_request ("1" , list (range (10 ))) # 2 blocks and some more
312
+
313
+ computed_blocks = manager .get_computed_blocks (req1 )
314
+ assert not computed_blocks
315
+ blocks = manager .allocate_slots (req1 , 10 , computed_blocks )
316
+ assert len (blocks ) == 3
317
+
318
+ # Free the blocks.
319
+ manager .free (req1 )
320
+
321
+ # No caching.
322
+ req2 = make_request ("2" , list (range (16 ))) # shared prefix
323
+ computed_blocks = manager .get_computed_blocks (req2 )
324
+ assert not computed_blocks
325
+ blocks = manager .allocate_slots (req2 , 16 , computed_blocks )
326
+ assert len (blocks ) == 4
327
+
328
+ # New requests should not have any blocks.
329
+ req3 = make_request ("3" , list (range (4 )))
330
+ computed_blocks = manager .get_computed_blocks (req3 )
331
+ assert not computed_blocks
332
+ blocks = manager .allocate_slots (req3 , 4 , computed_blocks )
333
+ assert not blocks
334
+
335
+
336
+ @pytest .mark .parametrize ("num_preallocate_tokens" , list (range (0 , 8 )))
337
+ @pytest .mark .parametrize ("block_size" , [4 ])
338
+ def test_preallocate_blocks (num_preallocate_tokens : int , block_size : int ):
339
+ """
340
+ This tests that the preallocated blocks are correctly added.
341
+ """
342
+ manager = KVCacheManager (
343
+ block_size = block_size ,
344
+ num_gpu_blocks = 10 ,
345
+ sliding_window = False ,
346
+ enable_caching = True ,
347
+ num_preallocate_tokens = num_preallocate_tokens ,
348
+ )
349
+ num_preallocated_blocks = cdiv (num_preallocate_tokens , block_size )
350
+
351
+ req = make_request ("0" , list (range (block_size * 30 )))
352
+ computed_blocks = manager .get_computed_blocks (req )
353
+ assert not computed_blocks
354
+ # Just ask for 1 block.
355
+ blocks = manager .allocate_slots (req , block_size , computed_blocks )
356
+ assert len (blocks ) == 1 + num_preallocated_blocks
357
+
358
+ # Append slots to the block.
359
+ req .num_computed_tokens = block_size * len (blocks ) # Assume all used.
360
+ blocks = manager .append_slots (req , block_size ) # Append 1 block.
361
+ assert len (blocks ) == 1 + num_preallocated_blocks
362
+
363
+
364
+ def test_cache_blocks ():
365
+ """
366
+ This is a unit test that tests the correctness of the _cache_full_blocks
367
+ function of KVCacheManager.
368
+ """
369
+ block_size = 4
370
+ manager = KVCacheManager (
371
+ block_size = block_size ,
372
+ num_gpu_blocks = 5 ,
373
+ sliding_window = False ,
374
+ enable_caching = True ,
375
+ num_preallocate_tokens = 0 ,
376
+ )
377
+ # Req:
378
+ # Block 0: [0, 1, 2, 3]
379
+ # Block 1: [4, 5, 6, 7]
380
+ # Block 2: [8, 9, 10, 11]
381
+ # Block 3: [12, 13]
382
+ req = make_request ("0" , list (range (14 )))
383
+
384
+ # Test that blocks are cached correctly for 2 full blocks from the start.
385
+ blocks = [KVCacheBlock (block_id = i ) for i in range (2 )]
386
+
387
+ manager ._cache_full_blocks (
388
+ request = req ,
389
+ blk_start_idx = 0 ,
390
+ full_blocks = blocks ,
391
+ prev_block = None ,
392
+ )
393
+
394
+ assert len (manager .cached_block_hash_to_block ) == 2
395
+ assert all ([block .block_hash is not None for block in blocks ])
396
+
397
+ # Test that blocks that don't start from the beginning are cached correctly.
398
+ blocks = [KVCacheBlock (block_id = 2 )]
399
+ manager ._cache_full_blocks (
400
+ request = req ,
401
+ blk_start_idx = 2 ,
402
+ full_blocks = blocks ,
403
+ prev_block = None ,
404
+ )
405
+ assert len (manager .cached_block_hash_to_block ) == 3
406
+ assert blocks [0 ].block_hash is not None
0 commit comments