[Core] Support reset_prefix_cache (vllm-project#12284)
Signed-off-by: Bowen Wang <[email protected]>
comaniac authored and abmfy committed Jan 24, 2025
1 parent af9a034 commit cf60957
Showing 27 changed files with 300 additions and 21 deletions.
38 changes: 38 additions & 0 deletions tests/core/block/test_prefix_caching_block.py
@@ -796,6 +796,44 @@ def test_find_cached_blocks_prefix():
             block_hashes=block_hashes_seq1)
         assert len(cached_blocks) == len(blocks_seq1) - num_evicted_blocks
 
+    # Test reset prefix cache
+    @staticmethod
+    @pytest.mark.parametrize("num_blocks", [10])
+    @pytest.mark.parametrize("block_size", [16])
+    def test_reset_prefix_cache(num_blocks: int, block_size: int):
+        """This test case simulates resetting the prefix cache."""
+
+        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
+                                                block_size=block_size)
+        token_ids = list(range(3 * block_size))
+
+        first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
+            block_size=block_size,
+            token_ids=token_ids,
+            allocator=allocator,
+        )
+        second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
+            block_size=block_size,
+            token_ids=token_ids,
+            allocator=allocator,
+        )
+
+        # Free each block in the first chain.
+        for block in first_chain:
+            allocator.free(block)
+
+        # Resetting should fail: the second chain still holds its blocks.
+        assert not allocator.reset_prefix_cache()
+        assert allocator.get_prefix_cache_hit_rate() > 0.0
+
+        # Free each block in the second chain.
+        for block in second_chain:
+            allocator.free(block)
+
+        # With every block freed, the reset succeeds and clears the metrics.
+        assert allocator.reset_prefix_cache()
+        assert allocator.get_prefix_cache_hit_rate() == 0.0
+
     @staticmethod
     def create_immutable_chain(
         block_size: int,
39 changes: 39 additions & 0 deletions tests/v1/core/test_prefix_caching.py
@@ -587,3 +587,42 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
     assert {block.ref_cnt for block in block_part1[:3]} == {1}
     # Blocks 3-5 are free.
     assert {block.ref_cnt for block in block_part1[3:]} == {0}
+
+
+def test_reset_prefix_cache():
+    manager = KVCacheManager(
+        block_size=16,
+        num_gpu_blocks=10,
+        max_model_len=8192,
+        sliding_window=None,
+        enable_caching=True,
+        num_preallocate_tokens=0,
+    )
+
+    full_block_token_ids = [i for i in range(3) for _ in range(16)]
+    unique_token_ids = [3] * 7
+    all_token_ids = full_block_token_ids + unique_token_ids
+    req0 = make_request("0", all_token_ids)
+    blocks = manager.allocate_slots(req0, 55, [])
+    assert [b.block_id for b in blocks] == [0, 1, 2, 3]
+
+    unique_token_ids = [4] * 7
+    all_token_ids = full_block_token_ids + unique_token_ids
+    req1 = make_request("1", all_token_ids)
+    computed_blocks, _ = manager.get_computed_blocks(req1)
+    assert len(req1.kv_block_hashes) == 3
+    assert len(computed_blocks) == 3
+    blocks = manager.allocate_slots(req1, 7, computed_blocks)
+    assert [b.block_id for b in blocks] == [4]
+
+    # Resetting should fail since both requests still hold their blocks.
+    assert not manager.reset_prefix_cache()
+    assert manager.cached_block_hash_to_block
+
+    # Free the blocks.
+    manager.free(req0)
+    manager.free(req1)
+
+    assert manager.reset_prefix_cache()
+    assert not manager.cached_block_hash_to_block
+    assert all(blk.block_hash is None for blk in manager.block_pool)
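The token layout in this test is compact but worth unpacking: make_request is a helper defined earlier in this test module, and the list comprehension builds exactly three full blocks of repeated ids. A standalone sketch of the arithmetic (the 16s mirror the block_size passed to KVCacheManager above):

    # Three full blocks of repeated token ids: [0]*16 + [1]*16 + [2]*16.
    full_block_token_ids = [i for i in range(3) for _ in range(16)]
    assert full_block_token_ids == [0] * 16 + [1] * 16 + [2] * 16

    # Seven trailing tokens fill only part of a fourth block, which is why
    # allocate_slots() hands out exactly 4 block ids for the 55 tokens.
    all_token_ids = full_block_token_ids + [3] * 7
    assert len(all_token_ids) == 55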
7 changes: 7 additions & 0 deletions vllm/core/block/cpu_gpu_block_allocator.py
@@ -339,6 +339,13 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float:
         assert device in self._allocators
         return self._allocators[device].get_prefix_cache_hit_rate()
 
+    def reset_prefix_cache(self) -> bool:
+        """Reset prefix cache for all devices."""
+        success = True
+        for allocator in self._allocators.values():
+            success = allocator.reset_prefix_cache() and success
+        return success
+
     def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
         """Returns and clears the mapping of source to destination block IDs.
         Will be called after every swapping operation for now, and after every
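Note the ordering inside the loop: the reset call comes first in the conjunction, so one allocator reporting failure does not short-circuit the resets of the remaining allocators. An equivalent formulation that makes the evaluate-everything intent explicit (a sketch against the same `_allocators` mapping, not part of the diff):

    def reset_prefix_cache(self) -> bool:
        """Reset prefix cache for all devices, even if some fail."""
        # Materialize the list first: every allocator runs its reset before
        # all() folds the results into a single success flag.
        results = [a.reset_prefix_cache() for a in self._allocators.values()]
        return all(results)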
10 changes: 10 additions & 0 deletions vllm/core/block/interfaces.py
@@ -192,6 +192,11 @@ def get_prefix_cache_hit_rate(self) -> float:
         """Prefix cache hit rate. -1 means not supported or disabled."""
         pass
 
+    @abstractmethod
+    def reset_prefix_cache(self) -> bool:
+        """Reset prefix cache."""
+        pass
+
     class NoFreeBlocksError(ValueError):
         pass
 
@@ -297,6 +302,11 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float:
         """Prefix cache hit rate. -1 means not supported or disabled."""
         pass
 
+    @abstractmethod
+    def reset_prefix_cache(self) -> bool:
+        """Reset prefix cache."""
+        pass
+
     @abstractmethod
     def find_cached_blocks_prefix(
         self,
19 changes: 14 additions & 5 deletions vllm/core/block/naive_block.py
@@ -1,5 +1,5 @@
 from collections import deque
-from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple
+from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple, Union
 
 from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter,
                                     get_all_blocks_recursively)
@@ -136,16 +136,18 @@ def _allocate_block_id(self) -> BlockId:
         self._refcounter.incr(block_id)
         return block_id
 
-    def _free_block_id(self, block: Block) -> None:
-        block_id = block.block_id
+    def _free_block_id(self, block: Union[Block, BlockId]) -> None:
+        if isinstance(block, Block):
+            block_id = block.block_id
+            block.block_id = None
+        else:
+            block_id = block
         assert block_id is not None
 
         refcount = self._refcounter.decr(block_id)
         if refcount == 0:
             self._free_block_indices.appendleft(block_id)
 
-        block.block_id = None
-
     def free(self, block: Block, keep_block_object: bool = False) -> None:
         # Release the physical block id
         self._free_block_id(block)
@@ -154,6 +156,9 @@ def free(self, block: Block, keep_block_object: bool = False) -> None:
         if not keep_block_object:
             self._block_pool.free_block(block)
 
+    def free_block_id(self, block_id: BlockId) -> None:
+        self._free_block_id(block_id)
+
     def fork(self, last_block: Block) -> List[Block]:
         """Creates a new sequence of blocks that shares the same underlying
         memory as the original sequence.
@@ -325,6 +330,10 @@ def swap_in(self, blocks: List[Block]) -> None:
     def get_prefix_cache_hit_rate(self) -> float:
         return -1
 
+    def reset_prefix_cache(self) -> bool:
+        """No prefix cache for naive block allocator."""
+        return True
+
     def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]:
         # Not applicable for naive block allocator.
         return []
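The widening of `_free_block_id` to `Union[Block, BlockId]` exists for the reset path: blocks sitting in the evictor are tracked only by their integer ids, with no `Block` object left to hand back. A short sketch of the two entry points after this change (the allocation call is illustrative; the freeing method names are as in the diff above):

    block = allocator.allocate_mutable_block(prev_block=None)
    allocator.free(block)        # public path: frees via the Block object
    allocator.free_block_id(7)   # new path: frees a raw id, such as one
                                 # recovered from the evictor during a reset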
44 changes: 43 additions & 1 deletion vllm/core/block/prefix_caching_block.py
@@ -12,6 +12,7 @@
 from vllm.core.block.naive_block import (BlockPool, NaiveBlock,
                                          NaiveBlockAllocator)
 from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor
+from vllm.logger import init_logger
 from vllm.sequence import Sequence
 
 PrefixHash = int
@@ -21,6 +22,8 @@
 # then we know this block hasn't been accessed yet.
 _DEFAULT_LAST_ACCESSED_TIME = -1
 
+logger = init_logger(__name__)
+
 
 class BlockTracker:
     """Used to track the status of a block inside the prefix caching allocator
@@ -105,7 +108,8 @@ def __init__(
 
         # Evictor used to maintain how we want to handle those computed blocks
         # if we find memory pressure is high.
-        self.evictor: Evictor = make_evictor(eviction_policy)
+        self.eviction_policy = eviction_policy
+        self.evictor: Evictor = make_evictor(self.eviction_policy)
 
         # We share the refcounter between allocators. This allows us to promote
         # blocks originally allocated in the hashless allocator to immutable
@@ -428,6 +432,44 @@ def all_block_ids(self) -> FrozenSet[int]:
     def get_prefix_cache_hit_rate(self) -> float:
         return self.metric_data.get_hit_rate()
 
+    def reset_prefix_cache(self) -> bool:
+        """Reset prefix cache. This function may be used in RLHF
+        flows to invalidate prefix caching after the weights are updated,
+        or for resetting prefix caching status for benchmarking.
+
+        Returns:
+            bool: True if the prefix cache is successfully reset,
+            False otherwise.
+        """
+        num_used_blocks = (self.get_num_total_blocks() -
+                           self.get_num_free_blocks())
+        if num_used_blocks > 0:
+            logger.warning(
+                "Failed to reset prefix cache because some "
+                "blocks (%d) are not freed yet", num_used_blocks)
+            return False
+
+        # Free all blocks in the evictor.
+        while (block_id :=
+               self._maybe_allocate_evicted_block_id()) is not None:
+            self._hashless_allocator.free_block_id(block_id)
+
+        # Should not have any cached blocks because all blocks are evicted.
+        assert not self._cached_blocks
+
+        # Reset the evictor.
+        self.evictor = make_evictor(self.eviction_policy)
+
+        # Reset the block tracker.
+        for block_id in self._block_tracker:
+            self._block_tracker[block_id] = BlockTracker()
+
+        # Reset the metrics.
+        self.metric_data = CacheMetricData()
+
+        logger.info("Successfully reset prefix cache")
+        return True
+
     def is_block_cached(self, block: Block) -> bool:
         assert block.content_hash is not None
         return block.content_hash in self._cached_blocks
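The eviction drain above uses an assignment expression to pull block ids out of the evictor until it is empty, handing each one back to the hashless allocator. A self-contained illustration of the same drain-loop pattern (plain Python with a stand-in for `_maybe_allocate_evicted_block_id()`, independent of vLLM types):

    from collections import deque
    from typing import Deque, Optional

    evicted: Deque[int] = deque([3, 7, 9])

    def maybe_pop() -> Optional[int]:
        # Stand-in for _maybe_allocate_evicted_block_id(): returns the next
        # evicted block id, or None once the evictor is drained.
        return evicted.popleft() if evicted else None

    freed = []
    while (block_id := maybe_pop()) is not None:
        freed.append(block_id)

    assert freed == [3, 7, 9]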
3 changes: 3 additions & 0 deletions vllm/core/block_manager.py
@@ -455,6 +455,9 @@ def get_num_free_cpu_blocks(self) -> int:
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         return self.block_allocator.get_prefix_cache_hit_rate(device)
 
+    def reset_prefix_cache(self) -> bool:
+        return self.block_allocator.reset_prefix_cache()
+
     def _can_swap(self,
                   seq_group: SequenceGroup,
                   device: Device,
5 changes: 5 additions & 0 deletions vllm/core/interfaces.py
@@ -122,6 +122,11 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float:
         """Prefix cache hit rate. -1 means not supported or disabled."""
         pass
 
+    @abstractmethod
+    def reset_prefix_cache(self) -> bool:
+        """Reset prefix cache for all devices."""
+        pass
+
     @abstractmethod
     def get_num_cached_tokens(self, seq: Sequence) -> int:
         pass
3 changes: 3 additions & 0 deletions vllm/core/placeholder_block_space_manager.py
@@ -90,5 +90,8 @@ def mark_blocks_as_computed(self, seq_group: SequenceGroup,
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         return -1
 
+    def reset_prefix_cache(self) -> bool:
+        return True
+
     def get_num_cached_tokens(self, seq: Sequence) -> int:
         return 0
3 changes: 3 additions & 0 deletions vllm/core/scheduler.py
@@ -504,6 +504,9 @@ def has_unfinished_seqs(self) -> bool:
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         return self.block_manager.get_prefix_cache_hit_rate(device)
 
+    def reset_prefix_cache(self) -> bool:
+        return self.block_manager.reset_prefix_cache()
+
     def get_num_unfinished_seq_groups(self) -> int:
         return len(self.waiting) + len(self.running) + len(self.swapped)
 
3 changes: 3 additions & 0 deletions vllm/engine/async_llm_engine.py
@@ -1182,6 +1182,9 @@ async def start_profile(self) -> None:
     async def stop_profile(self) -> None:
         self.engine.stop_profile()
 
+    async def reset_prefix_cache(self) -> None:
+        self.engine.reset_prefix_cache()
+
     async def add_lora(self, lora_request: LoRARequest) -> None:
         self.engine.add_lora(lora_request)
 
8 changes: 8 additions & 0 deletions vllm/engine/llm_engine.py
@@ -914,6 +914,14 @@ def has_unfinished_requests_for_virtual_engine(
         """
         return self.scheduler[virtual_engine].has_unfinished_seqs()
 
+    def reset_prefix_cache(self) -> bool:
+        """Reset prefix cache for all devices."""
+
+        success = True
+        for scheduler in self.scheduler:
+            success = scheduler.reset_prefix_cache() and success
+        return success
+
     @staticmethod
     def _process_sequence_group_outputs(
         seq_group: SequenceGroup,
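As in the CpuGpuBlockAllocator, the reset call leads the conjunction so a failure on one virtual engine's scheduler does not skip the remaining ones. With this hook in place, a weight-update flow can drop stale KV blocks in one call. A hedged usage sketch for the offline entrypoint (the `LLM`-level wiring is not part of this excerpt, so the call below goes through the engine attribute directly; the model name is illustrative):

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)
    llm.generate(["shared system prompt ..."], SamplingParams(max_tokens=8))

    # ... update the model weights in place (e.g. one RLHF step) ...

    # Cached blocks now encode stale weights. The reset returns False if any
    # request still holds blocks, so it should run while the engine is idle.
    ok = llm.llm_engine.reset_prefix_cache()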
7 changes: 6 additions & 1 deletion vllm/engine/multiprocessing/__init__.py
@@ -121,6 +121,10 @@ class RPCUProfileRequest(Enum):
     STOP_PROFILE = 2
 
 
+class RPCResetPrefixCacheRequest(Enum):
+    RESET_PREFIX_CACHE = 1
+
+
 @dataclass
 class RPCLoadAdapterRequest:
     lora_request: LoRARequest
@@ -134,7 +138,8 @@ class RPCAdapterLoadedResponse:
 
 
 RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
-                      RPCUProfileRequest, RPCLoadAdapterRequest]
+                      RPCUProfileRequest, RPCLoadAdapterRequest,
+                      RPCResetPrefixCacheRequest]
 
 REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCAdapterLoadedResponse,
                           RPCError]
12 changes: 10 additions & 2 deletions vllm/engine/multiprocessing/client.py
@@ -27,8 +27,9 @@
                                          VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
                                          RPCAdapterLoadedResponse, RPCError,
                                          RPCLoadAdapterRequest,
-                                         RPCProcessRequest, RPCStartupRequest,
-                                         RPCStartupResponse,
+                                         RPCProcessRequest,
+                                         RPCResetPrefixCacheRequest,
+                                         RPCStartupRequest, RPCStartupResponse,
                                          RPCUProfileRequest)
 from vllm.engine.protocol import EngineClient
 # yapf: enable
@@ -675,6 +676,13 @@ async def stop_profile(self) -> None:
         await self._send_one_way_rpc_request(
             request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)
 
+    async def reset_prefix_cache(self) -> None:
+        """Reset the prefix cache."""
+
+        await self._send_one_way_rpc_request(
+            request=RPCResetPrefixCacheRequest.RESET_PREFIX_CACHE,
+            socket=self.input_socket)
+
     async def add_lora(self, lora_request: LoRARequest) -> None:
         """Load a new LoRA adapter into the engine for future requests."""
         # Uses the same I/O as generate requests
10 changes: 8 additions & 2 deletions vllm/engine/multiprocessing/engine.py
@@ -16,8 +16,9 @@
                                          VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
                                          RPCAdapterLoadedResponse, RPCError,
                                          RPCLoadAdapterRequest,
-                                         RPCProcessRequest, RPCStartupRequest,
-                                         RPCStartupResponse,
+                                         RPCProcessRequest,
+                                         RPCResetPrefixCacheRequest,
+                                         RPCStartupRequest, RPCStartupResponse,
                                          RPCUProfileRequest)
 # yapf: enable
 from vllm.logger import init_logger
@@ -237,6 +238,8 @@ def handle_new_input(self):
                     self.stop_profile()
                 elif isinstance(request, RPCLoadAdapterRequest):
                     self._handle_load_adapter_request(request)
+                elif isinstance(request, RPCResetPrefixCacheRequest):
+                    self.reset_prefix_cache()
                 else:
                     raise ValueError("Unknown RPCRequest Type: "
                                      f"{type(request)}")
@@ -361,6 +364,9 @@ def start_profile(self) -> None:
     def stop_profile(self) -> None:
         self.engine.stop_profile()
 
+    def reset_prefix_cache(self) -> bool:
+        return self.engine.reset_prefix_cache()
+
 
 def signal_handler(*_) -> None:
     raise KeyboardInterrupt("MQLLMEngine terminated")
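For the multiprocessing front end this completes a one-way request path: the async client drops RPCResetPrefixCacheRequest.RESET_PREFIX_CACHE onto its input socket, and handle_new_input() above dispatches it to the engine without sending a payload back. A minimal async usage sketch (assuming an already-constructed client):

    from vllm.engine.multiprocessing.client import MQLLMEngineClient

    async def drop_prefix_cache(client: MQLLMEngineClient) -> None:
        # One-way RPC: nothing comes back on success; failures surface
        # through the engine's normal error reporting, not a return value.
        await client.reset_prefix_cache()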