From 2b55acd318c2245eb360b156071d8acb964b1321 Mon Sep 17 00:00:00 2001
From: Surya
Date: Tue, 16 Apr 2024 13:19:04 -0700
Subject: [PATCH] 0.09724 seconds retrieval

---
 vlite/index.py |  38 ++++++++++
 vlite/main.py  | 170 +++++++++++++++++++++----------------------
 vlite/onnx.py  |  67 -------------------
 3 files changed, 112 insertions(+), 163 deletions(-)
 create mode 100644 vlite/index.py
 delete mode 100644 vlite/onnx.py

diff --git a/vlite/index.py b/vlite/index.py
new file mode 100644
index 0000000..e40f071
--- /dev/null
+++ b/vlite/index.py
@@ -0,0 +1,38 @@
+import numpy as np
+
+class BinaryVectorIndex:
+    def __init__(self, embedding_size=64):
+        self.index = {}
+        self.embedding_size = embedding_size
+
+    def add(self, chunk_id, binary_vector):
+        # Store as a plain list so lists and numpy arrays are both accepted.
+        self.index[chunk_id] = np.asarray(binary_vector).tolist()
+
+    def add_batch(self, chunk_ids, binary_vectors):
+        for chunk_id, binary_vector in zip(chunk_ids, binary_vectors):
+            self.add(chunk_id, binary_vector)
+
+    def remove(self, chunk_id):
+        if chunk_id in self.index:
+            del self.index[chunk_id]
+
+    def search(self, query_vector, top_k):
+        if not self.index:
+            return [], []
+        query_vector = np.asarray(query_vector).flatten()
+        binary_vectors = np.array(list(self.index.values()))
+        chunk_ids = np.array(list(self.index.keys()))
+
+        # Hamming distance between the query and every stored vector (query
+        # truncated to the stored width), normalized to a similarity in [0, 1].
+        distances = np.count_nonzero(binary_vectors != query_vector[:binary_vectors.shape[1]], axis=1)
+        similarities = 1 - distances / binary_vectors.shape[1]
+
+        sorted_indices = np.argsort(similarities)[::-1]
+        top_k_indices = sorted_indices[:top_k]
+
+        top_k_ids = chunk_ids[top_k_indices]
+        top_k_scores = similarities[top_k_indices]
+
+        return top_k_ids.tolist(), top_k_scores.tolist()
\ No newline at end of file
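For a quick sense of how the new index behaves, here is a self-contained sketch; it is illustrative, not part of the patch, and assumes binary embeddings arrive as equal-length 0/1 numpy arrays and that the package layout exposes vlite.index:

    import numpy as np
    from vlite.index import BinaryVectorIndex

    index = BinaryVectorIndex()
    rng = np.random.default_rng(0)
    doc_a = rng.integers(0, 2, size=64)   # one 64-bit binary embedding
    doc_b = rng.integers(0, 2, size=64)
    index.add("doc_a_0", doc_a)
    index.add("doc_b_0", doc_b)

    # A query three bit-flips away from doc_a should rank doc_a first,
    # with normalized Hamming similarity 1 - 3/64.
    query = doc_a.copy()
    query[:3] ^= 1
    ids, scores = index.search(query, top_k=2)
    print(ids, scores)

Because search() is a single vectorized numpy comparison over the whole corpus, lookup cost grows linearly but with a very small constant, which is presumably where the sub-0.1 s retrieval number in the subject line comes from.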
diff --git a/vlite/main.py b/vlite/main.py
index 47a7e74..5a56b0e 100644
--- a/vlite/main.py
+++ b/vlite/main.py
@@ -3,14 +3,15 @@
 from .utils import check_cuda_available, check_mps_available
 from .model import EmbeddingModel
 from .utils import chop_and_chunk
+from .index import BinaryVectorIndex
 import datetime
 from .ctx import Ctx
 import time
 import logging
 
 # Configure logging
-logging.basicConfig(level=logging.WARNING, format='%(asctime)s - %(levelname)s - %(message)s')
-logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
 
 class VLite:
     def __init__(self, collection=None, device=None, model_name='mixedbread-ai/mxbai-embed-large-v1'):
@@ -22,7 +23,7 @@ def __init__(self, collection=None, device=None, model_name='mixedbread-ai/mxbai
             device = 'mps'
         else:
             device = 'cpu'
-        logger.info(f"[VLite.__init__] Initializing VLite with device: {device}")
+        print(f"[VLite.__init__] Initializing VLite with device: {device}")
         self.device = device
         if collection is None:
             current_datetime = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -31,25 +32,32 @@ def __init__(self, collection=None, device=None, model_name='mixedbread-ai/mxbai
         self.model = EmbeddingModel(model_name, device=device) if model_name else EmbeddingModel()
         self.ctx = Ctx()
         self.index = {}
+        self.binary_index = BinaryVectorIndex()
         try:
             ctx_file = self.ctx.read(collection)
             ctx_file.load()
-            logger.debug(f"[VLite.__init__] Number of embeddings: {len(ctx_file.embeddings)}")
-            logger.debug(f"[VLite.__init__] Number of metadata: {len(ctx_file.metadata)}")
+            print(f"[VLite.__init__] Number of embeddings: {len(ctx_file.embeddings)}")
+            print(f"[VLite.__init__] Number of metadata: {len(ctx_file.metadata)}")
+
+            chunk_ids = list(ctx_file.metadata.keys())
             self.index = {
                 chunk_id: {
                     'text': ctx_file.contexts[idx] if idx < len(ctx_file.contexts) else "",
                     'metadata': ctx_file.metadata.get(chunk_id, {}),
-                    'binary_vector': np.array(ctx_file.embeddings[idx]) if idx < len(ctx_file.embeddings) else np.zeros(self.model.embedding_size)
                 }
-                for idx, chunk_id in enumerate(ctx_file.metadata.keys())
+                for idx, chunk_id in enumerate(chunk_ids)
             }
+
+            self.binary_index.add_batch(
+                chunk_ids,
+                [np.array(ctx_file.embeddings[idx]) if idx < len(ctx_file.embeddings) else np.zeros(self.binary_index.embedding_size) for idx in range(len(chunk_ids))]
+            )
         except FileNotFoundError:
             logger.warning(f"[VLite.__init__] Collection file {self.collection} not found. Initializing empty attributes.")
         end_time = time.time()
-        logger.debug(f"[VLite.__init__] Execution time: {end_time - start_time:.5f} seconds")
-        logger.info(f"[VLite.__init__] Using device: {self.device}")
-
+        print(f"[VLite.__init__] Execution time: {end_time - start_time:.5f} seconds")
+        print(f"[VLite.__init__] Using device: {self.device}")
+
     def add(self, data, metadata=None, item_id=None, need_chunks=False, fast=True):
         start_time = time.time()
         data = [data] if not isinstance(data, list) else data
@@ -71,91 +79,57 @@ def add(self, data, metadata=None, item_id=None, need_chunks=False, fast=True):
                 chunks = chop_and_chunk(text_content, fast=fast)
             else:
                 chunks = [text_content]
-                logger.debug("[VLite.add] Encoding text... not chunking")
+                print("[VLite.add] Encoding text... not chunking")
             all_chunks.extend(chunks)
             all_metadata.extend([item_metadata] * len(chunks))
             all_ids.extend([item_id] * len(chunks))
         binary_encoded_data = self.model.embed(all_chunks, precision="binary")
         for idx, (chunk, binary_vector, metadata) in enumerate(zip(all_chunks, binary_encoded_data, all_metadata)):
-            chunk_id = f"{item_id}_{idx}"
+            chunk_id = f"{all_ids[idx]}_{idx}"
             self.index[chunk_id] = {
                 'text': chunk,
                 'metadata': metadata,
-                'binary_vector': binary_vector.tolist()
+                'item_id': all_ids[idx]  # Store the item ID along with the chunk data
             }
-
+            self.binary_index.add(chunk_id, binary_vector)
+            print(f"[VLite.add] Added chunk ID: {chunk_id}")
+            print(f"[VLite.add] Main index keys: {list(self.index.keys())}")
+            print(f"[VLite.add] Binary index keys: {list(self.binary_index.index.keys())}")
         if item_id not in [result[0] for result in results]:
             results.append((item_id, binary_encoded_data, metadata))
         self.save()
-        logger.info("[VLite.add] Text added successfully.")
+        print("[VLite.add] Text added successfully.")
         end_time = time.time()
-        logger.debug(f"[VLite.add] Execution time: {end_time - start_time:.5f} seconds")
+        print(f"[VLite.add] Execution time: {end_time - start_time:.5f} seconds")
         return results
 
     def retrieve(self, text=None, top_k=5, metadata=None, return_scores=False):
         start_time = time.time()
-        logger.info("[VLite.retrieve] Retrieving similar texts...")
         if text:
-            logger.info(f"[VLite.retrieve] Retrieving top {top_k} similar texts for query: {text}")
             query_binary_vectors = self.model.embed(text, precision="binary")
-            # Perform search on the query binary vectors
-            results = []
-            for query_binary_vector in query_binary_vectors:
-                chunk_results = self.rank_and_filter(query_binary_vector, top_k, metadata)
-                results.extend(chunk_results)
-            # Sort the results by similarity score
-            results.sort(key=lambda x: x[1])
-            results = results[:top_k]
-            logger.info("[VLite.retrieve] Retrieval completed.")
-            end_time = time.time()
-            logger.debug(f"[VLite.retrieve] Execution time: {end_time - start_time:.5f} seconds")
+            top_k_ids, top_k_scores = self.binary_index.search(query_binary_vectors, top_k)
+            scores_by_id = dict(zip(top_k_ids, top_k_scores))
+
+            # Pair each ID with its data so results and scores stay aligned
+            # even after the metadata filter drops some hits.
+            hits = [(chunk_id, self.index[chunk_id]) for chunk_id in top_k_ids if chunk_id in self.index]
+            if metadata:
+                hits = [(chunk_id, data) for chunk_id, data in hits
+                        if all(data['metadata'].get(key) == value for key, value in metadata.items())]
             if return_scores:
-                return [(idx, self.index[idx]['text'], self.index[idx]['metadata'], score) for idx, score in results]
+                results = [(chunk_id, data['text'], data['metadata'], scores_by_id[chunk_id]) for chunk_id, data in hits]
             else:
-                return [(idx, self.index[idx]['text'], self.index[idx]['metadata']) for idx, _ in results]
+                results = [(chunk_id, data['text'], data['metadata']) for chunk_id, data in hits]
+            end_time = time.time()
+            logger.debug(f"[VLite.retrieve] Execution time: {end_time - start_time:.5f} seconds")
+            return results
+        else:
+            logger.warning("[VLite.retrieve] No query text provided.")
+            return []
 
-    def rank_and_filter(self, query_binary_vector, top_k, metadata=None):
-        start_time = time.time()
-        logger.debug(f"[VLite.rank_and_filter] Shape of query vector: {query_binary_vector.shape}")
-        query_binary_vector = np.array(query_binary_vector).reshape(-1)
-        logger.debug(f"[VLite.rank_and_filter] Shape of query vector after reshaping: {query_binary_vector.shape}")
-        # Collect all binary vectors and ensure they all have the same shape as the query vector
-        binary_vectors = []
-        mismatch_count = 0
-        for item_id, item in self.index.items():
-            binary_vector = item['binary_vector']
-            if len(binary_vector) == len(query_binary_vector):
-                binary_vectors.append(binary_vector)
-            else:
-                mismatch_count += 1
-                logger.warning(f"[VLite.rank_and_filter] Skipping vector with ID {item_id} of length {len(binary_vector)}, expected length {len(query_binary_vector)}")
-        if mismatch_count > 0:
-            logger.warning(f"[VLite.rank_and_filter] Skipped {mismatch_count} vectors due to length mismatch.")
-        # Convert list of binary vectors to a NumPy array
-        if binary_vectors:
-            corpus_binary_vectors = np.array(binary_vectors)
-            logger.debug(f"[VLite.rank_and_filter] Shape of corpus binary vectors array: {corpus_binary_vectors.shape}")
-        else:
-            raise ValueError("No valid binary vectors found for comparison.")
-        top_k_indices, top_k_scores = self.model.search(query_binary_vector, corpus_binary_vectors, top_k)
-        logger.debug(f"[VLite.rank_and_filter] Top {top_k} indices: {top_k_indices}")
-        logger.debug(f"[VLite.rank_and_filter] Top {top_k} scores: {top_k_scores}")
-        logger.debug(f"[VLite.rank_and_filter] No. of items in the collection: {len(self.index)}")
-        logger.debug(f"[VLite.rank_and_filter] Vlite count: {self.count()}")
-        top_k_ids = [list(self.index.keys())[idx] for idx in top_k_indices]
-        # Apply metadata filter on the retrieved top_k items
-        filtered_ids = []
-        if metadata:
-            for chunk_id in top_k_ids:
-                item_metadata = self.index[chunk_id]['metadata']
-                if all(item_metadata.get(key) == value for key, value in metadata.items()):
-                    filtered_ids.append(chunk_id)
-            top_k_ids = filtered_ids[:top_k]
-            top_k_scores = top_k_scores[:len(top_k_ids)]
-        end_time = time.time()
-        logger.debug(f"[VLite.rank_and_filter] Execution time: {end_time - start_time:.5f} seconds")
-        return list(zip(top_k_ids, top_k_scores))
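To make the new return shape concrete, an illustrative call follows; this is not part of the patch, and the collection name "demo" and the query string are hypothetical:

    from vlite.main import VLite

    vl = VLite(collection="demo")
    vl.add("Binary embeddings trade a little recall for a lot of speed.", item_id="note")

    # retrieve() yields (chunk_id, text, metadata[, score]) tuples, already
    # sorted by the normalized Hamming similarity from BinaryVectorIndex.
    for chunk_id, text, meta, score in vl.retrieve("fast retrieval", top_k=3, return_scores=True):
        print(f"{score:.3f}  {chunk_id}  {text[:60]}")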
 
     def update(self, id, text=None, metadata=None, vector=None):
         start_time = time.time()
@@ -167,11 +141,12 @@ def update(self, id, text=None, metadata=None, vector=None):
             if metadata is not None:
                 self.index[chunk_id]['metadata'].update(metadata)
             if vector is not None:
-                self.index[chunk_id]['vector'] = vector
+                self.binary_index.remove(chunk_id)
+                self.binary_index.add(chunk_id, vector)
             self.save()
-            logger.info(f"[VLite.update] Item with ID '{id}' updated successfully.")
+            print(f"[VLite.update] Item with ID '{id}' updated successfully.")
             end_time = time.time()
-            logger.debug(f"[VLite.update] Execution time: {end_time - start_time:.5f} seconds")
+            print(f"[VLite.update] Execution time: {end_time - start_time:.5f} seconds")
             return True
         else:
             logger.warning(f"[VLite.update] Item with ID '{id}' not found.")
@@ -185,11 +160,12 @@ def delete(self, ids):
             chunk_ids = [chunk_id for chunk_id in self.index if chunk_id.startswith(f"{id}_")]
             for chunk_id in chunk_ids:
                 if chunk_id in self.index:
+                    self.binary_index.remove(chunk_id)
                     del self.index[chunk_id]
                     deleted_count += 1
         if deleted_count > 0:
             self.save()
-            logger.info(f"[VLite.delete] Deleted {deleted_count} item(s) from the collection.")
+            print(f"[VLite.delete] Deleted {deleted_count} item(s) from the collection.")
         else:
             logger.warning("[VLite.delete] No items found with the specified IDs.")
         return deleted_count
@@ -221,47 +197,49 @@ def get(self, ids=None, where=None):
         return items
 
     def set(self, id, text=None, metadata=None, vector=None):
-        logger.info(f"[VLite.set] Setting attributes for item with ID: {id}")
-        chunk_ids = [chunk_id for chunk_id in self.index if chunk_id.startswith(f"{id}_")]
-        if chunk_ids:
-            self.update(id, text, metadata, vector)
-        else:
-            self.add(text, metadata=metadata, item_id=id)
-            logger.info(f"[VLite.set] Item with ID '{id}' created successfully.")
+        print(f"[VLite.set] Setting attributes for item with ID: {id}")
+        self.delete(id)  # Remove existing item with the same ID
+        self.add(text, metadata=metadata, item_id=id)  # Add the item as a new entry
+        print(f"[VLite.set] Item with ID '{id}' created successfully.")
 
     def count(self):
         return len(self.index)
 
     def save(self):
-        logger.info(f"[VLite.save] Saving collection to {self.collection}")
+        print(f"[VLite.save] Saving collection to {self.collection}")
        with self.ctx.create(self.collection) as ctx_file:
             ctx_file.set_header(
                 embedding_model="mixedbread-ai/mxbai-embed-large-v1",
-                embedding_size=self.model.model_metadata.get('bert.embedding_length', 1024),
+                embedding_size=self.binary_index.embedding_size,
                 embedding_dtype=self.model.embedding_dtype,
-                context_length=self.model.model_metadata.get('bert.context_length', 512)
+                context_length=512
             )
             for chunk_id, chunk_data in self.index.items():
-                ctx_file.add_embedding(chunk_data['binary_vector'])
+                binary_vector = self.binary_index.index.get(chunk_id, np.zeros(self.binary_index.embedding_size).tolist())
+                ctx_file.add_embedding(binary_vector)
                 ctx_file.add_context(chunk_data['text'])
                 if 'metadata' in chunk_data:
                     ctx_file.add_metadata(chunk_id, chunk_data['metadata'])
-        logger.info("[VLite.save] Collection saved successfully.")
+        print("[VLite.save] Collection saved successfully.")
 
     def clear(self):
-        logger.info("[VLite.clear] Clearing the collection...")
+        print("[VLite.clear] Clearing the collection...")
         self.index = {}
+        self.binary_index = BinaryVectorIndex()
         self.ctx.delete(self.collection)
-        logger.info("[VLite.clear] Collection cleared.")
+        print("[VLite.clear] Collection cleared.")
 
     def info(self):
-        logger.info("[VLite.info] Collection Information:")
-        logger.info(f"[VLite.info] Items: {self.count()}")
-        logger.info(f"[VLite.info] Collection file: {self.collection}")
-        logger.info(f"[VLite.info] Embedding model: {self.model}")
+        print("[VLite.info] Collection Information:")
+        print(f"[VLite.info] Items: {self.count()}")
+        print(f"[VLite.info] Collection file: {self.collection}")
+        print(f"[VLite.info] Embedding model: {self.model}")
 
     def __repr__(self):
         return f"VLite(collection={self.collection}, device={self.device}, model={self.model})"
 
     def dump(self):
-        return self.index
\ No newline at end of file
+        return {
+            'index': self.index,
+            'binary_index': self.binary_index.index
+        }
\ No newline at end of file
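The mutation paths (set, update, delete, clear) now have to keep self.index and self.binary_index in step. A short illustrative lifecycle, again with a hypothetical collection name and item ID:

    from vlite.main import VLite

    vl = VLite(collection="demo")                 # "demo" is a hypothetical name
    vl.add("Binary vectors keep retrieval fast.", item_id="note")

    vl.set("note", text="Rewritten text.")        # set() is now delete-then-re-add
    deleted = vl.delete(["note"])                 # removes chunks from both indexes
    print(deleted, vl.count())                    # expect 1 chunk deleted, 0 left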
diff --git a/vlite/onnx.py b/vlite/onnx.py
deleted file mode 100644
index b02f5be..0000000
--- a/vlite/onnx.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# A dependency-light way to run the onnx model
-
-from tokenizers import Tokenizer
-import onnxruntime as ort
-import numpy as np
-from typing import List
-
-MODEL_ID = "mixedbread-ai/mxbai-embed-large-v1"
-
-# Use pytorches default epsilon for division by zero
-# https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
-def normalize(v):
-    norm = np.linalg.norm(v, axis=1)
-    norm[norm == 0] = 1e-12
-    return v / norm[:, np.newaxis]
-
-# Sampel implementation of the default sentence-transformers model using ONNX
-class ONNXModel():
-
-    def __init__(self):
-        # max_seq_length = 256, for some reason sentence-transformers uses 256 even though the HF config has a max length of 128
-        # https://github.com/UKPLab/sentence-transformers/blob/3e1929fddef16df94f8bc6e3b10598a98f46e62d/docs/_static/html/models_en_sentence_embeddings.html#LL480
-        self.tokenizer = Tokenizer.from_file("onnx/tokenizer.json")
-        print("[tokenizer ]",self.tokenizer.get_vocab_size())
-        print("[tokenizer ]",self.tokenizer.get_vocab())
-
-
-        self.tokenizer.enable_truncation(max_length=512)
-        self.model = ort.InferenceSession("onnx/model.onnx")
-        print("[model ]",self.model.get_modelmeta())
-
-
-    def __call__(self, documents: List[str], batch_size: int = 32):
-        all_embeddings = []
-        for i in range(0, len(documents), batch_size):
-            batch = documents[i:i + batch_size]
-            encoded = [self.tokenizer.encode(d) for d in batch]
-            input_ids = np.array([e.ids for e in encoded])
-            attention_mask = np.array([e.attention_mask for e in encoded])
-            onnx_input = {
-                "input_ids": np.array(input_ids, dtype=np.int64),
-                "attention_mask": np.array(attention_mask, dtype=np.int64),
-                "token_type_ids": np.array([np.zeros(len(e), dtype=np.int64) for e in input_ids], dtype=np.int64),
-            }
-            model_output = self.model.run(None, onnx_input)
-            last_hidden_state = model_output[0]
-            # Perform mean pooling with attention weighting
-            input_mask_expanded = np.broadcast_to(np.expand_dims(attention_mask, -1), last_hidden_state.shape)
-            embeddings = np.sum(last_hidden_state * input_mask_expanded, 1) / np.clip(input_mask_expanded.sum(1), a_min=1e-9, a_max=None)
-            embeddings = normalize(embeddings).astype(np.float32)
-            all_embeddings.append(embeddings)
-        return np.concatenate(all_embeddings)
-
-
-# sample_text = "This is a sample text that is likely to overflow the entire model and will be truncated. \
-# Keep writing and writing until we reach the end of the model.This is a sample text that is likely to overflow the entire model and \
-# will be truncated. Keep writing and writing until we reach the end of the model.This is a sample text that is likely to overflow the entire \
-# model and will be truncated. Keep writing and writing until we reach the end of the model. This is a sample text that is likely to overflow \
-# the entire model and will be truncated. Keep writing and writing until we reach the end of the model. This is a sample text that is likely to overflow \
-# the entire model and will be truncated. Keep writing and writing until we reach the end of the model."
-# model = DefaultEmbeddingModel()
-# # print(model([sample_text, sample_text]))
-
-# embeddings = model([sample_text, sample_text])
-# print(embeddings.shape)
-# # print(embeddings[0] == embeddings[1])
-# # print(embeddings)
\ No newline at end of file
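Since onnx.py goes away wholesale, the one non-obvious technique it carried is attention-weighted mean pooling over token embeddings. For future reference, a standalone numpy sketch of just that pooling step; this is not repo code, and the array shapes are illustrative:

    import numpy as np

    def mean_pool(last_hidden_state, attention_mask):
        # Expand the 0/1 token mask to the hidden dimension, zero out padding,
        # then average over real tokens only (epsilon guards all-padding rows).
        mask = np.broadcast_to(np.expand_dims(attention_mask, -1), last_hidden_state.shape)
        summed = np.sum(last_hidden_state * mask, axis=1)
        counts = np.clip(mask.sum(axis=1), a_min=1e-9, a_max=None)
        return summed / counts

    hidden = np.random.rand(2, 5, 8)                      # (batch, tokens, dim)
    mask = np.array([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
    print(mean_pool(hidden, mask).shape)                  # (2, 8)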