Skip to content

Commit

Permalink
optimizoor
Browse files Browse the repository at this point in the history
  • Loading branch information
sdan committed Apr 4, 2024
1 parent 014a509 commit 967a5e8
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions vlite/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def retrieve(self, text=None, top_k=5, metadata=None, newEmbedding=False):
print("Retrieval completed.")
return [(self.index[idx]['text'], score, self.index[idx]['metadata']) for idx, score in results]

def rescore(self, query_binary_vector, query_int8_vector, top_k, metadata=None, rescore_multiplier=4):
def rescore(self, query_binary_vector, query_int8_vector, top_k, metadata=None):
"""
Performs retrieval using binary search and rescoring using int8 embeddings.
Expand All @@ -139,30 +139,30 @@ def rescore(self, query_binary_vector, query_int8_vector, top_k, metadata=None,
"""
# Perform binary search
binary_vectors = np.array([item['binary_vector'] for item in self.index.values()])
similarities = np.dot(query_binary_vector, binary_vectors.T).flatten()
top_k_indices = np.argsort(similarities)[-top_k*rescore_multiplier:][::-1]
binary_similarities = np.einsum('i,ji->j', query_binary_vector[0], binary_vectors)
top_k_indices = np.argpartition(binary_similarities, -top_k*4)[-top_k*4:]
top_k_ids = [list(self.index.keys())[idx] for idx in top_k_indices]

# Filter results based on metadata
# Apply metadata filter on the retrieved top_k*4 items
if metadata:
filtered_ids = []
for item_id in top_k_ids:
item_metadata = self.index[item_id]['metadata']
if all(item_metadata.get(key) == value for key, value in metadata.items()):
filtered_ids.append(item_id)
top_k_ids = filtered_ids[:top_k*rescore_multiplier]
top_k_ids = filtered_ids[:top_k*4]

# Perform rescoring using int8 embeddings
int8_vectors = np.array([self.index[idx]['int8_vector'] for idx in top_k_ids])
rescored_similarities = self.model.rescore(query_int8_vector, int8_vectors)
int8_similarities = np.einsum('i,ji->j', query_int8_vector[0], int8_vectors)

# Sort the results based on the rescored similarities
sorted_indices = np.argsort(rescored_similarities)[::-1]
sorted_ids = [top_k_ids[idx] for idx in sorted_indices]
sorted_scores = rescored_similarities[sorted_indices]
# Sort the results based on the int8 similarities
sorted_indices = np.argpartition(int8_similarities, -top_k)[-top_k:]
sorted_ids = np.take(top_k_ids, sorted_indices)
sorted_scores = int8_similarities[sorted_indices]

return list(zip(sorted_ids, sorted_scores))

return list(zip(sorted_ids, sorted_scores))[:top_k] # Return top_k results after rescoring

def delete(self, ids):
"""
Deletes items from the collection by their IDs.
Expand Down

0 comments on commit 967a5e8

Please sign in to comment.