From 8d221e4d82f03fc0f24ca8705a5b70438117ca3c Mon Sep 17 00:00:00 2001
From: Surya
Date: Fri, 12 Apr 2024 16:24:59 -0700
Subject: [PATCH] add sentence-transformers, use MRL

---
 requirements.txt |  3 +-
 vlite/main.py    | 30 ++++++++--------
 vlite/model.py   | 90 +++++++++++++++++++++---------------------
 3 files changed, 56 insertions(+), 67 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 8226f39..602ca71 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,4 +8,5 @@ llama-cpp-python
 huggingface_hub
 tiktoken
 onnxruntime==1.17.1
-tokenizers==0.15.2
\ No newline at end of file
+tokenizers==0.15.2
+sentence_transformers
\ No newline at end of file

diff --git a/vlite/main.py b/vlite/main.py
index 6d3c456..70c8b60 100644
--- a/vlite/main.py
+++ b/vlite/main.py
@@ -86,16 +86,19 @@ def retrieve(self, text=None, top_k=5, metadata=None, return_scores=False):
         print("Retrieving similar texts...")
         if text:
             print(f"Retrieving top {top_k} similar texts for query: {text}")
-            query_chunks = chop_and_chunk(text, fast=True)
-            query_vectors = self.model.embed(query_chunks, device=self.device)
+
+            # Embed and quantize the query text
+            query_vectors = self.model.embed(text, device=self.device)
             query_binary_vectors = self.model.quantize(query_vectors, precision="binary")
 
+            # Perform search on the query binary vectors
             results = []
             for query_binary_vector in query_binary_vectors:
-                chunk_results = self.search(query_binary_vector, top_k, metadata)
+                chunk_results = self.rank_and_filter(query_binary_vector, top_k, metadata)
                 results.extend(chunk_results)
 
-            results.sort(key=lambda x: x[1], reverse=True)
+            # Sort ascending: scores are now Hamming distances (smaller = more similar)
+            results.sort(key=lambda x: x[1])
             results = results[:top_k]
 
             print("Retrieval completed.")
@@ -103,15 +106,12 @@ def retrieve(self, text=None, top_k=5, metadata=None, return_scores=False):
             return [(idx, self.index[idx]['text'], self.index[idx]['metadata'], score) for idx, score in results]
         else:
             return [(idx, self.index[idx]['text'], self.index[idx]['metadata']) for idx, _ in results]
-
-    def search(self, query_binary_vector, top_k, metadata=None):
-        # Reshape query_binary_vector to 1D array
-        query_binary_vector = query_binary_vector.reshape(-1)
-
-        # Perform binary search
-        binary_vectors = np.array([item['binary_vector'] for item in self.index.values()])
-        binary_similarities = np.einsum('i,ji->j', query_binary_vector, binary_vectors)
-        top_k_indices = np.argpartition(binary_similarities, -top_k)[-top_k:]
+
+    def rank_and_filter(self, query_binary_vector, top_k, metadata=None):
+        query_binary_vector = np.array(query_binary_vector).reshape(-1)
+
+        corpus_binary_vectors = np.array([item['binary_vector'] for item in self.index.values()])
+        top_k_indices, top_k_scores = self.model.search(query_binary_vector, corpus_binary_vectors, top_k)
         top_k_ids = [list(self.index.keys())[idx] for idx in top_k_indices]
 
         # Apply metadata filter on the retrieved top_k items
@@ -122,9 +122,7 @@ def search(self, query_binary_vector, top_k, metadata=None):
             if all(item_metadata.get(key) == value for key, value in metadata.items()):
                 filtered_ids.append(chunk_id)
             top_k_ids = filtered_ids[:top_k]
-
-        # Get the similarity scores for the top_k items
-        top_k_scores = binary_similarities[top_k_indices]
+        top_k_scores = top_k_scores[:len(top_k_ids)]
 
         return list(zip(top_k_ids, top_k_scores))
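Note on the ranking change in retrieve(): rank_and_filter returns Hamming distances from EmbeddingModel.search rather than dot-product similarities, so results now sort ascending (smaller distance = more similar). The sketch below shows the intended embed -> binary-quantize -> rank pipeline end to end; it is illustrative only, and the names binary_quantize and rank are not part of vlite's API.

    import numpy as np

    def binary_quantize(embeddings):
        # Pack each embedding's sign bits into bytes: 8 dims per uint8.
        return np.packbits(embeddings > 0, axis=-1)

    def rank(query_bits, corpus_bits, top_k):
        # XOR marks differing bits; unpackbits + sum popcounts them per row.
        diff = np.bitwise_xor(query_bits, corpus_bits)
        distances = np.unpackbits(diff, axis=-1).sum(axis=-1)
        order = np.argsort(distances)[:top_k]  # ascending: smaller = closer
        return order, distances[order]

    rng = np.random.default_rng(0)
    corpus_bits = binary_quantize(rng.normal(size=(100, 512)))
    query_bits = binary_quantize(rng.normal(size=(1, 512)))[0]
    print(rank(query_bits, corpus_bits, top_k=5))

For large corpora, note that the removed np.argpartition call selected the top_k in O(n) versus argsort's O(n log n), and the vectorized XOR above avoids a per-row Python loop.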
diff --git a/vlite/model.py b/vlite/model.py
index 054ea13..64970e9 100644
--- a/vlite/model.py
+++ b/vlite/model.py
@@ -4,71 +4,61 @@
 import numpy as np
 from typing import List
 from tokenizers import Tokenizer
+import numpy as np
+from typing import List
+from sentence_transformers import SentenceTransformer
+
 def normalize(v):
-    norm = np.linalg.norm(v, axis=1)
+    if v.ndim == 1:
+        v = v.reshape(1, -1)  # Reshape v to 2D array if it is 1D
+    norm = np.linalg.norm(v, axis=1, keepdims=True)
     norm[norm == 0] = 1e-12
-    return v / norm[:, np.newaxis]
+    return v / norm
+
+
 class EmbeddingModel:
     def __init__(self, model_name="mixedbread-ai/mxbai-embed-large-v1"):
-        tokenizer_path = hf_hub_download(repo_id=model_name, filename="tokenizer.json")
-        model_path = hf_hub_download(repo_id=model_name, filename="onnx/model.onnx")
-
-        self.tokenizer = Tokenizer.from_file(tokenizer_path)
-        self.tokenizer.enable_truncation(max_length=512)
-        self.tokenizer.enable_padding(pad_id=0, pad_token="[PAD]", length=512)
-
-        self.model = ort.InferenceSession(model_path)
-        print("[model]", self.model.get_modelmeta())
-
+        self.model = SentenceTransformer(model_name)
         self.model_metadata = {
-            "bert.embedding_length": 1024,
+            "bert.embedding_length": 512,
             "bert.context_length": 512
         }
         self.embedding_size = self.model_metadata.get("bert.embedding_length", 1024)
         self.context_length = self.model_metadata.get("bert.context_length", 512)
         self.embedding_dtype = "float32"
+
+    def embed(self, texts, max_seq_length=512, device="cpu", batch_size=32):
+        if isinstance(texts, str):
+            texts = [texts]  # Ensure texts is always a list
+        embeddings = self.model.encode(texts, device=device, batch_size=batch_size, normalize_embeddings=True)
+        return embeddings
 
-    def embed(self, texts: List[str], max_seq_length=512, device="cpu", batch_size=32):
-        all_embeddings = []
-        for i in range(0, len(texts), batch_size):
-            batch = texts[i:i + batch_size]
-            encoded = [self.tokenizer.encode(d) for d in batch]
-            input_ids = np.array([e.ids for e in encoded])
-            attention_mask = np.array([e.attention_mask for e in encoded])
-            token_type_ids = np.zeros_like(input_ids, dtype=np.int64)
-
-            onnx_input = {
-                "input_ids": input_ids,
-                "attention_mask": attention_mask,
-                "token_type_ids": token_type_ids,
-            }
-            model_output = self.model.run(None, onnx_input)
-            last_hidden_state = model_output[0]
-
-            input_mask_expanded = np.broadcast_to(np.expand_dims(attention_mask, -1), last_hidden_state.shape)
-            embeddings = np.sum(last_hidden_state * input_mask_expanded, 1) / np.clip(input_mask_expanded.sum(1), a_min=1e-9, a_max=None)
-            embeddings = normalize(embeddings).astype(np.float32)
-            all_embeddings.append(embeddings)
-
-        return np.concatenate(all_embeddings)
-
-    def token_count(self, texts):
-        tokens = 0
-        for text in texts:
-            encoded = self.tokenizer.encode(text)
-            tokens += len(encoded.ids)
-        return tokens
 
     def quantize(self, embeddings, precision="binary"):
-        embeddings = np.array(embeddings)
+        # First normalize the embeddings to unit length
+        embeddings = normalize(embeddings)
+        # Slice to the first 512 dims (Matryoshka/MRL truncation)
+        embeddings_slice = embeddings[..., :512]
+
         if precision == "binary":
-            return np.packbits(embeddings > 0).reshape(embeddings.shape[0], -1)
-        elif precision == "int8":
-            return ((embeddings - np.min(embeddings, axis=0)) / (np.max(embeddings, axis=0) - np.min(embeddings, axis=0)) * 255).astype(np.uint8)
+            return self._binary_quantize(embeddings_slice)
         else:
-            raise ValueError(f"Unsupported precision: {precision}")
+            raise ValueError(f"Precision {precision} is not supported")
 
-    def rescore(self, query_vector, vectors):
-        return np.dot(query_vector, vectors.T).flatten()
\ No newline at end of file
+    def _binary_quantize(self, embeddings):
+        return (np.packbits(embeddings > 0).reshape(embeddings.shape[0], -1) - 128).astype(np.int8)
+
+    def hamming_distance(self, embedding1, embedding2):
+        # Count differing elements; on packed binary vectors this compares bytes, not bits.
+        return np.count_nonzero(np.array(embedding1) != np.array(embedding2))
+
+    def search(self, query_embedding, embeddings, top_k):
+        # Convert embeddings to a numpy array for efficient operations if not already.
+        embeddings = np.array(embeddings)
+        distances = np.array([self.hamming_distance(query_embedding, emb) for emb in embeddings])
+
+        # Find the indices of the top_k smallest distances
+        top_k_indices = np.argsort(distances)[:top_k]
+        return top_k_indices, distances[top_k_indices]
\ No newline at end of file
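Note on quantize(): mxbai-embed-large-v1 produces 1024-dim embeddings trained with Matryoshka Representation Learning (MRL), so truncating to the first 512 dims halves storage while keeping most retrieval quality. Normalizing before slicing is safe here because binary quantization keeps only sign bits, and positive scaling never flips a sign. Below is a standalone shape check; it is a sketch under those assumptions, and recent sentence-transformers releases appear to ship an equivalent helper (sentence_transformers.quantization.quantize_embeddings).

    import numpy as np

    emb = np.random.default_rng(0).normal(size=(2, 1024)).astype(np.float32)
    emb /= np.linalg.norm(emb, axis=1, keepdims=True)  # unit length
    mrl = emb[..., :512]                               # Matryoshka truncation

    # Sign bits -> packed uint8; subtracting 128 relabels 0..255 as int8 -128..127
    # without changing which bits are set.
    packed = (np.packbits(mrl > 0).reshape(mrl.shape[0], -1) - 128).astype(np.int8)
    print(packed.shape, packed.dtype)                  # (2, 64) int8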
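Note on hamming_distance(): on the packed int8 vectors produced by _binary_quantize it counts differing bytes, a coarse proxy for the bitwise Hamming distance (a byte differing in one bit scores the same as one differing in all eight). A bit-exact, vectorized alternative under the same int8 packing assumption is sketched below; the -128 offset flips every byte's top bit, and XOR cancels that flip on both sides, so the result matches the distance on the original bit vectors.

    import numpy as np

    def bitwise_hamming(query, corpus):
        # View int8 bytes as uint8, XOR to mark differing bits, then popcount
        # each row via unpackbits + sum.
        diff = np.bitwise_xor(query.view(np.uint8), corpus.view(np.uint8))
        return np.unpackbits(diff, axis=-1).sum(axis=-1)

    # Usage with vlite's packed vectors:
    #   distances = bitwise_hamming(query_binary_vector, corpus_binary_vectors)
    #   top_k_indices = np.argsort(distances)[:top_k]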