Refactor vector retrieval and apply metadata filter
sdan committed Apr 4, 2024
1 parent 84c51e1 commit f58e9ea
Showing 1 changed file with 10 additions and 14 deletions.
24 changes: 10 additions & 14 deletions vlite/main.py
@@ -80,7 +80,6 @@ def add(self, data, metadata=None, need_chunks=True, newEmbedding=False):
             print("Encoding text... not chunking")
         encoded_data = self.model.embed(chunks, device=self.device)
 
-        # Quantize the embeddings to binary and int8
         binary_encoded_data = self.model.quantize(encoded_data, precision="binary")
         int8_encoded_data = self.model.quantize(encoded_data, precision="int8")
 
@@ -90,8 +89,8 @@ def add(self, data, metadata=None, need_chunks=True, newEmbedding=False):
                 'text': chunk,
                 'metadata': item_metadata,
                 'vector': vector,
-                'binary_vector': binary_vector.tolist(),  # Convert to list for JSON serialization
-                'int8_vector': int8_vector.tolist()  # Convert to list for JSON serialization
+                'binary_vector': binary_vector.tolist(),
+                'int8_vector': int8_vector.tolist()
             }


@@ -142,20 +141,17 @@ def retrieval_rescore(self, query_binary_vector, query_int8_vector, top_k, metad
         # Perform binary search
         binary_vectors = np.array([item['binary_vector'] for item in self.index.values()])
         similarities = np.dot(query_binary_vector, binary_vectors.T).flatten()
+        top_k_indices = np.argsort(similarities)[-top_k*4:][::-1]  # Retrieve top_k*4 results for rescoring
+        top_k_ids = [list(self.index.keys())[idx] for idx in top_k_indices]
 
-        # Apply metadata filter while finding similar texts
+        # Apply metadata filter on the retrieved top_k*4 items
         if metadata:
-            filtered_indices = []
-            for idx, item_id in enumerate(self.index.keys()):  # Iterate over item IDs
+            filtered_ids = []
+            for item_id in top_k_ids:
                 item_metadata = self.index[item_id]['metadata']
                 if all(item_metadata.get(key) == value for key, value in metadata.items()):
-                    filtered_indices.append(idx)
-                    if len(filtered_indices) == top_k:  # Stop when we have found top_k
-                        top_k_ids = [list(self.index.keys())[idx] for idx in filtered_indices]
-                    else:
-                        top_k_ids = [list(self.index.keys())[idx] for idx in np.argsort(similarities)[-top_k:][::-1]]
-        else:
-            top_k_ids = [list(self.index.keys())[idx] for idx in np.argsort(similarities)[-top_k:][::-1]]
+                    filtered_ids.append(item_id)
+            top_k_ids = filtered_ids[:top_k*4]
 
         # Perform rescoring using int8 embeddings
         int8_vectors = np.array([self.index[idx]['int8_vector'] for idx in top_k_ids])
@@ -166,7 +162,7 @@ def retrieval_rescore(self, query_binary_vector, query_int8_vector, top_k, metad
         sorted_ids = [top_k_ids[idx] for idx in sorted_indices]
         sorted_scores = rescored_similarities[sorted_indices]
 
-        return list(zip(sorted_ids, sorted_scores))
+        return list(zip(sorted_ids, sorted_scores))[:top_k]  # Return top_k results after rescoring
 
     def delete(self, ids):
         """
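
The refactored retrieval_rescore follows a two-stage pattern: a cheap binary dot-product search oversampled to top_k*4 candidates, an optional metadata filter applied only to that candidate pool, then an int8 rescoring pass that keeps the best top_k. Below is a self-contained sketch of the same pattern, assuming an index shaped like vlite's (a dict of id -> {'binary_vector', 'int8_vector', 'metadata'}); the function and variable names are illustrative, not part of vlite's API.

# Minimal sketch of the two-stage retrieval used above; `two_stage_search`
# and its arguments are hypothetical names, not vlite's API.
import numpy as np

def two_stage_search(index, query_binary, query_int8, top_k, metadata=None):
    ids = list(index.keys())

    # Stage 1: coarse binary similarity, oversampled 4x to leave room for rescoring.
    binary_matrix = np.array([index[i]['binary_vector'] for i in ids])
    coarse = np.dot(query_binary, binary_matrix.T).flatten()
    candidates = [ids[i] for i in np.argsort(coarse)[-top_k * 4:][::-1]]

    # Optional metadata filter, applied only to the oversampled candidate pool.
    if metadata:
        candidates = [i for i in candidates
                      if all(index[i]['metadata'].get(k) == v for k, v in metadata.items())]
    if not candidates:
        return []

    # Stage 2: rescore the survivors with the higher-fidelity int8 vectors.
    int8_matrix = np.array([index[i]['int8_vector'] for i in candidates])
    fine = np.dot(query_int8, int8_matrix.T).flatten()
    order = np.argsort(fine)[::-1]
    return [(candidates[i], float(fine[i])) for i in order][:top_k]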
