[V1] Move KV block hashes from Request to KVCacheManager (#12922)
Signed-off-by: Woosuk Kwon <[email protected]>
WoosukKwon authored Feb 8, 2025
1 parent b21f0f9 commit 3243158
Showing 4 changed files with 35 additions and 31 deletions.
21 changes: 11 additions & 10 deletions tests/v1/core/test_prefix_caching.py
@@ -51,7 +51,7 @@ def test_prefill():
all_token_ids = common_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
- assert len(req0.kv_block_hashes) == 3
+ assert len(manager.req_to_block_hashes[req0.request_id]) == 3
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
@@ -76,7 +76,7 @@ def test_prefill():
unique_token_ids = [3] * 5
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
- assert len(req1.kv_block_hashes) == 3
+ assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
@@ -107,7 +107,7 @@ def test_prefill():
unique_token_ids = [3] * 6
req2 = make_request("2", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
- assert len(req2.kv_block_hashes) == 3
+ assert len(manager.req_to_block_hashes[req2.request_id]) == 3
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
@@ -494,10 +494,11 @@ def test_mm_prefix_caching():
# Completed block should have hashes with extra keys.
assert not computed_blocks
assert num_computed_tokens == 0
- assert len(req0.kv_block_hashes) == 3
- assert req0.kv_block_hashes[0].extra_keys == ("aaa", )
- assert req0.kv_block_hashes[1].extra_keys == ("aaa", "bbb")
- assert req0.kv_block_hashes[2].extra_keys == ("bbb", )
+ block_hashes = manager.req_to_block_hashes[req0.request_id]
+ assert len(block_hashes) == 3
+ assert block_hashes[0].extra_keys == ("aaa", )
+ assert block_hashes[1].extra_keys == ("aaa", "bbb")
+ assert block_hashes[2].extra_keys == ("bbb", )

blocks = manager.allocate_slots(req0, 59, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
@@ -510,8 +511,8 @@ def test_mm_prefix_caching():
assert new_blocks is not None and len(new_blocks) == 0

# The just completed block should have hashes with extra keys.
- assert len(req0.kv_block_hashes) == 4
- assert req0.kv_block_hashes[3].extra_keys == ("ccc", )
+ assert len(block_hashes) == 4
+ assert block_hashes[3].extra_keys == ("ccc", )

# Cache hit.
unique_token_ids = [-1] * 7 + [200] * 5
@@ -613,7 +614,7 @@ def test_reset_prefix_cache():
all_token_ids = full_block_token_ids + unique_token_ids
req1 = make_request("1", all_token_ids)
computed_blocks, _ = manager.get_computed_blocks(req1)
- assert len(req1.kv_block_hashes) == 3
+ assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert len(computed_blocks) == 3
blocks = manager.allocate_slots(req1, 7, computed_blocks)
assert [b.block_id for b in blocks] == [4]
31 changes: 23 additions & 8 deletions vllm/v1/core/kv_cache_manager.py
@@ -72,6 +72,12 @@ def __init__(
self.req_to_blocks: DefaultDict[str,
List[KVCacheBlock]] = defaultdict(list)

+ # Mapping from request ID to kv block hashes.
+ # This is to avoid recomputing the block hashes for each call of
+ # `get_computed_blocks` or `allocate_slots`.
+ self.req_to_block_hashes: DefaultDict[
+     str, List[BlockHashType]] = defaultdict(list)

@property
def usage(self) -> float:
return 1.0 - (self.free_block_queue.num_free_blocks /
@@ -97,11 +103,11 @@ def get_computed_blocks(
computed_blocks = []

# The block hashes for the request may already be computed
- # if the request was preempted and resumed.
- if not request.kv_block_hashes:
-     request.set_kv_block_hashes(
-         hash_request_tokens(self.block_size, request))
- block_hashes = request.kv_block_hashes
+ # if the scheduler has tried to schedule the request before.
+ block_hashes = self.req_to_block_hashes[request.request_id]
+ if not block_hashes:
+     block_hashes = hash_request_tokens(self.block_size, request)
+     self.req_to_block_hashes[request.request_id] = block_hashes

for block_hash in block_hashes:
# block_hashes is a chain of block hashes. If a block hash is not
@@ -435,7 +441,8 @@ def _cache_full_blocks(
full_blocks: The list of blocks to update hash metadata.
prev_block: The previous block in the chain.
"""
- num_cached_block_hashes = len(request.kv_block_hashes)
+ block_hashes = self.req_to_block_hashes[request.request_id]
+ num_cached_block_hashes = len(block_hashes)

# Update the new blocks with the block hashes through the chain.
prev_block_hash_value = None
@@ -468,7 +475,7 @@
# this request (either the prompt tokens or the previously
# generated tokens with preemption). In this case we simply
# reuse the block hash.
- block_hash = request.kv_block_hashes[blk_idx]
+ block_hash = block_hashes[blk_idx]
else:
# Otherwise compute the block hash and cache it in the request
# in case it will be preempted in the future.
@@ -490,9 +497,17 @@
# Compute the hash of the current block.
block_hash = hash_block_tokens(prev_block_hash_value,
block_tokens, extra_keys)
- request.append_kv_block_hashes(block_hash)
+ block_hashes.append(block_hash)

# Update and add the full block to the cache.
blk.block_hash = block_hash
self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
prev_block_hash_value = block_hash.hash_value

+ def free_block_hashes(self, request: Request) -> None:
+     """Discard the block hashes for the request.
+     NOTE: Unlike `free`, this method should be called only when the request
+     is finished, not when it is preempted.
+     """
+     self.req_to_block_hashes.pop(request.request_id, None)
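
Taken together, the manager now owns the per-request hash cache end to end: hashes are computed lazily on first lookup, reused verbatim when a preempted request is rescheduled, and dropped only at request completion. A minimal, self-contained sketch of that lifecycle follows; the names mirror the diff, but BlockHashType and the hashing logic here are simplified stand-ins, not vLLM's actual implementations.

from collections import defaultdict
from typing import DefaultDict, List, NamedTuple, Optional, Tuple


class BlockHashType(NamedTuple):
    # Simplified stand-in: a hash value plus the token ids that produced it.
    hash_value: int
    token_ids: Tuple[int, ...]


class KVCacheManagerSketch:

    def __init__(self, block_size: int = 16) -> None:
        self.block_size = block_size
        # Per-request hash cache keyed by request ID, as in the diff.
        self.req_to_block_hashes: DefaultDict[
            str, List[BlockHashType]] = defaultdict(list)

    def get_computed_blocks(self, request_id: str,
                            token_ids: List[int]) -> List[BlockHashType]:
        # Compute lazily and memoize: if the scheduler has tried to schedule
        # this request before (e.g. it was preempted and resumed), the cached
        # hashes are reused as-is.
        block_hashes = self.req_to_block_hashes[request_id]
        if not block_hashes:
            prev_hash: Optional[int] = None
            for i in range(len(token_ids) // self.block_size):
                block = tuple(
                    token_ids[i * self.block_size:(i + 1) * self.block_size])
                # Chain each hash to the previous one so a block's hash
                # covers its entire prefix, as prefix caching requires.
                prev_hash = hash((prev_hash, block))
                block_hashes.append(BlockHashType(prev_hash, block))
        return block_hashes

    def free_block_hashes(self, request_id: str) -> None:
        # Called only when the request finishes, never on preemption, so a
        # resumed request still hits the memoized hashes above.
        self.req_to_block_hashes.pop(request_id, None)

Keying the cache by request ID, instead of storing it on the Request object, gives the manager a single owner for both the KV blocks and their hashes.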
1 change: 1 addition & 0 deletions vllm/v1/core/scheduler.py
@@ -579,6 +579,7 @@ def finish_requests(
def _free_request(self, request: Request) -> None:
assert request.is_finished()
self.kv_cache_manager.free(request)
+ self.kv_cache_manager.free_block_hashes(request)
self.encoder_cache_manager.free(request)
self._cached_reqs_data.pop(request.request_id, None)
del self.requests[request.request_id]
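
For orientation, this is how the new call pairs up on the finish path, using the sketch class above (the real _free_request also frees the KV blocks, the encoder cache, and cached request data):

mgr = KVCacheManagerSketch()
hashes = mgr.get_computed_blocks("req-0", list(range(48)))
assert len(hashes) == 3  # 48 tokens make three full 16-token blocks
# ... the request runs to completion; freeing blocks would happen here ...
mgr.free_block_hashes("req-0")  # the hash cache is dropped only at finish
assert "req-0" not in mgr.req_to_block_hashes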
13 changes: 0 additions & 13 deletions vllm/v1/request.py
@@ -12,7 +12,6 @@
if TYPE_CHECKING:
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange
- from vllm.v1.core.kv_cache_utils import BlockHashType


class Request:
@@ -63,11 +62,6 @@ def __init__(
if self.mm_hashes:
assert len(self.mm_inputs) == len(self.mm_hashes)

- # Cache the computed kv block hashes of the request to avoid
- # recomputing.
- self._kv_block_hashes: List[BlockHashType] = []
- self.kv_block_hashes = ConstantList(self._kv_block_hashes)

# Read-only views
# Prevent directly appending to these lists since
# they should also be updated simultaneously.
@@ -124,13 +118,6 @@ def get_num_encoder_tokens(self, input_id: int) -> int:
num_tokens = self.mm_positions[input_id]["length"]
return num_tokens

- def set_kv_block_hashes(self, value: List["BlockHashType"]) -> None:
-     self._kv_block_hashes = value
-     self.kv_block_hashes = ConstantList(self._kv_block_hashes)
-
- def append_kv_block_hashes(self, block_hash: "BlockHashType") -> None:
-     self._kv_block_hashes.append(block_hash)


class RequestStatus(enum.IntEnum):
"""Status of a request."""
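
The "Read-only views" comment kept above refers to vLLM's ConstantList wrapper, which kv_block_hashes also used before this commit removed it from Request. A minimal sketch of that pattern, assuming an illustrative implementation rather than vLLM's exact class:

from typing import List, Sequence, TypeVar

T = TypeVar("T")


class ConstantListSketch(Sequence[T]):
    """A read-only view over a list that its owner may still mutate."""

    def __init__(self, backing: List[T]) -> None:
        self._backing = backing

    def __getitem__(self, index):
        return self._backing[index]

    def __len__(self) -> int:
        return len(self._backing)


token_ids: List[int] = [1, 2, 3]
view = ConstantListSketch(token_ids)
token_ids.append(4)      # the owner appends to the backing list
assert len(view) == 4    # the view reflects the update immediately
# The view exposes no append/extend/__setitem__, so callers cannot mutate it.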
