Skip to content

Commit

Permalink
Refactor retrieval_rescore method in VLite class
Browse files: browse the repository at this point in the history
  • Loading branch information
sdan committed Apr 4, 2024
1 parent f58e9ea commit 014a509
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions vlite/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class VLite:
def __init__(self, collection=None, device='cpu', model_name='mixedbread-ai/mxbai-embed-large-v1'):
if collection is None:
current_datetime = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
collection = f"vlite_{current_datetime}.npz"
collection = f"vlite_{current_datetime}"
self.collection = f"{collection}.npz"
self.device = device
self.model = EmbeddingModel(model_name) if model_name else EmbeddingModel()
Expand Down Expand Up @@ -119,13 +119,12 @@ def retrieve(self, text=None, top_k=5, metadata=None, newEmbedding=False):
query_binary_vector = self.model.quantize(query_vector, precision="binary")
query_int8_vector = self.model.quantize(query_vector, precision="int8")

# Perform binary search and rescoring
results = self.retrieval_rescore(query_binary_vector, query_int8_vector, top_k, metadata)
results = self.rescore(query_binary_vector, query_int8_vector, top_k, metadata)

print("Retrieval completed.")
return [(self.index[idx]['text'], score, self.index[idx]['metadata']) for idx, score in results]

def retrieval_rescore(self, query_binary_vector, query_int8_vector, top_k, metadata=None):
def rescore(self, query_binary_vector, query_int8_vector, top_k, metadata=None, rescore_multiplier=4):
"""
Performs retrieval using binary search and rescoring using int8 embeddings.
Expand All @@ -141,17 +140,17 @@ def retrieval_rescore(self, query_binary_vector, query_int8_vector, top_k, metad
# Perform binary search
binary_vectors = np.array([item['binary_vector'] for item in self.index.values()])
similarities = np.dot(query_binary_vector, binary_vectors.T).flatten()
top_k_indices = np.argsort(similarities)[-top_k*4:][::-1] # Retrieve top_k*4 results for rescoring
top_k_indices = np.argsort(similarities)[-top_k*rescore_multiplier:][::-1]
top_k_ids = [list(self.index.keys())[idx] for idx in top_k_indices]

# Apply metadata filter on the retrieved top_k*4 items
# Filter results based on metadata
if metadata:
filtered_ids = []
for item_id in top_k_ids:
item_metadata = self.index[item_id]['metadata']
if all(item_metadata.get(key) == value for key, value in metadata.items()):
filtered_ids.append(item_id)
top_k_ids = filtered_ids[:top_k*4]
top_k_ids = filtered_ids[:top_k*rescore_multiplier]

# Perform rescoring using int8 embeddings
int8_vectors = np.array([self.index[idx]['int8_vector'] for idx in top_k_ids])
Expand Down

0 comments on commit 014a509

Please sign in to comment.