[V1] Add uncache_blocks (vllm-project#12333)
comaniac authored and tjtanaa committed Jan 28, 2025
1 parent 8dab4e9 commit 7e5655a
Showing 2 changed files with 61 additions and 2 deletions.
30 changes: 30 additions & 0 deletions tests/v1/core/test_prefix_caching.py
@@ -626,3 +626,33 @@ def test_reset_prefix_cache():
    assert manager.reset_prefix_cache()
    assert not manager.cached_block_hash_to_block
    assert all([blk.block_hash is None for blk in manager.block_pool])


def test_uncache_blocks():
    manager = KVCacheManager(
        block_size=16,
        num_gpu_blocks=10,
        max_model_len=8192,
        sliding_window=None,
        enable_caching=True,
        num_preallocate_tokens=0,
    )

    req0 = make_request("0", list(range(30)))
    blocks = manager.allocate_slots(req0, 30, [])
    assert [b.block_id for b in blocks] == [0, 1]
    assert len(manager.cached_block_hash_to_block) == 1

    req0.num_computed_tokens = 30

    # Simulate speculative tokens.
    for _ in range(5):
        req0.append_output_token_ids(8)
    manager.append_slots(req0, 5)
    assert len(manager.cached_block_hash_to_block) == 2

    # After sampling, assuming only 1 token is accepted.
    req0.num_computed_tokens = 31
    num_uncached_blocks = manager.uncache_blocks(req0)
    assert num_uncached_blocks == 1
    assert len(manager.cached_block_hash_to_block) == 1
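
For readers checking the assertions: with block_size=16, every expected count in this test follows from full-block arithmetic. A minimal standalone sketch of that bookkeeping (plain Python, independent of the vLLM types above):

# Standalone arithmetic behind the test's expected counts (block_size = 16).
# Plain Python; no vLLM types involved.
from math import ceil

block_size = 16

num_prompt_tokens = 30
assert ceil(num_prompt_tokens / block_size) == 2  # blocks [0, 1] allocated
assert num_prompt_tokens // block_size == 1       # 1 full block -> 1 cached hash

num_with_spec = num_prompt_tokens + 5             # 5 speculative tokens appended
assert num_with_spec // block_size == 2           # 2 full blocks -> 2 cached hashes

num_accepted_total = 31                           # only 1 speculative token accepted
assert num_accepted_total // block_size == 1      # back to 1 full block, so
                                                  # uncache_blocks() evicts 1 entry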
33 changes: 31 additions & 2 deletions vllm/v1/core/kv_cache_manager.py
@@ -285,6 +285,29 @@ def free(self, request: Request) -> None:
            if block.ref_cnt == 0:
                self.free_block_queue.append(block)

    def uncache_blocks(self, request: Request) -> int:
        """Uncache the blocks that are no longer full based on the
        num_computed_tokens in the given request. This happens when
        the blocks were full and cached due to speculative tokens, but the
        speculative tokens are not accepted.

        Args:
            request: The request.

        Returns:
            The number of uncached blocks.
        """
        blocks = self.req_to_blocks[request.request_id]
        num_computed_tokens = request.num_computed_tokens
        num_full_blocks = num_computed_tokens // self.block_size
        num_uncached_blocks = 0
        for block in blocks[num_full_blocks:]:
            # If the block is not cached, the following blocks are not cached.
            if not self._maybe_evict_cached_block(block):
                break
            num_uncached_blocks += 1
        return num_uncached_blocks

    def reset_prefix_cache(self) -> bool:
        """Reset prefix cache. This function may be used in RLHF
        flows to invalidate prefix caching after the weights are updated,
@@ -386,21 +409,24 @@ def _get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]:

            # If the block is cached, evict it.
            if self.enable_caching:
-                self._evict_cached_block(curr_block)
+                self._maybe_evict_cached_block(curr_block)

            curr_block.incr_ref()
            ret.append(curr_block)
            idx += 1

        return ret

-    def _evict_cached_block(self, block: KVCacheBlock) -> None:
+    def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
        """
        If a block is cached in `cached_block_hash_to_block`, we reset its hash
        metadata and evict it from the cache.

        Args:
            block: The block to evict.

        Returns:
            True if the block is evicted, False otherwise.
        """
        block_hash = block.block_hash
        if block_hash and block_hash in self.cached_block_hash_to_block:
@@ -410,6 +436,9 @@ def _evict_cached_block(self, block: KVCacheBlock) -> None:
            if len(self.cached_block_hash_to_block[block_hash]) == 0:
                del self.cached_block_hash_to_block[block_hash]

            return True
        return False

    def _get_cached_block(self,
                          block_hash: BlockHashType) -> Optional[KVCacheBlock]:
        """Get a cached block by the block hash, or None if cache miss.
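
Taken together: `uncache_blocks` computes the index of the first block that is no longer full and walks forward from there, evicting cache entries via `_maybe_evict_cached_block` until it reaches a block that was never cached. Below is a minimal self-contained sketch of that loop; `SimpleBlock` and the plain dict cache are hypothetical stand-ins for `KVCacheBlock` and `cached_block_hash_to_block`, not the actual vLLM types:

# Sketch only: SimpleBlock and the plain dict cache are hypothetical
# stand-ins, not the real KVCacheBlock / cached_block_hash_to_block.
from typing import Dict, List, Optional


class SimpleBlock:

    def __init__(self, block_id: int, block_hash: Optional[str] = None):
        self.block_id = block_id
        self.block_hash = block_hash  # None means the block was never cached


def uncache_blocks(blocks: List[SimpleBlock], cached: Dict[str, SimpleBlock],
                   num_computed_tokens: int, block_size: int) -> int:
    """Evict cache entries for blocks past the last full block, stopping at
    the first block that is not cached (later blocks cannot be cached)."""
    num_full_blocks = num_computed_tokens // block_size
    num_uncached = 0
    for block in blocks[num_full_blocks:]:
        if block.block_hash is None or block.block_hash not in cached:
            break
        del cached[block.block_hash]
        block.block_hash = None  # reset hash metadata, as in the real method
        num_uncached += 1
    return num_uncached


# Mirrors the test: 2 full blocks cached, then only 31 of 35 tokens accepted.
blocks = [SimpleBlock(0, "h0"), SimpleBlock(1, "h1"), SimpleBlock(2)]
cached = {"h0": blocks[0], "h1": blocks[1]}
assert uncache_blocks(blocks, cached, num_computed_tokens=31, block_size=16) == 1
assert list(cached) == ["h0"]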
