add timing and default to ST
sdan committed Apr 13, 2024
1 parent d662f86 commit 2598156
Showing 2 changed files with 115 additions and 60 deletions.
49 changes: 33 additions & 16 deletions vlite/main.py
@@ -4,9 +4,11 @@
from .utils import chop_and_chunk
import datetime
from .ctx import Ctx
import time

class VLite:
def __init__(self, collection=None, device='cpu', model_name='mixedbread-ai/mxbai-embed-large-v1'):
start_time = time.time()
if collection is None:
current_datetime = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
collection = f"vlite_{current_datetime}"
@@ -33,8 +35,12 @@ def __init__(self, collection=None, device='cpu', model_name='mixedbread-ai/mxbai-embed-large-v1'):
}
except FileNotFoundError:
print(f"Collection file {self.collection} not found. Initializing empty attributes.")

end_time = time.time()
print(f"[__init__] Execution time: {end_time - start_time:.5f} seconds")

def add(self, data, metadata=None, item_id=None, need_chunks=True, fast=True):
start_time = time.time()
data = [data] if not isinstance(data, list) else data
results = []
all_chunks = []
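
Aside: this commit adds the same start_time/end_time bookkeeping to every public method. A decorator could express the pattern once; a minimal sketch (our own suggestion, not part of the diff — the name `timed` is hypothetical):

    import time
    from functools import wraps

    def timed(fn):
        # Print how long the wrapped call took, mirroring the
        # "[name] Execution time: ..." lines added in this commit.
        @wraps(fn)
        def wrapper(*args, **kwargs):
            start = time.time()
            try:
                return fn(*args, **kwargs)
            finally:
                print(f"[{fn.__name__}] Execution time: {time.time() - start:.5f} seconds")
        return wrapper
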
@@ -65,6 +71,7 @@ def add(self, data, metadata=None, item_id=None, need_chunks=True, fast=True):
all_ids.extend([item_id] * len(chunks))

encoded_data = self.model.embed(all_chunks, device=self.device)
# encoded_data = self.model.encode_with_onnx(all_chunks)
binary_encoded_data = self.model.quantize(encoded_data, precision="binary")

for idx, (chunk, binary_vector, metadata) in enumerate(zip(all_chunks, binary_encoded_data, all_metadata)):
@@ -80,38 +87,45 @@ def add(self, data, metadata=None, item_id=None, need_chunks=True, fast=True):

self.save()
print("Text added successfully.")
end_time = time.time()
print(f"[add] Execution time: {end_time - start_time:.5f} seconds")
return results

def retrieve(self, text=None, top_k=5, metadata=None, return_scores=False):
start_time = time.time()
print("Retrieving similar texts...")
if text:
print(f"Retrieving top {top_k} similar texts for query: {text}")
query_chunks = chop_and_chunk(text, fast=True)
query_vectors = self.model.embed(query_chunks, device=self.device)

# Embed and quantize the query text
query_vectors = self.model.embed(text, device=self.device)
# query_vectors = self.model.encode_with_onnx([text])
query_binary_vectors = self.model.quantize(query_vectors, precision="binary")

# Perform search on the query binary vectors
results = []
for query_binary_vector in query_binary_vectors:
chunk_results = self.search(query_binary_vector, top_k, metadata)
chunk_results = self.rank_and_filter(query_binary_vector, top_k, metadata)
results.extend(chunk_results)

results.sort(key=lambda x: x[1], reverse=True)
# Sort the results by similarity score
results.sort(key=lambda x: x[1])
results = results[:top_k]

print("Retrieval completed.")
end_time = time.time()
print(f"[retrieve] Execution time: {end_time - start_time:.5f} seconds")
if return_scores:
return [(idx, self.index[idx]['text'], self.index[idx]['metadata'], score) for idx, score in results]
else:
return [(idx, self.index[idx]['text'], self.index[idx]['metadata']) for idx, _ in results]
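
Note the ordering change in retrieve: the old code sorted with reverse=True because the einsum dot products were similarities (higher is better), while rank_and_filter returns Hamming distances from model.search, where lower means more similar, so the ascending sort keeps the closest chunks. A toy check of the distance semantics (our own example, not from the diff):

    import numpy as np

    a = np.unpackbits(np.array([166], dtype=np.uint8))  # bits 10100110
    b = np.unpackbits(np.array([167], dtype=np.uint8))  # bits 10100111
    print(np.count_nonzero(a != b))  # 1 -> nearly identical codes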

def search(self, query_binary_vector, top_k, metadata=None):
# Reshape query_binary_vector to 1D array
query_binary_vector = query_binary_vector.reshape(-1)

# Perform binary search
binary_vectors = np.array([item['binary_vector'] for item in self.index.values()])
binary_similarities = np.einsum('i,ji->j', query_binary_vector, binary_vectors)
top_k_indices = np.argpartition(binary_similarities, -top_k)[-top_k:]

def rank_and_filter(self, query_binary_vector, top_k, metadata=None):
start_time = time.time()
query_binary_vector = np.array(query_binary_vector).reshape(-1)

corpus_binary_vectors = np.array([item['binary_vector'] for item in self.index.values()])
top_k_indices, top_k_scores = self.model.search(query_binary_vector, corpus_binary_vectors, top_k)
top_k_ids = [list(self.index.keys())[idx] for idx in top_k_indices]

# Apply metadata filter on the retrieved top_k items
@@ -122,14 +136,15 @@ def search(self, query_binary_vector, top_k, metadata=None):
if all(item_metadata.get(key) == value for key, value in metadata.items()):
filtered_ids.append(chunk_id)
top_k_ids = filtered_ids[:top_k]

# Get the similarity scores for the top_k items
top_k_scores = binary_similarities[top_k_indices]
top_k_scores = top_k_scores[:len(top_k_ids)]

return list(zip(top_k_ids, top_k_scores))
end_time = time.time()
print(f"[rank_and_filter] Execution time: {end_time - start_time:.5f} seconds")


def update(self, id, text=None, metadata=None, vector=None):
start_time = time.time()
chunk_ids = [chunk_id for chunk_id in self.index if chunk_id.startswith(f"{id}_")]
if chunk_ids:
for chunk_id in chunk_ids:
Expand All @@ -141,6 +156,8 @@ def update(self, id, text=None, metadata=None, vector=None):
self.index[chunk_id]['vector'] = vector
self.save()
print(f"Item with ID '{id}' updated successfully.")
end_time = time.time()
print(f"[update] Execution time: {end_time - start_time:.5f} seconds")
return True
else:
print(f"Item with ID '{id}' not found.")
126 changes: 82 additions & 44 deletions vlite/model.py
@@ -4,71 +4,109 @@
import numpy as np
from typing import List
from tokenizers import Tokenizer
import numpy as np
from typing import List
from sentence_transformers import SentenceTransformer
import time


def normalize(v):
norm = np.linalg.norm(v, axis=1)
if v.ndim == 1:
v = v.reshape(1, -1) # Reshape v to 2D array if it is 1D
norm = np.linalg.norm(v, axis=1, keepdims=True)
norm[norm == 0] = 1e-12
return v / norm[:, np.newaxis]
return v / norm
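
The reworked normalize accepts 1-D input by reshaping it to a 2-D row (the old axis=1 norm raised an error on 1-D arrays), and keepdims=True lets the division broadcast without the old [:, np.newaxis] indexing. A quick usage check (assuming normalize is imported from vlite.model):

    import numpy as np

    v = np.array([3.0, 4.0])   # 1-D input, previously unsupported
    print(normalize(v))        # [[0.6 0.8]] -- a unit-length row vector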



class EmbeddingModel:
def __init__(self, model_name="mixedbread-ai/mxbai-embed-large-v1"):
tokenizer_path = hf_hub_download(repo_id=model_name, filename="tokenizer.json")
model_path = hf_hub_download(repo_id=model_name, filename="onnx/model.onnx")

self.tokenizer = Tokenizer.from_file(tokenizer_path)
self.tokenizer.enable_truncation(max_length=512)
self.tokenizer.enable_padding(pad_id=0, pad_token="[PAD]", length=512)
start_time = time.time()
self.model = SentenceTransformer(model_name)

self.model = ort.InferenceSession(model_path)
print("[model]", self.model.get_modelmeta())
# tokenizer_path = hf_hub_download(repo_id=model_name, filename="tokenizer.json")
# model_path = hf_hub_download(repo_id=model_name, filename="onnx/model.onnx")

# self.tokenizer = Tokenizer.from_file(tokenizer_path)
# self.tokenizer.enable_truncation(max_length=512)
# self.tokenizer.enable_padding(pad_id=0, pad_token="[PAD]", length=512)

# self.model = ort.InferenceSession(model_path)

self.model_metadata = {
"bert.embedding_length": 1024,
"bert.embedding_length": 512,
"bert.context_length": 512
}
self.embedding_size = self.model_metadata.get("bert.embedding_length", 1024)
self.context_length = self.model_metadata.get("bert.context_length", 512)
self.embedding_dtype = "float32"
end_time = time.time()
print(f"[model.__init__] Execution time: {end_time - start_time:.5f} seconds")

def embed(self, texts, max_seq_length=512, device="cpu", batch_size=32):
start_time = time.time()
if isinstance(texts, str):
texts = [texts] # Ensure texts is always a list
embeddings = self.model.encode(texts, device=device, batch_size=batch_size, normalize_embeddings=True)
end_time = time.time()
print(f"[model.embed] Execution time: {end_time - start_time:.5f} seconds")
return embeddings
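
embed now delegates to SentenceTransformer.encode with normalize_embeddings=True, so the returned vectors are already unit length. A typical call (mxbai-embed-large-v1 produces 1024-dimensional embeddings):

    model = EmbeddingModel()             # loads the SentenceTransformer weights
    vecs = model.embed(["hello world"])  # a bare str would be wrapped into a list
    print(vecs.shape)                    # (1, 1024)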

def encode_with_onnx(self, texts):
# Ensure all text items are strings
if not all(isinstance(text, str) for text in texts):
raise ValueError("All items in the 'texts' list should be strings.")

def embed(self, texts: List[str], max_seq_length=512, device="cpu", batch_size=32):
all_embeddings = []
for i in range(0, len(texts), batch_size):
batch = texts[i:i + batch_size]
encoded = [self.tokenizer.encode(d) for d in batch]
input_ids = np.array([e.ids for e in encoded])
attention_mask = np.array([e.attention_mask for e in encoded])
token_type_ids = np.zeros_like(input_ids, dtype=np.int64)
try:
# Tokenize texts and convert to the correct format
inputs = self.tokenizer.encode_batch(texts)
input_ids = np.array([x.ids for x in inputs], dtype=np.int64)
attention_mask = np.array([x.attention_mask for x in inputs], dtype=np.int64)
token_type_ids = np.zeros_like(input_ids, dtype=np.int64) # Add token_type_ids input

onnx_input = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
ort_inputs = {
self.model.get_inputs()[0].name: input_ids,
self.model.get_inputs()[1].name: attention_mask,
self.model.get_inputs()[2].name: token_type_ids # Add token_type_ids input
}
model_output = self.model.run(None, onnx_input)
last_hidden_state = model_output[0]

input_mask_expanded = np.broadcast_to(np.expand_dims(attention_mask, -1), last_hidden_state.shape)
embeddings = np.sum(last_hidden_state * input_mask_expanded, 1) / np.clip(input_mask_expanded.sum(1), a_min=1e-9, a_max=None)
embeddings = normalize(embeddings).astype(np.float32)
all_embeddings.append(embeddings)
ort_outs = self.model.run(None, ort_inputs)
embeddings = ort_outs[0]
return embeddings
except Exception as e:
print(f"Failed during ONNX encoding: {e}")
raise

return np.concatenate(all_embeddings)
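
One caveat: with the tokenizer and ONNX session setup commented out in __init__, self.tokenizer is never assigned and self.model is a SentenceTransformer, so encode_with_onnx would fail if called; the commented-out call sites in main.py are consistent with that. A caller-side guard (our own sketch, under those assumptions):

    from vlite.model import EmbeddingModel

    model = EmbeddingModel()
    if hasattr(model, "tokenizer"):      # only true if the ONNX setup is re-enabled
        emb = model.encode_with_onnx(["hello"])
    else:
        emb = model.embed(["hello"])     # the SentenceTransformers default after this commit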

def token_count(self, texts):
tokens = 0
for text in texts:
encoded = self.tokenizer.encode(text)
tokens += len(encoded.ids)
return tokens

def quantize(self, embeddings, precision="binary"):
embeddings = np.array(embeddings)
start_time = time.time()
# first normalize embeddings to unit length
embeddings = normalize(embeddings)
# slice to get MRL embeddings
embeddings_slice = embeddings[..., :512]

if precision == "binary":
return np.packbits(embeddings > 0).reshape(embeddings.shape[0], -1)
elif precision == "int8":
return ((embeddings - np.min(embeddings, axis=0)) / (np.max(embeddings, axis=0) - np.min(embeddings, axis=0)) * 255).astype(np.uint8)
end_time = time.time()
print(f"[model.quantize] Execution time: {end_time - start_time:.5f} seconds")
return self._binary_quantize(embeddings_slice)
else:
raise ValueError(f"Unsupported precision: {precision}")
raise ValueError(f"Precision {precision} is not supported")

def _binary_quantize(self, embeddings):
return (np.packbits(embeddings > 0).reshape(embeddings.shape[0], -1) - 128).astype(np.int8)
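
quantize now normalizes, slices to the first 512 dimensions (the MRL cut), and packs the sign bits: np.packbits turns every 8 dimensions into one uint8 byte, and subtracting 128 before the int8 cast shifts the unsigned byte range {0..255} onto {-128..127}. A worked byte (our own example):

    import numpy as np

    emb = np.array([[0.2, -0.1, 0.5, 0.0, -0.3, 0.7, 0.1, -0.9]])
    bits = np.packbits(emb > 0)   # signs 10100110 -> array([166], dtype=uint8)
    print((bits.reshape(1, -1) - 128).astype(np.int8))  # [[38]]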

def rescore(self, query_vector, vectors):
return np.dot(query_vector, vectors.T).flatten()
def hamming_distance(self, embedding1, embedding2):
# Ensure the embeddings are numpy arrays for the operation.
return np.count_nonzero(np.array(embedding1) != np.array(embedding2))

def search(self, query_embedding, embeddings, top_k):
start_time = time.time()
# Convert embeddings to a numpy array for efficient operations if not already.
embeddings = np.array(embeddings)
distances = np.array([self.hamming_distance(query_embedding, emb) for emb in embeddings])

# Find the indices of the top_k smallest distances
top_k_indices = np.argsort(distances)[:top_k]
end_time = time.time()
print(f"[model.search] Execution time: {end_time - start_time:.5f} seconds")
return top_k_indices, distances[top_k_indices]
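
Two caveats about this search implementation: hamming_distance compares whole elements with !=, so over packed codes it counts differing bytes rather than differing bits, and the Python-level loop runs once per corpus row. A vectorized bit-level alternative (an assumption, not part of this commit) XORs the packed uint8 codes and popcounts via a 256-entry lookup table:

    import numpy as np

    POPCOUNT = np.array([bin(i).count("1") for i in range(256)], dtype=np.uint8)

    def hamming_bulk(query: np.ndarray, corpus: np.ndarray) -> np.ndarray:
        # query: (n_bytes,) uint8; corpus: (N, n_bytes) uint8
        # XOR leaves a 1 bit wherever the codes disagree; the table counts them.
        return POPCOUNT[np.bitwise_xor(corpus, query)].sum(axis=1)

Because the -128 shift in _binary_quantize only flips each byte's top bit, the XOR of two shifted codes equals the XOR of the raw packed bytes, so the same distances come out of the int8 codes stored in the index.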
