From ddbf16b1b2aa56017e7768e407a16b5ddb3bb5bf Mon Sep 17 00:00:00 2001 From: Surya Date: Wed, 17 Apr 2024 21:38:13 -0700 Subject: [PATCH] set batch, optimize retrieve --- vlite/main.py | 54 ++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 40 insertions(+), 14 deletions(-) diff --git a/vlite/main.py b/vlite/main.py index 96d1682..4eacdce 100644 --- a/vlite/main.py +++ b/vlite/main.py @@ -126,20 +126,10 @@ def rank_and_filter(self, query_binary_vector, top_k, metadata=None): query_binary_vector = np.array(query_binary_vector).reshape(-1) logger.debug(f"[VLite.rank_and_filter] Shape of query vector after reshaping: {query_binary_vector.shape}") # Collect all binary vectors and ensure they all have the same shape as the query vector - binary_vectors = [] - mismatch_count = 0 - for item_id, item in self.index.items(): - binary_vector = item['binary_vector'] - if len(binary_vector) == len(query_binary_vector): - binary_vectors.append(binary_vector) - else: - mismatch_count += 1 - logger.warning(f"[VLite.rank_and_filter] Skipping vector with ID {item_id} of length {len(binary_vector)}, expected length {len(query_binary_vector)}") - if mismatch_count > 0: - logger.warning(f"[VLite.rank_and_filter] Skipped {mismatch_count} vectors due to length mismatch.") - # Convert list of binary vectors to a NumPy array + binary_vectors = [item['binary_vector'] for item in self.index.values() if len(item['binary_vector']) == len(query_binary_vector)] + binary_vector_indices = [idx for idx, item in enumerate(self.index.values()) if len(item['binary_vector']) == len(query_binary_vector)] if binary_vectors: - corpus_binary_vectors = np.array(binary_vectors, dtype=np.float32) + corpus_binary_vectors = np.asarray(binary_vectors, dtype=np.float32) logger.debug(f"[VLite.rank_and_filter] Shape of corpus binary vectors array: {corpus_binary_vectors.shape}") else: raise ValueError("No valid binary vectors found for comparison.") @@ -148,7 +138,8 @@ def rank_and_filter(self, 
query_binary_vector, top_k, metadata=None): logger.debug(f"[VLite.rank_and_filter] Top {top_k} scores: {top_k_scores}") logger.debug(f"[VLite.rank_and_filter] No. of items in the collection: {len(self.index)}") logger.debug(f"[VLite.rank_and_filter] Vlite count: {self.count()}") - top_k_ids = [list(self.index.keys())[idx] for idx in top_k_indices] + + top_k_ids = [list(self.index.keys())[binary_vector_indices[idx]] for idx in top_k_indices] # Apply metadata filter on the retrieved top_k items filtered_ids = [] if metadata: @@ -233,6 +224,41 @@ def set(self, id, text=None, metadata=None, vector=None): else: self.add(text, metadata=metadata, item_id=id) logger.info(f"[VLite.set] Item with ID '{id}' created successfully.") + + + def set_batch(self, texts, embeddings, metadatas=None): + start_time = time.time() + if not isinstance(texts, list): + texts = [texts] + + if metadatas is None: + metadatas = [{}] * len(texts) + elif not isinstance(metadatas, list): + metadatas = [metadatas] + + if len(texts) != len(embeddings): + print("asdasd",len(texts), len(embeddings)) + raise ValueError("The number of texts and embeddings must be the same.") + + if len(texts) != len(metadatas): + raise ValueError("The number of texts and metadatas must be the same.") + + for text, embedding, metadata in zip(texts, embeddings, metadatas): + item_id = str(uuid4()) + chunk_id = f"{item_id}_0" + + self.index[chunk_id] = { + 'text': text, + 'metadata': metadata, + 'binary_vector': embedding.tolist() + } + + + self.save() + logger.info("[VLite.set_batch] Texts added successfully.") + end_time = time.time() + logger.debug(f"[VLite.set_batch] Execution time: {end_time - start_time:.5f} seconds") + def count(self): return len(self.index)