remove int8 rescoring and use mrl #47

Open. Wants to merge 2 commits into master.
requirements.txt (1 change: 0 additions & 1 deletion)

@@ -4,7 +4,6 @@ docx2txt
 pandas
 Requests
 beautifulsoup4
-llama-cpp-python
 huggingface_hub
 tiktoken
 onnxruntime==1.17.1

tests/contexts/my_collection.ctx (binary file added, not shown)

vlite/main.py (32 changes: 15 additions & 17 deletions)

@@ -64,7 +64,7 @@ def add(self, data, metadata=None, item_id=None, need_chunks=True, fast=True):
             all_metadata.extend([item_metadata] * len(chunks))
             all_ids.extend([item_id] * len(chunks))

-        encoded_data = self.model.embed(all_chunks, device=self.device)
+        encoded_data = self.model.encode_with_onnx(all_chunks)
         binary_encoded_data = self.model.quantize(encoded_data, precision="binary")

         for idx, (chunk, binary_vector, metadata) in enumerate(zip(all_chunks, binary_encoded_data, all_metadata)):
@@ -86,32 +86,32 @@ def retrieve(self, text=None, top_k=5, metadata=None, return_scores=False):
print("Retrieving similar texts...")
if text:
print(f"Retrieving top {top_k} similar texts for query: {text}")
query_chunks = chop_and_chunk(text, fast=True)
query_vectors = self.model.embed(query_chunks, device=self.device)

# Embed and quantize the query text
query_vectors = self.model.encode_with_onnx([text])
query_binary_vectors = self.model.quantize(query_vectors, precision="binary")

# Perform search on the query binary vectors
results = []
for query_binary_vector in query_binary_vectors:
chunk_results = self.search(query_binary_vector, top_k, metadata)
chunk_results = self.rank_and_filter(query_binary_vector, top_k, metadata)
results.extend(chunk_results)

results.sort(key=lambda x: x[1], reverse=True)
# Sort the results by similarity score
results.sort(key=lambda x: x[1])
results = results[:top_k]

print("Retrieval completed.")
if return_scores:
return [(idx, self.index[idx]['text'], self.index[idx]['metadata'], score) for idx, score in results]
else:
return [(idx, self.index[idx]['text'], self.index[idx]['metadata']) for idx, _ in results]

def search(self, query_binary_vector, top_k, metadata=None):
# Reshape query_binary_vector to 1D array
query_binary_vector = query_binary_vector.reshape(-1)

# Perform binary search
binary_vectors = np.array([item['binary_vector'] for item in self.index.values()])
binary_similarities = np.einsum('i,ji->j', query_binary_vector, binary_vectors)
top_k_indices = np.argpartition(binary_similarities, -top_k)[-top_k:]

def rank_and_filter(self, query_binary_vector, top_k, metadata=None):
query_binary_vector = np.array(query_binary_vector).reshape(-1)

corpus_binary_vectors = np.array([item['binary_vector'] for item in self.index.values()])
top_k_indices, top_k_scores = self.model.search(query_binary_vector, corpus_binary_vectors, top_k)
top_k_ids = [list(self.index.keys())[idx] for idx in top_k_indices]

# Apply metadata filter on the retrieved top_k items
@@ -122,9 +122,7 @@ def search(self, query_binary_vector, top_k, metadata=None):
                 if all(item_metadata.get(key) == value for key, value in metadata.items()):
                     filtered_ids.append(chunk_id)
             top_k_ids = filtered_ids[:top_k]
-
-        # Get the similarity scores for the top_k items
-        top_k_scores = binary_similarities[top_k_indices]
+        top_k_scores = top_k_scores[:len(top_k_ids)]

         return list(zip(top_k_ids, top_k_scores))

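Editor's note on the new ranking semantics: quantize() now produces packed binary vectors and model.search() returns Hamming distances, so retrieve() sorts ascending (smaller distance means a closer match) where the old code sorted descending similarity. Below is a minimal sketch of the new flow using a hypothetical two-document index shaped like self.index; the names and values are illustrative, not from the PR.

import numpy as np

# Hypothetical toy index mirroring the shape of VLite's self.index.
index = {
    "doc-0": {"binary_vector": np.array([-128, -1, 3], dtype=np.int8), "metadata": {"lang": "en"}},
    "doc-1": {"binary_vector": np.array([5, -1, 3], dtype=np.int8), "metadata": {"lang": "de"}},
}
query = np.array([-128, -1, 3], dtype=np.int8)  # stand-in for a quantized query

# Byte-level Hamming distance, as EmbeddingModel.hamming_distance computes it.
corpus = np.array([item["binary_vector"] for item in index.values()])
distances = np.array([np.count_nonzero(query != vec) for vec in corpus])

# Ascending sort: smaller distance means more similar.
top = np.argsort(distances)[:2]
ids = [list(index.keys())[i] for i in top]
print(list(zip(ids, distances[top].tolist())))  # [('doc-0', 0), ('doc-1', 1)]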

vlite/model.py (88 changes: 65 additions & 23 deletions)

@@ -6,31 +6,65 @@
 from tokenizers import Tokenizer

 def normalize(v):
-    norm = np.linalg.norm(v, axis=1)
+    if v.ndim == 1:
+        v = v.reshape(1, -1)  # Reshape v to 2D array if it is 1D
+    norm = np.linalg.norm(v, axis=1, keepdims=True)
     norm[norm == 0] = 1e-12
-    return v / norm[:, np.newaxis]
+    return v / norm

 class EmbeddingModel:
     def __init__(self, model_name="mixedbread-ai/mxbai-embed-large-v1"):
         tokenizer_path = hf_hub_download(repo_id=model_name, filename="tokenizer.json")
         model_path = hf_hub_download(repo_id=model_name, filename="onnx/model.onnx")

         self.tokenizer = Tokenizer.from_file(tokenizer_path)
         self.tokenizer.enable_truncation(max_length=512)
         self.tokenizer.enable_padding(pad_id=0, pad_token="[PAD]", length=512)

         self.model = ort.InferenceSession(model_path)
+        print("[model]", self.model.get_modelmeta())

         self.model_metadata = {
-            "bert.embedding_length": 1024,
+            "bert.embedding_length": 512,
             "bert.context_length": 512
         }
         self.embedding_size = self.model_metadata.get("bert.embedding_length", 1024)
         self.context_length = self.model_metadata.get("bert.context_length", 512)
         self.embedding_dtype = "float32"


+    def encode_with_onnx(self, texts):
+        # Ensure all text items are strings
+        if not all(isinstance(text, str) for text in texts):
+            raise ValueError("All items in the 'texts' list should be strings.")
+
+        try:
+            # Tokenize texts and convert to the correct format
+            inputs = self.tokenizer.encode_batch(texts)
+            input_ids = np.array([x.ids for x in inputs], dtype=np.int64)
+            attention_mask = np.array([x.attention_mask for x in inputs], dtype=np.int64)
+            token_type_ids = np.zeros_like(input_ids, dtype=np.int64)  # Add token_type_ids input
+
+            ort_inputs = {
+                self.model.get_inputs()[0].name: input_ids,
+                self.model.get_inputs()[1].name: attention_mask,
+                self.model.get_inputs()[2].name: token_type_ids  # Add token_type_ids input
+            }
+
+            ort_outs = self.model.run(None, ort_inputs)
+            embeddings = ort_outs[0]
+            return embeddings
+        except Exception as e:
+            print(f"Failed during ONNX encoding: {e}")
+            raise
+
-    def embed(self, texts: List[str], max_seq_length=512, device="cpu", batch_size=32):
+    def embed(self, texts, max_seq_length=512, device="cpu", batch_size=32):
+        if isinstance(texts, str):
+            texts = [texts]  # Ensure texts is always a list
+
         all_embeddings = []
         for i in range(0, len(texts), batch_size):
             batch = texts[i:i + batch_size]
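
For orientation, a hypothetical usage sketch of the new encode_with_onnx path (not part of the PR, and assuming the Hugging Face downloads succeed). One caveat worth flagging: on a typical BERT-style ONNX export, ort_outs[0] is the token-level last_hidden_state, so unlike embed() this may return unpooled activations.

model = EmbeddingModel()
out = model.encode_with_onnx(["hello world", "binary vectors"])
# On typical exports this is (batch, seq_len, hidden), e.g. (2, 512, 1024),
# rather than one pooled vector per text as embed() produces.
print(out.shape)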
@@ -44,31 +78,39 @@ def embed(self, texts: List[str], max_seq_length=512, device="cpu", batch_size=32):
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}

model_output = self.model.run(None, onnx_input)
last_hidden_state = model_output[0]

input_mask_expanded = np.broadcast_to(np.expand_dims(attention_mask, -1), last_hidden_state.shape)
embeddings = np.sum(last_hidden_state * input_mask_expanded, 1) / np.clip(input_mask_expanded.sum(1), a_min=1e-9, a_max=None)
embeddings = normalize(embeddings).astype(np.float32)
all_embeddings.append(embeddings)

return np.concatenate(all_embeddings)

def token_count(self, texts):
tokens = 0
for text in texts:
encoded = self.tokenizer.encode(text)
tokens += len(encoded.ids)
return tokens

def quantize(self, embeddings, precision="binary"):
embeddings = np.array(embeddings)
# first normalize_embeddings to unit length
embeddings = normalize(embeddings)
# slice to get MRL embeddings
embeddings_slice = embeddings[..., :512]

if precision == "binary":
return np.packbits(embeddings > 0).reshape(embeddings.shape[0], -1)
elif precision == "int8":
return ((embeddings - np.min(embeddings, axis=0)) / (np.max(embeddings, axis=0) - np.min(embeddings, axis=0)) * 255).astype(np.uint8)
return self._binary_quantize(embeddings_slice)
else:
raise ValueError(f"Unsupported precision: {precision}")
raise ValueError(f"Precision {precision} is not supported")

def _binary_quantize(self, embeddings):
return (np.packbits(embeddings > 0).reshape(embeddings.shape[0], -1) - 128).astype(np.int8)

-    def rescore(self, query_vector, vectors):
-        return np.dot(query_vector, vectors.T).flatten()
+    def hamming_distance(self, embedding1, embedding2):
+        # Ensure the embeddings are numpy arrays for the operation.
+        return np.count_nonzero(np.array(embedding1) != np.array(embedding2))
+
+    def search(self, query_embedding, embeddings, top_k):
+        # Convert embeddings to a numpy array for efficient operations if not already.
+        embeddings = np.array(embeddings)
+        distances = np.array([self.hamming_distance(query_embedding, emb) for emb in embeddings])
+
+        # Find the indices of the top_k smallest distances
+        top_k_indices = np.argsort(distances)[:top_k]
+        return top_k_indices, distances[top_k_indices]
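
One observation on the new scoring path: hamming_distance() compares the packed int8 representation with !=, so it counts differing bytes (at most 64 for a 512-bit vector) rather than differing bits (at most 512). That is a valid ranking signal but coarser than true Hamming distance. A vectorized bit-level variant, sketched here as a possible follow-up (illustrative, not part of this PR):

import numpy as np

def bitwise_hamming(query, corpus):
    # XOR the packed bytes, then count set bits per row: true bit-level
    # Hamming distance against every corpus vector at once.
    diff = np.bitwise_xor(corpus.view(np.uint8), query.view(np.uint8))
    return np.unpackbits(diff, axis=1).sum(axis=1)

# Usage with an (n, 64) int8 corpus and a (64,) int8 query:
# distances = bitwise_hamming(query, corpus)
# top_k_indices = np.argsort(distances)[:top_k]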