Skip to content

Commit

Permalink
move to ctx and fix file loading issues
Browse files Browse the repository at this point in the history
  • Loading branch information
sdan committed Apr 11, 2024
1 parent bb25e2b commit 87d2687
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 43 deletions.
10 changes: 5 additions & 5 deletions tests/unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,12 @@ def tearDownClass(cls):
if os.path.exists('vlite-unit.npz'):
print("[+] Removing vlite")
os.remove('vlite-unit.npz')
if os.path.exists('vlite-unit.omom'):
if os.path.exists('vlite-unit.ctx'):
print("[+] Removing vlite")
os.remove('vlite-unit.omom')
if os.path.exists('omnom/vlite-unit.omom'):
print("[+] Removing vlite omom")
os.remove('omnom/vlite-unit.omom')
os.remove('vlite-unit.ctx')
if os.path.exists('contexts/vlite-unit.ctx'):
print("[+] Removing vlite ctx")
os.remove('contexts/vlite-unit.ctx')

# Run the unit-test suite directly (e.g. `python tests/unit.py`);
# verbosity=2 prints one line per test case.
if __name__ == '__main__':
    unittest.main(verbosity=2)
42 changes: 23 additions & 19 deletions vlite/omom.py → vlite/ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
from typing import List, Dict, Union
import numpy as np

class CtxSectionType(Enum):
    """Tags for the four length-prefixed sections of a .ctx file.

    The integer values are written to disk in each section header
    (see CtxFile.save), so they must never be renumbered.
    """
    HEADER = 0
    EMBEDDINGS = 1
    CONTEXTS = 2
    METADATA = 3

class OmomFile:
MAGIC_NUMBER = b"OMOM"
class CtxFile:
MAGIC_NUMBER = b"CTXF"
VERSION = 1

def __init__(self, file_path):
Expand Down Expand Up @@ -43,31 +43,35 @@ def add_metadata(self, key: str, value: Union[int, float, str]):
self.metadata[key] = value

def save(self):
    """Serialize this file's header, embeddings, contexts and metadata.

    On-disk layout: MAGIC_NUMBER, little-endian uint32 VERSION, then a
    sequence of sections, each introduced by a packed
    (section_type, byte_length) uint32 pair (see CtxSectionType).

    Fixes vs. the original: removed leftover debug prints, and each
    context string is now UTF-8 encoded once instead of twice.
    """
    with open(self.file_path, "wb") as file:
        file.write(self.MAGIC_NUMBER)
        file.write(struct.pack("<I", self.VERSION))

        header_json = json.dumps(self.header).encode("utf-8")
        file.write(struct.pack("<II", CtxSectionType.HEADER.value, len(header_json)))
        file.write(header_json)

        if self.embeddings:
            # NaN components are zeroed so struct.pack only ever sees
            # finite floats; each embedding is a flat run of float32s.
            embeddings_data = b"".join(
                struct.pack(
                    f"<{len(emb)}f",
                    *[0.0 if np.isnan(x) else float(x) for x in emb],
                )
                for emb in self.embeddings
            )
            file.write(struct.pack("<II", CtxSectionType.EMBEDDINGS.value, len(embeddings_data)))
            file.write(embeddings_data)

        # Each context is a 4-byte length prefix followed by UTF-8 bytes.
        contexts_data = b"".join(
            struct.pack("<I", len(encoded)) + encoded
            for encoded in (context.encode("utf-8") for context in self.contexts)
        )
        file.write(struct.pack("<II", CtxSectionType.CONTEXTS.value, len(contexts_data)))
        file.write(contexts_data)

        metadata_json = json.dumps(self.metadata).encode("utf-8")
        file.write(struct.pack("<II", CtxSectionType.METADATA.value, len(metadata_json)))
        file.write(metadata_json)

def load(self):
print("Number of embeddings loaded: ", len(self.embeddings))
print("Number of metadata keys loaded: ", len(self.metadata))
try:
with open(self.file_path, "rb") as file:
# Read and verify header
Expand All @@ -86,18 +90,18 @@ def load(self):
break
section_type, section_length = struct.unpack("<II", section_header)

if section_type == OmomSectionType.HEADER.value:
if section_type == CtxSectionType.HEADER.value:
header_json = file.read(section_length).decode("utf-8")
self.header = json.loads(header_json)
elif section_type == OmomSectionType.EMBEDDINGS.value:
elif section_type == CtxSectionType.EMBEDDINGS.value:
embeddings_data = file.read(section_length)
if embeddings_data:
embedding_size = len(embeddings_data) // 4
self.embeddings = [
list(struct.unpack_from(f"<{embedding_size // len(self.embeddings)}f", embeddings_data, i * embedding_size))
for i in range(len(self.embeddings))
] if self.embeddings else [list(struct.unpack_from(f"<{embedding_size}f", embeddings_data))]
elif section_type == OmomSectionType.CONTEXTS.value:
elif section_type == CtxSectionType.CONTEXTS.value:
contexts_data = file.read(section_length)
self.contexts = []
offset = 0
Expand All @@ -110,7 +114,7 @@ def load(self):
except UnicodeDecodeError as e:
print(f"Error decoding context: {e}")
offset += context_length
elif section_type == OmomSectionType.METADATA.value:
elif section_type == CtxSectionType.METADATA.value:
metadata_json = file.read(section_length).decode("utf-8")
self.metadata = json.loads(metadata_json)
else:
Expand All @@ -120,7 +124,7 @@ def load(self):
pass

def __repr__(self):
output = "OmomFile:\n\n"
output = "CtxFile:\n\n"
output += "Header:\n"
for key, value in self.header.items():
output += f" {key}: {value}\n"
Expand All @@ -142,22 +146,22 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
    # Persist buffered sections when the with-block ends; save() runs
    # unconditionally, and returning None means exceptions from the
    # with-body are not suppressed.
    self.save()

class Omom:
def __init__(self, directory="omnoms"):
class Ctx:
def __init__(self, directory="contexts"):
self.directory = directory
if not os.path.exists(directory):
os.makedirs(directory)

def get(self, user):
    """Return the path of *user*'s .ctx file inside self.directory."""
    return os.path.join(self.directory, f"{user}.ctx")

def create(self, user: str) -> OmomFile:
def create(self, user: str) -> CtxFile:
file_path = self.get(user)
return OmomFile(file_path)
return CtxFile(file_path)

def read(self, user_id: str) -> OmomFile:
def read(self, user_id: str) -> CtxFile:
file_path = self.get(user_id)
return OmomFile(file_path)
return CtxFile(file_path)

def delete(self, user_id: str):
file_path = self.get(user_id)
Expand Down
36 changes: 20 additions & 16 deletions vlite/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .model import EmbeddingModel
from .utils import chop_and_chunk
import datetime
from .omom import Omom
from .ctx import Ctx

class VLite:
def __init__(self, collection=None, device='cpu', model_name='mixedbread-ai/mxbai-embed-large-v1'):
Expand All @@ -14,19 +14,23 @@ def __init__(self, collection=None, device='cpu', model_name='mixedbread-ai/mxba
self.device = device
self.model = EmbeddingModel(model_name) if model_name else EmbeddingModel()

self.omom = Omom()
self.ctx = Ctx()
self.index = {}

try:
with self.omom.read(collection) as omom_file:
self.index = {
chunk_id: {
'text': chunk_data['text'],
'metadata': chunk_data['metadata'],
'binary_vector': np.array(chunk_data['binary_vector'])
}
for chunk_id, chunk_data in omom_file.metadata.items()
ctx_file = self.ctx.read(collection)
ctx_file.load()
# debug print
print("Number of embeddings: ", len(ctx_file.embeddings))
print("Number of metadata: ", len(ctx_file.metadata))
self.index = {
chunk_id: {
'text': ctx_file.contexts[idx],
'metadata': ctx_file.metadata.get(chunk_id, {}),
'binary_vector': np.array(ctx_file.embeddings[idx])
}
for idx, chunk_id in enumerate(ctx_file.metadata.keys())
}
except FileNotFoundError:
print(f"Collection file {self.collection} not found. Initializing empty attributes.")

Expand Down Expand Up @@ -216,23 +220,23 @@ def count(self):

def save(self):
    """Write the in-memory index out to the collection's .ctx file.

    The CtxFile context manager persists everything on exit, so each
    chunk's embedding, text and metadata only need to be queued here.
    """
    print(f"Saving collection to {self.collection}")
    with self.ctx.create(self.collection) as ctx_file:
        ctx_file.set_header(
            # NOTE(review): header hard-codes the default model name even
            # when VLite was constructed with a different model_name —
            # confirm whether it should come from the instance instead.
            embedding_model="mixedbread-ai/mxbai-embed-large-v1",
            embedding_size=self.model.model_metadata.get('bert.embedding_length', 1024),
            embedding_dtype=self.model.embedding_dtype,
            context_length=self.model.model_metadata.get('bert.context_length', 512)
        )
        for chunk_id, chunk_data in self.index.items():
            ctx_file.add_embedding(chunk_data['binary_vector'])
            ctx_file.add_context(chunk_data['text'])
            ctx_file.add_metadata(chunk_id, chunk_data['metadata'])
    print("Collection saved successfully.")

def clear(self):
    """Drop the in-memory index and delete the backing .ctx file."""
    print("Clearing the collection...")
    self.index = {}
    self.ctx.delete(self.collection)
    print("Collection cleared.")

def info(self):
Expand Down
6 changes: 3 additions & 3 deletions vlite/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def chop_and_chunk(text, max_seq_length=512, fast=False):
enc = tiktoken.get_encoding("cl100k_base")
chunks = []
print(f"Length of text: {len(text)}")
print(f"Original text: {text}")
# print(f"Original text: {text}")
for t in text:
if fast:
chunk_size = max_seq_length * 4
Expand All @@ -43,8 +43,8 @@ def chop_and_chunk(text, max_seq_length=512, fast=False):
for i in range(0, num_tokens, max_seq_length):
chunk = enc.decode(token_ids[i:i + max_seq_length])
chunks.append(chunk)
print("Chopped text into these chunks:", chunks)
print(f"Chopped text into {len(chunks)} chunks.")
# print("Chopped text into these chunks:", chunks)
# print(f"Chopped text into {len(chunks)} chunks.")
return chunks

def process_pdf(file_path: str, chunk_size: int = 512, use_ocr: bool = False, langs: List[str] = None) -> List[str]:
Expand Down

0 comments on commit 87d2687

Please sign in to comment.