add telemetry + top k multiplier
sdan committed Apr 29, 2024
1 parent cb6bf7b commit fde3355
Showing 5 changed files with 62 additions and 4 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -9,3 +9,4 @@ beautifulsoup4
huggingface_hub
tiktoken
tokenizers==0.15.2
posthog
1 change: 1 addition & 0 deletions setup.py
@@ -21,6 +21,7 @@
        'torch==2.2.2',
        'transformers==4.39.0',
        'tokenizers==0.15.2',
        'posthog',
    ],
    extras_require={
        'ocr': ['surya-ocr-vlite']
36 changes: 36 additions & 0 deletions vlite/constants.py
@@ -0,0 +1,36 @@
# Constants for VLite application
class Constants:
    # Device options
    DEVICE_CPU = 'cpu'
    DEVICE_CUDA = 'cuda'
    DEVICE_MPS = 'mps'

    # Model details
    DEFAULT_MODEL = 'mixedbread-ai/mxbai-embed-large-v1'

    # Metadata keys
    METADATA_TEXT = 'text'
    METADATA_BINARY_VECTOR = 'binary_vector'
    METADATA_METADATA = 'metadata'

    # Precision types
    PRECISION_BINARY = 'binary'

    # Logging
    LOG_INIT = "[VLite.__init__] Initializing VLite with device: {}"
    LOG_EXEC_TIME = "[VLite.{}] Execution time: {:.5f} seconds"
    LOG_NO_COLLECTION = "[VLite.__init__] Collection file {} not found. Initializing empty attributes."
    LOG_RETRIEVING = "[VLite.retrieve] Retrieving similar texts..."
    LOG_RETRIEVING_QUERY = "[VLite.retrieve] Retrieving top {} similar texts for query: {}"
    LOG_RETRIEVAL_COMPLETED = "[VLite.retrieve] Retrieval completed."
    LOG_RANK_FILTER = "[VLite.rank_and_filter] Shape of query vector: {}"
    LOG_RANK_FILTER_RESHAPE = "[VLite.rank_and_filter] Shape of query vector after reshaping: {}"
    LOG_RANK_FILTER_CORPUS_SHAPE = "[VLite.rank_and_filter] Shape of corpus binary vectors array: {}"
    LOG_RANK_FILTER_TOP_K = "[VLite.rank_and_filter] Top {} indices: {}"
    LOG_RANK_FILTER_TOP_K_SCORES = "[VLite.rank_and_filter] Top {} scores: {}"
    LOG_RANK_FILTER_COLLECTION_COUNT = "[VLite.rank_and_filter] No. of items in the collection: {}"
    LOG_RANK_FILTER_VLITE_COUNT = "[VLite.rank_and_filter] Vlite count: {}"
    LOG_ADD_ENCODING = "[VLite.add] Encoding text... not chunking"

    TELEMETRY_POSTHOG = 'phc_i9Aq4aTt4aFpFqyKzxN9LGq3SfoIjYNAcazDnn6dSLP'
    TELEMETRY_POSTHOG_HOST = 'https://us.i.posthog.com'
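
For context (not part of this diff): a minimal sketch of how these constants are consumed, mirroring the ctx.py and main.py changes below. The vlite.constants import path and the example arguments are assumptions.

import logging
from posthog import Posthog
from vlite.constants import Constants

logger = logging.getLogger(__name__)

# LOG_* entries are str.format templates filled in at the call site.
logger.info(Constants.LOG_INIT.format(Constants.DEVICE_CPU))
logger.info(Constants.LOG_EXEC_TIME.format("retrieve", 0.01234))

# TELEMETRY_* entries configure the shared PostHog client.
posthog = Posthog(project_api_key=Constants.TELEMETRY_POSTHOG,
                  host=Constants.TELEMETRY_POSTHOG_HOST,
                  disable_geoip=False)
posthog.capture("anon-user-id", "vlite_init", {"device": Constants.DEVICE_CPU})
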
15 changes: 14 additions & 1 deletion vlite/ctx.py
@@ -5,10 +5,13 @@
from typing import List, Dict, Union
import numpy as np
import logging

from posthog import Posthog
from .constants import Constants
import uuid

logging.basicConfig(level=logging.WARNING, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
posthog = Posthog(project_api_key=Constants.TELEMETRY_POSTHOG, host=Constants.TELEMETRY_POSTHOG_HOST, disable_geoip=False)

class CtxSectionType(Enum):
    HEADER = 0
@@ -31,6 +34,8 @@ def __init__(self, file_path):
        self.embeddings = []
        self.contexts = []
        self.metadata = {}
        # Anonymized telemetry
        self.anon_user_id = uuid.uuid4().hex if not os.path.exists(f"./contexts/{uuid.uuid4().hex}.telm") else open(f"./contexts/{uuid.uuid4().hex}.telm", "r").read()

    def set_header(self, embedding_model: str, embedding_size: int, embedding_dtype: str, context_length: int):
        self.header["embedding_model"] = embedding_model
@@ -49,6 +54,10 @@ def add_metadata(self, key: str, value: Union[int, float, str]):

    def save(self):
        with open(self.file_path, "wb") as file:
            # Anonymized telemetry
            if not os.path.exists(f"./contexts/{self.anon_user_id}.telm"): open(f"./contexts/{self.anon_user_id}.telm", "w").write(self.anon_user_id)
            posthog.capture(self.anon_user_id,'ctx_save',{'file_path': self.file_path,'header': self.header})

            file.write(self.MAGIC_NUMBER)
            file.write(struct.pack("<I", self.VERSION))

@@ -76,6 +85,10 @@ def save(self):
    def load(self):
        try:
            with open(self.file_path, "rb") as file:
                # Anonymized telemetry
                if not os.path.exists(f"./contexts/{self.anon_user_id}.telm"): open(f"./contexts/{self.anon_user_id}.telm", "w").write(self.anon_user_id)
                posthog.capture(self.anon_user_id,'ctx_load',{'file_path': self.file_path})

                # Read and verify header
                magic_number = file.read(len(self.MAGIC_NUMBER))
                if magic_number != self.MAGIC_NUMBER:
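Aside (not part of this diff): the anonymized-telemetry one-liners above call uuid4() separately for the existence check and the read, so a previously written ./contexts/*.telm id is never actually reused. A stdlib-only sketch of what the pattern appears to intend; the helper name and the ./contexts location are assumptions, not code from this commit.

import os
import uuid

def get_or_create_anon_id(telm_dir="./contexts"):
    # Reuse the first persisted anonymous id, or create and persist a new one.
    os.makedirs(telm_dir, exist_ok=True)
    existing = [f for f in os.listdir(telm_dir) if f.endswith(".telm")]
    if existing:
        with open(os.path.join(telm_dir, existing[0]), "r") as fh:
            return fh.read().strip()
    anon_id = uuid.uuid4().hex
    with open(os.path.join(telm_dir, f"{anon_id}.telm"), "w") as fh:
        fh.write(anon_id)
    return anon_id
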
13 changes: 10 additions & 3 deletions vlite/main.py
@@ -7,15 +7,22 @@
from .ctx import Ctx
import time
import logging
from posthog import Posthog
from .constants import Constants
import os

# Configure logging
logging.basicConfig(level=logging.WARNING, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

posthog = Posthog(project_api_key=Constants.TELEMETRY_POSTHOG, host=Constants.TELEMETRY_POSTHOG_HOST, disable_geoip=False)

class VLite:
    def __init__(self, collection=None, device=None, model_name='mixedbread-ai/mxbai-embed-large-v1'):
        start_time = time.time()
        # Anonymized telemetry
        self.anon_user_id = uuid4().hex if not os.path.exists(f"./contexts/{uuid4().hex}.telm") else open(f"./contexts/{uuid4().hex}.telm", "r").read()
        posthog.capture(self.anon_user_id,'vlite_init',{'collection': collection,'device': device,'model_name': model_name})

        if device is None:
            if check_cuda_available():
                device = 'cuda'
@@ -120,12 +127,12 @@ def retrieve(self, text=None, top_k=5, metadata=None, return_scores=False, top_k
            else:
                return [(idx, self.index[idx]['text'], self.index[idx]['metadata']) for idx, _ in results]

    def rank_and_filter(self, query_binary_vector, top_k, metadata=None):
    def rank_and_filter(self, query_binary_vector, top_k, metadata=None, top_k_multiplier=4):
        start_time = time.time()

        # If metadata filter is provided, retrieve more items initially
        if metadata:
            initial_top_k = top_k * 4 # Adjust this factor as needed
            initial_top_k = top_k * top_k_multiplier # Adjust this factor as needed
        else:
            initial_top_k = top_k
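
For context (not part of this diff): with a metadata filter, rank_and_filter over-fetches by the new top_k_multiplier before filtering, so a stricter filter can be compensated for with a larger multiplier. A hedged usage sketch, where vl (an initialized VLite collection), qvec (a binary query vector), and the metadata dict are assumptions.

top_k = 5

# No metadata filter: exactly top_k candidates are ranked (initial_top_k == 5).
hits = vl.rank_and_filter(qvec, top_k)

# Metadata filter: top_k * top_k_multiplier candidates are fetched first
# (5 * 8 == 40 here), then filtered down to top_k matches.
hits = vl.rank_and_filter(qvec, top_k, metadata={"source": "docs"}, top_k_multiplier=8)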
