import uuid
import os
from dotenv import load_dotenv, find_dotenv
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from qdrant_client import QdrantClient
from qdrant_client.http.models import PointStruct, SearchParams, VectorParams, Distance
from sentence_transformers import SentenceTransformer
from llama_index.llms.ollama import Ollama


class SemanticCache:
    def __init__(self, threshold=0.35):
        # Load environment variables (Qdrant URL/API key, Ollama settings)
        load_dotenv(find_dotenv())

        self.encoder = SentenceTransformer('all-MiniLM-L6-v2')
        self.cache_client = QdrantClient(url=os.environ.get('QDRANT_URL'), api_key=os.environ.get('QDRANT_API_KEY'))
        self.cache_collection_name = "cache"
        self.threshold = threshold

        # Create the cache collection if it does not already exist
        if not self.cache_client.collection_exists(collection_name=self.cache_collection_name):
            self.cache_client.create_collection(
                collection_name=self.cache_collection_name,
                vectors_config=VectorParams(
                    size=384,  # embedding dimension of all-MiniLM-L6-v2
                    distance=Distance.COSINE
                )
            )

    def get_embedding(self, text):
        # Encode a single text and return a plain Python list so it serializes cleanly for Qdrant
        return self.encoder.encode([text])[0].tolist()

    def search_cache(self, query):
        query_vector = self.get_embedding(query)
        search_result = self.cache_client.search(
            collection_name=self.cache_collection_name,
            query_vector=query_vector,
            limit=1,
            search_params=SearchParams(hnsw_ef=128)
        )
        # With cosine distance, Qdrant returns a similarity score (higher = more similar),
        # so a cache hit requires the best match to exceed the threshold
        if search_result and search_result[0].score > self.threshold:
            return search_result[0].payload['response']
        return None

    def add_to_cache(self, query, response):
        # Store the query embedding with the query/response pair as payload
        query_vector = self.get_embedding(query)
        point = PointStruct(
            id=str(uuid.uuid4()),
            vector=query_vector,
            payload={"query": query, "response": response}
        )
        self.cache_client.upsert(
            collection_name=self.cache_collection_name,
            points=[point]
        )

    def get_response(self, query, compute_response_func):
        # Return a cached answer on a semantic hit; otherwise compute, cache, and return it
        cached_response = self.search_cache(query)
        if cached_response is not None:
            return cached_response
        _response = compute_response_func(query)
        self.add_to_cache(query, _response)
        return _response


# Example usage
def compute_response(query: str):
    llm = Ollama(model=os.environ.get('OLLAMA_MODEL'), base_url=os.environ.get('OLLAMA_BASE_URL'))
    # Wrap the query as a user message
    user_message = ChatMessage(
        role=MessageRole.USER,
        content=query
    )

    # Generate a response and keep only the assistant's text content in the cached string
    assistant_message = llm.chat(messages=[user_message])
    return f"Computed response for: {query} is {assistant_message.message.content}"


# Use a stricter cosine-similarity threshold (0.8) than the class default
semantic_cache = SemanticCache(threshold=0.8)
query = "What is the capital of France?"
response = semantic_cache.get_response(query, compute_response)
print(response)
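
# Illustrative sketch (assumption): a paraphrased query should be served from the
# cache without a second LLM call, provided its embedding clears the 0.8 similarity
# threshold for this encoder; the exact score depends on the model, so treat this
# query string as a hypothetical example.
similar_query = "Which city is the capital of France?"
cached_response = semantic_cache.get_response(similar_query, compute_response)
print(cached_response)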