diff --git a/buster/completers/base.py b/buster/completers/base.py
index 559e8341..dc77bdf5 100644
--- a/buster/completers/base.py
+++ b/buster/completers/base.py
@@ -16,6 +16,32 @@
 class Completion:
+    """
+    A class representing the completion of a model's output for a user's question.
+
+    Attributes:
+        error (bool): A boolean indicating if an error occurred when generating the completion.
+        user_inputs (UserInputs): The inputs from the user.
+        matched_documents (pd.DataFrame): The documents that were matched to the user's question.
+        answer_generator (Iterator): An optional iterator used to generate the model's answer.
+        answer_text (str): An optional answer text.
+        answer_relevant (bool): An optional boolean indicating if the answer is relevant.
+        question_relevant (bool): An optional boolean indicating if the question is relevant.
+        completion_kwargs (dict): Optional arguments for the completion.
+        validator (Validator): An optional Validator object.
+
+    Methods:
+        __repr__: Outputs a string representation of the object.
+        _validate_arguments: Validates the answer_generator and answer_text arguments.
+        answer_relevant: Determines if the answer is relevant or not.
+        question_relevant: Retrieves the relevance of the question.
+        answer_text: Retrieves the answer text.
+        answer_generator: Retrieves the answer generator.
+        postprocess: Postprocesses the results after generating the model's answer.
+        to_json: Outputs selected attributes of the object in JSON format.
+        from_dict: Creates a Completion object from a dictionary.
+    """
+
     def __init__(
         self,
         error: bool,
@@ -202,7 +228,12 @@ def from_dict(cls, completion_dict: dict):
 
 class Completer(ABC):
-    """Generic LLM-based completer. Requires a prompt and an input to produce an output."""
+    """
+    Abstract base class for completers, which generate an answer to a prompt.
+
+    Methods:
+        complete: The method that any child class must implement to provide an answer to a prompt.
+    """
 
     @abstractmethod
     def complete(self, prompt: str, user_input) -> (str | Iterator, bool):
@@ -211,9 +242,21 @@ def complete(self, prompt: str, user_input) -> (str | Iterator, bool):
 
 class DocumentAnswerer:
-    """Completer that will answer questions based on documents.
+    """
+    A class that answers questions based on documents. It takes care of formatting the prompts
+    and the documents, and of generating the answer when relevant.
+
+    Attributes:
+        completer (Completer): Object that actually generates an answer to the prompt.
+        documents_formatter (DocumentsFormatter): Object that formats the documents for the prompt.
+        prompt_formatter (PromptFormatter): Object that prepares the prompt for the completer.
+        no_documents_message (str): Message to display when no documents are found to match the query.
+        completion_class (Completion): Class to use for the resulting completion.
+
+    Methods:
+        prepare_prompt: Prepares the prompt that will be passed to the completer.
+        get_completion: Generates a completion to the user's question based on matched documents.
     """
 
     def __init__(
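To make the DocumentAnswerer flow concrete, here is a minimal usage sketch. The completer, formatter, and retriever instances are assumed to be configured elsewhere, and the get_completion keyword names are inferred from the docstrings above rather than confirmed against the implementation:

# Hypothetical wiring of the pieces documented above; `completer`,
# `documents_formatter`, `prompt_formatter`, `retriever`, and `user_inputs`
# are assumed to be configured elsewhere.
answerer = DocumentAnswerer(
    completer=completer,
    documents_formatter=documents_formatter,
    prompt_formatter=prompt_formatter,
    no_documents_message="No documents were found that match your question.",
)
matched_documents = retriever.retrieve(user_inputs)  # pd.DataFrame of matched docs
completion = answerer.get_completion(
    user_inputs=user_inputs,
    matched_documents=matched_documents,
)
print(completion.answer_text)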
diff --git a/buster/completers/chatgpt.py b/buster/completers/chatgpt.py
index b4a4021f..e7042971 100644
--- a/buster/completers/chatgpt.py
+++ b/buster/completers/chatgpt.py
@@ -28,6 +28,12 @@ class ChatGPTCompleter(Completer):
     def __init__(self, completion_kwargs: dict, client_kwargs: Optional[dict] = None):
+        """Initialize the ChatGPTCompleter with completion and client keyword arguments.
+
+        Args:
+            completion_kwargs: A dictionary of keyword arguments to be used for completions.
+            client_kwargs: An optional dictionary of keyword arguments to be used for the OpenAI client.
+        """
         # use default client if none passed
         self.completion_kwargs = completion_kwargs
@@ -37,8 +43,21 @@ def __init__(self, completion_kwargs: dict, client_kwargs: Optional[dict] = None
         self.client = OpenAI(**client_kwargs)
 
     def complete(self, prompt: str, user_input: str, completion_kwargs=None) -> (str | Iterator, bool):
-        """Returns the completed message (can be a generator), and a boolean to indicate if an error occured or not."""
-        # Uses default configuration if not overriden
+        """Given a prompt and user input, returns the generated message and an error flag.
+
+        Args:
+            prompt: The prompt containing the formatted documents and instructions.
+            user_input: The user input to be responded to.
+            completion_kwargs: An optional dictionary of keyword arguments to override the default completion kwargs.
+
+        Returns:
+            A tuple containing the completed message and a boolean indicating if an error occurred.
+
+        Raises:
+            openai.BadRequestError: If the completion request is invalid.
+            openai.RateLimitError: If the OpenAI servers are overloaded.
+        """
+        # Uses default configuration if not overridden
         if completion_kwargs is None:
             completion_kwargs = self.completion_kwargs
@@ -70,7 +89,7 @@ def complete(self, prompt: str, user_input: str, completion_kwargs=None) -> (str
             return error_message, error
 
         if completion_kwargs.get("stream") is True:
-            # We are entering streaming mode, so here were just wrapping the streamed
+            # We are entering streaming mode, so here we're just wrapping the streamed
            # openai response to be easier to handle later
            def answer_generator():
                for chunk in response:
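The streaming comment above is the heart of answer_generator. Below is a self-contained sketch of the same wrapping idea, with a simulated stream shaped like an openai>=1.0 streaming response (real chunks would come from client.chat.completions.create(..., stream=True); the exact chunk layout is an assumption based on that client version):

from types import SimpleNamespace

def answer_generator(response):
    # Yield only the text of each streamed chunk; empty deltas become "".
    for chunk in response:
        token = chunk.choices[0].delta.content
        yield token if token is not None else ""

def make_chunk(text):
    # Mimics the chunk.choices[0].delta.content layout of a streamed response.
    return SimpleNamespace(choices=[SimpleNamespace(delta=SimpleNamespace(content=text))])

stream = iter([make_chunk("Hello"), make_chunk(" world"), make_chunk(None)])
print("".join(answer_generator(stream)))  # -> "Hello world"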
diff --git a/buster/completers/user_inputs.py b/buster/completers/user_inputs.py
index 4fa01727..297782b5 100644
--- a/buster/completers/user_inputs.py
+++ b/buster/completers/user_inputs.py
@@ -4,9 +4,24 @@
 
 @dataclass
 class UserInputs:
+    """A class that represents user inputs.
+
+    Attributes:
+        original_input: The original user input.
+        reformulated_input: The reformulated user input (optional).
+    """
+
     original_input: str
     reformulated_input: Optional[str] = None
 
     @property
     def current_input(self):
+        """Returns the current user input.
+
+        If the reformulated input is not None, it returns the reformulated input.
+        Otherwise, it returns the original input.
+
+        Returns:
+            The current user input.
+        """
         return self.reformulated_input if self.reformulated_input is not None else self.original_input
diff --git a/buster/documents_manager/base.py b/buster/documents_manager/base.py
index 380333a1..c024c4d5 100644
--- a/buster/documents_manager/base.py
+++ b/buster/documents_manager/base.py
@@ -22,9 +22,9 @@ def __init__(self, required_columns: Optional[list[str]] = None):
         """
         Constructor for DocumentsManager class.
 
-        Parameters:
+        Args:
             required_columns (Optional[list[str]]): A list of column names that are required for the dataframe to contain.
-            If None, no columns are enforced.
+                If None, no columns are enforced.
         """
         self.required_columns = required_columns
@@ -35,6 +35,14 @@ def _check_required_columns(self, df: pd.DataFrame):
             raise ValueError(f"DataFrame is missing one or more of {self.required_columns=}")
 
     def _checkpoint_csv(self, df, csv_filename: str, csv_overwrite: bool = True):
+        """
+        Saves a DataFrame with embeddings to a CSV checkpoint.
+
+        Args:
+            df (pd.DataFrame): The DataFrame with embeddings.
+            csv_filename (str): Path to save a copy of the dataframe with computed embeddings for later use.
+            csv_overwrite (bool, optional): Whether to overwrite the file with a new file. Defaults to True.
+        """
         import os
 
         if csv_overwrite:
@@ -68,7 +76,7 @@ def add(
         1. Checks if the required columns are present in the DataFrame.
         2. Computes embeddings for the 'content' column if they are not already present.
         3. Optionally saves the DataFrame with computed embeddings to a CSV checkpoint.
-        4. Calls the '_add_documents' method to add documents with embeddinsg to the DocumentsManager.
+        4. Calls the '_add_documents' method to add documents with embeddings to the DocumentsManager.
 
         Args:
             df (pd.DataFrame): The DataFrame containing the documents to be added.
@@ -77,9 +85,8 @@ def add(
             embedding_fn (callable, optional): A function that computes embeddings for a given input string.
                 Default is 'get_embedding_openai' which uses the text-embedding-ada-002 model.
             sparse_embedding_fn (callable, optional): A function that computes sparse embeddings for a given input string.
                 Default is None. Only use if you want sparse embeddings.
-
-        csv_filename: (str, optional) = Path to save a copy of the dataframe with computed embeddings for later use.
-        csv_overwrite: (bool, optional) = If csv_filename is specified, whether to overwrite the file with a new file.
+            csv_filename (str, optional): Path to save a copy of the dataframe with computed embeddings for later use.
+            csv_overwrite (bool, optional): Whether to overwrite the file with a new file. Defaults to True.
 
             **add_kwargs: Additional keyword arguments to be passed to the '_add_documents' method.
         """
@@ -110,30 +117,29 @@ def batch_add(
         **add_kwargs,
     ):
         """
+        Adds DataFrame data to a DocumentsManager instance in batches.
+
         This function takes a DataFrame and adds its data to a DataManager instance in batches.
         It ensures that a minimum time interval is maintained between successive batches
        to prevent timeouts or excessive load. This is useful for APIs like openAI with rate limits.
 
         Args:
-            df (pandas.DataFrame): The input DataFrame containing data to be added.
+            df (pd.DataFrame): The input DataFrame containing data to be added.
             batch_size (int, optional): The size of each batch. Defaults to 3000.
             min_time_interval (int, optional): The minimum time interval (in seconds) between batches.
-            Defaults to 60.
+                Defaults to 60.
             num_workers (int, optional): The number of parallel workers to use when adding data. Defaults to 32.
             embedding_fn (callable, optional): A function that computes embeddings for a given input string.
                 Default is 'get_embedding_openai' which uses the text-embedding-ada-002 model.
             sparse_embedding_fn (callable, optional): A function that computes sparse embeddings for a given input string.
                 Default is None. Only use if you want sparse embeddings.
-            csv_filename: (str, optional) = Path to save a copy of the dataframe with computed embeddings for later use.
-            csv_overwrite: (bool, optional) = If csv_filename is specified, whether to overwrite the file with a new file.
+            csv_filename (str, optional): Path to save a copy of the dataframe with computed embeddings for later use.
+            csv_overwrite (bool, optional): Whether to overwrite the file with a new file. Defaults to False.
+                When using batches, set to False to keep all embeddings in the same file.
+                You may want to manually remove the file if experimenting.
-
         **add_kwargs: Additional keyword arguments to be passed to the '_add_documents' method.
-
-        Returns:
-            None
         """
+
         total_batches = (len(df) // batch_size) + 1
 
         logger.info(f"Adding {len(df)} documents with {batch_size=} for {total_batches=}")
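A standalone sketch of the batching loop documented above, using the same total_batches arithmetic as the code; the embedding and _add_documents calls are stubbed out so the pacing logic can run on its own:

import time

import pandas as pd

df = pd.DataFrame({"content": [f"document {i}" for i in range(7)]})
batch_size = 3
min_time_interval = 1  # seconds; the documented default is 60

total_batches = (len(df) // batch_size) + 1
for i in range(total_batches):
    start = time.time()
    batch = df.iloc[i * batch_size : (i + 1) * batch_size]
    # ... compute embeddings for `batch` and call _add_documents(batch) here ...
    elapsed = time.time() - start
    if elapsed < min_time_interval and i < total_batches - 1:
        # Keep a minimum interval between batches to respect rate limits.
        time.sleep(min_time_interval - elapsed)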
diff --git a/buster/documents_manager/deeplake.py b/buster/documents_manager/deeplake.py
index 990a4f56..9932aa2a 100644
--- a/buster/documents_manager/deeplake.py
+++ b/buster/documents_manager/deeplake.py
@@ -18,6 +18,13 @@ def __init__(
         required_columns: Optional[list[str]] = None,
         **vector_store_kwargs,
     ):
+        """Initialize a DeepLakeDocumentsManager object.
+
+        Args:
+            vector_store_path: The path to the vector store.
+            required_columns: A list of columns that are required in the dataframe.
+            **vector_store_kwargs: Additional keyword arguments to pass to the VectorStore initializer.
+        """
         from deeplake.core.vectorstore import VectorStore
 
         self.vector_store_path = vector_store_path
@@ -28,12 +35,23 @@
         )
 
     def __len__(self):
+        """Get the number of documents in the vector store.
+
+        Returns:
+            The number of documents in the vector store.
+        """
         return len(self.vector_store)
 
     @classmethod
     def _extract_metadata(cls, df: pd.DataFrame) -> dict:
-        """extract the metadata from the dataframe in deeplake dict format"""
+        """Extract metadata from the dataframe in DeepLake dict format.
+
+        Args:
+            df: The dataframe from which to extract metadata.
+
+        Returns:
+            The extracted metadata in DeepLake dict format.
+        """
 
         # Ignore the content and embedding column for metadata
         df = df.drop(columns=["content", "embedding"], errors="ignore")
@@ -46,12 +64,16 @@
         return metadata
 
     def _add_documents(self, df: pd.DataFrame, **add_kwargs):
-        """Write all documents from the dataframe into the db as a new version.
+        """Write all documents from the dataframe into the vector store as a new version.
 
-        Each entry in the df is expected to have at least the following columns:
+        Each entry in the dataframe is expected to have at least the following columns:
         ["content", "embedding"]
 
         Embeddings will have been precomputed in the self.add() method, which calls this one.
+
+        Args:
+            df: The dataframe containing the documents to add.
+            **add_kwargs: Additional keyword arguments to pass to the add method of the vector store.
         """
         # Embedding should already be computed in the .add method
         assert "embedding" in df.columns, "expected column=embedding in the dataframe"
@@ -70,7 +92,14 @@
         )
 
     def to_zip(self, output_path: str = "."):
-        """Zip the contents of the vector_store_path folder to a .zip file in output_path."""
+        """Zip the contents of the vector store path folder to a .zip file in the output path.
+
+        Args:
+            output_path: The path where the zip file should be created.
+
+        Returns:
+            The path to the created zip file.
+        """
         vector_store_path = self.vector_store_path
         logger.info(f"Compressing {vector_store_path}...")
         zip_file_path = zip_contents(input_path=vector_store_path, output_path=output_path)
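A short usage sketch of the manager documented above. It assumes deeplake is installed and an OpenAI key is available for the default embedding function; the extra "source" column is illustrative:

import pandas as pd

from buster.documents_manager.deeplake import DeepLakeDocumentsManager

dm = DeepLakeDocumentsManager(vector_store_path="deeplake_store")
df = pd.DataFrame(
    [{"content": "Deep Lake stores embeddings.", "source": "docs"}]  # illustrative columns
)
dm.add(df)              # computes embeddings, then calls _add_documents
zip_path = dm.to_zip()  # archive the store for distribution
print(len(dm), zip_path)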
diff --git a/buster/documents_manager/service.py b/buster/documents_manager/service.py
index 046a02c0..b6903f36 100644
--- a/buster/documents_manager/service.py
+++ b/buster/documents_manager/service.py
@@ -24,6 +24,17 @@ def __init__(
         mongo_db_name: str,
         **kwargs,
     ):
+        """Initialize the DocumentsService.
+
+        Args:
+            pinecone_api_key: The Pinecone API key.
+            pinecone_env: The Pinecone environment.
+            pinecone_index: The Pinecone index.
+            pinecone_namespace: The Pinecone namespace.
+            mongo_uri: The MongoDB URI.
+            mongo_db_name: The MongoDB database name.
+            **kwargs: Additional keyword arguments to pass to the parent class.
+        """
         super().__init__(**kwargs)
 
         pinecone.init(api_key=pinecone_api_key, environment=pinecone_env)
@@ -36,15 +47,26 @@
         self.db = self.client[mongo_db_name]
 
     def __repr__(self):
+        """Return a string representation of the DocumentsService."""
         return "DocumentsService"
 
     def get_source_id(self, source: str) -> str:
-        """Get the id of a source."""
+        """Get the id of a source.
+
+        Args:
+            source: The name of the source.
+
+        Returns:
+            The id of the source.
+        """
         return str(self.db.sources.find_one({"name": source})["_id"])
 
     def _add_documents(self, df: pd.DataFrame):
-        """Write all documents from the dataframe into the db as a new version."""
-
+        """Write all documents from the dataframe into the db as a new version.
+
+        Args:
+            df: The dataframe containing the documents.
+        """
         use_sparse_vector = "sparse_embedding" in df.columns
         if use_sparse_vector:
             logger.info("Uploading sparse embeddings too.")
@@ -85,13 +107,26 @@
             self.index.upsert(vectors=to_upsert[i : i + MAX_PINECONE_BATCH_SIZE], namespace=self.namespace)
 
     def update_source(self, source: str, display_name: str = None, note: str = None):
-        """Update the display name and/or note of a source. Also create the source if it does not exist."""
+        """Update the display name and/or note of a source. Also create the source if it does not exist.
+
+        Args:
+            source: The name of the source.
+            display_name: The new display name of the source.
+            note: The new note of the source.
+        """
         self.db.sources.update_one(
             {"name": source}, {"$set": {"display_name": display_name, "note": note}}, upsert=True
         )
 
     def delete_source(self, source: str) -> tuple[int, int]:
-        """Delete a source and all its documents. Return if the source was deleted and the number of deleted documents."""
+        """Delete a source and all its documents.
+
+        Args:
+            source: The name of the source.
+
+        Returns:
+            A tuple containing the number of deleted sources and the number of deleted documents.
+        """
         source_id = self.get_source_id(source)
 
         # MongoDB
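The upsert loop shown in the hunk above slices to_upsert into fixed-size batches. Here is a self-contained sketch of that slicing with the upsert call stubbed out; the MAX_PINECONE_BATCH_SIZE value is an assumption for illustration, not the constant defined in service.py:

MAX_PINECONE_BATCH_SIZE = 100  # assumed value for illustration

# Stand-in for the (id, vector, metadata) tuples built by _add_documents.
to_upsert = [(f"id-{i}", [0.0, 0.1], {"source": "docs"}) for i in range(250)]

def upsert(vectors, namespace):
    # Stub for index.upsert; just reports the batch size.
    print(f"upserting {len(vectors)} vectors to namespace {namespace!r}")

for i in range(0, len(to_upsert), MAX_PINECONE_BATCH_SIZE):
    upsert(vectors=to_upsert[i : i + MAX_PINECONE_BATCH_SIZE], namespace="documents")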
diff --git a/buster/formatters/prompts.py b/buster/formatters/prompts.py
index 77a2dc04..9171956c 100644
--- a/buster/formatters/prompts.py
+++ b/buster/formatters/prompts.py
@@ -12,19 +12,24 @@
 @dataclass
 class PromptFormatter:
     tokenizer: Tokenizer
-    max_tokens: 3500
+    max_tokens: int
     text_before_docs: str
     text_after_docs: str
     formatter: str = "{text_before_docs}\n{documents}\n{text_after_docs}"
 
-    def format(
-        self,
-        documents: str,
-    ) -> str:
-        """
-        Prepare the system prompt with prompt engineering.
+    def format(self, documents: str) -> str:
+        """Formats the system prompt with prompt engineering.
+
+        Joins the text before and after documents with the documents provided.
+
+        Args:
+            documents (str): The already formatted documents to include in the system prompt.
 
-        Joins the text before and after documents with
+        Returns:
+            str: The formatted system prompt.
+
+        Raises:
+            ValueError: If the number of prompt tokens exceeds the maximum allowed tokens.
         """
         system_prompt = self.formatter.format(
             text_before_docs=self.text_before_docs, documents=documents, text_after_docs=self.text_after_docs
         )
@@ -35,7 +40,16 @@
         return system_prompt
 
-def prompt_formatter_factory(tokenizer: Tokenizer, prompt_cfg):
+def prompt_formatter_factory(tokenizer: Tokenizer, prompt_cfg) -> PromptFormatter:
+    """Creates a PromptFormatter instance.
+
+    Args:
+        tokenizer (Tokenizer): The tokenizer to use for the PromptFormatter.
+        prompt_cfg: The configuration for the PromptFormatter.
+
+    Returns:
+        PromptFormatter: The created PromptFormatter instance.
+    """
     return PromptFormatter(
         tokenizer=tokenizer,
         max_tokens=prompt_cfg["max_tokens"],
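A usage sketch of the formatter documented above; it assumes GPTTokenizer is exported from buster.tokenizers and that tiktoken is installed:

from buster.formatters.prompts import PromptFormatter
from buster.tokenizers import GPTTokenizer  # assumed export path

prompt_formatter = PromptFormatter(
    tokenizer=GPTTokenizer("gpt-3.5-turbo"),
    max_tokens=3500,
    text_before_docs="You are a helpful assistant. Use only the documents below.",
    text_after_docs="Now answer the user's question.",
)
# format() raises ValueError if the assembled prompt exceeds max_tokens.
system_prompt = prompt_formatter.format("<formatted document 1>\n<formatted document 2>")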
diff --git a/buster/retriever/base.py b/buster/retriever/base.py
index 12055ad6..605bda12 100644
--- a/buster/retriever/base.py
+++ b/buster/retriever/base.py
@@ -26,6 +26,15 @@ def __init__(
         *args,
         **kwargs,
     ):
+        """Initializes a Retriever instance.
+
+        Args:
+            top_k: The maximum number of documents to retrieve.
+            thresh: The similarity threshold for document retrieval.
+            embedding_fn: The function to compute document embeddings.
+            sparse_embedding_fn: (Optional) The function to compute sparse document embeddings.
+            *args, **kwargs: Additional arguments and keyword arguments.
+        """
         if embedding_fn is None:
             embedding_fn = get_openai_embedding
 
@@ -37,31 +46,78 @@
         # Add your access to documents in your own init
 
     @abstractmethod
-    def get_documents(self, source: str = None) -> pd.DataFrame:
-        """Get all current documents from a given source."""
+    def get_documents(self, source: Optional[str] = None) -> pd.DataFrame:
+        """Get all current documents from a given source.
+
+        Args:
+            source: The source from which to retrieve documents. If None, retrieves documents from all sources.
+
+        Returns:
+            A pandas DataFrame containing the documents.
+        """
         ...
 
     @abstractmethod
     def get_source_display_name(self, source: str) -> str:
         """Get the display name of a source.
 
-        If source is None, returns all documents. If source does not exist, returns empty dataframe."""
+        Args:
+            source: The source for which to retrieve the display name.
+
+        Returns:
+            The display name of the source.
+
+        If source is None, a display name representing all sources is returned.
+        """
         ...
 
     @abstractmethod
-    def get_topk_documents(self, query: str, source: str = None, top_k: int = None) -> pd.DataFrame:
+    def get_topk_documents(self, query: str, source: Optional[str] = None, top_k: Optional[int] = None) -> pd.DataFrame:
         """Get the topk documents matching a user's query.
 
-        If no matches are found, returns an empty dataframe."""
+        Args:
+            query: The user's query.
+            source: The source from which to retrieve documents. If None, retrieves documents from all sources.
+            top_k: The maximum number of documents to retrieve.
+
+        Returns:
+            A pandas DataFrame containing the topk matched documents.
+
+        If no matches are found, returns an empty dataframe.
+        """
         ...
 
-    def threshold_documents(self, matched_documents, thresh: float) -> pd.DataFrame:
+    def threshold_documents(self, matched_documents: pd.DataFrame, thresh: float) -> pd.DataFrame:
+        """Filters out matched documents using a similarity threshold.
+
+        Args:
+            matched_documents: The DataFrame containing the matched documents.
+            thresh: The similarity threshold.
+
+        Returns:
+            A pandas DataFrame containing the filtered matched documents.
+        """
         # filter out matched_documents using a threshold
         return matched_documents[matched_documents.similarity > thresh]
 
     def retrieve(
-        self, user_inputs: UserInputs, sources: Optional[list[str]] = None, top_k: int = None, thresh: float = None
+        self,
+        user_inputs: UserInputs,
+        sources: Optional[list[str]] = None,
+        top_k: Optional[int] = None,
+        thresh: Optional[float] = None,
     ) -> pd.DataFrame:
+        """Retrieves documents based on user inputs.
+
+        Args:
+            user_inputs: The user's inputs.
+            sources: The sources from which to retrieve documents. If None, retrieves documents from all sources.
+            top_k: The maximum number of documents to retrieve.
+            thresh: The similarity threshold for document retrieval.
+
+        Returns:
+            A pandas DataFrame containing the retrieved documents.
+        """
         if top_k is None:
             top_k = self.top_k
         if thresh is None:
             thresh = self.thresh
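The retrieve flow above is: fall back to the instance defaults, fetch the top-k matches, then threshold on similarity. A standalone sketch of that logic, with a toy DataFrame standing in for real matches:

import pandas as pd

matched_documents = pd.DataFrame(
    {"content": ["doc a", "doc b", "doc c"], "similarity": [0.91, 0.75, 0.42]}
)

top_k, thresh = None, None              # caller overrides
default_top_k, default_thresh = 3, 0.7  # instance defaults

top_k = default_top_k if top_k is None else top_k
thresh = default_thresh if thresh is None else thresh

# threshold_documents: keep only rows above the similarity threshold.
filtered = matched_documents[matched_documents.similarity > thresh]
print(filtered)  # keeps "doc a" and "doc b"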
+ """ if query is not None: query_embedding = self.embedding_fn(query) elif embedding is not None: diff --git a/buster/retriever/service.py b/buster/retriever/service.py index b78e58ea..b0686895 100644 --- a/buster/retriever/service.py +++ b/buster/retriever/service.py @@ -1,5 +1,5 @@ import logging -from typing import Optional +from typing import List, Optional import numpy as np import pandas as pd @@ -25,6 +25,22 @@ def __init__( mongo_db_name: str, **kwargs, ): + """ + Initializes a ServiceRetriever instance. + + The ServiceRetriever is a hybrid retrieval combining pinecone and mongodb services. + + Pinecone is exclusively used as a vector store. + The id of the pinecone vectors are used as a key in the mongodb database to store its associated metadata. + + Args: + pinecone_api_key: The API key for Pinecone. + pinecone_env: The environment for Pinecone. + pinecone_index: The name of the Pinecone index. + pinecone_namespace: The namespace for Pinecone. + mongo_uri: The URI for MongoDB. + mongo_db_name: The name of the MongoDB database. + """ super().__init__(**kwargs) pinecone.init(api_key=pinecone_api_key, environment=pinecone_env) @@ -36,15 +52,26 @@ def __init__( self.db = self.client[mongo_db_name] def get_source_id(self, source: str) -> str: - """Get the id of a source. Returns empty string if the source does not exist.""" + """Get the id of a source. Returns an empty string if the source does not exist. + + Args: + source: The name of the source. + + Returns: + The id of the source. + """ source_pointer = self.db.sources.find_one({"name": source}) return "" if source_pointer is None else str(source_pointer["_id"]) - def get_documents(self, source: str = None) -> pd.DataFrame: + def get_documents(self, source: Optional[str] = None) -> pd.DataFrame: """Get all current documents from a given source. - If source is None, returns all documents. If source does not exist, returns empty dataframe.""" + Args: + source: The name of the source. Defaults to None. + Returns: + A DataFrame containing all the documents. If the source does not exist, returns an empty DataFrame. + """ if source is None: # No source specified, return all documents documents = self.db.documents.find() @@ -60,14 +87,31 @@ def get_documents(self, source: str = None) -> pd.DataFrame: return pd.DataFrame(list(documents)) def get_source_display_name(self, source: str) -> str: - """Get the display name of a source.""" + """Get the display name of a source. + + Args: + source: The name of the source. + + Returns: + The display name of the source. + """ if source is None: return ALL_SOURCES else: display_name = self.db.sources.find_one({"name": source})["display_name"] return display_name - def get_topk_documents(self, query: str, sources: Optional[list[str]], top_k: int) -> pd.DataFrame: + def get_topk_documents(self, query: str, sources: Optional[List[str]], top_k: int) -> pd.DataFrame: + """Get the top k documents matching a query from the specified sources. + + Args: + query: The query string. + sources: The list of source names to search. Defaults to None. + top_k: The number of top matches to return. + + Returns: + A DataFrame containing the top k matching documents. + """ if sources is None: filter = None else: diff --git a/buster/tokenizers/base.py b/buster/tokenizers/base.py index b83db94f..45ac7c93 100644 --- a/buster/tokenizers/base.py +++ b/buster/tokenizers/base.py @@ -3,20 +3,60 @@ class Tokenizer(ABC): - """Abstract base class for a tokenizer.""" + """Abstract base class for a tokenizer. 
diff --git a/buster/tokenizers/base.py b/buster/tokenizers/base.py
index b83db94f..45ac7c93 100644
--- a/buster/tokenizers/base.py
+++ b/buster/tokenizers/base.py
@@ -3,20 +3,60 @@
 class Tokenizer(ABC):
-    """Abstract base class for a tokenizer."""
+    """Abstract base class for a tokenizer.
+
+    Args:
+        model_name: The name of the tokenizer model.
+
+    Attributes:
+        model_name: The name of the tokenizer model.
+    """
 
     def __init__(self, model_name: str):
         self.model_name = model_name
 
     @abstractmethod
     def encode(self, string: str) -> list[int]:
+        """Encodes a string into a list of integers.
+
+        Args:
+            string: The input string to be encoded.
+
+        Returns:
+            A list of integers representing the encoded string.
+        """
         ...
 
     @abstractmethod
     def decode(self, encoded: list[int]) -> str:
+        """Decodes a list of integers into a string.
+
+        Args:
+            encoded: The list of integers to be decoded.
+
+        Returns:
+            The decoded string.
+        """
         ...
 
     def num_tokens(self, string: str, return_encoded: bool = False) -> Union[int, tuple[int, list[int]]]:
+        """Returns the number of tokens in a string.
+
+        Args:
+            string: The input string.
+            return_encoded: Whether or not to return the encoded string along with the number of tokens.
+
+        Returns:
+            If `return_encoded` is False, returns the number of tokens in the string.
+            If `return_encoded` is True, returns a tuple containing the number of tokens and the encoded string.
+        """
         encoded = self.encode(string)
         if return_encoded:
             return len(encoded), encoded
diff --git a/buster/tokenizers/gpt.py b/buster/tokenizers/gpt.py
index 72993c45..df1eb707 100644
--- a/buster/tokenizers/gpt.py
+++ b/buster/tokenizers/gpt.py
@@ -4,14 +4,42 @@
 class GPTTokenizer(Tokenizer):
-    """Tokenizer from openai, supports most GPT models."""
+    """Tokenizer class for GPT models.
+
+    This class implements a tokenizer for GPT models using the tiktoken library.
+
+    Args:
+        model_name (str): The name of the GPT model to be used.
+
+    Attributes:
+        encoder: The encoder object created using tiktoken.encoding_for_model().
+    """
 
     def __init__(self, model_name: str):
         super().__init__(model_name)
         self.encoder = tiktoken.encoding_for_model(model_name=model_name)
 
     def encode(self, string: str):
+        """Encodes a given string using the GPT tokenizer.
+
+        Args:
+            string (str): The string to be encoded.
+
+        Returns:
+            list[int]: The encoded representation of the string.
+        """
         return self.encoder.encode(string)
 
     def decode(self, encoded: list[int]):
+        """Decodes a list of tokens using the GPT tokenizer.
+
+        Args:
+            encoded (list[int]): The list of tokens to be decoded.
+
+        Returns:
+            str: The decoded string representation of the tokens.
+        """
         return self.encoder.decode(encoded)
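A short example of the Tokenizer contract using GPTTokenizer; it assumes tiktoken is installed and that GPTTokenizer is exported from buster.tokenizers:

from buster.tokenizers import GPTTokenizer  # assumed export path

tokenizer = GPTTokenizer("gpt-3.5-turbo")
text = "How many tokens is this?"

n_tokens = tokenizer.num_tokens(text)
n_tokens, encoded = tokenizer.num_tokens(text, return_encoded=True)

# decode() round-trips the encode() output.
assert tokenizer.decode(encoded) == text
print(n_tokens)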
diff --git a/buster/validators/base.py b/buster/validators/base.py
index 5e0e22ee..aa45d336 100644
--- a/buster/validators/base.py
+++ b/buster/validators/base.py
@@ -22,6 +22,16 @@ def __init__(
         answer_validator_cfg=None,
         documents_validator_cfg=None,
     ):
+        """
+        Initializes the Validator class.
+
+        Args:
+            use_reranking: A boolean indicating whether to use reranking.
+            validate_documents: A boolean indicating whether to validate documents.
+            question_validator_cfg: A configuration dictionary for the QuestionValidator.
+            answer_validator_cfg: A configuration dictionary for the AnswerValidator.
+            documents_validator_cfg: A configuration dictionary for the DocumentsValidator.
+        """
         self.question_validator = (
             QuestionValidator(**question_validator_cfg) if question_validator_cfg is not None else QuestionValidator()
         )
@@ -37,17 +47,56 @@
         self.validate_documents = validate_documents
 
     def check_question_relevance(self, question: str) -> tuple[bool, str]:
+        """
+        Checks the relevance of a question.
+
+        Args:
+            question: The question to be checked.
+
+        Returns:
+            A tuple containing a boolean indicating the relevance and a string describing the result.
+        """
         return self.question_validator.check_question_relevance(question)
 
     def check_answer_relevance(self, answer: str) -> bool:
+        """
+        Checks the relevance of an answer.
+
+        Args:
+            answer: The answer to be checked.
+
+        Returns:
+            A boolean indicating the relevance of the answer.
+        """
         return self.answer_validator.check_answer_relevance(answer)
 
     def check_documents_relevance(self, answer: str, matched_documents: pd.DataFrame) -> pd.DataFrame:
+        """
+        Checks the relevance of documents.
+
+        Args:
+            answer: The answer to be checked.
+            matched_documents: The DataFrame containing the matched documents.
+
+        Returns:
+            A DataFrame containing the relevance of the documents.
+        """
         return self.documents_validator.check_documents_relevance(answer, matched_documents)
 
     def rerank_docs(
         self, answer: str, matched_documents: pd.DataFrame, embedding_fn=get_openai_embedding
     ) -> pd.DataFrame:
+        """
+        Reranks the matched documents based on answer similarity.
+
+        Args:
+            answer: The answer for reranking.
+            matched_documents: The DataFrame containing the matched documents.
+            embedding_fn: The function used to calculate document embeddings.
+
+        Returns:
+            A DataFrame containing the reranked documents.
+        """
         """Here we re-rank matched documents according to the answer provided by the llm.
         This score could be used to determine whether a document was actually relevant to generation.
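A self-contained sketch of the reranking idea this docstring describes: embed the answer, score each matched document by cosine similarity to it, and sort. Random vectors stand in for real embeddings here; the actual method uses embedding_fn (OpenAI embeddings by default), and the column names below are illustrative:

import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
matched_documents = pd.DataFrame({"content": ["doc a", "doc b", "doc c"]})
matched_documents["embedding"] = [rng.standard_normal(8) for _ in range(3)]
answer_embedding = rng.standard_normal(8)  # stand-in for embedding_fn(answer)

def cosine_similarity(u, v):
    return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)))

matched_documents["similarity_to_answer"] = matched_documents.embedding.map(
    lambda emb: cosine_similarity(emb, answer_embedding)
)
reranked = matched_documents.sort_values("similarity_to_answer", ascending=False)
print(reranked)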