Skip to content

Commit

Permalink
WIP: refactor embedding fn
Browse files Browse the repository at this point in the history
  • Loading branch information
jerpint committed Nov 20, 2023
1 parent af15e42 commit 1ccd60b
Showing 1 changed file with 26 additions and 18 deletions.
44 changes: 26 additions & 18 deletions buster/llm_utils/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,37 @@
import numpy as np
import pandas as pd
from openai import OpenAI
from tqdm.contrib.concurrent import process_map
from tqdm.contrib.concurrent import thread_map

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

client = OpenAI()

def get_embedding_fn(client_kwargs=None):
if client_kwargs is None:
client_kwargs = {}
client = OpenAI(**client_kwargs)

@lru_cache
def get_openai_embedding(text: str, model: str = "text-embedding-ada-002") -> np.array:
try:
text = text.replace("\n", " ")
response = client.embeddings.create(
input=text,
model=model,
)
embedding = response.data[0].embedding
return np.array(embedding, dtype="float32")
except Exception as e:
# This rarely happens with the API but in the off chance it does, will allow us not to loose the progress.
logger.exception(e)
logger.warning(f"Embedding failed to compute for {text=}")
return None
@lru_cache
def get_openai_embedding(text: str, model: str = "text-embedding-ada-002") -> np.array:
try:
text = text.replace("\n", " ")
response = client.embeddings.create(
input=text,
model=model,
)
embedding = response.data[0].embedding
return np.array(embedding, dtype="float32")
except Exception as e:
# This rarely happens with the API but in the off chance it does, will allow us not to loose the progress.
logger.exception(e)
logger.warning(f"Embedding failed to compute for {text=}")
return None

return get_openai_embedding


get_openai_embedding = get_embedding_fn()


def cosine_similarity(a, b):
Expand All @@ -50,7 +58,7 @@ def compute_embeddings_parallelized(df: pd.DataFrame, embedding_fn: callable, nu
"""

logger.info(f"Computing embeddings of {len(df)} chunks. Using {num_workers=}")
embeddings = process_map(embedding_fn, df.content.to_list(), max_workers=num_workers)
embeddings = thread_map(embedding_fn, df.content.to_list(), max_workers=num_workers)

logger.info(f"Finished computing embeddings")
return embeddings

0 comments on commit 1ccd60b

Please sign in to comment.