diff --git a/buster/completers/base.py b/buster/completers/base.py
index b3fefab..559e834 100644
--- a/buster/completers/base.py
+++ b/buster/completers/base.py
@@ -204,12 +204,6 @@ def from_dict(cls, completion_dict: dict):
 class Completer(ABC):
     """Generic LLM-based completer. Requires a prompt and an input to produce an output."""

-    def __init__(
-        self,
-        completion_kwargs: dict,
-    ):
-        self.completion_kwargs = completion_kwargs
-
     @abstractmethod
     def complete(self, prompt: str, user_input) -> (str | Iterator, bool):
         """Returns the completed message (can be a generator), and a boolean to indicate if an error occured or not."""
diff --git a/buster/completers/chatgpt.py b/buster/completers/chatgpt.py
index fb01672..b4a4021 100644
--- a/buster/completers/chatgpt.py
+++ b/buster/completers/chatgpt.py
@@ -1,14 +1,12 @@
 import logging
 import os
-from typing import Iterator
+from typing import Iterator, Optional

 import openai
 from openai import OpenAI

 from buster.completers import Completer

-client = OpenAI()
-
 logger = logging.getLogger(__name__)
 logging.basicConfig(level=logging.INFO)
@@ -29,9 +27,19 @@
 class ChatGPTCompleter(Completer):
+    def __init__(self, completion_kwargs: dict, client_kwargs: Optional[dict] = None):
+        self.completion_kwargs = completion_kwargs
+
+        # use default client if none passed
+        if client_kwargs is None:
+            client_kwargs = {}
+
+        self.client = OpenAI(**client_kwargs)
+
     def complete(self, prompt: str, user_input: str, completion_kwargs=None) -> (str | Iterator, bool):
         """Returns the completed message (can be a generator), and a boolean to indicate if an error occured or not."""
         # Uses default configuration if not overriden
+        if completion_kwargs is None:
             completion_kwargs = self.completion_kwargs
@@ -42,7 +50,7 @@ def complete(self, prompt: str, user_input: str, completion_kwargs=None) -> (str
         try:
             error = False
-            response = client.chat.completions.create(messages=messages, **completion_kwargs)
+            response = self.client.chat.completions.create(messages=messages, **completion_kwargs)
         except openai.BadRequestError:
             error = True
             logger.exception("Invalid request to OpenAI API. See traceback:")
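# A minimal usage sketch of the per-instance client introduced above. It assumes
# OPENAI_API_KEY is set in the environment; the completion_kwargs values shown
# are illustrative, not library defaults.
from buster.completers import ChatGPTCompleter

completer = ChatGPTCompleter(
    completion_kwargs={"model": "gpt-3.5-turbo", "temperature": 0, "stream": False},
    client_kwargs={"timeout": 20, "max_retries": 3},  # forwarded verbatim to OpenAI(**client_kwargs)
)
# complete() returns the completion (string or generator) and an error flag
completion, error = completer.complete(prompt="You are a helpful assistant.", user_input="Hello!")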
See traceback:") diff --git a/buster/documents_manager/base.py b/buster/documents_manager/base.py index 22b9f86..38a7cf7 100644 --- a/buster/documents_manager/base.py +++ b/buster/documents_manager/base.py @@ -5,13 +5,10 @@ from typing import Optional import pandas as pd -from openai import OpenAI from tqdm import tqdm from buster.llm_utils import compute_embeddings_parallelized, get_openai_embedding -client = OpenAI() - tqdm.pandas() logger = logging.getLogger(__name__) diff --git a/buster/examples/cfg.py b/buster/examples/cfg.py index e47422f..42ec66d 100644 --- a/buster/examples/cfg.py +++ b/buster/examples/cfg.py @@ -2,10 +2,19 @@ from buster.completers import ChatGPTCompleter, DocumentAnswerer from buster.formatters.documents import DocumentsFormatterJSON from buster.formatters.prompts import PromptFormatter +from buster.llm_utils import get_openai_embedding_constructor from buster.retriever import DeepLakeRetriever, Retriever from buster.tokenizers import GPTTokenizer from buster.validators import Validator +# kwargs to pass to OpenAI client +client_kwargs = { + "timeout": 20, + "max_retries": 3, +} + +embedding_fn = get_openai_embedding_constructor(client_kwargs=client_kwargs) + buster_cfg = BusterConfig( validator_cfg={ "question_validator_cfg": { @@ -15,6 +24,7 @@ "stream": False, "temperature": 0, }, + "client_kwargs": client_kwargs, "check_question_prompt": """You are a chatbot answering questions on artificial intelligence. Your job is to determine wether or not a question is valid, and should be answered. More general questions are not considered valid, even if you might know the response. @@ -35,6 +45,7 @@ "I'm sorry, but I am an AI language model trained to assist with questions related to AI. I cannot answer that question as it is not relevant to the library or its usage. 
diff --git a/buster/examples/cfg.py b/buster/examples/cfg.py
index e47422f..42ec66d 100644
--- a/buster/examples/cfg.py
+++ b/buster/examples/cfg.py
@@ -2,10 +2,19 @@
 from buster.completers import ChatGPTCompleter, DocumentAnswerer
 from buster.formatters.documents import DocumentsFormatterJSON
 from buster.formatters.prompts import PromptFormatter
+from buster.llm_utils import get_openai_embedding_constructor
 from buster.retriever import DeepLakeRetriever, Retriever
 from buster.tokenizers import GPTTokenizer
 from buster.validators import Validator

+# kwargs to pass to OpenAI client
+client_kwargs = {
+    "timeout": 20,
+    "max_retries": 3,
+}
+
+embedding_fn = get_openai_embedding_constructor(client_kwargs=client_kwargs)
+
 buster_cfg = BusterConfig(
@@ -15,6 +24,7 @@
                 "stream": False,
                 "temperature": 0,
             },
+            "client_kwargs": client_kwargs,
             "check_question_prompt": """You are a chatbot answering questions on artificial intelligence.

 Your job is to determine wether or not a question is valid, and should be answered.
 More general questions are not considered valid, even if you might know the response.
@@ -35,6 +45,7 @@
                 "I'm sorry, but I am an AI language model trained to assist with questions related to AI. I cannot answer that question as it is not relevant to the library or its usage. Is there anything else I can assist you with?",
             ],
             "unknown_threshold": 0.85,
+            "embedding_fn": embedding_fn,
         },
         "documents_validator_cfg": {
             "completion_kwargs": {
@@ -42,15 +53,16 @@
                 "stream": False,
                 "temperature": 0,
             },
+            "client_kwargs": client_kwargs,
         },
         "use_reranking": True,
-        "validate_documents": True,
+        "validate_documents": False,
     },
     retriever_cfg={
         "path": "deeplake_store",
         "top_k": 3,
         "thresh": 0.7,
-        "embedding_model": "text-embedding-ada-002",
+        "embedding_fn": embedding_fn,
     },
     documents_answerer_cfg={
         "no_documents_message": "No documents are available for this question.",
@@ -61,6 +73,7 @@
             "stream": True,
             "temperature": 0,
         },
+        "client_kwargs": client_kwargs,
     },
     tokenizer_cfg={
         "model_name": "gpt-3.5-turbo",
diff --git a/buster/llm_utils/__init__.py b/buster/llm_utils/__init__.py
index f1c6be2..7588d6c 100644
--- a/buster/llm_utils/__init__.py
+++ b/buster/llm_utils/__init__.py
@@ -2,7 +2,14 @@
     compute_embeddings_parallelized,
     cosine_similarity,
     get_openai_embedding,
+    get_openai_embedding_constructor,
 )
 from buster.llm_utils.question_reformulator import QuestionReformulator

-__all__ = [QuestionReformulator, cosine_similarity, get_openai_embedding, compute_embeddings_parallelized]
+__all__ = [
+    QuestionReformulator,
+    cosine_similarity,
+    get_openai_embedding,
+    compute_embeddings_parallelized,
+    get_openai_embedding_constructor,
+]
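# A minimal sketch of the constructor exported above and implemented in the next
# file: it returns an lru_cache-d embedding function bound to a single OpenAI
# client. client_kwargs go straight to OpenAI(...); the model can be overridden.
from buster.llm_utils import get_openai_embedding_constructor

embedding_fn = get_openai_embedding_constructor(
    client_kwargs={"timeout": 20, "max_retries": 3},
    model="text-embedding-ada-002",
)
vector = embedding_fn("What is a transformer?")  # float32 np.array, or None if the API call fails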
Using {num_workers=}") - embeddings = process_map(embedding_fn, df.content.to_list(), max_workers=num_workers) + embeddings = thread_map(embedding_fn, df.content.to_list(), max_workers=num_workers) logger.info(f"Finished computing embeddings") return embeddings diff --git a/buster/llm_utils/question_reformulator.py b/buster/llm_utils/question_reformulator.py index 7b50050..9ce549e 100644 --- a/buster/llm_utils/question_reformulator.py +++ b/buster/llm_utils/question_reformulator.py @@ -5,8 +5,13 @@ class QuestionReformulator: - def __init__(self, system_prompt: Optional[str] = None, completion_kwargs: Optional[dict] = None): - self.completer = ChatGPTCompleter(completion_kwargs=completion_kwargs) + def __init__( + self, + system_prompt: Optional[str] = None, + completion_kwargs: Optional[dict] = None, + client_kwargs: Optional[dict] = None, + ): + self.completer = ChatGPTCompleter(completion_kwargs=completion_kwargs, client_kwargs=client_kwargs) if completion_kwargs is None: # Default kwargs diff --git a/buster/retriever/base.py b/buster/retriever/base.py index 0436ae1..b02ed09 100644 --- a/buster/retriever/base.py +++ b/buster/retriever/base.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from functools import lru_cache -from typing import Optional +from typing import Callable, Optional import numpy as np import pandas as pd @@ -18,10 +18,13 @@ @dataclass class Retriever(ABC): - def __init__(self, top_k, thresh, embedding_model, *args, **kwargs): + def __init__(self, top_k, thresh, embedding_fn: Callable[[str], np.array] = None, *args, **kwargs): + if embedding_fn is None: + embedding_fn = get_openai_embedding + self.top_k = top_k self.thresh = thresh - self.embedding_model = embedding_model + self.embedding_fn = embedding_fn # Add your access to documents in your own init @@ -37,11 +40,9 @@ def get_source_display_name(self, source: str) -> str: If source is None, returns all documents. If source does not exist, returns empty dataframe.""" ... - @staticmethod - @lru_cache - def get_embedding(query: str, model: str) -> np.ndarray: + def get_embedding(self, query: str) -> np.ndarray: logger.info("generating embedding") - return get_openai_embedding(query, model=model) + return self.embedding_fn(query) @abstractmethod def get_topk_documents(self, query: str, source: str = None, top_k: int = None) -> pd.DataFrame: diff --git a/buster/retriever/deeplake.py b/buster/retriever/deeplake.py index 236626a..84103e6 100644 --- a/buster/retriever/deeplake.py +++ b/buster/retriever/deeplake.py @@ -118,7 +118,7 @@ def get_topk_documents( If no matches are found, returns an empty dataframe.""" if query is not None: - query_embedding = self.get_embedding(query, model=self.embedding_model) + query_embedding = self.get_embedding(query) elif embedding is not None: query_embedding = embedding else: diff --git a/buster/retriever/service.py b/buster/retriever/service.py index 38c38b8..78834c8 100644 --- a/buster/retriever/service.py +++ b/buster/retriever/service.py @@ -77,7 +77,7 @@ def get_topk_documents(self, query: str, sources: Optional[list[str]], top_k: in logger.warning(f"Sources {sources} do not exist. 
Returning empty dataframe.") return pd.DataFrame() - query_embedding = self.get_embedding(query, model=self.embedding_model) + query_embedding = self.get_embedding(query) if isinstance(query_embedding, np.ndarray): # pinecone expects a list of floats, so convert from ndarray if necessary diff --git a/buster/validators/validators.py b/buster/validators/validators.py index 93336c6..be6a467 100644 --- a/buster/validators/validators.py +++ b/buster/validators/validators.py @@ -19,7 +19,7 @@ def __init__( check_question_prompt: Optional[str] = None, invalid_question_response: Optional[str] = None, completion_kwargs: Optional[dict] = None, - completer: Optional[Completer] = None, + client_kwargs: Optional[dict] = None, ): if check_question_prompt is None: check_question_prompt = ( @@ -39,10 +39,8 @@ def __init__( A user will submit a question. Respond 'true' if it is valid, respond 'false' if it is invalid.""", ) - if completer is None: - completer = ChatGPTCompleter - if completion_kwargs is None: + # default completion kwargs completion_kwargs = ( { "model": "gpt-3.5-turbo", @@ -51,7 +49,7 @@ def __init__( }, ) - self.completer = completer(completion_kwargs=completion_kwargs) + self.completer = ChatGPTCompleter(completion_kwargs=completion_kwargs, client_kwargs=client_kwargs) self.check_question_prompt = check_question_prompt self.invalid_question_response = invalid_question_response @@ -117,6 +115,7 @@ class DocumentsValidator: def __init__( self, completion_kwargs: Optional[dict] = None, + client_kwargs: Optional[dict] = None, system_prompt: Optional[str] = None, user_input_formatter: Optional[str] = None, max_calls: int = 30, @@ -144,7 +143,7 @@ def __init__( "temperature": 0, } - self.completer = ChatGPTCompleter(completion_kwargs=completion_kwargs) + self.completer = ChatGPTCompleter(completion_kwargs=completion_kwargs, client_kwargs=client_kwargs) self.max_calls = max_calls diff --git a/tests/test_chatbot.py b/tests/test_chatbot.py index a074b5b..a0ed267 100644 --- a/tests/test_chatbot.py +++ b/tests/test_chatbot.py @@ -30,6 +30,10 @@ "model": "gpt-3.5-turbo", "temperature": 0, }, + "client_kwargs": { + "timeout": 20, + "max_retries": 2, + }, }, validator_cfg={ "validate_documents": False, @@ -47,6 +51,10 @@ "stream": False, "temperature": 0, }, + "client_kwargs": { + "timeout": 20, + "max_retries": 2, + }, "check_question_prompt": "You are validating if questions are related to AI. If a question is relevant, respond with 'true', if it is irrlevant, respond with 'false'.", }, },