Commit

0.09724 seconds retrieval
sdan committed Apr 16, 2024
1 parent 7555737 commit 2b55acd
Showing 3 changed files with 112 additions and 163 deletions.
38 changes: 38 additions & 0 deletions vlite/index.py
@@ -0,0 +1,38 @@
import numpy as np
import struct
import json
from enum import Enum
from typing import List, Union, Dict

class BinaryVectorIndex:
    def __init__(self, embedding_size=64):
        self.index = {}
        self.embedding_size = embedding_size

    def add(self, chunk_id, binary_vector):
        binary_vector = binary_vector.tolist()
        self.index[chunk_id] = binary_vector

    def add_batch(self, chunk_ids, binary_vectors):
        for chunk_id, binary_vector in zip(chunk_ids, binary_vectors):
            self.add(chunk_id, binary_vector)

    def remove(self, chunk_id):
        if chunk_id in self.index:
            del self.index[chunk_id]

    def search(self, query_vector, top_k):
        query_vector = np.array(query_vector.tolist())
        binary_vectors = np.array(list(self.index.values()))
        chunk_ids = np.array(list(self.index.keys()))

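        # Hamming distance: count the bit positions where each stored vector
        # differs from the query, then normalize to a similarity in [0, 1].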
        distances = np.count_nonzero(binary_vectors != query_vector[:binary_vectors.shape[1]], axis=1)
        similarities = 1 - distances / binary_vectors.shape[1]

        sorted_indices = np.argsort(similarities)[::-1]
        top_k_indices = sorted_indices[:top_k]

        top_k_ids = chunk_ids[top_k_indices]
        top_k_scores = similarities[top_k_indices]

        return top_k_ids.tolist(), top_k_scores.tolist()
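
For orientation, here is a minimal sketch (not part of the commit) of the new BinaryVectorIndex on its own; the IDs and 4-bit vectors are invented for illustration. search() ranks by normalized Hamming similarity, i.e. 1 - differing_bits / vector_length:

import numpy as np
from vlite.index import BinaryVectorIndex

index = BinaryVectorIndex(embedding_size=4)
index.add_batch(
    ["doc1_0", "doc2_0"],
    [np.array([1, 0, 1, 1]), np.array([0, 1, 0, 0])],
)

# The query differs from doc1_0 in 1 of 4 bits (similarity 0.75)
# and from doc2_0 in 3 of 4 bits (similarity 0.25).
ids, scores = index.search(np.array([1, 0, 1, 0]), top_k=2)
print(ids, scores)  # ['doc1_0', 'doc2_0'] [0.75, 0.25]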
170 changes: 74 additions & 96 deletions vlite/main.py
@@ -3,14 +3,15 @@
 from .utils import check_cuda_available, check_mps_available
 from .model import EmbeddingModel
 from .utils import chop_and_chunk
+from .index import BinaryVectorIndex
 import datetime
 from .ctx import Ctx
 import time
 import logging
 
 # Configure logging
-logging.basicConfig(level=logging.WARNING, format='%(asctime)s - %(levelname)s - %(message)s')
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
 
 class VLite:
     def __init__(self, collection=None, device=None, model_name='mixedbread-ai/mxbai-embed-large-v1'):
@@ -22,7 +23,7 @@ def __init__(self, collection=None, device=None, model_name='mixedbread-ai/mxbai
             device = 'mps'
         else:
             device = 'cpu'
-        logger.info(f"[VLite.__init__] Initializing VLite with device: {device}")
+        print(f"[VLite.__init__] Initializing VLite with device: {device}")
         self.device = device
         if collection is None:
             current_datetime = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -31,25 +32,32 @@ def __init__(self, collection=None, device=None, model_name='mixedbread-ai/mxbai
         self.model = EmbeddingModel(model_name, device=device) if model_name else EmbeddingModel()
         self.ctx = Ctx()
         self.index = {}
+        self.binary_index = BinaryVectorIndex()
         try:
             ctx_file = self.ctx.read(collection)
             ctx_file.load()
-            logger.debug(f"[VLite.__init__] Number of embeddings: {len(ctx_file.embeddings)}")
-            logger.debug(f"[VLite.__init__] Number of metadata: {len(ctx_file.metadata)}")
+            print(f"[VLite.__init__] Number of embeddings: {len(ctx_file.embeddings)}")
+            print(f"[VLite.__init__] Number of metadata: {len(ctx_file.metadata)}")
 
+            chunk_ids = list(ctx_file.metadata.keys())
             self.index = {
                 chunk_id: {
                     'text': ctx_file.contexts[idx] if idx < len(ctx_file.contexts) else "",
                     'metadata': ctx_file.metadata.get(chunk_id, {}),
                     'binary_vector': np.array(ctx_file.embeddings[idx]) if idx < len(ctx_file.embeddings) else np.zeros(self.model.embedding_size)
                 }
-                for idx, chunk_id in enumerate(ctx_file.metadata.keys())
+                for idx, chunk_id in enumerate(chunk_ids)
             }
+
+            self.binary_index.add_batch(
+                chunk_ids,
+                [np.array(embedding) if idx < len(ctx_file.embeddings) else np.zeros(64) for idx, embedding in enumerate(ctx_file.embeddings)]
+            )
         except FileNotFoundError:
             logger.warning(f"[VLite.__init__] Collection file {self.collection} not found. Initializing empty attributes.")
         end_time = time.time()
-        logger.debug(f"[VLite.__init__] Execution time: {end_time - start_time:.5f} seconds")
-        logger.info(f"[VLite.__init__] Using device: {self.device}")
+        print(f"[VLite.__init__] Execution time: {end_time - start_time:.5f} seconds")
+        print(f"[VLite.__init__] Using device: {self.device}")
     def add(self, data, metadata=None, item_id=None, need_chunks=False, fast=True):
         start_time = time.time()
         data = [data] if not isinstance(data, list) else data
@@ -71,91 +79,57 @@ def add(self, data, metadata=None, item_id=None, need_chunks=False, fast=True):
                 chunks = chop_and_chunk(text_content, fast=fast)
             else:
                 chunks = [text_content]
-                logger.debug("[VLite.add] Encoding text... not chunking")
+                print("[VLite.add] Encoding text... not chunking")
             all_chunks.extend(chunks)
             all_metadata.extend([item_metadata] * len(chunks))
             all_ids.extend([item_id] * len(chunks))
         binary_encoded_data = self.model.embed(all_chunks, precision="binary")
         for idx, (chunk, binary_vector, metadata) in enumerate(zip(all_chunks, binary_encoded_data, all_metadata)):
-            chunk_id = f"{item_id}_{idx}"
+            chunk_id = f"{all_ids[idx]}_{idx}"
             self.index[chunk_id] = {
                 'text': chunk,
                 'metadata': metadata,
-                'binary_vector': binary_vector.tolist()
+                'item_id': all_ids[idx]  # Store the item ID along with the chunk data
             }
 
+            self.binary_index.add(chunk_id, binary_vector)
+            print(f"[VLite.add] Added chunk ID: {chunk_id}")
+            print(f"[VLite.add] Main index keys: {list(self.index.keys())}")
+            print(f"[VLite.add] Binary index keys: {list(self.binary_index.index.keys())}")
             if item_id not in [result[0] for result in results]:
                 results.append((item_id, binary_encoded_data, metadata))
 
         self.save()
-        logger.info("[VLite.add] Text added successfully.")
+        print("[VLite.add] Text added successfully.")
         end_time = time.time()
-        logger.debug(f"[VLite.add] Execution time: {end_time - start_time:.5f} seconds")
+        print(f"[VLite.add] Execution time: {end_time - start_time:.5f} seconds")
         return results
 
     def retrieve(self, text=None, top_k=5, metadata=None, return_scores=False):
         start_time = time.time()
-        logger.info("[VLite.retrieve] Retrieving similar texts...")
         if text:
-            logger.info(f"[VLite.retrieve] Retrieving top {top_k} similar texts for query: {text}")
             query_binary_vectors = self.model.embed(text, precision="binary")
-            # Perform search on the query binary vectors
-            results = []
-            for query_binary_vector in query_binary_vectors:
-                chunk_results = self.rank_and_filter(query_binary_vector, top_k, metadata)
-                results.extend(chunk_results)
-            # Sort the results by similarity score
-            results.sort(key=lambda x: x[1])
-            results = results[:top_k]
-            logger.info("[VLite.retrieve] Retrieval completed.")
-            end_time = time.time()
-            logger.debug(f"[VLite.retrieve] Execution time: {end_time - start_time:.5f} seconds")
+
+            top_k_ids, top_k_scores = self.binary_index.search(query_binary_vectors, top_k)
+
+            chunk_data = [self.index[chunk_id] for chunk_id in top_k_ids if chunk_id in self.index]
+
+            if metadata:
+                chunk_data = [data for data in chunk_data if all(data['metadata'].get(key) == value for key, value in metadata.items())]
+
+            texts = [data['text'] for data in chunk_data]
+            metadatas = [data['metadata'] for data in chunk_data]
+            scores = [score for chunk_id, score in zip(top_k_ids, top_k_scores) if chunk_id in self.index][:len(chunk_data)]
+
             if return_scores:
-                return [(idx, self.index[idx]['text'], self.index[idx]['metadata'], score) for idx, score in results]
+                results = list(zip(chunk_data, texts, metadatas, scores))
             else:
-                return [(idx, self.index[idx]['text'], self.index[idx]['metadata']) for idx, _ in results]
-
-    def rank_and_filter(self, query_binary_vector, top_k, metadata=None):
-        start_time = time.time()
-        logger.debug(f"[VLite.rank_and_filter] Shape of query vector: {query_binary_vector.shape}")
-        query_binary_vector = np.array(query_binary_vector).reshape(-1)
-        logger.debug(f"[VLite.rank_and_filter] Shape of query vector after reshaping: {query_binary_vector.shape}")
-        # Collect all binary vectors and ensure they all have the same shape as the query vector
-        binary_vectors = []
-        mismatch_count = 0
-        for item_id, item in self.index.items():
-            binary_vector = item['binary_vector']
-            if len(binary_vector) == len(query_binary_vector):
-                binary_vectors.append(binary_vector)
-            else:
-                mismatch_count += 1
-                logger.warning(f"[VLite.rank_and_filter] Skipping vector with ID {item_id} of length {len(binary_vector)}, expected length {len(query_binary_vector)}")
-        if mismatch_count > 0:
-            logger.warning(f"[VLite.rank_and_filter] Skipped {mismatch_count} vectors due to length mismatch.")
-        # Convert list of binary vectors to a NumPy array
-        if binary_vectors:
-            corpus_binary_vectors = np.array(binary_vectors)
-            logger.debug(f"[VLite.rank_and_filter] Shape of corpus binary vectors array: {corpus_binary_vectors.shape}")
-        else:
-            raise ValueError("No valid binary vectors found for comparison.")
-        top_k_indices, top_k_scores = self.model.search(query_binary_vector, corpus_binary_vectors, top_k)
-        logger.debug(f"[VLite.rank_and_filter] Top {top_k} indices: {top_k_indices}")
-        logger.debug(f"[VLite.rank_and_filter] Top {top_k} scores: {top_k_scores}")
-        logger.debug(f"[VLite.rank_and_filter] No. of items in the collection: {len(self.index)}")
-        logger.debug(f"[VLite.rank_and_filter] Vlite count: {self.count()}")
-        top_k_ids = [list(self.index.keys())[idx] for idx in top_k_indices]
-        # Apply metadata filter on the retrieved top_k items
-        filtered_ids = []
-        if metadata:
-            for chunk_id in top_k_ids:
-                item_metadata = self.index[chunk_id]['metadata']
-                if all(item_metadata.get(key) == value for key, value in metadata.items()):
-                    filtered_ids.append(chunk_id)
-            top_k_ids = filtered_ids[:top_k]
-            top_k_scores = top_k_scores[:len(top_k_ids)]
-        end_time = time.time()
-        logger.debug(f"[VLite.rank_and_filter] Execution time: {end_time - start_time:.5f} seconds")
-        return list(zip(top_k_ids, top_k_scores))
+                results = list(zip(chunk_data, texts, metadatas))
+
+            end_time = time.time()
+            logger.debug(f"[VLite.retrieve] Execution time: {end_time - start_time:.5f} seconds")
+            return results
+
+        logger.warning("[VLite.retrieve] No query text provided.")
+        return []
 
     def update(self, id, text=None, metadata=None, vector=None):
         start_time = time.time()
@@ -167,11 +141,12 @@ def update(self, id, text=None, metadata=None, vector=None):
                 if metadata is not None:
                     self.index[chunk_id]['metadata'].update(metadata)
                 if vector is not None:
-                    self.index[chunk_id]['vector'] = vector
+                    self.binary_index.remove(chunk_id)
+                    self.binary_index.add(chunk_id, vector)
             self.save()
-            logger.info(f"[VLite.update] Item with ID '{id}' updated successfully.")
+            print(f"[VLite.update] Item with ID '{id}' updated successfully.")
            end_time = time.time()
-            logger.debug(f"[VLite.update] Execution time: {end_time - start_time:.5f} seconds")
+            print(f"[VLite.update] Execution time: {end_time - start_time:.5f} seconds")
             return True
         else:
             logger.warning(f"[VLite.update] Item with ID '{id}' not found.")
@@ -185,11 +160,12 @@ def delete(self, ids):
             chunk_ids = [chunk_id for chunk_id in self.index if chunk_id.startswith(f"{id}_")]
             for chunk_id in chunk_ids:
                 if chunk_id in self.index:
+                    self.binary_index.remove(chunk_id)
                     del self.index[chunk_id]
                     deleted_count += 1
         if deleted_count > 0:
             self.save()
-            logger.info(f"[VLite.delete] Deleted {deleted_count} item(s) from the collection.")
+            print(f"[VLite.delete] Deleted {deleted_count} item(s) from the collection.")
         else:
             logger.warning("[VLite.delete] No items found with the specified IDs.")
         return deleted_count
@@ -221,47 +197,49 @@ def get(self, ids=None, where=None):
         return items
 
     def set(self, id, text=None, metadata=None, vector=None):
-        logger.info(f"[VLite.set] Setting attributes for item with ID: {id}")
-        chunk_ids = [chunk_id for chunk_id in self.index if chunk_id.startswith(f"{id}_")]
-        if chunk_ids:
-            self.update(id, text, metadata, vector)
-        else:
-            self.add(text, metadata=metadata, item_id=id)
-        logger.info(f"[VLite.set] Item with ID '{id}' created successfully.")
+        print(f"[VLite.set] Setting attributes for item with ID: {id}")
+        self.delete(id)  # Remove existing item with the same ID
+        self.add(text, metadata=metadata, item_id=id)  # Add the item as a new entry
+        print(f"[VLite.set] Item with ID '{id}' created successfully.")
 
     def count(self):
         return len(self.index)
 
     def save(self):
-        logger.info(f"[VLite.save] Saving collection to {self.collection}")
+        print(f"[VLite.save] Saving collection to {self.collection}")
        with self.ctx.create(self.collection) as ctx_file:
             ctx_file.set_header(
                 embedding_model="mixedbread-ai/mxbai-embed-large-v1",
-                embedding_size=self.model.model_metadata.get('bert.embedding_length', 1024),
+                embedding_size=64,
                 embedding_dtype=self.model.embedding_dtype,
-                context_length=self.model.model_metadata.get('bert.context_length', 512)
+                context_length=512
             )
             for chunk_id, chunk_data in self.index.items():
-                ctx_file.add_embedding(chunk_data['binary_vector'])
+                binary_vector = self.binary_index.index.get(chunk_id, np.zeros(64))
+                ctx_file.add_embedding(binary_vector)
                 ctx_file.add_context(chunk_data['text'])
                 if 'metadata' in chunk_data:
                     ctx_file.add_metadata(chunk_id, chunk_data['metadata'])
-        logger.info("[VLite.save] Collection saved successfully.")
+        print("[VLite.save] Collection saved successfully.")
 
     def clear(self):
-        logger.info("[VLite.clear] Clearing the collection...")
+        print("[VLite.clear] Clearing the collection...")
         self.index = {}
+        self.binary_index = BinaryVectorIndex()
         self.ctx.delete(self.collection)
-        logger.info("[VLite.clear] Collection cleared.")
+        print("[VLite.clear] Collection cleared.")
 
     def info(self):
-        logger.info("[VLite.info] Collection Information:")
-        logger.info(f"[VLite.info] Items: {self.count()}")
-        logger.info(f"[VLite.info] Collection file: {self.collection}")
-        logger.info(f"[VLite.info] Embedding model: {self.model}")
+        print("[VLite.info] Collection Information:")
+        print(f"[VLite.info] Items: {self.count()}")
+        print(f"[VLite.info] Collection file: {self.collection}")
+        print(f"[VLite.info] Embedding model: {self.model}")
 
     def __repr__(self):
         return f"VLite(collection={self.collection}, device={self.device}, model={self.model})"
 
     def dump(self):
-        return self.index
+        return {
+            'index': self.index,
+            'binary_index': self.binary_index.index
+        }
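
Taken together, the changes route every vector through the new BinaryVectorIndex: add() writes to it, retrieve() searches it, and update()/delete()/clear()/save() keep it in sync with the main index. A minimal end-to-end sketch against the post-commit API (collection name and texts are invented; the mxbai embedding model must be available for EmbeddingModel to load):

from vlite import VLite

vdb = VLite(collection="demo_collection")
vdb.add("Binary quantization keeps one bit per embedding dimension.", item_id="doc1")
vdb.add("Hamming distance counts the bits where two vectors differ.", item_id="doc2")

# With return_scores=True, each result is a (chunk_data, text, metadata, score) tuple.
for chunk, text, meta, score in vdb.retrieve("bitwise similarity", top_k=2, return_scores=True):
    print(f"{score:.2f} {text}")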
