Skip to content

Commit 047d036

Browse files
pavanjavapavanmantha
and
pavanmantha
authored
Qdrant search improvements (#21)
* enhanced the query engines with max_retries parameter * -implemented hybrid search, -improved the simple search * -implemented hybrid search, -improved the simple search * -improved the docs * -implemented advanced search, -corrected hybrid search code, -improved docs --------- Co-authored-by: pavanmantha <[email protected]>
1 parent e28eb73 commit 047d036

File tree

9 files changed

+22213
-2
lines changed

9 files changed

+22213
-2
lines changed

bootstraprag/cli.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def create(project_name, framework, template, observability):
4040
]
4141
elif framework == 'None':
4242
framework = 'qdrant'
43-
template_choices = ['simple-search']
43+
template_choices = ['simple-search', 'hybrid-search', 'hybrid-search-advanced']
4444
# Use InquirerPy to select template with arrow keys
4545
template = inquirer.select(
4646
message="Which template would you like to use?",

bootstraprag/templates/qdrant/hybrid_search/search.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def __init__(self, collection_name: str, vector_dimension: int = 384, distance:
1515

1616
# set the dense and sparse embedding models
1717
self.client.set_model(os.environ.get('DENSE_MODEL'))
18+
self.client.set_sparse_model(os.environ.get('SPARSE_MODEL'))
1819
self.vector_dimension = vector_dimension
1920
self.distance = distance
2021
self.collection_name = collection_name
@@ -49,7 +50,7 @@ def insert(self) -> UpdateResult:
4950
collection_name=self.collection_name,
5051
documents=self.documents,
5152
metadata=self.metadata,
52-
batch_size=128, # a batch os 128 embeddings will be pushed in a single request
53+
# batch_size=128, # a batch os 128 embeddings will be pushed in a single request
5354
ids=tqdm(range(len(self.documents)))
5455
)
5556

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
DB_URL='http://localhost:6333'
2+
DB_API_KEY='th3s3cr3tk3y'
3+
COLLECTION_NAME='YOUR_COLLECTION'
4+
DENSE_MODEL='sentence-transformers/all-MiniLM-L6-v2'
5+
SPARSE_MODEL='prithivida/Splade_PP_en_v1'
6+
LATE_INTERACTION_MODEL="colbert-ir/colbertv2.0"

bootstraprag/templates/qdrant/hybrid_search_advanced/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
import os
2+
import tqdm
3+
from qdrant_client import QdrantClient, models
4+
from fastembed.embedding import TextEmbedding
5+
from fastembed.sparse.sparse_text_embedding import SparseTextEmbedding
6+
from fastembed.late_interaction import LateInteractionTextEmbedding
7+
from dotenv import load_dotenv, find_dotenv
8+
from datasets import load_dataset
9+
10+
_ = load_dotenv(find_dotenv())
11+
12+
13+
class AdvancedHybridSearch:
    """Index and query a Qdrant collection using dense, sparse and late-interaction vectors.

    On construction the embedding models named by the ``DENSE_MODEL``,
    ``SPARSE_MODEL`` and ``LATE_INTERACTION_MODEL`` environment variables are
    loaded, a Qdrant client is created from ``DB_URL`` / ``DB_API_KEY``, the
    ``BeIR/scifact`` corpus is loaded, and the target collection is created if
    it does not exist yet.
    """

    # Named-vector keys. They must be identical in the collection schema,
    # in every upserted point and in every query, so they live in one place.
    DENSE_VECTOR_NAME = "all-MiniLM-L6-v2"
    SPARSE_VECTOR_NAME = "splade-PP-en-v1"
    LATE_INTERACTION_VECTOR_NAME = "colbertv2.0"

    def __init__(self, collection_name: str):
        """Load the embedding models, connect to Qdrant and ensure the collection exists.

        Args:
            collection_name: Name of the Qdrant collection to create and/or query.

        Raises:
            KeyError: If ``DB_URL`` or ``DB_API_KEY`` is missing from the environment.
        """
        self.dense_embedding_model = TextEmbedding(model_name=os.environ.get("DENSE_MODEL"))
        self.sparse_embedding_model = SparseTextEmbedding(model_name=os.environ.get("SPARSE_MODEL"))
        self.late_interaction_embedding_model = LateInteractionTextEmbedding(os.environ.get("LATE_INTERACTION_MODEL"))

        self.client = QdrantClient(url=os.environ['DB_URL'], api_key=os.environ['DB_API_KEY'])

        self.collection_name = collection_name
        self.dense_embeddings = None
        self.sparse_embeddings = None
        self.late_interaction_embeddings = None
        self.dataset = None

        self._create_collection()

    @staticmethod
    def _first_query_embedding(model, query_text: str):
        """Return the first embedding produced by ``model.query_embed(query_text)``.

        Indexing uses ``passage_embed``; ``query_embed`` is the matching
        query-side encoder. Late-interaction models such as ColBERT encode
        queries differently from passages, so the generic ``embed`` call must
        not be used for query text.
        """
        return next(iter(model.query_embed(query_text)))

    def _get_dimensions(self):
        """Load the dataset and embed its first text with every model.

        The resulting sample embeddings are stored on the instance so that
        ``_create_collection`` can read the vector sizes from them.
        """
        self.dataset = load_dataset("BeIR/scifact", 'corpus', split="corpus")
        sample_texts = self.dataset["text"][0:1]
        self.dense_embeddings = list(self.dense_embedding_model.passage_embed(sample_texts))
        self.sparse_embeddings = list(self.sparse_embedding_model.passage_embed(sample_texts))
        self.late_interaction_embeddings = list(
            self.late_interaction_embedding_model.passage_embed(sample_texts))

    def _create_collection(self):
        """Create the collection with its three named vectors unless it already exists."""
        # Always runs: it also loads ``self.dataset``, which ``insert_data`` needs.
        self._get_dimensions()

        if self.client.collection_exists(collection_name=self.collection_name):
            return

        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config={
                self.DENSE_VECTOR_NAME: models.VectorParams(
                    size=len(self.dense_embeddings[0]),
                    distance=models.Distance.COSINE,
                ),
                self.LATE_INTERACTION_VECTOR_NAME: models.VectorParams(
                    # Late-interaction embeddings are a list of per-token
                    # vectors; the size is the length of a single token vector.
                    size=len(self.late_interaction_embeddings[0][0]),
                    distance=models.Distance.COSINE,
                    multivector_config=models.MultiVectorConfig(
                        comparator=models.MultiVectorComparator.MAX_SIM
                    ),
                ),
            },
            sparse_vectors_config={
                self.SPARSE_VECTOR_NAME: models.SparseVectorParams(
                    modifier=models.Modifier.IDF
                )
            },
        )

    def insert_data(self, batch_size: int = 4):
        """Embed the loaded dataset in batches and upsert the points into the collection.

        Args:
            batch_size: Number of documents embedded and upserted per request.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        # Ceiling division so a trailing partial batch is counted by the progress bar.
        total_batches = (len(self.dataset) + batch_size - 1) // batch_size

        for batch in tqdm.tqdm(self.dataset.iter(batch_size=batch_size), total=total_batches):
            texts = batch["text"]
            dense_embeddings = list(self.dense_embedding_model.passage_embed(texts))
            sparse_embeddings = list(self.sparse_embedding_model.passage_embed(texts))
            late_interaction_embeddings = list(self.late_interaction_embedding_model.passage_embed(texts))

            points = [
                models.PointStruct(
                    id=int(doc_id),
                    vector={
                        self.DENSE_VECTOR_NAME: dense.tolist(),
                        self.SPARSE_VECTOR_NAME: sparse.as_object(),
                        self.LATE_INTERACTION_VECTOR_NAME: late.tolist(),
                    },
                    payload={
                        "_id": doc_id,
                        "title": title,
                        "text": text,
                    },
                )
                for doc_id, title, text, dense, sparse, late in zip(
                    batch["_id"],
                    batch["title"],
                    texts,
                    dense_embeddings,
                    sparse_embeddings,
                    late_interaction_embeddings,
                )
            ]
            self.client.upsert(collection_name=self.collection_name, points=points)

    def query_with_dense_embedding(self, query_text: str, limit: int = 10):
        """Search the dense named vector.

        Args:
            query_text: Free-text query.
            limit: Maximum number of points to return.

        Returns:
            Whatever ``client.query_points`` returns (payloads are not requested).
        """
        query_vector = self._first_query_embedding(self.dense_embedding_model, query_text).tolist()
        return self.client.query_points(
            collection_name=self.collection_name,
            query=query_vector,
            using=self.DENSE_VECTOR_NAME,
            with_payload=False,
            limit=limit,
        )

    def query_with_sparse_embedding(self, query_text: str, limit: int = 10):
        """Search the sparse named vector.

        Args:
            query_text: Free-text query.
            limit: Maximum number of points to return.

        Returns:
            Whatever ``client.query_points`` returns (payloads are not requested).
        """
        query_vector = self._first_query_embedding(self.sparse_embedding_model, query_text)
        return self.client.query_points(
            collection_name=self.collection_name,
            query=models.SparseVector(**query_vector.as_object()),
            using=self.SPARSE_VECTOR_NAME,
            with_payload=False,
            limit=limit,
        )

    def query_with_late_interaction_embedding(self, query_text: str, limit: int = 10):
        """Search the late-interaction (multivector) named vector.

        Args:
            query_text: Free-text query.
            limit: Maximum number of points to return.

        Returns:
            Whatever ``client.query_points`` returns (payloads are not requested).
        """
        query_vector = self._first_query_embedding(
            self.late_interaction_embedding_model, query_text
        ).tolist()
        return self.client.query_points(
            collection_name=self.collection_name,
            query=query_vector,
            using=self.LATE_INTERACTION_VECTOR_NAME,
            with_payload=False,
            limit=limit,
        )

    def query_with_rrf(self, query_text: str, limit: int = 10, prefetch_limit: int = 20):
        """Run dense and sparse prefetch queries and fuse them with Reciprocal Rank Fusion.

        Args:
            query_text: Free-text query.
            limit: Maximum number of fused points to return.
            prefetch_limit: Number of candidates fetched by each prefetch query.

        Returns:
            Whatever ``client.query_points`` returns (payloads are not requested).
        """
        dense_query_vector = self._first_query_embedding(self.dense_embedding_model, query_text).tolist()
        sparse_query_vector = self._first_query_embedding(self.sparse_embedding_model, query_text)

        prefetch = [
            models.Prefetch(
                query=dense_query_vector,
                using=self.DENSE_VECTOR_NAME,
                limit=prefetch_limit,
            ),
            models.Prefetch(
                query=models.SparseVector(**sparse_query_vector.as_object()),
                using=self.SPARSE_VECTOR_NAME,
                limit=prefetch_limit,
            ),
        ]

        return self.client.query_points(
            collection_name=self.collection_name,
            prefetch=prefetch,
            query=models.FusionQuery(
                fusion=models.Fusion.RRF
            ),
            with_payload=False,
            limit=limit,
        )

0 commit comments

Comments
 (0)