Skip to content

Commit

Permalink
set batch, optimize retrieve
Browse files Browse the repository at this point in the history
  • Loading branch information
sdan committed Apr 18, 2024
1 parent 8b19146 commit ddbf16b
Showing 1 changed file with 40 additions and 14 deletions.
54 changes: 40 additions & 14 deletions vlite/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,20 +126,10 @@ def rank_and_filter(self, query_binary_vector, top_k, metadata=None):
query_binary_vector = np.array(query_binary_vector).reshape(-1)
logger.debug(f"[VLite.rank_and_filter] Shape of query vector after reshaping: {query_binary_vector.shape}")
# Collect all binary vectors and ensure they all have the same shape as the query vector
binary_vectors = []
mismatch_count = 0
for item_id, item in self.index.items():
binary_vector = item['binary_vector']
if len(binary_vector) == len(query_binary_vector):
binary_vectors.append(binary_vector)
else:
mismatch_count += 1
logger.warning(f"[VLite.rank_and_filter] Skipping vector with ID {item_id} of length {len(binary_vector)}, expected length {len(query_binary_vector)}")
if mismatch_count > 0:
logger.warning(f"[VLite.rank_and_filter] Skipped {mismatch_count} vectors due to length mismatch.")
# Convert list of binary vectors to a NumPy array
binary_vectors = [item['binary_vector'] for item in self.index.values() if len(item['binary_vector']) == len(query_binary_vector)]
binary_vector_indices = [idx for idx, item in enumerate(self.index.values()) if len(item['binary_vector']) == len(query_binary_vector)]
if binary_vectors:
corpus_binary_vectors = np.array(binary_vectors, dtype=np.float32)
corpus_binary_vectors = np.asarray(binary_vectors, dtype=np.float32)
logger.debug(f"[VLite.rank_and_filter] Shape of corpus binary vectors array: {corpus_binary_vectors.shape}")
else:
raise ValueError("No valid binary vectors found for comparison.")
Expand All @@ -148,7 +138,8 @@ def rank_and_filter(self, query_binary_vector, top_k, metadata=None):
logger.debug(f"[VLite.rank_and_filter] Top {top_k} scores: {top_k_scores}")
logger.debug(f"[VLite.rank_and_filter] No. of items in the collection: {len(self.index)}")
logger.debug(f"[VLite.rank_and_filter] Vlite count: {self.count()}")
top_k_ids = [list(self.index.keys())[idx] for idx in top_k_indices]

top_k_ids = [list(self.index.keys())[binary_vector_indices[idx]] for idx in top_k_indices]
# Apply metadata filter on the retrieved top_k items
filtered_ids = []
if metadata:
Expand Down Expand Up @@ -233,6 +224,41 @@ def set(self, id, text=None, metadata=None, vector=None):
else:
self.add(text, metadata=metadata, item_id=id)
logger.info(f"[VLite.set] Item with ID '{id}' created successfully.")


def set_batch(self, texts, embeddings, metadatas=None):
    """Store a batch of texts with precomputed embeddings in the index.

    Each text is written to ``self.index`` under a fresh key of the form
    ``"<uuid4>_0"`` together with its metadata and embedding, and
    ``self.save()`` is called once after the whole batch has been added.

    Args:
        texts: A list of texts. Any non-list value is treated as a single
            text and wrapped in a list.
        embeddings: One embedding per text, in the same order. Each may be
            an object exposing ``tolist()`` (e.g. a NumPy array) or any
            other iterable of numbers.
        metadatas: Optional list with one metadata object per text. A
            non-list value is wrapped in a list; ``None`` gives every text
            its own empty dict.

    Raises:
        ValueError: If the number of embeddings or metadatas differs from
            the number of texts. Validation runs before any write, so the
            index is left untouched in that case.
    """
    start_time = time.time()
    if not isinstance(texts, list):
        texts = [texts]

    if metadatas is None:
        # One dict per item: ``[{}] * n`` would alias a single shared dict,
        # so editing one item's metadata would silently edit all of them.
        metadatas = [{} for _ in texts]
    elif not isinstance(metadatas, list):
        metadatas = [metadatas]

    if len(texts) != len(embeddings):
        raise ValueError("The number of texts and embeddings must be the same.")

    if len(texts) != len(metadatas):
        raise ValueError("The number of texts and metadatas must be the same.")

    for text, embedding, metadata in zip(texts, embeddings, metadatas):
        chunk_id = f"{uuid4()}_0"

        # Accept array-like objects as well as plain sequences.
        to_list = getattr(embedding, "tolist", None)
        vector = to_list() if callable(to_list) else list(embedding)

        self.index[chunk_id] = {
            'text': text,
            'metadata': metadata,
            'binary_vector': vector,
        }

    # Persist once for the whole batch rather than once per item.
    self.save()
    logger.info("[VLite.set_batch] Texts added successfully.")
    end_time = time.time()
    logger.debug(f"[VLite.set_batch] Execution time: {end_time - start_time:.5f} seconds")


def count(self) -> int:
    """Return the number of entries currently held in ``self.index``."""
    return len(self.index)
Expand Down

0 comments on commit ddbf16b

Please sign in to comment.