embed all at once
sdan committed Apr 4, 2024
1 parent 967a5e8 commit 7c87085
Showing 3 changed files with 59 additions and 41 deletions.
66 changes: 42 additions & 24 deletions vlite/main.py
@@ -43,21 +43,25 @@ def __init__(self, collection=None, device='cpu', model_name='mixedbread-ai/mxba
print(f"Collection file {self.collection} not found. Initializing empty attributes.")
self.index = {}

def add(self, data, metadata=None, need_chunks=True, newEmbedding=False):
def add(self, data, metadata=None, need_chunks=True, newEmbedding=False, fast=True):
"""
Adds text or a list of texts to the collection with optional ID within metadata.
Args:
data (str, dict, or list): Text data to be added. Can be a string, a dictionary containing text, id, and/or metadata, or a list of strings or dictionaries.
metadata (dict, optional): Additional metadata to be appended to each text entry.
need_chunks (bool, optional): Whether to split the text into chunks before embedding. Defaults to True.
fast (bool, optional): Whether to use fast mode for chunking. Defaults to True.
Returns:
list: A list of tuples, each containing the ID of the added text and the updated vectors array.
"""
print("Adding text to the collection...")
data = [data] if not isinstance(data, list) else data
results = []
all_chunks = []
all_metadata = []
all_ids = []

for item in data:
if isinstance(item, dict):
@@ -73,28 +77,31 @@ def add(self, data, metadata=None, need_chunks=True, newEmbedding=False):
item_metadata['id'] = item_id

if need_chunks:
chunks = chop_and_chunk(text_content)
encoded_data = self.model.embed(chunks, device=self.device)
chunks = chop_and_chunk(text_content, fast=fast)
else:
chunks = [text_content]
print("Encoding text... not chunking")
encoded_data = self.model.embed(chunks, device=self.device)

binary_encoded_data = self.model.quantize(encoded_data, precision="binary")
int8_encoded_data = self.model.quantize(encoded_data, precision="int8")

for idx, (chunk, vector, binary_vector, int8_vector) in enumerate(zip(chunks, encoded_data, binary_encoded_data, int8_encoded_data)):
chunk_id = f"{item_id}_{idx}"
self.index[chunk_id] = {
'text': chunk,
'metadata': item_metadata,
'vector': vector,
'binary_vector': binary_vector.tolist(),
'int8_vector': int8_vector.tolist()
}

all_chunks.extend(chunks)
all_metadata.extend([item_metadata] * len(chunks))
all_ids.extend([item_id] * len(chunks))

encoded_data = self.model.embed(all_chunks, device=self.device)
binary_encoded_data = self.model.quantize(encoded_data, precision="binary")
int8_encoded_data = self.model.quantize(encoded_data, precision="int8")

for idx, (chunk, vector, binary_vector, int8_vector, metadata, item_id) in enumerate(zip(all_chunks, encoded_data, binary_encoded_data, int8_encoded_data, all_metadata, all_ids)):
chunk_id = f"{item_id}_{idx}"
self.index[chunk_id] = {
'text': chunk,
'metadata': metadata,
'vector': vector,
'binary_vector': binary_vector.tolist(),
'int8_vector': int8_vector.tolist()
}

results.append((item_id, encoded_data, item_metadata))
if item_id not in [result[0] for result in results]:
results.append((item_id, encoded_data, metadata))

self.save()
print("Text added successfully.")
@@ -115,11 +122,18 @@ def retrieve(self, text=None, top_k=5, metadata=None, newEmbedding=False):
print("Retrieving similar texts...")
if text:
print(f"Retrieving top {top_k} similar texts for query: {text}")
query_vector = self.model.embed([text], device=self.device)
query_binary_vector = self.model.quantize(query_vector, precision="binary")
query_int8_vector = self.model.quantize(query_vector, precision="int8")
query_chunks = chop_and_chunk(text, fast=True)
query_vectors = self.model.embed(query_chunks, device=self.device)
query_binary_vectors = self.model.quantize(query_vectors, precision="binary")
query_int8_vectors = self.model.quantize(query_vectors, precision="int8")

results = self.rescore(query_binary_vector, query_int8_vector, top_k, metadata)
results = []
for query_binary_vector, query_int8_vector in zip(query_binary_vectors, query_int8_vectors):
chunk_results = self.rescore(query_binary_vector, query_int8_vector, top_k, metadata)
results.extend(chunk_results)

results.sort(key=lambda x: x[1], reverse=True)
results = results[:top_k]

print("Retrieval completed.")
return [(self.index[idx]['text'], score, self.index[idx]['metadata']) for idx, score in results]
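retrieve now chops the query itself (in fast mode), embeds each query chunk, and merges the per-chunk hits before taking the global top_k. A short sketch of the merge step, where score_fn stands in for rescore and is assumed to return (chunk_id, score) pairs:

def retrieve_merged(score_fn, query_chunks, top_k=5):
    # Gather candidates from every query chunk, then keep the global best.
    results = []
    for chunk in query_chunks:
        results.extend(score_fn(chunk))
    results.sort(key=lambda pair: pair[1], reverse=True)
    return results[:top_k]

As in the diff, a stored chunk matched by several query chunks can appear more than once in the merged list; deduplicating by id before the final sort would be one possible refinement.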
@@ -137,9 +151,13 @@ def rescore(self, query_binary_vector, query_int8_vector, top_k, metadata=None):
Returns:
list: A list of tuples containing the chunk IDs and their similarity scores.
"""
# Reshape query_binary_vector and query_int8_vector to 1D arrays
query_binary_vector = query_binary_vector.reshape(-1)
query_int8_vector = query_int8_vector.reshape(-1)

# Perform binary search
binary_vectors = np.array([item['binary_vector'] for item in self.index.values()])
binary_similarities = np.einsum('i,ji->j', query_binary_vector[0], binary_vectors)
binary_similarities = np.einsum('i,ji->j', query_binary_vector, binary_vectors)
top_k_indices = np.argpartition(binary_similarities, -top_k*4)[-top_k*4:]
top_k_ids = [list(self.index.keys())[idx] for idx in top_k_indices]

@@ -154,7 +172,7 @@ def rescore(self, query_binary_vector, query_int8_vector, top_k, metadata=None):

# Perform rescoring using int8 embeddings
int8_vectors = np.array([self.index[idx]['int8_vector'] for idx in top_k_ids])
int8_similarities = np.einsum('i,ji->j', query_int8_vector[0], int8_vectors)
int8_similarities = np.einsum('i,ji->j', query_int8_vector, int8_vectors)

# Sort the results based on the int8 similarities
sorted_indices = np.argpartition(int8_similarities, -top_k)[-top_k:]
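rescore is a two-stage search: a cheap binary dot product (the first einsum) over-fetches 4*top_k candidates from the whole corpus, and only those are rescored with the higher-precision int8 vectors (the second einsum). The new reshape(-1) calls flatten the (1, d) query returned by embed so the 'i,ji->j' subscripts line up. A self-contained NumPy sketch of the pattern, assuming the corpus holds at least 4*top_k vectors:

import numpy as np

def two_stage_search(query_bin, query_int8, binary_vectors, int8_vectors, top_k):
    # Stage 1: binary similarities against every vector; keep 4*top_k candidates.
    binary_sims = np.einsum('i,ji->j', query_bin, binary_vectors)
    candidates = np.argpartition(binary_sims, -top_k * 4)[-top_k * 4:]
    # Stage 2: int8 similarities over the shortlisted candidates only.
    int8_sims = np.einsum('i,ji->j', query_int8, int8_vectors[candidates])
    order = np.argsort(int8_sims)[::-1][:top_k]
    return [(int(candidates[i]), float(int8_sims[i])) for i in order]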
2 changes: 2 additions & 0 deletions vlite/model.py
@@ -15,6 +15,8 @@ def __init__(self, model_name='mixedbread-ai/mxbai-embed-large-v1'):
self.max_seq_length = 512 # hardcoded

def embed(self, texts, max_seq_length=512, device="cpu"):
if isinstance(texts, str):
texts = [texts]
embeddings_dict = self.model.create_embedding(texts)
return [item["embedding"] for item in embeddings_dict["data"]]
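The two added lines make embed accept either a single string or a list of strings. A hedged usage sketch (the class name EmbeddingModel is an assumption; the diff only shows its methods):

model = EmbeddingModel()                 # backed by self.model.create_embedding
single = model.embed("hello world")      # str is wrapped into ["hello world"]
batch = model.embed(["hello", "world"])  # lists pass through unchanged
# Both calls return a list of embedding vectors, one per input text.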

32 changes: 15 additions & 17 deletions vlite/utils.py
@@ -21,32 +21,30 @@
except ImportError:
run_ocr = None

def chop_and_chunk(text, max_seq_length=512):
def chop_and_chunk(text, max_seq_length=512, fast=False):
"""
Chop text into chunks of max_seq_length tokens.
Chop text into chunks of max_seq_length tokens or max_seq_length*4 characters (fast mode).
"""
if isinstance(text, str):
text = [text]

enc = tiktoken.get_encoding("cl100k_base")
chunks = []

print(f"Lenght of text: {len(text)}")
print(f"Length of text: {len(text)}")
print(f"Original text: {text}")

for t in text:
token_ids = enc.encode(t, disallowed_special=())
num_tokens = len(token_ids)

if num_tokens <= max_seq_length:
chunks.append(t)
if fast:
chunk_size = max_seq_length * 4
chunks.extend([t[i:i + chunk_size] for i in range(0, len(t), chunk_size)])
else:
for i in range(0, num_tokens, max_seq_length):
chunk = enc.decode(token_ids[i:i + max_seq_length])
chunks.append(chunk)

print("Chopped text into this chunk:",chunks)

token_ids = enc.encode(t, disallowed_special=())
num_tokens = len(token_ids)
if num_tokens <= max_seq_length:
chunks.append(t)
else:
for i in range(0, num_tokens, max_seq_length):
chunk = enc.decode(token_ids[i:i + max_seq_length])
chunks.append(chunk)
print("Chopped text into these chunks:", chunks)
print(f"Chopped text into {len(chunks)} chunks.")
return chunks
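Fast mode trades tokenizer accuracy for speed: it slices the raw string every max_seq_length * 4 characters, leaning on the rough four-characters-per-token rule, while the default path still round-trips through tiktoken. A standalone sketch of the two paths:

import tiktoken

def chunk_fast(text, max_seq_length=512):
    # ~4 chars per token heuristic; no tokenizer call at all.
    size = max_seq_length * 4
    return [text[i:i + size] for i in range(0, len(text), size)]

def chunk_exact(text, max_seq_length=512):
    # Token-accurate chunking via an encode/decode round trip.
    enc = tiktoken.get_encoding("cl100k_base")
    token_ids = enc.encode(text, disallowed_special=())
    if len(token_ids) <= max_seq_length:
        return [text]
    return [enc.decode(token_ids[i:i + max_seq_length])
            for i in range(0, len(token_ids), max_seq_length)]

Fast mode can split mid-token (or mid-word), which is usually acceptable for embedding input; exact mode guarantees each chunk stays within max_seq_length tokens.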

