Skip to content

Commit

Permalink
tests run, omnoms
Browse files Browse the repository at this point in the history
  • Loading branch information
sdan committed Apr 5, 2024
1 parent 4d25801 commit ef81aed
Show file tree
Hide file tree
Showing 5 changed files with 200 additions and 195 deletions.
Binary file added tests/omnoms/vlite-unit.omom
Binary file not shown.
3 changes: 3 additions & 0 deletions tests/unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ def tearDownClass(cls):
if os.path.exists('vlite-unit.npz'):
print("[+] Removing vlite")
os.remove('vlite-unit.npz')
if os.path.exists('vlite-unit.omom'):
print("[+] Removing vlite")
os.remove('vlite-unit.omom')

if __name__ == '__main__':
unittest.main(verbosity=2)
203 changes: 39 additions & 164 deletions vlite/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,60 +3,35 @@
from .model import EmbeddingModel
from .utils import chop_and_chunk
import datetime
from .omom import Omom

class VLite:
"""
A simple vector database for text embedding and retrieval.
Attributes:
collection (str): Path to the collection file.
device (str): Device to use for embedding ('cpu' or 'cuda').
model (EmbeddingModel): The embedding model used for text representation.
Methods:
add(text, id=None, metadata=None): Adds a text to the collection with optional ID and metadata.
retrieve(text=None, id=None, top_k=5): Retrieves similar texts from the collection.
save(): Saves the collection to a file.
"""
def __init__(self, collection=None, device='cpu', model_name='mixedbread-ai/mxbai-embed-large-v1'):
    """
    Sets up the collection name, device, embedding model and Omom handle, then
    tries to populate the in-memory index from an existing collection.

    Args:
        collection (str, optional): Collection name. When omitted, a name of the form
            "vlite_<YYYYmmdd_HHMMSS>" is generated from the current local time.
        device (str): Stored on the instance as given.
        model_name (str, optional): Passed to EmbeddingModel; a falsy value constructs
            EmbeddingModel with no arguments.
    """
    if collection is None:
        stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        collection = f"vlite_{stamp}"
    self.collection = f"{collection}"
    self.device = device
    if model_name:
        self.model = EmbeddingModel(model_name)
    else:
        self.model = EmbeddingModel()

    self.omom = Omom()
    # Start empty; this is what remains if the collection cannot be read.
    self.index = {}

    try:
        with self.omom.read(collection) as omom_file:
            loaded = {}
            for chunk_id, chunk_data in omom_file.metadata.items():
                loaded[chunk_id] = {
                    'text': chunk_data['text'],
                    'metadata': chunk_data['metadata'],
                    'binary_vector': np.array(chunk_data['binary_vector'])
                }
            # Assign only after every entry was converted successfully.
            self.index = loaded
    except FileNotFoundError:
        print(f"Collection file {self.collection} not found. Initializing empty attributes.")

def add(self, data, metadata=None, need_chunks=True, newEmbedding=False, fast=True):
"""
Adds text or a list of texts to the collection with optional ID within metadata.
Args:
data (str, dict, or list): Text data to be added. Can be a string, a dictionary containing text, id, and/or metadata, or a list of strings or dictionaries.
metadata (dict, optional): Additional metadata to be appended to each text entry.
need_chunks (bool, optional): Whether to split the text into chunks before embedding. Defaults to True.
fast (bool, optional): Whether to use fast mode for chunking. Defaults to True.
Returns:
list: A list of tuples, each containing the ID of the added text and the updated vectors array.
"""
print("Adding text to the collection...")

def add(self, data, metadata=None, need_chunks=True, fast=True):
print("Adding text to the collection...", self.collection)
data = [data] if not isinstance(data, list) else data
results = []
all_chunks = []
Expand Down Expand Up @@ -88,48 +63,33 @@ def add(self, data, metadata=None, need_chunks=True, newEmbedding=False, fast=Tr

encoded_data = self.model.embed(all_chunks, device=self.device)
binary_encoded_data = self.model.quantize(encoded_data, precision="binary")
int8_encoded_data = self.model.quantize(encoded_data, precision="int8")

for idx, (chunk, vector, binary_vector, int8_vector, metadata, item_id) in enumerate(zip(all_chunks, encoded_data, binary_encoded_data, int8_encoded_data, all_metadata, all_ids)):

for idx, (chunk, binary_vector, metadata, item_id) in enumerate(zip(all_chunks, binary_encoded_data, all_metadata, all_ids)):
chunk_id = f"{item_id}_{idx}"
self.index[chunk_id] = {
'text': chunk,
'metadata': metadata,
'vector': vector,
'binary_vector': binary_vector.tolist(),
'int8_vector': int8_vector.tolist()
'binary_vector': binary_vector.tolist()
}

if item_id not in [result[0] for result in results]:
results.append((item_id, encoded_data, metadata))
results.append((item_id, binary_encoded_data, metadata))

self.save()
print("Text added successfully.")
return results

def retrieve(self, text=None, top_k=5, metadata=None, newEmbedding=False):
"""
Retrieves similar texts from the collection based on text content, ID, or metadata.
Args:
text (str, optional): Query text for finding similar texts.
top_k (int, optional): Number of top similar texts to retrieve. Defaults to 5.
metadata (dict, optional): Metadata to filter the retrieved texts.
Returns:
tuple: A tuple containing a list of similar texts, their similarity scores, and metadata (if applicable).
"""
def retrieve(self, text=None, top_k=5, metadata=None):
print("Retrieving similar texts...")
if text:
print(f"Retrieving top {top_k} similar texts for query: {text}")
query_chunks = chop_and_chunk(text, fast=True)
query_vectors = self.model.embed(query_chunks, device=self.device)
query_binary_vectors = self.model.quantize(query_vectors, precision="binary")
query_int8_vectors = self.model.quantize(query_vectors, precision="int8")

results = []
for query_binary_vector, query_int8_vector in zip(query_binary_vectors, query_int8_vectors):
chunk_results = self.rescore(query_binary_vector, query_int8_vector, top_k, metadata)
for query_binary_vector in query_binary_vectors:
chunk_results = self.search(query_binary_vector, top_k, metadata)
results.extend(chunk_results)

results.sort(key=lambda x: x[1], reverse=True)
Expand All @@ -138,59 +98,31 @@ def retrieve(self, text=None, top_k=5, metadata=None, newEmbedding=False):
print("Retrieval completed.")
return [(self.index[idx]['text'], score, self.index[idx]['metadata']) for idx, score in results]

def search(self, query_binary_vector, top_k, metadata=None):
    """
    Ranks the stored chunks against a binary query vector by dot product.

    Args:
        query_binary_vector (numpy.ndarray): Query vector; flattened to 1-D before scoring.
        top_k (int): Maximum number of results to return. Values larger than the
            collection size are clamped; values <= 0 yield an empty list.
        metadata (dict, optional): Key/value pairs a candidate's metadata must match
            exactly. The filter is applied to the top_k candidates after ranking, so
            fewer than top_k results may come back.

    Returns:
        list: (chunk_id, score) tuples, highest score first.
    """
    # Materialise the ids once; indexing list(self.index.keys()) per hit is quadratic.
    chunk_ids = list(self.index)
    if not chunk_ids or top_k <= 0:
        # Nothing to rank; also avoids einsum/argpartition on an empty array and
        # the "[-0:]" slice that would otherwise return every item.
        return []

    query_binary_vector = np.asarray(query_binary_vector).reshape(-1)
    binary_vectors = np.array([self.index[chunk_id]['binary_vector'] for chunk_id in chunk_ids])
    binary_similarities = np.einsum('i,ji->j', query_binary_vector, binary_vectors)

    # argpartition raises when asked for more elements than exist, so clamp.
    k = min(top_k, len(chunk_ids))
    candidates = np.argpartition(binary_similarities, -k)[-k:]
    # argpartition leaves the selected slice unordered; sort it best-first.
    candidates = candidates[np.argsort(binary_similarities[candidates])[::-1]]

    results = []
    for position in candidates:
        chunk_id = chunk_ids[position]
        if metadata:
            item_metadata = self.index[chunk_id]['metadata'] or {}
            if not all(item_metadata.get(key) == value for key, value in metadata.items()):
                continue
        # Pair each surviving id with its own score so filtering cannot
        # shift ids and scores out of alignment.
        results.append((chunk_id, binary_similarities[position]))
    return results

def delete(self, ids):
"""
Deletes items from the collection by their IDs.
Args:
ids (list or str): A single ID or a list of IDs of the items to delete.
Returns:
int: The number of items deleted from the collection.
"""
if isinstance(ids, str):
ids = [ids]

Expand All @@ -209,18 +141,6 @@ def delete(self, ids):
return deleted_count

def update(self, id, text=None, metadata=None, vector=None):
"""
Updates an item in the collection by its ID.
Args:
id (str): The ID of the item to update.
text (str, optional): The updated text content of the item.
metadata (dict, optional): The updated metadata of the item.
vector (numpy.ndarray, optional): The updated embedding vector of the item.
Returns:
bool: True if the item was successfully updated, False otherwise.
"""
if id in self.index:
if text is not None:
self.index[id]['text'] = text
Expand All @@ -239,40 +159,18 @@ def update(self, id, text=None, metadata=None, vector=None):
return False

def get(self, ids=None, where=None):
    """
    Retrieves items from the collection by ID and/or metadata.

    Args:
        ids (str or list, optional): A single ID or an iterable of IDs. When given,
            only entries whose key is among them are considered.
        where (dict, optional): Metadata filter; an entry is kept only if every
            key/value pair equals the entry's metadata value for that key.

    Returns:
        list: (text, metadata) tuples in index order.
    """
    # A bare string would otherwise be split into single characters by set();
    # wrap it the same way delete() does.
    if isinstance(ids, str):
        ids = [ids]
    wanted = None if ids is None else set(ids)

    items = []
    for chunk_id, entry in self.index.items():
        if wanted is not None and chunk_id not in wanted:
            continue
        entry_metadata = entry['metadata']
        if where is not None:
            # Entries without metadata simply fail a non-empty filter instead of raising.
            lookup = entry_metadata or {}
            if not all(lookup.get(key) == value for key, value in where.items()):
                continue
        items.append((entry['text'], entry_metadata))
    return items


def set(self, id, text=None, metadata=None, vector=None):
"""
Updates the attributes of an item in the collection by ID.
Args:
id (str): ID of the item to update.
text (str, optional): Updated text content of the item.
metadata (dict, optional): Updated metadata of the item.
vector (numpy.ndarray, optional): Updated embedding vector of the item.
"""
print(f"Setting attributes for item with ID: {id}")
if id in self.index:
if text is not None:
Expand All @@ -286,48 +184,31 @@ def set(self, id, text=None, metadata=None, vector=None):
print(f"Item with ID {id} not found.")

def count(self):
    """
    Reports how many entries the in-memory index currently holds.

    Returns:
        int: Number of keys in ``self.index``.
    """
    entries = self.index
    return len(entries)


def save(self):
    """
    Writes the in-memory index out through ``self.omom``.

    A header built from the embedding model's metadata is set first; then, for
    every index entry, its binary vector, its text and its metadata (keyed by
    chunk id) are appended in that order.
    """
    print(f"Saving collection to {self.collection}")
    with self.omom.create(self.collection) as writer:
        # Read model details inside the context so a failure here still
        # goes through the writer's exit handling.
        model_info = self.model.model_metadata
        header = {
            'embedding_model': model_info['general.name'],
            'embedding_size': model_info.get('bert.embedding_length', 1024),
            'embedding_dtype': self.model.embedding_dtype,
            'context_length': model_info.get('bert.context_length', 512),
        }
        writer.set_header(**header)
        for key, entry in self.index.items():
            writer.add_embedding(entry['binary_vector'])
            writer.add_context(entry['text'])
            writer.add_metadata(key, entry['metadata'])
    print("Collection saved successfully.")


def clear(self):
    """
    Empties the in-memory index and asks ``self.omom`` to delete the collection.
    """
    print("Clearing the collection...")
    # Reset memory first, then remove the stored collection by name.
    self.index = {}
    target = self.collection
    self.omom.delete(target)
    print("Collection cleared.")

def info(self):
"""
Prints information about the collection, including the number of items, collection file path,
and the embedding model used.
"""
print("Collection Information:")
print(f" Items: {self.count()}")
print(f" Collection file: {self.collection}")
Expand All @@ -337,10 +218,4 @@ def __repr__(self):
return f"VLite(collection={self.collection}, device={self.device}, model={self.model})"

def dump(self):
    """
    Exposes the collection data for serialization.

    Returns:
        dict: The live ``self.index`` mapping itself (not a copy).
    """
    return self.index
Loading

0 comments on commit ef81aed

Please sign in to comment.