Commit 24f7605

Qdrant cache (#73)
* created the semantic cache using qdrant
* updated readme
1 parent 3c06654 commit 24f7605


4 files changed, +97 −0 lines changed

@@ -0,0 +1,5 @@
QDRANT_URL='http://localhost:6333'
QDRANT_API_KEY='th3s3cr3tk3y'

OLLAMA_MODEL='llama3.2:latest'
OLLAMA_BASE_URL='http://localhost:11434'
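
The file above holds the Qdrant and Ollama connection settings. A minimal sketch of how they are consumed, assuming the file is saved as a `.env` file that `find_dotenv` can locate (the script further down does the same thing inside `SemanticCache.__init__`):

```python
# Minimal sketch: read the connection settings from the env file above.
# Assumes it is saved as `.env` in or above the working directory.
import os
from dotenv import load_dotenv, find_dotenv

load_dotenv(find_dotenv())  # populate os.environ from the .env file

print(os.environ.get('QDRANT_URL'))    # http://localhost:6333
print(os.environ.get('OLLAMA_MODEL'))  # llama3.2:latest
```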
@@ -0,0 +1,4 @@
## Qdrant Semantic Cache

- `pip install -r requirements.txt`
- `python semantic_cache.py`
@@ -0,0 +1,4 @@
qdrant-client==1.12.0
python-dotenv==1.0.1
fastembed==0.4.1
datasets==3.0.1
@@ -0,0 +1,84 @@
import uuid
import os
from dotenv import load_dotenv, find_dotenv
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from qdrant_client import QdrantClient
from qdrant_client.http.models import PointStruct, SearchParams, VectorParams, Distance
# Note: sentence-transformers and the llama-index Ollama integration are imported
# here but are not pinned in requirements.txt above.
from sentence_transformers import SentenceTransformer
from llama_index.llms.ollama import Ollama


class SemanticCache:
    def __init__(self, threshold=0.35):
        # load the data from env
        load_dotenv(find_dotenv())

        self.encoder = SentenceTransformer('all-MiniLM-L6-v2')
        self.cache_client = QdrantClient(url=os.environ.get('QDRANT_URL'), api_key=os.environ.get('QDRANT_API_KEY'))
        self.cache_collection_name = "cache"
        self.threshold = threshold

        # Create the cache collection (384-dim vectors to match all-MiniLM-L6-v2, cosine distance)
        if not self.cache_client.collection_exists(collection_name=self.cache_collection_name):
            self.cache_client.create_collection(
                collection_name=self.cache_collection_name,
                vectors_config=VectorParams(
                    size=384,
                    distance=Distance.COSINE
                )
            )

    def get_embedding(self, text):
        # Encode a single text and return a plain list for the Qdrant client
        return self.encoder.encode([text])[0].tolist()

    def search_cache(self, query):
        # Return the cached response of the most similar stored query, or None on a miss
        query_vector = self.get_embedding(query)
        search_result = self.cache_client.search(
            collection_name=self.cache_collection_name,
            query_vector=query_vector,
            limit=1,
            search_params=SearchParams(hnsw_ef=128)
        )
        if search_result and search_result[0].score > self.threshold:
            return search_result[0].payload['response']
        return None

    def add_to_cache(self, query, response):
        # Store the query embedding together with the query/response payload
        query_vector = self.get_embedding(query)
        point = PointStruct(
            id=str(uuid.uuid4()),
            vector=query_vector,
            payload={"query": query, "response": response}
        )
        self.cache_client.upsert(
            collection_name=self.cache_collection_name,
            points=[point]
        )

    def get_response(self, query, compute_response_func):
        # Serve from the cache on a hit; otherwise compute, cache, and return the response
        cached_response = self.search_cache(query)
        if cached_response:
            return cached_response
        _response = compute_response_func(query)
        self.add_to_cache(query, _response)
        return _response


# Example usage
def compute_response(query: str):
    llm = Ollama(model=os.environ.get('OLLAMA_MODEL'), base_url=os.environ.get('OLLAMA_BASE_URL'))
    # Create a user message
    user_message = ChatMessage(
        role=MessageRole.USER,
        content=query
    )

    # Generate a response from the assistant
    assistant_message = llm.chat(messages=[user_message])
    return f"Computed response for: {query} is {assistant_message}"


if __name__ == "__main__":
    semantic_cache = SemanticCache(threshold=0.8)
    query = "What is the capital of France?"
    response = semantic_cache.get_response(query, compute_response)
    print(response)
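
For debugging, the underlying collection can be inspected or reset directly with the same client. A small sketch, assuming the default collection name `"cache"` and the connection settings from the `.env` file:

```python
# Sketch: inspect or clear the "cache" collection used by SemanticCache.
# Assumes Qdrant is reachable with the settings from the .env file above.
import os
from dotenv import load_dotenv, find_dotenv
from qdrant_client import QdrantClient

load_dotenv(find_dotenv())
client = QdrantClient(url=os.environ.get('QDRANT_URL'),
                      api_key=os.environ.get('QDRANT_API_KEY'))

print(client.count(collection_name="cache"))        # number of cached query/response points
# client.delete_collection(collection_name="cache")  # uncomment to drop the cache entirely
```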
