-
Notifications
You must be signed in to change notification settings - Fork 0
/
query.py
78 lines (62 loc) · 3.39 KB
/
query.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from transformers import AutoModel, AutoModelForSequenceClassification, AutoTokenizer
import torch
import chromadb
from chromadb.config import DEFAULT_TENANT, DEFAULT_DATABASE, Settings
# Load the Jina CLIP checkpoint and its tokenizer once at import time; they are
# used by QueryRAG.embed_text to turn query text into an embedding.
# NOTE: trust_remote_code=True allows Python code shipped in the model
# repository to run locally when the model is loaded.
model_id = "jinaai/jina-clip-v1"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
embedding_model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
# Load the reranker model and its tokenizer (also at import time).
# num_labels=1 requests a single output logit per input, which QueryRAG.query
# reads as the relevance score of a (query, document) pair.
reranker_model = AutoModelForSequenceClassification.from_pretrained('jinaai/jina-reranker-v1-turbo-en', num_labels=1, trust_remote_code=True)
reranker_tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-reranker-v1-turbo-en', trust_remote_code=True)
class QueryRAG:
    """Retrieve candidate passages from ChromaDB and rerank them for a query.

    Query embedding uses the module-level ``tokenizer`` / ``embedding_model``;
    reranking uses the module-level ``reranker_tokenizer`` / ``reranker_model``.
    """

    # Sole entry returned when the similarity search yields no documents.
    _NO_RESULTS = ("No relevant documents found.", "")

    def __init__(self):
        """Open the persistent Chroma store under ``chroma`` and fetch its collections."""
        # Initialize ChromaDB Persistent Client
        CHROMA_PATH = "chroma"
        self.client = chromadb.PersistentClient(
            path=CHROMA_PATH,
        )
        self.text_collection = self.client.get_collection("texts")
        # Not read by any method of this class; retained because code outside
        # this class may access the attribute.
        self.image_collection = self.client.get_collection("images")

    def embed_text(self, texts):
        """Embed ``texts`` with the embedding model's text encoder.

        Args:
            texts: List of strings to embed.

        Returns:
            The NumPy array obtained from ``get_text_features`` for the
            tokenized (padded, truncated) inputs.
        """
        inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        # Inference only: no gradients are needed.
        with torch.no_grad():
            embeddings = embedding_model.get_text_features(**inputs).cpu().numpy()
        return embeddings

    def _rerank(self, query_text, documents):
        """Return one reranker score per document, in the order of ``documents``."""
        sentence_pairs = [[query_text, context_text] for context_text in documents]
        reranker_inputs = reranker_tokenizer(
            sentence_pairs, padding=True, truncation=True, return_tensors='pt'
        )
        with torch.no_grad():
            outputs = reranker_model(**reranker_inputs)
        # Drop the trailing single-label axis so there is one score per pair.
        return outputs.logits.squeeze(-1).cpu().tolist()

    def query(self, query_text: str, *, n_candidates: int = 10, top_k: int = 3):
        """Find the passages most relevant to ``query_text``.

        Runs a similarity search on the text collection, scores every
        candidate with the reranker, and keeps the best-scoring ones.

        Args:
            query_text: The user's question.
            n_candidates: Number of nearest neighbours requested from the
                text collection before reranking.
            top_k: Maximum number of reranked passages to return.

        Returns:
            A list of ``(document, source_id)`` tuples ordered from highest to
            lowest reranker score, containing at most ``top_k`` entries (fewer
            when the search returns fewer candidates). When nothing is found
            the list holds the single tuple
            ``("No relevant documents found.", "")``.

        Raises:
            ValueError: If ``n_candidates`` or ``top_k`` is less than 1.
        """
        if n_candidates < 1 or top_k < 1:
            raise ValueError("n_candidates and top_k must be positive")

        # Embed the query text
        query_embeddings = self.embed_text([query_text])

        # Perform similarity search on text_collection
        text_results = self.text_collection.query(
            query_embeddings=query_embeddings, n_results=n_candidates
        )

        # 'documents' may be missing, None, or hold an empty batch; treat all
        # of these as "no hits" instead of indexing blindly.
        document_batches = text_results.get('documents') or [[]]
        documents = [str(doc) for doc in (document_batches[0] or [])]
        if not documents:
            return [self._NO_RESULTS]

        sources = list(text_results['ids'][0])
        scores = self._rerank(query_text, documents)

        # Rank by descending score. Slicing (rather than iterating a fixed
        # range) keeps this safe when fewer than top_k candidates came back.
        ranked_indices = sorted(
            range(len(documents)), key=lambda idx: scores[idx], reverse=True
        )[:top_k]
        return [(documents[idx], sources[idx]) for idx in ranked_indices]
def get_contexts(query_text: str):
    """Return only the passage texts retrieved for ``query_text``.

    Builds a fresh ``QueryRAG``, runs the query, and drops both the source
    ids and the "no results" placeholder entry, so an unsuccessful search
    yields an empty list.
    """
    placeholder = "No relevant documents found."
    hits = QueryRAG().query(query_text)
    return [text for text, _source in hits if text != placeholder]
def main():
    """Prompt for a question on stdin and print each retrieved passage with its source."""
    rag = QueryRAG()
    question = input("Enter your question: ")
    for content, source in rag.query(question):
        print(f"Most relevant content: {content}\nMost relevant source: {source}\n")


# Run the interactive prompt only when executed as a script, not on import.
if __name__ == "__main__":
    main()