import os

import tqdm
from qdrant_client import QdrantClient, models
from fastembed import TextEmbedding, SparseTextEmbedding, LateInteractionTextEmbedding
from dotenv import load_dotenv, find_dotenv
from datasets import load_dataset

# Read the Qdrant connection settings and embedding model names from a .env file.
_ = load_dotenv(find_dotenv())
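
# Example .env contents (a sketch, not part of the original file; the model names
# below are assumptions inferred from the named vectors used in the collection,
# and DB_URL / DB_API_KEY must point at your own Qdrant instance):
#
#   DENSE_MODEL=sentence-transformers/all-MiniLM-L6-v2
#   SPARSE_MODEL=prithivida/Splade_PP_en_v1
#   LATE_INTERACTION_MODEL=colbert-ir/colbertv2.0
#   DB_URL=https://your-cluster.qdrant.io:6333
#   DB_API_KEY=...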


class AdvancedHybridSearch:
    """Hybrid search over a Qdrant collection that stores dense, sparse, and
    late-interaction (multi-vector) representations of the same documents."""

    def __init__(self, collection_name: str):
        # Embedding models are selected through environment variables.
        self.dense_embedding_model = TextEmbedding(model_name=os.environ.get("DENSE_MODEL"))
        self.sparse_embedding_model = SparseTextEmbedding(model_name=os.environ.get("SPARSE_MODEL"))
        self.late_interaction_embedding_model = LateInteractionTextEmbedding(os.environ.get("LATE_INTERACTION_MODEL"))

        self.client = QdrantClient(url=os.environ["DB_URL"], api_key=os.environ["DB_API_KEY"])

        self.collection_name = collection_name
        self.dense_embeddings = None
        self.sparse_embeddings = None
        self.late_interaction_embeddings = None
        self.dataset = None

        self._create_collection()

    def _get_dimensions(self):
        # Embed a single passage with each model so that the vector sizes can be
        # read off the sample embeddings when the collection is created.
        self.dataset = load_dataset("BeIR/scifact", "corpus", split="corpus")
        self.dense_embeddings = list(self.dense_embedding_model.passage_embed(self.dataset["text"][0:1]))
        self.sparse_embeddings = list(self.sparse_embedding_model.passage_embed(self.dataset["text"][0:1]))
        self.late_interaction_embeddings = list(
            self.late_interaction_embedding_model.passage_embed(self.dataset["text"][0:1])
        )

    def _create_collection(self):
        self._get_dimensions()

        if not self.client.collection_exists(collection_name=self.collection_name):
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config={
                    # Dense sentence embeddings.
                    "all-MiniLM-L6-v2": models.VectorParams(
                        size=len(self.dense_embeddings[0]),
                        distance=models.Distance.COSINE,
                    ),
                    # Late-interaction (ColBERT-style) multi-vectors, compared with MaxSim.
                    "colbertv2.0": models.VectorParams(
                        size=len(self.late_interaction_embeddings[0][0]),
                        distance=models.Distance.COSINE,
                        multivector_config=models.MultiVectorConfig(
                            comparator=models.MultiVectorComparator.MAX_SIM
                        ),
                    ),
                },
                sparse_vectors_config={
                    # SPLADE sparse vectors with server-side IDF weighting.
                    "splade-PP-en-v1": models.SparseVectorParams(
                        modifier=models.Modifier.IDF
                    )
                },
            )

    def insert_data(self):
        # Embed the corpus in small batches and upsert every representation of a
        # document into the same point under different named vectors.
        batch_size = 4
        for batch in tqdm.tqdm(self.dataset.iter(batch_size=batch_size), total=len(self.dataset) // batch_size):
            dense_embeddings = list(self.dense_embedding_model.passage_embed(batch["text"]))
            sparse_embeddings = list(self.sparse_embedding_model.passage_embed(batch["text"]))
            late_interaction_embeddings = list(self.late_interaction_embedding_model.passage_embed(batch["text"]))

            self.client.upsert(
                collection_name=self.collection_name,
                points=[
                    models.PointStruct(
                        id=int(batch["_id"][i]),
                        vector={
                            "all-MiniLM-L6-v2": dense_embeddings[i].tolist(),
                            "splade-PP-en-v1": sparse_embeddings[i].as_object(),
                            "colbertv2.0": late_interaction_embeddings[i].tolist(),
                        },
                        payload={
                            "_id": batch["_id"][i],
                            "title": batch["title"][i],
                            "text": batch["text"][i],
                        },
                    )
                    for i, _ in enumerate(batch["_id"])
                ],
            )

    def query_with_dense_embedding(self, query_text: str):
        query_vector = next(self.dense_embedding_model.embed(query_text)).tolist()
        results = self.client.query_points(
            collection_name=self.collection_name,
            query=query_vector,
            using="all-MiniLM-L6-v2",
            with_payload=False,
            limit=10,
        )
        return results

    def query_with_sparse_embedding(self, query_text: str):
        query_vector = next(self.sparse_embedding_model.embed(query_text))
        results = self.client.query_points(
            collection_name=self.collection_name,
            query=models.SparseVector(**query_vector.as_object()),
            using="splade-PP-en-v1",
            with_payload=False,
            limit=10,
        )
        return results

    def query_with_late_interaction_embedding(self, query_text: str):
        query_vector = next(self.late_interaction_embedding_model.embed(query_text)).tolist()
        results = self.client.query_points(
            collection_name=self.collection_name,
            query=query_vector,
            using="colbertv2.0",
            with_payload=False,
            limit=10,
        )
        return results

    def query_with_rrf(self, query_text: str):
        # Retrieve candidates with the dense and sparse vectors separately, then let
        # Qdrant fuse the two ranked lists with Reciprocal Rank Fusion (RRF).
        dense_query_vector = next(self.dense_embedding_model.embed(query_text)).tolist()
        sparse_query_vector = next(self.sparse_embedding_model.embed(query_text))

        prefetch = [
            models.Prefetch(
                query=dense_query_vector,
                using="all-MiniLM-L6-v2",
                limit=20,
            ),
            models.Prefetch(
                query=models.SparseVector(**sparse_query_vector.as_object()),
                using="splade-PP-en-v1",
                limit=20,
            ),
        ]

        results = self.client.query_points(
            collection_name=self.collection_name,
            prefetch=prefetch,
            query=models.FusionQuery(fusion=models.Fusion.RRF),
            with_payload=False,
            limit=10,
        )
        return results
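

# Minimal usage sketch (not part of the original file): the collection name and the
# query string are placeholders, and the environment variables shown above must be
# set. Indexing the full SciFact corpus can take a while with these models on CPU.
if __name__ == "__main__":
    searcher = AdvancedHybridSearch(collection_name="scifact_hybrid")
    searcher.insert_data()
    hits = searcher.query_with_rrf("What causes antibiotic resistance in bacteria?")
    for point in hits.points:
        print(point.id, point.score)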