Commit
default to mps and its 5 times faster
sdan committed Apr 13, 2024
1 parent 2598156 commit 37b132b
Showing 3 changed files with 100 additions and 19 deletions.
2 changes: 1 addition & 1 deletion vlite/main.py
@@ -7,7 +7,7 @@
import time

class VLite:
-    def __init__(self, collection=None, device='cpu', model_name='mixedbread-ai/mxbai-embed-large-v1'):
+    def __init__(self, collection=None, device='mps', model_name='mixedbread-ai/mxbai-embed-large-v1'):
        start_time = time.time()
        if collection is None:
            current_datetime = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
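Editor's note: 'mps' (Apple's Metal backend) is only available on Apple Silicon with a recent PyTorch build, so this hardcoded default will raise on other machines. A defensive caller can probe availability first. A minimal sketch, not part of the commit, assuming torch is installed and VLite is importable from vlite.main as in this repo's layout:

import torch
from vlite.main import VLite

# Fall back to CPU when the Metal (MPS) backend is not available;
# torch.backends.mps.is_available() returns False off Apple Silicon.
device = "mps" if torch.backends.mps.is_available() else "cpu"
db = VLite(device=device)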
50 changes: 32 additions & 18 deletions vlite/model.py
@@ -18,14 +18,20 @@ def normalize(v):
    return v / norm


+def normalize_onnx(v):
+    norm = np.linalg.norm(v, axis=1)
+    norm[norm == 0] = 1e-12
+    return v / norm[:, np.newaxis]
+
+
class EmbeddingModel:
    def __init__(self, model_name="mixedbread-ai/mxbai-embed-large-v1"):
        start_time = time.time()
-        self.model = SentenceTransformer(model_name)
+        self.model = SentenceTransformer(model_name, device="mps")

        # tokenizer_path = hf_hub_download(repo_id=model_name, filename="tokenizer.json")
-        # model_path = hf_hub_download(repo_id=model_name, filename="onnx/model.onnx")
+        # model_path = hf_hub_download(repo_id=model_name, filename="onnx/model_fp16.onnx")

        # self.tokenizer = Tokenizer.from_file(tokenizer_path)
        # self.tokenizer.enable_truncation(max_length=512)
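A quick sanity check of the new normalize_onnx helper (illustrative only, assuming normalize_onnx from the hunk above is in scope; the 1e-12 guard mirrors PyTorch's F.normalize epsilon so all-zero rows divide by a tiny constant instead of zero):

import numpy as np

v = np.array([[3.0, 4.0, 0.0], [0.0, 0.0, 0.0]])
# First row has norm 5 -> [0.6, 0.8, 0.0]; the zero row stays zero instead of NaN.
print(normalize_onnx(v))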
@@ -43,7 +49,7 @@ def __init__(self, model_name="mixedbread-ai/mxbai-embed-large-v1"):
        end_time = time.time()
        print(f"[model.__init__] Execution time: {end_time - start_time:.5f} seconds")

-    def embed(self, texts, max_seq_length=512, device="cpu", batch_size=32):
+    def embed(self, texts, max_seq_length=512, device="mps", batch_size=32):
        start_time = time.time()
        if isinstance(texts, str):
            texts = [texts]  # Ensure texts is always a list
@@ -53,26 +59,34 @@ def embed(self, texts, max_seq_length=512, device="cpu", batch_size=32):
        return embeddings

    def encode_with_onnx(self, texts):
        # texts is a List[str]

        batch_size = 32
        # Ensure all text items are strings
        if not all(isinstance(text, str) for text in texts):
            raise ValueError("All items in the 'texts' list should be strings.")

        try:
-            # Tokenize texts and convert to the correct format
-            inputs = self.tokenizer.encode_batch(texts)
-            input_ids = np.array([x.ids for x in inputs], dtype=np.int64)
-            attention_mask = np.array([x.attention_mask for x in inputs], dtype=np.int64)
-            token_type_ids = np.zeros_like(input_ids, dtype=np.int64)  # Add token_type_ids input
-
-            ort_inputs = {
-                self.model.get_inputs()[0].name: input_ids,
-                self.model.get_inputs()[1].name: attention_mask,
-                self.model.get_inputs()[2].name: token_type_ids  # Add token_type_ids input
-            }
-
-            ort_outs = self.model.run(None, ort_inputs)
-            embeddings = ort_outs[0]
-            return embeddings
+            all_embeddings = []
+            for i in range(0, len(texts), batch_size):
+                batch = texts[i:i + batch_size]
+                encoded = [self.tokenizer.encode(d) for d in batch]
+                input_ids = np.array([e.ids for e in encoded])
+                attention_mask = np.array([e.attention_mask for e in encoded])
+                onnx_input = {
+                    "input_ids": np.array(input_ids, dtype=np.int64),
+                    "attention_mask": np.array(attention_mask, dtype=np.int64),
+                    "token_type_ids": np.array([np.zeros(len(e), dtype=np.int64) for e in input_ids], dtype=np.int64),
+                }
+                model_output = self.model.run(None, onnx_input)
+                last_hidden_state = model_output[0]
+                # Perform mean pooling with attention weighting
+                input_mask_expanded = np.broadcast_to(np.expand_dims(attention_mask, -1), last_hidden_state.shape)
+                embeddings = np.sum(last_hidden_state * input_mask_expanded, 1) / np.clip(input_mask_expanded.sum(1), a_min=1e-9, a_max=None)
+                embeddings = normalize_onnx(embeddings).astype(np.float32)
+                all_embeddings.append(embeddings)
+            return np.concatenate(all_embeddings)

        except Exception as e:
            print(f"Failed during ONNX encoding: {e}")
            raise
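For orientation, a minimal usage sketch of the class above (assumptions: the vlite.model import path matches this repo's layout, weights download on first use, and mxbai-embed-large-v1 emits 1024-dimensional vectors):

from vlite.model import EmbeddingModel

model = EmbeddingModel()  # loads mixedbread-ai/mxbai-embed-large-v1 onto MPS
vecs = model.embed(["hello world", "vector databases"])
print(vecs.shape)  # expected: (2, 1024)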
67 changes: 67 additions & 0 deletions vlite/onnx.py
@@ -0,0 +1,67 @@
# A dependency-light way to run the onnx model

from tokenizers import Tokenizer
import onnxruntime as ort
import numpy as np
from typing import List

MODEL_ID = "mixedbread-ai/mxbai-embed-large-v1"

# Use PyTorch's default epsilon for division by zero
# https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
def normalize(v):
    norm = np.linalg.norm(v, axis=1)
    norm[norm == 0] = 1e-12
    return v / norm[:, np.newaxis]

# Sample implementation of the default sentence-transformers model using ONNX
class ONNXModel():

    def __init__(self):
        # max_seq_length = 256; for some reason sentence-transformers uses 256 even though the HF config has a max length of 128
        # https://github.com/UKPLab/sentence-transformers/blob/3e1929fddef16df94f8bc6e3b10598a98f46e62d/docs/_static/html/models_en_sentence_embeddings.html#LL480
        self.tokenizer = Tokenizer.from_file("onnx/tokenizer.json")
        print("[tokenizer ]", self.tokenizer.get_vocab_size())
        print("[tokenizer ]", self.tokenizer.get_vocab())

        self.tokenizer.enable_truncation(max_length=512)
        self.model = ort.InferenceSession("onnx/model.onnx")
        print("[model ]", self.model.get_modelmeta())

    def __call__(self, documents: List[str], batch_size: int = 32):
        all_embeddings = []
        for i in range(0, len(documents), batch_size):
            batch = documents[i:i + batch_size]
            encoded = [self.tokenizer.encode(d) for d in batch]
            input_ids = np.array([e.ids for e in encoded])
            attention_mask = np.array([e.attention_mask for e in encoded])
            onnx_input = {
                "input_ids": np.array(input_ids, dtype=np.int64),
                "attention_mask": np.array(attention_mask, dtype=np.int64),
                "token_type_ids": np.array([np.zeros(len(e), dtype=np.int64) for e in input_ids], dtype=np.int64),
            }
            model_output = self.model.run(None, onnx_input)
            last_hidden_state = model_output[0]
            # Perform mean pooling with attention weighting
            input_mask_expanded = np.broadcast_to(np.expand_dims(attention_mask, -1), last_hidden_state.shape)
            embeddings = np.sum(last_hidden_state * input_mask_expanded, 1) / np.clip(input_mask_expanded.sum(1), a_min=1e-9, a_max=None)
            embeddings = normalize(embeddings).astype(np.float32)
            all_embeddings.append(embeddings)
        return np.concatenate(all_embeddings)
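The pooling step above averages token vectors only over non-padding positions before L2-normalizing. A toy numpy check of just that step (illustrative, not part of the commit):

import numpy as np

# One sentence, three tokens (the last is padding, mask 0), hidden size 2.
hidden = np.array([[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]])
mask = np.array([[1, 1, 0]])
expanded = np.broadcast_to(np.expand_dims(mask, -1), hidden.shape)
pooled = (hidden * expanded).sum(1) / np.clip(expanded.sum(1), 1e-9, None)
print(pooled)  # [[2. 3.]] -> the padded token is excluded from the mean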


# sample_text = "This is a sample text that is likely to overflow the entire model and will be truncated. \
# Keep writing and writing until we reach the end of the model.This is a sample text that is likely to overflow the entire model and \
# will be truncated. Keep writing and writing until we reach the end of the model.This is a sample text that is likely to overflow the entire \
# model and will be truncated. Keep writing and writing until we reach the end of the model. This is a sample text that is likely to overflow \
# the entire model and will be truncated. Keep writing and writing until we reach the end of the model. This is a sample text that is likely to overflow \
# the entire model and will be truncated. Keep writing and writing until we reach the end of the model."
# model = ONNXModel()
# # print(model([sample_text, sample_text]))

# embeddings = model([sample_text, sample_text])
# print(embeddings.shape)
# # print(embeddings[0] == embeddings[1])
# # print(embeddings)
