From f58e9ea3b33c26199ce75b30e486de9524364542 Mon Sep 17 00:00:00 2001
From: Surya
Date: Thu, 4 Apr 2024 00:45:19 -0700
Subject: [PATCH] Refactor vector retrieval and apply metadata filter

---
 vlite/main.py | 24 ++++++++++--------------
 1 file changed, 10 insertions(+), 14 deletions(-)

diff --git a/vlite/main.py b/vlite/main.py
index 16a6ac8..edcd956 100644
--- a/vlite/main.py
+++ b/vlite/main.py
@@ -80,7 +80,6 @@ def add(self, data, metadata=None, need_chunks=True, newEmbedding=False):
             print("Encoding text... not chunking")
 
         encoded_data = self.model.embed(chunks, device=self.device)
-        # Quantize the embeddings to binary and int8
         binary_encoded_data = self.model.quantize(encoded_data, precision="binary")
         int8_encoded_data = self.model.quantize(encoded_data, precision="int8")
 
@@ -90,8 +89,8 @@ def add(self, data, metadata=None, need_chunks=True, newEmbedding=False):
                 'text': chunk,
                 'metadata': item_metadata,
                 'vector': vector,
-                'binary_vector': binary_vector.tolist(),  # Convert to list for JSON serialization
-                'int8_vector': int8_vector.tolist()  # Convert to list for JSON serialization
+                'binary_vector': binary_vector.tolist(),
+                'int8_vector': int8_vector.tolist()
             }
 
 
@@ -142,20 +141,17 @@ def retrieval_rescore(self, query_binary_vector, query_int8_vector, top_k, metad
         # Perform binary search
         binary_vectors = np.array([item['binary_vector'] for item in self.index.values()])
         similarities = np.dot(query_binary_vector, binary_vectors.T).flatten()
+        top_k_indices = np.argsort(similarities)[-top_k*4:][::-1]  # Retrieve top_k*4 results for rescoring
+        top_k_ids = [list(self.index.keys())[idx] for idx in top_k_indices]
 
-        # Apply metadata filter while finding similar texts
+        # Apply metadata filter on the retrieved top_k*4 items
         if metadata:
-            filtered_indices = []
-            for idx, item_id in enumerate(self.index.keys()):  # Iterate over item IDs
+            filtered_ids = []
+            for item_id in top_k_ids:
                 item_metadata = self.index[item_id]['metadata']
                 if all(item_metadata.get(key) == value for key, value in metadata.items()):
-                    filtered_indices.append(idx)
-                    if len(filtered_indices) == top_k:  # Stop when we have found top_k
-            top_k_ids = [list(self.index.keys())[idx] for idx in filtered_indices]
-        else:
-            top_k_ids = [list(self.index.keys())[idx] for idx in np.argsort(similarities)[-top_k:][::-1]]
-        else:
-            top_k_ids = [list(self.index.keys())[idx] for idx in np.argsort(similarities)[-top_k:][::-1]]
+                    filtered_ids.append(item_id)
+            top_k_ids = filtered_ids[:top_k*4]
 
         # Perform rescoring using int8 embeddings
         int8_vectors = np.array([self.index[idx]['int8_vector'] for idx in top_k_ids])
@@ -166,7 +162,7 @@ def retrieval_rescore(self, query_binary_vector, query_int8_vector, top_k, metad
         sorted_ids = [top_k_ids[idx] for idx in sorted_indices]
         sorted_scores = rescored_similarities[sorted_indices]
 
-        return list(zip(sorted_ids, sorted_scores))
+        return list(zip(sorted_ids, sorted_scores))[:top_k]  # Return top_k results after rescoring
 
     def delete(self, ids):
         """