Skip to content

Commit

Permalink
onnx flag
Browse files Browse the repository at this point in the history
  • Loading branch information
sdan committed Apr 13, 2024
1 parent 37b132b commit cf0754f
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 13 deletions.
104 changes: 104 additions & 0 deletions tests/notebook3.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/sdan/miniforge3/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Similarities: [[0.7919583 0.6369279 0.16512007 0.36207786]]\n"
]
}
],
"source": [
"from typing import Dict\n",
"import numpy as np\n",
"from transformers import AutoModel, AutoTokenizer\n",
"\n",
"def transform_query(query: str) -> str:\n",
" return f'Represent this sentence for searching relevant passages: {query}'\n",
"\n",
"def pooling_np(outputs, attention_mask, strategy='cls'):\n",
" if strategy == 'cls':\n",
" # Taking the first token (CLS token) for each sequence\n",
" return outputs[:, 0]\n",
" elif strategy == 'mean':\n",
" # Applying attention mask and computing mean pooling\n",
" outputs_masked = outputs * attention_mask[:, :, None]\n",
" return np.sum(outputs_masked, axis=1) / np.sum(attention_mask, axis=1)[:, None]\n",
" else:\n",
" raise NotImplementedError\n",
"\n",
"def cos_sim_np(a, b):\n",
" dot_product = np.dot(a, b.T)\n",
" norm_a = np.linalg.norm(a, axis=1, keepdims=True)\n",
" norm_b = np.linalg.norm(b, axis=1)\n",
" return dot_product / (norm_a * norm_b)\n",
"\n",
"# Load the model and tokenizer\n",
"model_id = 'mixedbread-ai/mxbai-embed-large-v1'\n",
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
"model = AutoModel.from_pretrained(model_id) # Running on CPU\n",
"\n",
"# Example documents\n",
"docs = [transform_query('A man is eating a piece of bread')] + [\n",
" \"A man is eating food.\",\n",
" \"A man is eating pasta.\",\n",
" \"The girl is carrying a baby.\",\n",
" \"A man is riding a horse.\",\n",
"]\n",
"\n",
"# Tokenize and process with the model\n",
"inputs = tokenizer(docs, padding=True, return_tensors='pt')\n",
"outputs = model(**inputs).last_hidden_state.detach().numpy() # Convert to NumPy array\n",
"attention_mask = inputs['attention_mask'].numpy() # Convert attention mask to NumPy array\n",
"\n",
"# Pool embeddings using NumPy\n",
"embeddings = pooling_np(outputs, attention_mask, 'cls')\n",
"\n",
"# Calculate cosine similarities with NumPy\n",
"similarities_np = cos_sim_np(embeddings[0:1], embeddings[1:])\n",
"print('Similarities:', similarities_np)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
19 changes: 14 additions & 5 deletions vlite/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from .ctx import Ctx
import time

USE_ONNX = False

class VLite:
def __init__(self, collection=None, device='mps', model_name='mixedbread-ai/mxbai-embed-large-v1'):
start_time = time.time()
Expand Down Expand Up @@ -69,9 +71,12 @@ def add(self, data, metadata=None, item_id=None, need_chunks=True, fast=True):
all_chunks.extend(chunks)
all_metadata.extend([item_metadata] * len(chunks))
all_ids.extend([item_id] * len(chunks))

encoded_data = self.model.embed(all_chunks, device=self.device)
# encoded_data = self.model.encode_with_onnx(all_chunks)

if USE_ONNX:
encoded_data = self.model.encode_with_onnx(all_chunks)
else:
encoded_data = self.model.embed(all_chunks, device=self.device)

binary_encoded_data = self.model.quantize(encoded_data, precision="binary")

for idx, (chunk, binary_vector, metadata) in enumerate(zip(all_chunks, binary_encoded_data, all_metadata)):
Expand All @@ -98,8 +103,12 @@ def retrieve(self, text=None, top_k=5, metadata=None, return_scores=False):
print(f"Retrieving top {top_k} similar texts for query: {text}")

# Embed and quantize the query text
query_vectors = self.model.embed(text, device=self.device)
# query_vectors = self.model.encode_with_onnx([text])

if USE_ONNX:
query_vectors = self.model.encode_with_onnx([text])
else:
query_vectors = self.model.embed(text, device=self.device)

query_binary_vectors = self.model.quantize(query_vectors, precision="binary")

# Perform search on the query binary vectors
Expand Down
24 changes: 16 additions & 8 deletions vlite/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from sentence_transformers import SentenceTransformer
import time

USE_ONNX = False

def normalize(v):
if v.ndim == 1:
Expand All @@ -28,16 +29,18 @@ def normalize_onnx(v):
class EmbeddingModel:
def __init__(self, model_name="mixedbread-ai/mxbai-embed-large-v1"):
start_time = time.time()
self.model = SentenceTransformer(model_name, device="mps")

# tokenizer_path = hf_hub_download(repo_id=model_name, filename="tokenizer.json")
# model_path = hf_hub_download(repo_id=model_name, filename="onnx/model_fp16.onnx")
if USE_ONNX:
tokenizer_path = hf_hub_download(repo_id=model_name, filename="tokenizer.json")
model_path = hf_hub_download(repo_id=model_name, filename="onnx/model_fp16.onnx")

# self.tokenizer = Tokenizer.from_file(tokenizer_path)
# self.tokenizer.enable_truncation(max_length=512)
# self.tokenizer.enable_padding(pad_id=0, pad_token="[PAD]", length=512)
self.tokenizer = Tokenizer.from_file(tokenizer_path)
self.tokenizer.enable_truncation(max_length=512)
self.tokenizer.enable_padding(pad_id=0, pad_token="[PAD]", length=512)

# self.model = ort.InferenceSession(model_path)
self.model = ort.InferenceSession(model_path)
else:
self.model = SentenceTransformer(model_name, device="mps")

self.model_metadata = {
"bert.embedding_length": 512,
Expand All @@ -53,7 +56,12 @@ def embed(self, texts, max_seq_length=512, device="mps", batch_size=32):
start_time = time.time()
if isinstance(texts, str):
texts = [texts] # Ensure texts is always a list
embeddings = self.model.encode(texts, device=device, batch_size=batch_size, normalize_embeddings=True)

if USE_ONNX:
embeddings = self.encode_with_onnx(texts)
else:
embeddings = self.model.encode(texts, device=device, batch_size=batch_size, normalize_embeddings=True)

end_time = time.time()
print(f"[model.embed] Execution time: {end_time - start_time:.5f} seconds")
return embeddings
Expand Down

0 comments on commit cf0754f

Please sign in to comment.