Skip to content

Commit

Permalink
Improve prefix caching metrics: aggregate hit rate over a sliding window of recent requests
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac committed Jan 31, 2025
1 parent 19475f7 commit c5196d9
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 26 deletions.
31 changes: 13 additions & 18 deletions vllm/v1/core/kv_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,7 @@ def __init__(
self.req_to_blocks: Dict[str, List[KVCacheBlock]] = {}

# Prefix cache metrics.
self.prefix_caching_metrics: PrefixCachingMetrics = {
"query_total": 0,
"query_hit": 0,
}
self.prefix_caching_metrics = PrefixCachingMetrics()

@property
def usage(self) -> float:
Expand All @@ -85,21 +82,14 @@ def usage(self) -> float:
return 1.0 - (self.free_block_queue.num_free_blocks /
self.num_gpu_blocks)

def get_and_reset_prefix_cache_hit_rate(self) -> float:
"""Get the overall hit rate of prefix caching and reset
the metrics.
@property
def prefix_cache_hit_rate(self) -> float:
    """Expose the current prefix caching hit rate.

    Returns:
        The hit rate tracked by ``self.prefix_caching_metrics``.
    """
    metrics = self.prefix_caching_metrics
    return metrics.hit_rate

def get_computed_blocks(
self, request: Request) -> Tuple[List[KVCacheBlock], int]:
Expand Down Expand Up @@ -136,8 +126,10 @@ def get_computed_blocks(
else:
break

self.prefix_caching_metrics["query_total"] += len(block_hashes)
self.prefix_caching_metrics["query_hit"] += len(computed_blocks)
self.prefix_caching_metrics.add_request_query(
num_queries=len(block_hashes),
num_hits=len(computed_blocks),
)

# NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of
Expand Down Expand Up @@ -367,6 +359,9 @@ def reset_prefix_cache(self) -> bool:
for block in self.block_pool:
block.reset_hash()

# Reset the prefix caching metrics.
self.prefix_caching_metrics.reset()

logger.info("Successfully reset prefix cache")
return True

Expand Down
53 changes: 46 additions & 7 deletions vllm/v1/core/kv_cache_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""KV-Cache Utilities."""
from collections import deque
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, List, NamedTuple, Optional, Tuple, TypedDict
from typing import Any, List, NamedTuple, Optional, Tuple

from vllm.config import VllmConfig
from vllm.logger import init_logger
Expand All @@ -27,14 +28,52 @@ class BlockHashType(NamedTuple):
extra_keys: Optional[Any] = None


class PrefixCachingMetrics(TypedDict):
"""Metrics for prefix caching."""
class PrefixCachingMetrics:
    """Sliding-window prefix caching metrics.

    Tracks the prefix cache hit rate over the most recent ``interval``
    requests.

    Args:
        interval: The number of the most recent requests to aggregate.
            Defaults to 1000.
    """

    def __init__(self, interval: int = 1000):
        self.interval = interval
        # Running sums over the requests currently in the window.
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0
        # Per-request (num_queries, num_hits) pairs, oldest first.
        self.request_queries: deque[Tuple[int, int]] = deque()

    def add_request_query(self, num_queries: int, num_hits: int):
        """Record the query/hit counts of one scheduled request.

        Called when a new request is being scheduled and is looking for
        computed blocks. Once more than ``interval`` requests are
        tracked, the oldest request is evicted from the window.

        Args:
            num_queries: The number of queries in the request.
            num_hits: The number of hits in the request.
        """
        self.request_queries.append((num_queries, num_hits))
        self.aggregated_query_total += num_queries
        self.aggregated_query_hit += num_hits
        # Evict the oldest request(s) once the window overflows.
        while len(self.request_queries) > self.interval:
            evicted_queries, evicted_hits = self.request_queries.popleft()
            self.aggregated_query_total -= evicted_queries
            self.aggregated_query_hit -= evicted_hits

    def reset(self):
        """Clear the window and zero all aggregates."""
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0
        self.request_queries.clear()

    @property
    def hit_rate(self) -> float:
        """Hit rate over the current window (0.0 when empty)."""
        total = self.aggregated_query_total
        return self.aggregated_query_hit / total if total else 0.0


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ def make_stats(self) -> SchedulerStats:
num_waiting_reqs=len(self.waiting),
gpu_cache_usage=self.kv_cache_manager.usage,
gpu_prefix_cache_hit_rate=self.kv_cache_manager.
get_and_reset_prefix_cache_hit_rate(),
prefix_cache_hit_rate,
)


Expand Down

0 comments on commit c5196d9

Please sign in to comment.