Skip to content

Commit

Permalink
swap
Browse files Browse the repository at this point in the history
  • Loading branch information
sdan committed Apr 3, 2024
1 parent f0d39b7 commit e6f3198
Show file tree
Hide file tree
Showing 5 changed files with 1,656 additions and 39 deletions.
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ setuptools==65.3.0
tiktoken==0.4.0
torch==2.2.2
tqdm==4.65.0
transformers==4.36.2
chromadb==0.4.24
qdrant-client
git+https://github.com/sdan/surya.git
beautifulsoup4==4.12.3
llama-cpp-python==0.2.58
huggingface_hub
1,629 changes: 1,629 additions & 0 deletions tests/notebook.ipynb

Large diffs are not rendered by default.

15 changes: 8 additions & 7 deletions tests/unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,14 @@ def test_add_pdf(self):
print(f"[test_add_pdf] after Count of chunks in the collection: {self.vlite.count()}")
print(f"Time to add 71067 tokens: {TestVLite.test_times['add_pdf']} seconds")

def test_add_pdf_ocr(self):
start_time = time.time()
self.vlite.add(process_pdf(os.path.join(os.path.dirname(__file__), 'data/attention2.pdf'), use_ocr=True), need_chunks=False, metadata={"ocr": True})
end_time = time.time()
TestVLite.test_times["add_pdf_ocr"] = end_time - start_time
print(f"Time to add tokens: {TestVLite.test_times['add_pdf_ocr']} seconds")
print(f"[test_add_pdf_ocr] Count of chunks in the collection: {self.vlite.count()}")
# takes too long to run
# def test_add_pdf_ocr(self):
# start_time = time.time()
# self.vlite.add(process_pdf(os.path.join(os.path.dirname(__file__), 'data/attention2.pdf'), use_ocr=True), need_chunks=False, metadata={"ocr": True})
# end_time = time.time()
# TestVLite.test_times["add_pdf_ocr"] = end_time - start_time
# print(f"Time to add tokens: {TestVLite.test_times['add_pdf_ocr']} seconds")
# print(f"[test_add_pdf_ocr] Count of chunks in the collection: {self.vlite.count()}")

def test_retrieve(self):
queries = [
Expand Down
6 changes: 3 additions & 3 deletions vlite/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, collection=None, device='cpu', model_name='mixedbread-ai/mxba
print(f"Collection file {self.collection} not found. Initializing empty attributes.")
self.index = {}

def add(self, data, metadata=None, need_chunks=True):
def add(self, data, metadata=None, need_chunks=True, newEmbedding=False):
"""
Adds text or a list of texts to the collection with optional ID within metadata.
Expand Down Expand Up @@ -85,7 +85,7 @@ def add(self, data, metadata=None, need_chunks=True):
print("Text added successfully.")
return results

def retrieve(self, text=None, top_k=5, metadata=None):
def retrieve(self, text=None, top_k=5, metadata=None, newEmbedding=False):
"""
Retrieves similar texts from the collection based on text content, ID, or metadata.
Expand Down Expand Up @@ -118,7 +118,7 @@ def retrieve(self, text=None, top_k=5, metadata=None):

print("Retrieval completed.")
return [(self.index[idx]['text'], similarities[list(self.index.keys()).index(idx)], self.index[idx]['metadata']) for idx in top_k_ids]

def delete(self, ids):
"""
Deletes items from the collection by their IDs.
Expand Down
42 changes: 14 additions & 28 deletions vlite/model.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,24 @@
import os
import torch
from transformers import AutoModel, AutoTokenizer

# Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask, device="cpu"):
device = torch.device(device)
token_embeddings = model_output.last_hidden_state.to(device)
attention_mask = attention_mask.to(device)
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
import llama_cpp
from huggingface_hub import hf_hub_download

class EmbeddingModel:
def __init__(self, model_name='mixedbread-ai/mxbai-embed-large-v1'):
self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
self.model = AutoModel.from_pretrained(model_name)
self.dimension = self.model.embeddings.position_embeddings.embedding_dim
self.max_seq_length = self.model.embeddings.position_embeddings.num_embeddings
hf_path = hf_hub_download(repo_id="mixedbread-ai/mxbai-embed-large-v1", filename="gguf/mxbai-embed-large-v1-f16.gguf")
print(f"Downloaded model to {hf_path}")

self.model = llama_cpp.Llama(model_path=hf_path, embedding=True)
self.dimension = 1024 # hardcoded
self.max_seq_length = 512 # hardcoded

def embed(self, texts, max_seq_length=512, device="cpu"):
device = torch.device(device)
self.model.to(device)

encoded_input = self.tokenizer(texts, padding=True, truncation=True, return_tensors='pt', max_length=max_seq_length)
encoded_input = {name: tensor.to(device) for name, tensor in encoded_input.items()}

with torch.no_grad():
model_output = self.model(**encoded_input)
embeddings = mean_pooling(model_output, encoded_input['attention_mask'], device=device)
tensor_embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
np_embeddings = tensor_embeddings.cpu().numpy()

return np_embeddings

embeddings_dict = self.model.create_embedding(texts)
return [item["embedding"] for item in embeddings_dict["data"]]

def token_count(self, texts):
tokens = 0
for text in texts:
tokens += len(self.tokenizer.tokenize(text))
return tokens
return tokens

0 comments on commit e6f3198

Please sign in to comment.