Merge branch 'main' into hybrid_retrieval
hbertrand committed Dec 7, 2023
2 parents 7c798db + 13ab209 commit 70e7e9d
Showing 13 changed files with 488 additions and 59 deletions.
47 changes: 45 additions & 2 deletions buster/completers/base.py
@@ -16,6 +16,32 @@


class Completion:
"""
A class to represent the completion object of a model's output for a user's question.
Attributes:
error (bool): A boolean indicating if an error occurred when generating the completion.
user_inputs (UserInputs): The inputs from the user.
matched_documents (pd.DataFrame): The documents that were matched to the user's question.
answer_generator (Iterator): An optional iterator used to generate the model's answer.
answer_text (str): An optional answer text.
answer_relevant (bool): An optional boolean indicating if the answer is relevant.
question_relevant (bool): An optional boolean indicating if the question is relevant.
completion_kwargs (dict): Optional arguments for the completion.
validator (Validator): An optional Validator object.
Methods:
__repr__: Outputs a string representation of the object.
_validate_arguments: Validates answer_generator and answer_text arguments.
answer_relevant: Determines if the answer is relevant or not.
question_relevant: Retrieves the relevance of the question.
answer_text: Retrieves the answer text.
answer_generator: Retrieves the answer generator.
postprocess: Postprocesses the results after generating the model's answer.
to_json: Outputs selected attributes of the object in JSON format.
from_dict: Creates a Completion object from a dictionary.
"""

def __init__(
self,
error: bool,
@@ -202,7 +228,12 @@ def from_dict(cls, completion_dict: dict):


class Completer(ABC):
"""Generic LLM-based completer. Requires a prompt and an input to produce an output."""
"""
Abstract base class for completers, which generate an answer to a prompt.
Methods:
complete: The method that should be implemented by any child class to provide an answer to a prompt.
"""

@abstractmethod
def complete(self, prompt: str, user_input) -> (str | Iterator, bool):
@@ -211,9 +242,21 @@ def complete(self, prompt: str, user_input) -> (str | Iterator, bool):


class DocumentAnswerer:
"""Completer that will answer questions based on documents.
"""
A class that answers questions based on documents.
It takes care of formatting the prompts and the documents, and generating the answer when relevant.
Attributes:
completer (Completer): Object that actually generates an answer to the prompt.
documents_formatter (DocumentsFormatter): Object that formats the documents for the prompt.
prompt_formatter (PromptFormatter): Object that prepares the prompt for the completer.
no_documents_message (str): Message to display when no documents are found to match the query.
completion_class (Completion): Class to use for the resulting completion.
Methods:
prepare_prompt: Prepares the prompt that will be passed to the completer.
get_completion: Generates a completion to the user's question based on matched documents.
"""

def __init__(
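As context for the docstrings above, here is a minimal sketch of the Completer contract: complete() hands back the answer (plain text or a generator) together with an error flag. The import path is assumed from the file layout shown in this diff, and the echo behaviour is purely illustrative, not a real model call.

```python
# Minimal sketch of the Completer contract, assuming the module path
# buster.completers.base implied by this diff.
from typing import Iterator

from buster.completers.base import Completer


class EchoCompleter(Completer):
    """Toy completer: echoes the user's question back as the answer."""

    def complete(self, prompt: str, user_input) -> (str | Iterator, bool):
        # Mirrors the documented return contract: (answer, error_flag)
        return f"You asked: {user_input}", False


answer, error = EchoCompleter().complete(
    prompt="Answer using only the matched documents.",
    user_input="What is Buster?",
)
print(answer, error)  # "You asked: What is Buster?" False
```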
25 changes: 22 additions & 3 deletions buster/completers/chatgpt.py
@@ -28,6 +28,12 @@

class ChatGPTCompleter(Completer):
def __init__(self, completion_kwargs: dict, client_kwargs: Optional[dict] = None):
"""Initialize the ChatGPTCompleter with completion and client keyword arguments.
Args:
completion_kwargs: A dictionary of keyword arguments to be used for completions.
client_kwargs: An optional dictionary of keyword arguments to be used for the OpenAI client.
"""
# use default client if none passed
self.completion_kwargs = completion_kwargs

@@ -37,8 +43,21 @@ def __init__(self, completion_kwargs: dict, client_kwargs: Optional[dict] = None
self.client = OpenAI(**client_kwargs)

def complete(self, prompt: str, user_input: str, completion_kwargs=None) -> (str | Iterator, bool):
"""Returns the completed message (can be a generator), and a boolean to indicate if an error occured or not."""
# Uses default configuration if not overriden
"""Given a prompt and user input, returns the generated message and error flag.
Args:
prompt: The prompt containing the formatted documents and instructions.
user_input: The user input to be responded to.
completion_kwargs: An optional dictionary of keyword arguments to override the default completion kwargs.
Returns:
A tuple containing the completed message and a boolean indicating if an error occurred.
Raises:
openai.BadRequestError: If the completion request is invalid.
openai.RateLimitError: If the OpenAI servers are overloaded.
"""
# Uses default configuration if not overridden

if completion_kwargs is None:
completion_kwargs = self.completion_kwargs
@@ -70,7 +89,7 @@ def complete(self, prompt: str, user_input: str, completion_kwargs=None) -> (str
return error_message, error

if completion_kwargs.get("stream") is True:
# We are entering streaming mode, so here were just wrapping the streamed
# We are entering streaming mode, so here we're just wrapping the streamed
# openai response to be easier to handle later
def answer_generator():
for chunk in response:
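A hedged usage sketch for ChatGPTCompleter as documented above. Only the constructor arguments and the (message, error) return contract come from this diff; the model name, the API-key setup, and the exact chunk format in streaming mode are assumptions.

```python
from buster.completers.chatgpt import ChatGPTCompleter

# Assumes OPENAI_API_KEY is set in the environment; the model choice is illustrative.
completer = ChatGPTCompleter(
    completion_kwargs={"model": "gpt-3.5-turbo", "stream": True},
    client_kwargs=None,  # None -> the default OpenAI client is used
)

answer, error = completer.complete(
    prompt="Answer based only on the provided documents.",
    user_input="What does Buster do?",
)

if not error:
    # With stream=True the completer wraps the OpenAI response in a generator,
    # so the answer arrives incrementally (chunk format assumed to be text).
    for chunk in answer:
        print(chunk, end="")
```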
15 changes: 15 additions & 0 deletions buster/completers/user_inputs.py
@@ -4,9 +4,24 @@

@dataclass
class UserInputs:
"""A class that represents user inputs.
Attributes:
original_input: The original user input.
reformulated_input: The reformulated user input (optional).
"""

original_input: str
reformulated_input: Optional[str] = None

@property
def current_input(self):
"""Returns the current user input.
If the reformulated input is not None, it returns the reformulated input.
Otherwise, it returns the original input.
Returns:
The current user input.
"""
return self.reformulated_input if self.reformulated_input is not None else self.original_input
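UserInputs is small enough to exercise directly; the only assumption in this sketch is the import path implied by the file location.

```python
from buster.completers.user_inputs import UserInputs

inputs = UserInputs(original_input="whats buster?")
print(inputs.current_input)  # "whats buster?" (no reformulation yet)

inputs.reformulated_input = "What is Buster?"
print(inputs.current_input)  # "What is Buster?" (reformulated input takes precedence)
```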
34 changes: 20 additions & 14 deletions buster/documents_manager/base.py
@@ -22,9 +22,9 @@ def __init__(self, required_columns: Optional[list[str]] = None):
"""
Constructor for DocumentsManager class.
Parameters:
Args:
required_columns (Optional[list[str]]): A list of column names that the dataframe is required to contain.
If None, no columns are enforced.
If None, no columns are enforced.
"""

self.required_columns = required_columns
@@ -35,6 +35,14 @@ def _check_required_columns(self, df: pd.DataFrame):
raise ValueError(f"DataFrame is missing one or more of {self.required_columns=}")

def _checkpoint_csv(self, df, csv_filename: str, csv_overwrite: bool = True):
"""
Saves DataFrame with embeddings to a CSV checkpoint.
Args:
df (pd.DataFrame): The DataFrame with embeddings.
csv_filename (str): Path to save a copy of the dataframe with computed embeddings for later use.
csv_overwrite (bool, optional): Whether to overwrite the file with a new file. Defaults to True.
"""
import os

if csv_overwrite:
@@ -68,7 +76,7 @@ def add(
1. Checks if the required columns are present in the DataFrame.
2. Computes embeddings for the 'content' column if they are not already present.
3. Optionally saves the DataFrame with computed embeddings to a CSV checkpoint.
4. Calls the '_add_documents' method to add documents with embeddinsg to the DocumentsManager.
4. Calls the '_add_documents' method to add documents with embeddings to the DocumentsManager.
Args:
df (pd.DataFrame): The DataFrame containing the documents to be added.
Expand All @@ -77,9 +85,8 @@ def add(
Default is 'get_embedding_openai' which uses the text-embedding-ada-002 model.
sparse_embedding_fn (callable, optional): A function that computes sparse embeddings for a given input string.
Default is None. Only use if you want sparse embeddings.
csv_filename: (str, optional) = Path to save a copy of the dataframe with computed embeddings for later use.
csv_overwrite: (bool, optional) = If csv_filename is specified, whether to overwrite the file with a new file.
csv_filename (str, optional): Path to save a copy of the dataframe with computed embeddings for later use.
csv_overwrite (bool, optional): Whether to overwrite the file with a new file. Defaults to True.
**add_kwargs: Additional keyword arguments to be passed to the '_add_documents' method.
"""

@@ -110,30 +117,29 @@ def batch_add(
**add_kwargs,
):
"""
Adds DataFrame data to a DataManager instance in batches.
This function takes a DataFrame and adds its data to a DataManager instance in batches.
It ensures that a minimum time interval is maintained between successive batches
to prevent timeouts or excessive load. This is useful for APIs like OpenAI with rate limits.
Args:
df (pandas.DataFrame): The input DataFrame containing data to be added.
df (pd.DataFrame): The input DataFrame containing data to be added.
batch_size (int, optional): The size of each batch. Defaults to 3000.
min_time_interval (int, optional): The minimum time interval (in seconds) between batches.
Defaults to 60.
Defaults to 60.
num_workers (int, optional): The number of parallel workers to use when adding data.
Defaults to 32.
embedding_fn (callable, optional): A function that computes embeddings for a given input string.
Default is 'get_embedding_openai' which uses the text-embedding-ada-002 model.
sparse_embedding_fn (callable, optional): A function that computes sparse embeddings for a given input string.
Default is None. Only use if you want sparse embeddings.
csv_filename: (str, optional) = Path to save a copy of the dataframe with computed embeddings for later use.
csv_overwrite: (bool, optional) = If csv_filename is specified, whether to overwrite the file with a new file.
csv_filename (str, optional): Path to save a copy of the dataframe with computed embeddings for later use.
csv_overwrite (bool, optional): Whether to overwrite the file with a new file. Defaults to False.
When using batches, set to False to keep all embeddings in the same file. You may want to manually remove the file if experimenting.
**add_kwargs: Additional keyword arguments to be passed to the '_add_documents' method.
Returns:
None
"""

total_batches = (len(df) // batch_size) + 1

logger.info(f"Adding {len(df)} documents with {batch_size=} for {total_batches=}")
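To tie the add/batch_add docstrings together, here is a sketch of a rate-limited indexing call. The embedding function is a stand-in (the documented default is get_embedding_openai), and manager is assumed to be any concrete DocumentsManager subclass.

```python
import pandas as pd


def fake_embedding_fn(text: str) -> list[float]:
    # Stand-in for the default get_embedding_openai (text-embedding-ada-002)
    return [float(len(text))]


def index_documents(manager, df: pd.DataFrame) -> None:
    """Add df to a DocumentsManager subclass in rate-limited batches."""
    manager.batch_add(
        df,
        batch_size=3000,        # documents per batch
        min_time_interval=60,   # seconds between batches, to respect rate limits
        num_workers=32,         # parallel workers when computing embeddings
        embedding_fn=fake_embedding_fn,
        # sparse_embedding_fn=...,  # optional sparse embeddings for hybrid retrieval
        csv_filename="embeddings_checkpoint.csv",  # checkpoint after each batch
        csv_overwrite=False,    # keep all batches in a single checkpoint file
    )
```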
37 changes: 33 additions & 4 deletions buster/documents_manager/deeplake.py
@@ -18,6 +18,13 @@ def __init__(
required_columns: Optional[list[str]] = None,
**vector_store_kwargs,
):
"""Initialize a DeepLakeDocumentsManager object.
Args:
vector_store_path: The path to the vector store.
required_columns: A list of columns that are required in the dataframe.
**vector_store_kwargs: Additional keyword arguments to pass to the VectorStore initializer.
"""
from deeplake.core.vectorstore import VectorStore

self.vector_store_path = vector_store_path
@@ -28,12 +35,23 @@
)

def __len__(self):
"""Get the number of documents in the vector store.
Returns:
The number of documents in the vector store.
"""
return len(self.vector_store)

@classmethod
def _extract_metadata(cls, df: pd.DataFrame) -> dict:
"""extract the metadata from the dataframe in deeplake dict format"""
"""Extract metadata from the dataframe in DeepLake dict format.
Args:
df: The dataframe from which to extract metadata.
Returns:
The extracted metadata in DeepLake dict format.
"""
# Ignore the content and embedding column for metadata
df = df.drop(columns=["content", "embedding"], errors="ignore")

@@ -46,12 +64,16 @@ def _extract_metadata(cls, df: pd.DataFrame) -> dict:
return metadata

def _add_documents(self, df: pd.DataFrame, **add_kwargs):
"""Write all documents from the dataframe into the db as a new version.
"""Write all documents from the dataframe into the vector store as a new version.
Each entry in the df is expected to have at least the following columns:
Each entry in the dataframe is expected to have at least the following columns:
["content", "embedding"]
Embeddings will have been precomputed in the self.add() method, which calls this one.
Args:
df: The dataframe containing the documents to add.
**add_kwargs: Additional keyword arguments to pass to the add method of the vector store.
"""
# Embedding should already be computed in the .add method
assert "embedding" in df.columns, "expected column=embedding in the dataframe"
@@ -70,7 +92,14 @@ def _add_documents(self, df: pd.DataFrame, **add_kwargs):
)

def to_zip(self, output_path: str = "."):
"""Zip the contents of the vector_store_path folder to a .zip file in output_path."""
"""Zip the contents of the vector store path folder to a .zip file in the output path.
Args:
output_path: The path where the zip file should be created.
Returns:
The path to the created zip file.
"""
vector_store_path = self.vector_store_path
logger.info(f"Compressing {vector_store_path}...")
zip_file_path = zip_contents(input_path=vector_store_path, output_path=output_path)
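A usage sketch for DeepLakeDocumentsManager; the store path and the required_columns schema are illustrative assumptions, and add() will call the default OpenAI embedding function unless another one is supplied.

```python
import pandas as pd

from buster.documents_manager.deeplake import DeepLakeDocumentsManager

dm = DeepLakeDocumentsManager(
    vector_store_path="deeplake_store",   # local path, assumed for the example
    required_columns=["content", "url"],  # assumed schema for the dataframe
)

df = pd.DataFrame(
    {
        "content": ["Buster answers questions from your documentation."],
        "url": ["https://example.com/docs"],
    }
)

# Computes embeddings for the 'content' column (OpenAI by default, so an API
# key is needed) and writes the documents as a new version of the store.
dm.add(df)

print(len(dm))                         # number of documents in the vector store
zip_path = dm.to_zip(output_path=".")  # archive the store, e.g. for deployment
```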
45 changes: 40 additions & 5 deletions buster/documents_manager/service.py
@@ -24,6 +24,17 @@
mongo_db_name: str,
**kwargs,
):
"""Initialize the DocumentsService.
Args:
pinecone_api_key: The Pinecone API key.
pinecone_env: The Pinecone environment.
pinecone_index: The Pinecone index.
pinecone_namespace: The Pinecone namespace.
mongo_uri: The MongoDB URI.
mongo_db_name: The MongoDB database name.
**kwargs: Additional keyword arguments to pass to the parent class.
"""
super().__init__(**kwargs)

pinecone.init(api_key=pinecone_api_key, environment=pinecone_env)
@@ -36,15 +47,26 @@ def __init__(
self.db = self.client[mongo_db_name]

def __repr__(self):
"""Return a string representation of the DocumentsService."""
return "DocumentsService"

def get_source_id(self, source: str) -> str:
"""Get the id of a source."""
"""Get the id of a source.
Args:
source: The name of the source.
Returns:
The id of the source.
"""
return str(self.db.sources.find_one({"name": source})["_id"])

def _add_documents(self, df: pd.DataFrame):
"""Write all documents from the dataframe into the db as a new version."""

"""Write all documents from the dataframe into the db as a new version.
Args:
df: The dataframe containing the documents.
"""
use_sparse_vector = "sparse_embedding" in df.columns
if use_sparse_vector:
logger.info("Uploading sparse embeddings too.")
@@ -85,13 +107,26 @@ def _add_documents(self, df: pd.DataFrame):
self.index.upsert(vectors=to_upsert[i : i + MAX_PINECONE_BATCH_SIZE], namespace=self.namespace)

def update_source(self, source: str, display_name: str = None, note: str = None):
"""Update the display name and/or note of a source. Also create the source if it does not exist."""
"""Update the display name and/or note of a source. Also create the source if it does not exist.
Args:
source: The name of the source.
display_name: The new display name of the source.
note: The new note of the source.
"""
self.db.sources.update_one(
{"name": source}, {"$set": {"display_name": display_name, "note": note}}, upsert=True
)

def delete_source(self, source: str) -> tuple[int, int]:
"""Delete a source and all its documents. Return if the source was deleted and the number of deleted documents."""
"""Delete a source and all its documents. Return if the source was deleted and the number of deleted documents.
Args:
source: The name of the source.
Returns:
A tuple containing the number of deleted sources and the number of deleted documents.
"""
source_id = self.get_source_id(source)

# MongoDB
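Finally, a sketch of DocumentsService wiring; every credential and name below is a placeholder, and only the constructor arguments and method signatures come from the diff above.

```python
import os

from buster.documents_manager.service import DocumentsService

service = DocumentsService(
    pinecone_api_key=os.environ["PINECONE_API_KEY"],  # placeholder credentials
    pinecone_env="us-east-1-aws",                     # placeholder environment
    pinecone_index="buster-docs",
    pinecone_namespace="prod",
    mongo_uri=os.environ["MONGO_URI"],
    mongo_db_name="buster",
)

# Create or update a source's metadata (upserts when the source is missing).
service.update_source("docs", display_name="Product Docs", note="v1 crawl")

# Remove a source and everything indexed under it.
deleted_sources, deleted_documents = service.delete_source("docs")
print(deleted_sources, deleted_documents)
```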