Skip to content

Commit

Permalink
Configure openai client via config (#153)
Browse files Browse the repository at this point in the history
* pass client_kwargs to completor

* add client_kwargs to QuestionReformulator

* update retriever to accept an embedding_fn

* pass timeouts to embedding fn in example
  • Loading branch information
jerpint authored Nov 20, 2023
1 parent 64ba998 commit a129758
Show file tree
Hide file tree
Showing 12 changed files with 93 additions and 51 deletions.
6 changes: 0 additions & 6 deletions buster/completers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,12 +204,6 @@ def from_dict(cls, completion_dict: dict):
class Completer(ABC):
"""Generic LLM-based completer. Requires a prompt and an input to produce an output."""

def __init__(
self,
completion_kwargs: dict,
):
self.completion_kwargs = completion_kwargs

@abstractmethod
def complete(self, prompt: str, user_input) -> (str | Iterator, bool):
"""Returns the completed message (can be a generator), and a boolean to indicate if an error occured or not."""
Expand Down
16 changes: 12 additions & 4 deletions buster/completers/chatgpt.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import logging
import os
from typing import Iterator
from typing import Iterator, Optional

import openai
from openai import OpenAI

from buster.completers import Completer

client = OpenAI()

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

Expand All @@ -29,9 +27,19 @@


class ChatGPTCompleter(Completer):
def __init__(self, completion_kwargs: dict, client_kwargs: Optional[dict] = None):
# use default client if none passed
self.completion_kwargs = completion_kwargs

if client_kwargs is None:
client_kwargs = {}

self.client = OpenAI(**client_kwargs)

def complete(self, prompt: str, user_input: str, completion_kwargs=None) -> (str | Iterator, bool):
"""Returns the completed message (can be a generator), and a boolean to indicate if an error occured or not."""
# Uses default configuration if not overriden

if completion_kwargs is None:
completion_kwargs = self.completion_kwargs

Expand All @@ -42,7 +50,7 @@ def complete(self, prompt: str, user_input: str, completion_kwargs=None) -> (str

try:
error = False
response = client.chat.completions.create(messages=messages, **completion_kwargs)
response = self.client.chat.completions.create(messages=messages, **completion_kwargs)
except openai.BadRequestError:
error = True
logger.exception("Invalid request to OpenAI API. See traceback:")
Expand Down
3 changes: 0 additions & 3 deletions buster/documents_manager/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,10 @@
from typing import Optional

import pandas as pd
from openai import OpenAI
from tqdm import tqdm

from buster.llm_utils import compute_embeddings_parallelized, get_openai_embedding

client = OpenAI()

tqdm.pandas()

logger = logging.getLogger(__name__)
Expand Down
17 changes: 15 additions & 2 deletions buster/examples/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,19 @@
from buster.completers import ChatGPTCompleter, DocumentAnswerer
from buster.formatters.documents import DocumentsFormatterJSON
from buster.formatters.prompts import PromptFormatter
from buster.llm_utils import get_openai_embedding_constructor
from buster.retriever import DeepLakeRetriever, Retriever
from buster.tokenizers import GPTTokenizer
from buster.validators import Validator

# kwargs to pass to OpenAI client
client_kwargs = {
"timeout": 20,
"max_retries": 3,
}

embedding_fn = get_openai_embedding_constructor(client_kwargs=client_kwargs)

buster_cfg = BusterConfig(
validator_cfg={
"question_validator_cfg": {
Expand All @@ -15,6 +24,7 @@
"stream": False,
"temperature": 0,
},
"client_kwargs": client_kwargs,
"check_question_prompt": """You are a chatbot answering questions on artificial intelligence.
Your job is to determine wether or not a question is valid, and should be answered.
More general questions are not considered valid, even if you might know the response.
Expand All @@ -35,22 +45,24 @@
"I'm sorry, but I am an AI language model trained to assist with questions related to AI. I cannot answer that question as it is not relevant to the library or its usage. Is there anything else I can assist you with?",
],
"unknown_threshold": 0.85,
"embedding_fn": embedding_fn,
},
"documents_validator_cfg": {
"completion_kwargs": {
"model": "gpt-3.5-turbo",
"stream": False,
"temperature": 0,
},
"client_kwargs": client_kwargs,
},
"use_reranking": True,
"validate_documents": True,
"validate_documents": False,
},
retriever_cfg={
"path": "deeplake_store",
"top_k": 3,
"thresh": 0.7,
"embedding_model": "text-embedding-ada-002",
"embedding_fn": embedding_fn,
},
documents_answerer_cfg={
"no_documents_message": "No documents are available for this question.",
Expand All @@ -61,6 +73,7 @@
"stream": True,
"temperature": 0,
},
"client_kwargs": client_kwargs,
},
tokenizer_cfg={
"model_name": "gpt-3.5-turbo",
Expand Down
9 changes: 8 additions & 1 deletion buster/llm_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@
compute_embeddings_parallelized,
cosine_similarity,
get_openai_embedding,
get_openai_embedding_constructor,
)
from buster.llm_utils.question_reformulator import QuestionReformulator

__all__ = [QuestionReformulator, cosine_similarity, get_openai_embedding, compute_embeddings_parallelized]
__all__ = [
QuestionReformulator,
cosine_similarity,
get_openai_embedding,
compute_embeddings_parallelized,
get_openai_embedding_constructor,
]
46 changes: 28 additions & 18 deletions buster/llm_utils/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,42 @@
import logging
from functools import lru_cache
from typing import Optional

import numpy as np
import pandas as pd
from openai import OpenAI
from tqdm.contrib.concurrent import process_map
from tqdm.contrib.concurrent import thread_map

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

client = OpenAI()

def get_openai_embedding_constructor(client_kwargs: Optional[dict] = None, model: str = "text-embedding-ada-002"):
if client_kwargs is None:
client_kwargs = {}
client = OpenAI(**client_kwargs)

@lru_cache
def get_openai_embedding(text: str, model: str = "text-embedding-ada-002") -> np.array:
try:
text = text.replace("\n", " ")
response = client.embeddings.create(
input=text,
model=model,
)
embedding = response.data[0].embedding
return np.array(embedding, dtype="float32")
except Exception as e:
# This rarely happens with the API but in the off chance it does, will allow us not to loose the progress.
logger.exception(e)
logger.warning(f"Embedding failed to compute for {text=}")
return None
@lru_cache
def embedding_fn(text: str, model: str = model) -> np.array:
try:
text = text.replace("\n", " ")
response = client.embeddings.create(
input=text,
model=model,
)
embedding = response.data[0].embedding
return np.array(embedding, dtype="float32")
except Exception as e:
# This rarely happens with the API but in the off chance it does, will allow us not to loose the progress.
logger.exception(e)
logger.warning(f"Embedding failed to compute for {text=}")
return None

return embedding_fn


# default embedding function
get_openai_embedding = get_openai_embedding_constructor()


def cosine_similarity(a, b):
Expand All @@ -50,7 +60,7 @@ def compute_embeddings_parallelized(df: pd.DataFrame, embedding_fn: callable, nu
"""

logger.info(f"Computing embeddings of {len(df)} chunks. Using {num_workers=}")
embeddings = process_map(embedding_fn, df.content.to_list(), max_workers=num_workers)
embeddings = thread_map(embedding_fn, df.content.to_list(), max_workers=num_workers)

logger.info(f"Finished computing embeddings")
return embeddings
9 changes: 7 additions & 2 deletions buster/llm_utils/question_reformulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@


class QuestionReformulator:
def __init__(self, system_prompt: Optional[str] = None, completion_kwargs: Optional[dict] = None):
self.completer = ChatGPTCompleter(completion_kwargs=completion_kwargs)
def __init__(
self,
system_prompt: Optional[str] = None,
completion_kwargs: Optional[dict] = None,
client_kwargs: Optional[dict] = None,
):
self.completer = ChatGPTCompleter(completion_kwargs=completion_kwargs, client_kwargs=client_kwargs)

if completion_kwargs is None:
# Default kwargs
Expand Down
15 changes: 8 additions & 7 deletions buster/retriever/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import lru_cache
from typing import Optional
from typing import Callable, Optional

import numpy as np
import pandas as pd
Expand All @@ -18,10 +18,13 @@

@dataclass
class Retriever(ABC):
def __init__(self, top_k, thresh, embedding_model, *args, **kwargs):
def __init__(self, top_k, thresh, embedding_fn: Callable[[str], np.array] = None, *args, **kwargs):
if embedding_fn is None:
embedding_fn = get_openai_embedding

self.top_k = top_k
self.thresh = thresh
self.embedding_model = embedding_model
self.embedding_fn = embedding_fn

# Add your access to documents in your own init

Expand All @@ -37,11 +40,9 @@ def get_source_display_name(self, source: str) -> str:
If source is None, returns all documents. If source does not exist, returns empty dataframe."""
...

@staticmethod
@lru_cache
def get_embedding(query: str, model: str) -> np.ndarray:
def get_embedding(self, query: str) -> np.ndarray:
logger.info("generating embedding")
return get_openai_embedding(query, model=model)
return self.embedding_fn(query)

@abstractmethod
def get_topk_documents(self, query: str, source: str = None, top_k: int = None) -> pd.DataFrame:
Expand Down
2 changes: 1 addition & 1 deletion buster/retriever/deeplake.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def get_topk_documents(
If no matches are found, returns an empty dataframe."""

if query is not None:
query_embedding = self.get_embedding(query, model=self.embedding_model)
query_embedding = self.get_embedding(query)
elif embedding is not None:
query_embedding = embedding
else:
Expand Down
2 changes: 1 addition & 1 deletion buster/retriever/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def get_topk_documents(self, query: str, sources: Optional[list[str]], top_k: in
logger.warning(f"Sources {sources} do not exist. Returning empty dataframe.")
return pd.DataFrame()

query_embedding = self.get_embedding(query, model=self.embedding_model)
query_embedding = self.get_embedding(query)

if isinstance(query_embedding, np.ndarray):
# pinecone expects a list of floats, so convert from ndarray if necessary
Expand Down
11 changes: 5 additions & 6 deletions buster/validators/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(
check_question_prompt: Optional[str] = None,
invalid_question_response: Optional[str] = None,
completion_kwargs: Optional[dict] = None,
completer: Optional[Completer] = None,
client_kwargs: Optional[dict] = None,
):
if check_question_prompt is None:
check_question_prompt = (
Expand All @@ -39,10 +39,8 @@ def __init__(
A user will submit a question. Respond 'true' if it is valid, respond 'false' if it is invalid.""",
)

if completer is None:
completer = ChatGPTCompleter

if completion_kwargs is None:
# default completion kwargs
completion_kwargs = (
{
"model": "gpt-3.5-turbo",
Expand All @@ -51,7 +49,7 @@ def __init__(
},
)

self.completer = completer(completion_kwargs=completion_kwargs)
self.completer = ChatGPTCompleter(completion_kwargs=completion_kwargs, client_kwargs=client_kwargs)
self.check_question_prompt = check_question_prompt
self.invalid_question_response = invalid_question_response

Expand Down Expand Up @@ -117,6 +115,7 @@ class DocumentsValidator:
def __init__(
self,
completion_kwargs: Optional[dict] = None,
client_kwargs: Optional[dict] = None,
system_prompt: Optional[str] = None,
user_input_formatter: Optional[str] = None,
max_calls: int = 30,
Expand Down Expand Up @@ -144,7 +143,7 @@ def __init__(
"temperature": 0,
}

self.completer = ChatGPTCompleter(completion_kwargs=completion_kwargs)
self.completer = ChatGPTCompleter(completion_kwargs=completion_kwargs, client_kwargs=client_kwargs)

self.max_calls = max_calls

Expand Down
8 changes: 8 additions & 0 deletions tests/test_chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
"model": "gpt-3.5-turbo",
"temperature": 0,
},
"client_kwargs": {
"timeout": 20,
"max_retries": 2,
},
},
validator_cfg={
"validate_documents": False,
Expand All @@ -47,6 +51,10 @@
"stream": False,
"temperature": 0,
},
"client_kwargs": {
"timeout": 20,
"max_retries": 2,
},
"check_question_prompt": "You are validating if questions are related to AI. If a question is relevant, respond with 'true', if it is irrlevant, respond with 'false'.",
},
},
Expand Down

0 comments on commit a129758

Please sign in to comment.