From 1e2813404963199b396faa9678553110962d5e13 Mon Sep 17 00:00:00 2001
From: sumanthprabhu
Date: Sun, 15 Sep 2024 15:57:10 +0530
Subject: [PATCH 1/6] Add LLMCurate

- Implemented LLM-based curation using self-ensembling to compute reliability scores for text generation use cases
- Implemented utility functions to perform batch inference using LLMs
- Added corresponding tests
- Updated the dependencies in pyproject.toml
- Bumped the version from 0.1.3 to 0.2.0
---
 dqc/__init__.py                           |   3 +-
 dqc/llm.py                                | 174 +++++++++++
 dqc/llm_utils/__init__.py                 |   9 +
 dqc/llm_utils/_sanity_checks.py           | 231 ++++++++++++++
 dqc/llm_utils/compute_confidence_score.py |  93 ++++++
 dqc/llm_utils/inference.py                | 231 ++++++++++++++
 dqc/version.py                            |   2 +-
 pyproject.toml                            |   5 +-
 tests/llm/conftest.py                     |  71 +++++
 tests/llm/test_llm.py                     | 348 ++++++++++++++++++++++
 10 files changed, 1163 insertions(+), 4 deletions(-)
 create mode 100644 dqc/llm.py
 create mode 100644 dqc/llm_utils/__init__.py
 create mode 100644 dqc/llm_utils/_sanity_checks.py
 create mode 100644 dqc/llm_utils/compute_confidence_score.py
 create mode 100644 dqc/llm_utils/inference.py
 create mode 100644 tests/llm/conftest.py
 create mode 100644 tests/llm/test_llm.py

diff --git a/dqc/__init__.py b/dqc/__init__.py
index 52a4b3d..c4091ad 100644
--- a/dqc/__init__.py
+++ b/dqc/__init__.py
@@ -1,4 +1,5 @@
 from .crossval import CrossValCurate
+from .llm import LLMCurate
 from .version import __version__, show_versions
 
-__all__ = ["CrossValCurate"]
+__all__ = ["CrossValCurate", "LLMCurate"]
diff --git a/dqc/llm.py b/dqc/llm.py
new file mode 100644
index 0000000..b2cc5ef
--- /dev/null
+++ b/dqc/llm.py
@@ -0,0 +1,174 @@
+from typing import Callable, List, Tuple, Union
+
+import pandas as pd
+from datasets import Dataset
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from dqc.base import BaseCurate
+from dqc.llm_utils import (
+    _empty_ds_ensemble_handler,
+    _validate_init_params,
+    _validate_run_params,
+    compute_reliability_score,
+    run_LLM,
+)
+from dqc.utils import Logger
+
+logger = Logger("DQC-Toolkit")
+
+
+class LLMCurate(BaseCurate):
+    """
+    Args:
+        model (AutoModelForCausalLM): Instantiated LLM
+        tokenizer (AutoTokenizer): Instantiated tokenizer corresponding to the `model`
+        verbose (bool, optional): Sets the verbosity level during execution. `True` indicates logging level 'INFO' and `False` indicates logging level 'WARNING'. Defaults to False.
+ + Examples: + + """ + + def __init__( + self, + model: AutoModelForCausalLM, + tokenizer: AutoTokenizer, + verbose: bool = False, + **options, + ): + super().__init__(**options) + + _validate_init_params(model, tokenizer, verbose) + self.model = model + self.tokenizer = tokenizer + self.verbose = verbose + self._set_verbosity(verbose) + + self.ds_ensemble = None + + def __str__(self): + display_dict = self.__dict__.copy() + + for key in list(display_dict.keys()): + if key in ["ds_ensemble"]: + ## Don't need to display these attributes + del display_dict[key] + + return str(display_dict) + + __repr__ = __str__ + + def _set_verbosity(self, verbose: bool): + """Set logger level based on user input for parameter `verbose` + + Args: + verbose (bool): Indicator for verbosity + """ + if verbose: + logger.set_level("INFO") + else: + logger.set_level("WARNING") + + def fit_transform(self): + pass + + def run( + self, + column_to_curate: str, + data: Union[pd.DataFrame, Dataset] = None, + ds_column_mapping: dict = {}, + prompt_variants: List[str] = [""], + skip_llm_inference: bool = False, + llm_response_cleaned_column_list: List[str] = ["reference_prediction"], + return_scores: bool = True, + answer_start_token: str = "", + answer_end_token: str = "", + scoring_method: Union[Callable[[str, str], float], str] = "exact_match", + **options, + ) -> Dataset: + """Run LLMCurate on the input data + + Args: + column_to_curate (str): Column name in `data` with the text that needs to be curated + data (Union[pd.DataFrame, Dataset]): Input data for LLM based curation + ds_column_mapping (dict, optional): Mapping of entities to be used in the LLM prompt to the corresponding columns in the input data. Defaults to {}. + prompt_variants (List[str], optional): List of different LLM prompts to be used to curate the labels under `column_to_curate`. Defaults to ['']. + skip_llm_inference (bool, optional): Indicator variable to prevent re-running LLM inference. Set to `True` if artifacts from the previous run of LLMCurate needs to be reused. Else `False`. Defaults to False. + llm_response_cleaned_column_list (list, optional): Names of the columns that will contain LLM predictions for each input prompt in `prompt_variants`. Defaults to ['reference_prediction']. + return_scores (bool, optional): Indicator variable set to `True` if label reliability scores are to be computed for each label under `column_to_curate`. Defaults to True. + answer_start_token (str, optional): Token that indicates the start of answer generation. Defaults to '' + answer_end_token (str, optional): Token that indicates the end of answer generation. Defaults to '' + scoring_method (Union[Callable[[str, str], float], str], optional): A function or the string 'exact_match' to compute the reliability score. Defaults to 'exact_match'. + + Returns: + Dataset: Input dataset with reference responses. If `return_scores=True`, then input dataset with reference responses and reliability scores. + """ + if not skip_llm_inference: + empty_string_col_list = _validate_run_params( + data, + column_to_curate, + ds_column_mapping, + prompt_variants, + llm_response_cleaned_column_list, + ) + + if len(empty_string_col_list) > 0: + logger.warning( + "Found empty string(s) in the input data under column(s) {empty_string_col_list}" + ) + + logger.info( + f"Running the LLM to generate the {len(prompt_variants)} reference responses using `prompt_variants`.." 
+ ) + ds_ensemble = None + + model = self.model + tokenizer = self.tokenizer + + for index, prompt_template_prefix in enumerate(prompt_variants): + proposed_answer_col_name = llm_response_cleaned_column_list[index] + ds = run_LLM( + data, + model, + tokenizer, + ds_column_mapping=ds_column_mapping, + prompt_template_prefix=prompt_template_prefix, + answer_start_token=answer_start_token, + answer_end_token=answer_end_token, + llm_response_cleaned_col_name=proposed_answer_col_name, + **options, + ) + + if not ds_ensemble: + ds_ensemble = ds + else: + ds_ensemble = ds_ensemble.add_column( + proposed_answer_col_name, ds[proposed_answer_col_name] + ) + self.ds_ensemble = ds_ensemble + + if return_scores: + if skip_llm_inference: + if ( + isinstance(data, pd.DataFrame) + or ds_column_mapping + or prompt_variants + ): + logger.warning( + "Ignoring params `data`, `ds_column_mapping` and `prompt_variants` since `skip_llm_inference` is set to `True`" + ) + + _empty_ds_ensemble_handler(len(self.ds_ensemble) == 0, skip_llm_inference) + + logger.info( + "Computing reliability scores using the LLM reference responses.." + ) + self.ds_ensemble = self.ds_ensemble.map( + compute_reliability_score, + fn_kwargs={ + "target_column": column_to_curate, + "reference_column_list": llm_response_cleaned_column_list, + "scoring_method": scoring_method, + }, + ) + + return self.ds_ensemble diff --git a/dqc/llm_utils/__init__.py b/dqc/llm_utils/__init__.py new file mode 100644 index 0000000..0dc8ef5 --- /dev/null +++ b/dqc/llm_utils/__init__.py @@ -0,0 +1,9 @@ +from ._sanity_checks import ( + _empty_ds_ensemble_handler, + _validate_init_params, + _validate_run_params, +) +from .compute_confidence_score import compute_reliability_score +from .inference import build_LLM_prompt, infer_LLM, run_LLM + +__all__ = ["build_LLM_prompt", "compute_reliability_score", "infer_LLM", "run_LLM"] diff --git a/dqc/llm_utils/_sanity_checks.py b/dqc/llm_utils/_sanity_checks.py new file mode 100644 index 0000000..d5b0b8e --- /dev/null +++ b/dqc/llm_utils/_sanity_checks.py @@ -0,0 +1,231 @@ +from typing import List, Tuple, Union + +import pandas as pd +from datasets import Dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + + +def _validate_ds_column_mapping(data: pd.DataFrame, ds_column_mapping: dict): + """Sanity checks for `ds_column_mapping` + + Args: + data (Union[pd.DataFrame, Dataset]): Input data for LLM based curation + ds_column_mapping (dict): Mapping of entities to be used in the LLM prompt to the corresponding columns in the input data. + + Raises: + ValueError: If any of the enities or column names in `ds_column_mapping` are invalid. + """ + valid_column_names = data.columns + + if not ds_column_mapping: + raise ValueError( + f"`ds_column_mapping` cannot be empty when `skip_llm_inference=False`. Please pass a valid non empty dictionary." + ) + + for entity, col_name in ds_column_mapping.items(): + if not entity or entity == "": + raise ValueError( + f"Entity `{entity}` for column name {col_name} in `ds_column_mapping` is invalid. Please make sure to pass valid non empty texts only." + ) + + if not col_name or col_name == "" or col_name not in valid_column_names: + raise ValueError( + f"Column name `{col_name}` for entity `{entity}` in `ds_column_mapping` is invalid. Please make sure to pass valid column names from the input data." 
+ ) + + return + + +def _check_null_values( + data: pd.DataFrame, column_to_curate: str, ds_column_mapping: dict +) -> List[str]: + """Sanity checks to detect null entries in the data + + Args: + data (pd.DataFrame): Input data for LLM based curation + column_to_curate (str): Column name in `data` with the text that needs to be curated + ds_column_mapping (dict): Mapping of entities to be used in the LLM prompt to the corresponding columns in the input data + + Raises: + ValueError: If null values are found in the data + + Returns: + List[str]: List of column names containing blank values + """ + if ( + column_to_curate is None + or column_to_curate == "" + or column_to_curate not in data.columns + ): + raise ValueError( + f"`column_to_curate` should be a valid column name from the input data columns" + ) + + empty_string_col_list = [] + for _, column_name in ds_column_mapping.items(): + if len(data.loc[data[column_name].isnull(), column_name]) > 0: + raise ValueError( + f"Found `None` entries under column {column_name} in the input data." + ) + + count = (data[column_name] == "").sum() + if count > 0: + empty_string_col_list.append(column_name) + + return empty_string_col_list + + +def _check_num_rows(data: pd.DataFrame): + """Check if the input data contains atleast one row + + Args: + data (pd.DataFrame): Input data for LLM based curation + + Raises: + ValueError: If data contains zero rows + """ + if len(data) == 0: + raise ValueError( + "Input data must be non null when `skip_llm_inference` is set to `False`." + ) + + return + + +def _validate_prompts_and_response_col_names( + prompt_variants: List[str], llm_response_cleaned_column_list: List[str] +): + """Sanity checks to detect invalid entries in `prompt_variants` and `llm_response_cleaned_column_list` + + Args: + prompt_variants (List[str]): List of different LLM prompts to be used to curate the labels under `column_to_curate`. + llm_response_cleaned_column_list (list): Names of the columns that will contain LLM predictions for each input prompt in `prompt_variants`. + + Raises: + ValueError: If either `prompt_variants` or `llm_response_cleaned_column_list` contain invalid entries or number of entries in both lists do not match + """ + if None in prompt_variants: + raise ValueError( + f"Found {None} in `prompt_variants`. Please make sure to pass only non null prompts in `prompt_variants`" + ) + + if None in llm_response_cleaned_column_list: + raise ValueError( + f"Found {None} in `llm_response_cleaned_col_name_list`. Please make sure to pass only non empty strings in `llm_response_cleaned_col_name_list` " + ) + + if "" in llm_response_cleaned_column_list: + raise ValueError( + f"Found blank string in `llm_response_cleaned_col_name_list`. 
Please make sure to pass only non empty strings in `llm_response_cleaned_col_name_list`" + ) + + if len(prompt_variants) != len(llm_response_cleaned_column_list): + raise ValueError( + "Number of prompts in `prompt_list` should match number of response column names in `llm_response_cleaned_col_name_list`" + ) + + +def _validate_run_params( + data: pd.DataFrame, + column_to_curate: str, + ds_column_mapping: dict, + prompt_variants: List[str], + llm_response_cleaned_column_list: List[str], +) -> List[str]: + """Run collection of sanity checks on parameters passed to `LLMCurate.run()` + + Args: + data (pd.DataFrame): Input data for LLM based curation + column_to_curate (str): Column name in `data` with the text that needs to be curated + prompt_variants (List[str]): List of different LLM prompts to be used to curate the labels under `column_to_curate`.. + ds_column_mapping (dict): Mapping of entities to be used in the LLM prompt to the corresponding columns in the input data + llm_response_cleaned_column_list (list): Names of the columns that will contain LLM predictions for each input prompt in `prompt_variants` + + Returns: + List[str]: List of column names containing blank values + """ + if isinstance(data, Dataset): + data = data.to_pandas() + + _check_num_rows(data) + _validate_ds_column_mapping(data, ds_column_mapping) + _validate_prompts_and_response_col_names( + prompt_variants, llm_response_cleaned_column_list + ) + + return _check_null_values(data, column_to_curate, ds_column_mapping) + + +def is_valid_text_generation_pipeline( + model: AutoModelForCausalLM, tokenizer: AutoTokenizer +): + """Check if the given model and tokenizer are compatible with text generation. + + Args: + model (AutoModelForCausalLM): Instantiated LLM + tokenizer (AutoTokenizer): Instantiated tokenizer corresponding to the `model` + + Raises: + ValueError: If `model` or `tokenizer` or both are invalid artifacts + + """ + if not hasattr(model, "generate"): + raise ValueError( + "Found invalid model: Missing 'generate' method for text generation." + ) + + if not all(hasattr(tokenizer, attr) for attr in ["encode", "decode"]): + raise ValueError( + "Found invalid tokenizer: Missing 'encode' or 'decode' methods." + ) + + return + + +def _validate_init_params( + model: AutoModelForCausalLM, + tokenizer: AutoTokenizer, + verbose: bool, +): + """Sanity checks to verify the validity of params passed to initialize an instance of LLMCurate + + Args: + model (AutoModelForCausalLM): Instantiated LLM + tokenizer (AutoTokenizer): Instantiated tokenizer corresponding to the `model` + verbose (bool, optional): Sets the verbosity level during execution. `True` indicates logging level INFO and `False` indicates logging level 'WARNING'. + """ + + expected_types = [ + (verbose, "verbose", bool), + ] + + for obj, var_name, expected_type in expected_types: + if not isinstance(obj, expected_type): + raise ValueError( + f"Expected `{var_name}` to be an instance of `{expected_type}`, but got `{type(obj).__name__}`" + ) + + is_valid_text_generation_pipeline(model, tokenizer) + + return + + +def _empty_ds_ensemble_handler(empty_ds_ensemble: bool, skip_llm_inference: bool): + """Exception handling if `ds_ensemble` is `None` when reliability scores need to be computed + + Args: + empty_ds_ensemble (bool): Indicator variable set to `True` if `ds_ensemble` is empty. Else `False` + skip_llm_inference (bool): Indicator variable to prevent re-running LLM inference. 
Set to `True` if artifacts from the previous run of LLMCurate needs to be reused. Else `False`. + + Raises: + ValueError: If `ds_ensemble` is None + """ + if empty_ds_ensemble: + error_message = "Found `self.ds_ensemble` to be `None`." + + if skip_llm_inference: + error_message += "Try running with parameter `skip_llm_inference=False` to first generate reference responses using the input dataset." + + raise ValueError(error_message) + + return diff --git a/dqc/llm_utils/compute_confidence_score.py b/dqc/llm_utils/compute_confidence_score.py new file mode 100644 index 0000000..5da13f5 --- /dev/null +++ b/dqc/llm_utils/compute_confidence_score.py @@ -0,0 +1,93 @@ +from itertools import combinations +from typing import Callable, List, Union + +import datasets +import numpy as np +from transformers import AutoModel, AutoTokenizer + + +def _compute_exact_match_score( + target_text: str, reference_text_list: List[str] +) -> float: + """Util function to quantify the number of exact matches between a target text and a list of text strings. + + Args: + target_text (str): The text string that will be compared to each text in `reference_text_list`. + reference_text_list (List[str]): List of text strings that need to be individually matched against `target_text`. + + Returns: + float: Score between 0 and 1 indicating the percentage of texts in `reference_text_list` that exactly matched `target_text` + """ + matches = [1 if text == target_text else 0 for text in reference_text_list] + return sum(matches) / len(matches) + + +def _compute_custom_match_score( + target_text: str, + reference_text_list: List[str], + scoring_method: Union[Callable, str], +) -> float: + """Util function to compute the average similarity score between a target text and a list of text string based on a specified scoring function. + + Args: + target_text (str): The text string that will be compared to each text in `reference_text_list`. + reference_text_list (List[str]): List of text strings that need to be individually matched against `target_text` + scoring_method (Union[Callable, str]): A function or the string 'exact_match' to compute the reliability score. Defaults to 'exact_match'. + + Returns: + float: Score between 0 and 1 indicating how closely the texts in `reference_text_list` match the `target_text`. A score of 1 means perfect matches for all entries, + while a score of 0 indicates no similarity. + """ + matches = np.array( + [scoring_method(text, target_text) for text in reference_text_list] + ) + return matches.mean() + + +def compute_reliability_score( + example: datasets.formatting.formatting.LazyRow, + target_column: str, + reference_column_list: List[str], + scoring_method: Union[Callable[[str, str], float], str] = "exact_match", + case_sensitive: bool = False, +) -> float: + """Util function to assess the reliability of a given target text using LLM generated reference texts. + + Args: + example (datasets.formatting.formatting.LazyRow): A row of data from a dataset containing the target and reference texts. + target_column (str): Name of the column containing the target text for estimation of reliability. + reference_column_list (List[str]): Names of the columns containing the reference texts to be compared with the target text. + scoring_method (Union[Callable[[str, str], float], str], optional): A function or the string 'exact_match' to compute the reliability score. Defaults to 'exact_match'. + case_sensitive (bool, optional): `True` if string comparisons need to be case aware. Else `False`. 
Defaults to `False` + Raises: + ValueError: If `scoring_method` is neither 'exact_match' nor a valid callable function + + Returns: + float: Score between 0 and 1 quantifying the reliability of the target text + """ + if not callable(scoring_method) and scoring_method != "exact_match": + raise ValueError( + "Parameter `scoring_method` must be 'exact_match' or a valid callable that measures string similarity" + ) + + reference_text_list = [] + result_dict = {} + score = 0 + + for col in reference_column_list: + reference_text_list.append(example[col]) + + target_text = example[target_column] + + if not case_sensitive: + target_text = target_text.lower() + reference_text_list = [text.lower() for text in reference_text_list] + + if scoring_method == "exact_match": + score = _compute_exact_match_score(target_text, reference_text_list) + else: + score = _compute_custom_match_score( + target_text, reference_text_list, scoring_method + ) + + return {"reliability_score": score} diff --git a/dqc/llm_utils/inference.py b/dqc/llm_utils/inference.py new file mode 100644 index 0000000..87c33aa --- /dev/null +++ b/dqc/llm_utils/inference.py @@ -0,0 +1,231 @@ +import gc +from typing import Union + +import datasets +import pandas as pd +from datasets import Dataset +from tqdm import tqdm +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + BitsAndBytesConfig, + pipeline, +) + + +def _generate_predictions( + example: datasets.formatting.formatting.LazyBatch, + generator: pipeline, + llm_prompt_col_name: str, + llm_response_raw_col_name: str = "llm_response", + **options, +) -> dict: + """ + Generates predictions using the text generation model for a given example. + + Args: + example (datasets.formatting.formatting.LazyBatch): Batch of samples from a dataset. + generator (pipeline): Huggingface pipeline for text generation. + llm_prompt_col_name (str): Prompt for the text generation model. + llm_response_raw_col_name (str, optional): Name of the column containing prediction. Defaults to 'llm_response'. + + Returns: + dict: A dictionary containing the generated predictions. + """ + predictions = [] + batch_results = generator( + example[llm_prompt_col_name], early_stopping=True, **options + ) + res_dict = {} + + predictions = [result[0]["generated_text"] for result in batch_results] + res_dict[llm_response_raw_col_name] = predictions + + return res_dict + + +def build_LLM_prompt( + input_ds: Dataset, + ds_column_mapping: dict, + prompt_template_prefix: str = "", + answer_start_token: str = "", + llm_prompt_col_name: str = "llm_prompt", +) -> Dataset: + """Util function to build the LLM prompt from input text data + + Args: + input_ds (Dataset): Input dataset containing text + ds_column_mapping (dict): Dictionary mapping prompt entities to dataset column names. + prompt_template_prefix (Union[str, None], optional): Text instruction to prepend to each transformed input text sample. Defaults to "". + answer_start_token (str, optional): Token to append to the prompt to indicate start of the answer. Defaults to "" + llm_prompt_col_name (str, optional): Name of the column for the built LLM prompts. Defaults to 'llm_prompt' + Returns: + Dataset: Dataset with generated predictions. 
+ """ + if type(input_ds) == pd.DataFrame: + input_ds = Dataset.from_pandas(input_ds) + + def _helper( + example: datasets.formatting.formatting.LazyBatch, + prompt_template_prefix: str, + ds_column_mapping: dict, + llm_prompt_col_name: str, + ) -> dict: + llm_prompt = prompt_template_prefix + for entity_name, col_name in ds_column_mapping.items(): + if col_name: + entity_value = example[col_name] + if type(entity_value) == list: + entity_value = "|| ".join(map(str, entity_value)) + else: + entity_value = str(entity_value) + llm_prompt += f"[{entity_name}]{entity_value}[/{entity_name}]" + + if answer_start_token: + llm_prompt += answer_start_token + + return {llm_prompt_col_name: llm_prompt} + + input_ds = input_ds.map( + _helper, + fn_kwargs={ + "prompt_template_prefix": prompt_template_prefix, + "ds_column_mapping": ds_column_mapping, + "llm_prompt_col_name": llm_prompt_col_name, + }, + ) + return input_ds + + +def infer_LLM( + model: AutoModelForCausalLM, + tokenizer: AutoTokenizer, + input_ds: Dataset, + llm_prompt_col_name: str = "llm_prompt", + llm_response_raw_col_name: str = "llm_response", + **options, +) -> Dataset: + """ + Util function to run LLM inference + + Args: + model (AutoModelForCausalLM): LLM artifact. + tokenizer (Autotokenizer) : LLM tokenizer object + input_ds (Dataset): Input dataset containing text prompts. + llm_prompt_col_name (str, optional): Name of the column containing text prompts. Defaults to 'llm_prompt'. + llm_response_raw_col_name (str, optional): Name of the column containing prediction. Defaults to 'llm_response'. + + Returns: + dataset: Dataset with generated predictions. + """ + text_generator = pipeline( + "text-generation", model=model, tokenizer=tokenizer, truncation=False, **options + ) + text_generator.tokenizer.pad_token_id = model.config.eos_token_id + + batch_size = options["batch_size"] if "batch_size" in options else 8 + + input_ds = input_ds.map( + _generate_predictions, + fn_kwargs={ + "generator": text_generator, + "llm_prompt_col_name": llm_prompt_col_name, + "llm_response_raw_col_name": llm_response_raw_col_name, + **options, + }, + batched=True, + batch_size=batch_size, + ) + + return input_ds + + +def _postprocess( + sample: datasets.formatting.formatting.LazyRow, + llm_prompt_col_name: str = "llm_prompt", + llm_response_raw_col_name: str = "llm_response", + llm_response_cleaned_col_name: str = "llm_response_cleaned", + answer_end_token: str = "", +) -> dict: + """Util function to extract the generated answer from the generated LLM prediction + + Args: + sample (datasets.formatting.formatting.LazyRow): Batch of samples from a dataset + llm_prompt_col_name (str, optional): Name of the column containing the LLM prompts. Defaults to 'llm_prompt' + llm_response_raw_col_name (str, optional): Name of the column containing prediction. Defaults to 'llm_response'. + llm_response_cleaned_col_name (str, optional): Name of the column for the final processed result. 
Defaults to 'llm_response_cleaned' + answer_end_token (str, optional): Token to use to separate noise from expected output + + Returns: + dict: Dictionary of extracted answer sequences + """ + prompt_length = len(sample[llm_prompt_col_name]) + + extracted_answer = sample[llm_response_raw_col_name][prompt_length:].strip("\n") + + if answer_end_token: + extracted_answer = extracted_answer.split(answer_end_token)[0] + + return {llm_response_cleaned_col_name: extracted_answer} + + +def run_LLM( + val_data: Union[pd.DataFrame, Dataset], + model: AutoModelForCausalLM, + tokenizer: AutoTokenizer, + ds_column_mapping: dict, + prompt_template_prefix: Union[str, None] = "", + llm_prompt_col_name: str = "llm_prompt", + llm_response_raw_col_name: str = "llm_response", + llm_response_cleaned_col_name: str = "llm_response_cleaned", + answer_start_token: str = "", + answer_end_token: str = "", + **options, +) -> dict: + """Run end-to-end LLM inference (from pre-processing input data to post-processing the predictions) and return the computed performance metrics on input validation data + + Args: + val_data (Union[pd.DataFrame, Dataset]): Validation data with labels + model (AutoModelForCausalLM): LLM artifact. + tokenizer (Autotokenizer) : LLM tokenizer object + ds_column_mapping (dict): prompt entity mapping + prompt_template_prefix (Union[str, None], optional): Text instruction to prepend to each transformed input text sample. Defaults to "". + llm_prompt_col_name (str, optional): Name of the column with the built LLM prompts. Defaults to 'llm_prompt' + llm_response_raw_col_name (str, optional): Name of the column containing prediction. Defaults to 'llm_response'. + llm_response_cleaned_col_name (str, optional): Name of the column containing the final post processed result. Defaults to 'llm_response_cleaned' + answer_start_token (str, optional): Token that indicates the start of answer generation. Defaults to '' + answer_end_token (str, optional): Token that indicates the end of answer generation. Defaults to '' + + Returns: + dict: A dictionary containing F1 score. 
+ """ + predicted_label_list = [] + + val_ds = build_LLM_prompt( + val_data, + ds_column_mapping=ds_column_mapping, + prompt_template_prefix=prompt_template_prefix, + answer_start_token=answer_start_token, + llm_prompt_col_name=llm_prompt_col_name, + ) + + val_ds_with_pred = infer_LLM( + model, + tokenizer, + val_ds, + llm_prompt_col_name=llm_prompt_col_name, + llm_response_raw_col_name=llm_response_raw_col_name, + **options, + ) + + val_ds_with_pred = val_ds_with_pred.map( + _postprocess, + fn_kwargs={ + "llm_prompt_col_name": llm_prompt_col_name, + "llm_response_raw_col_name": llm_response_raw_col_name, + "llm_response_cleaned_col_name": llm_response_cleaned_col_name, + "answer_end_token": answer_end_token, + }, + ) + + return val_ds_with_pred diff --git a/dqc/version.py b/dqc/version.py index 782c293..d28cef5 100644 --- a/dqc/version.py +++ b/dqc/version.py @@ -4,7 +4,7 @@ import sklearn import transformers -__version__ = "0.1.3" +__version__ = "0.2.0" def show_versions(): diff --git a/pyproject.toml b/pyproject.toml index 18da587..501a797 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "dqc-toolkit" -version = "0.1.3" +version = "0.2.0" authors = [ { name="Sumanth S Prabhu", email="sumanthprabhu.104@gmail.com" }, ] @@ -16,7 +16,8 @@ dependencies = [ "transformers>=4.39", "sentence_transformers>=2.6.1", "datasets>=2.18", - "scikit-learn>=1.3.2,<1.5", + "accelerate>=0.34.2", + "scikit-learn>=1.3.2", "ruff>=0.3.4" ] keywords = [ diff --git a/tests/llm/conftest.py b/tests/llm/conftest.py new file mode 100644 index 0000000..727a023 --- /dev/null +++ b/tests/llm/conftest.py @@ -0,0 +1,71 @@ +import gc +from typing import Tuple, Union + +import pandas as pd +import pytest +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + + +@pytest.fixture(scope="session") +def data(): + dataset = "ChilleD/SVAMP" + dset = load_dataset(dataset) + data = pd.DataFrame(dset["test"])[["question_concat", "Equation"]] + return data.loc[:10] + + +@pytest.fixture(scope="session") +def ds_column_mapping(): + return { + "valid": {"QUESTION": "question_concat", "EQUATION": "Equation"}, + "invalid_entity": {"": "question_concat", None: "Equation"}, + "invalid_column": {"QUESTION": None, "EQUATION": "nonexistent"}, + } + + +@pytest.fixture(scope="session") +def prompt_variants(): + return { + "valid": [ + "You are a helpful assistant. 
", + "You are an honest assistant", + "Please respond with the correct answer to the following question", + ], + "invalid_count": ["", ""], + "invalid_prompt": ["", None, ""], + } + + +@pytest.fixture(scope="session") +def reliability_dataset_row(): + return { + "target_text": "sample sentence.", + "reference_1": "sample sentence.", + "reference_2": "SAMPLE sentence.", + "reference_3": "different sentence.", + } + + +@pytest.fixture(scope="session") +def model_and_tokenizer(): + return initialize_model_and_tokenizer("distilbert/distilgpt2") + + +def initialize_model_and_tokenizer( + model_name: str, +) -> Tuple[AutoModelForCausalLM, AutoTokenizer]: + """Util function to construct the model and tokenizer objects""" + tokenizer = AutoTokenizer.from_pretrained( + model_name, padding_side="left", trust_remote_code=True + ) + model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto") + return model, tokenizer + + +@pytest.fixture(scope="session") +def ref_col_set(): + return set( + ["reference_prediction1", "reference_prediction2", "reference_prediction3"] + ) diff --git a/tests/llm/test_llm.py b/tests/llm/test_llm.py new file mode 100644 index 0000000..2280e80 --- /dev/null +++ b/tests/llm/test_llm.py @@ -0,0 +1,348 @@ +from typing import Callable + +import numpy as np +import pandas as pd +import pytest + +from dqc import LLMCurate +from dqc.llm_utils import compute_reliability_score + + +@pytest.mark.parametrize("model", [None, "random_str", 1]) +@pytest.mark.parametrize("tokenizer", [None, "random_str", 1]) +def test_init_failure( + data, + ds_column_mapping, + prompt_variants, + model, + tokenizer, + ref_col_set, + batch_size=1, + max_new_tokens=1, + verbose=False, +): + with pytest.raises(ValueError): + llmc = LLMCurate( + model=model, + tokenizer=tokenizer, + batch_size=batch_size, + max_new_tokens=max_new_tokens, + verbose=verbose, + ) + ds = llmc.run( + data=data, + column_to_curate="Equation", + ds_column_mapping=ds_column_mapping["valid"], + prompt_variants=prompt_variants["valid"], + llm_response_cleaned_column_list=list(ref_col_set), + batch_size=batch_size, + max_new_tokens=max_new_tokens, + ) + + +def test_init_success( + data, + ds_column_mapping, + prompt_variants, + model_and_tokenizer, + ref_col_set, + batch_size=1, + max_new_tokens=1, + verbose=False, +): + model, tokenizer = model_and_tokenizer + llmc = LLMCurate(model=model, tokenizer=tokenizer, verbose=verbose) + ds = llmc.run( + data=data, + column_to_curate="Equation", + ds_column_mapping=ds_column_mapping["valid"], + prompt_variants=prompt_variants["valid"], + llm_response_cleaned_column_list=list(ref_col_set), + batch_size=batch_size, + max_new_tokens=max_new_tokens, + ) + + ds_col_list = ds.column_names + assert len(set(ds_col_list).intersection(ref_col_set)) == 3 + assert "reliability_score" in ds_col_list + + +@pytest.mark.parametrize("column_to_curate", [None, 1]) +def test_column_to_curate_failure( + data, + column_to_curate, + ds_column_mapping, + prompt_variants, + model_and_tokenizer, + ref_col_set, + batch_size=1, + max_new_tokens=1, + verbose=False, +): + model, tokenizer = model_and_tokenizer + with pytest.raises(ValueError): + llmc = LLMCurate(model=model, tokenizer=tokenizer, verbose=verbose) + ds = llmc.run( + data=data, + column_to_curate=column_to_curate, + ds_column_mapping=ds_column_mapping["valid"], + prompt_variants=prompt_variants["valid"], + llm_response_cleaned_column_list=list(ref_col_set), + batch_size=batch_size, + max_new_tokens=max_new_tokens, + ) + + 
+@pytest.mark.parametrize("colmap_variant", ["invalid_entity", "invalid_column"]) +def test_dscolmap_failure( + data, + colmap_variant, + ds_column_mapping, + prompt_variants, + model_and_tokenizer, + ref_col_set, + batch_size=1, + max_new_tokens=1, + verbose=False, +): + model, tokenizer = model_and_tokenizer + llmc = LLMCurate(model=model, tokenizer=tokenizer, verbose=verbose) + with pytest.raises(ValueError): + ds = llmc.run( + data=data, + column_to_curate="Equation", + ds_column_mapping=ds_column_mapping[colmap_variant], + prompt_variants=prompt_variants["valid"], + llm_response_cleaned_column_list=list(ref_col_set), + batch_size=batch_size, + max_new_tokens=max_new_tokens, + ) + + ds_col_list = ds.column_names + assert len(set(ds_col_list).intersection(ref_col_set)) == 3 + assert "reliability_score" in ds_col_list + + +@pytest.mark.parametrize("p_variant", ["invalid_count", "invalid_prompt"]) +def test_prompt_failure( + data, + p_variant, + ds_column_mapping, + prompt_variants, + model_and_tokenizer, + ref_col_set, + batch_size=1, + max_new_tokens=1, + verbose=False, +): + model, tokenizer = model_and_tokenizer + llmc = LLMCurate(model=model, tokenizer=tokenizer, verbose=verbose) + with pytest.raises(ValueError): + ds = llmc.run( + data=data, + column_to_curate="Equation", + ds_column_mapping=ds_column_mapping["valid"], + prompt_variants=prompt_variants[p_variant], + llm_response_cleaned_column_list=list(ref_col_set), + batch_size=batch_size, + max_new_tokens=max_new_tokens, + ) + + ds_col_list = ds.column_names + assert len(set(ds_col_list).intersection(ref_col_set)) == 3 + assert "reliability_score" in ds_col_list + + +def test_skip_llm_inference_success( + data, + ds_column_mapping, + prompt_variants, + model_and_tokenizer, + ref_col_set, + batch_size=1, + max_new_tokens=1, + verbose=False, +): + model, tokenizer = model_and_tokenizer + llmc = LLMCurate(model=model, tokenizer=tokenizer, verbose=verbose) + llmc.run( + data=data, + column_to_curate="Equation", + ds_column_mapping=ds_column_mapping["valid"], + prompt_variants=prompt_variants["valid"], + llm_response_cleaned_column_list=list(ref_col_set), + batch_size=batch_size, + max_new_tokens=max_new_tokens, + skip_llm_inference=False, + ) + ds = llmc.run( + column_to_curate="Equation", + llm_response_cleaned_column_list=list(ref_col_set), + skip_llm_inference=True, + ) + + ds_col_list = ds.column_names + assert len(set(ds_col_list).intersection(ref_col_set)) == 3 + assert "reliability_score" in ds_col_list + + +def test_noscores_success( + data, + ds_column_mapping, + prompt_variants, + model_and_tokenizer, + ref_col_set, + batch_size=1, + max_new_tokens=1, + verbose=False, +): + model, tokenizer = model_and_tokenizer + llmc = LLMCurate(model=model, tokenizer=tokenizer, verbose=verbose) + ds = llmc.run( + data=data, + column_to_curate="Equation", + ds_column_mapping=ds_column_mapping["valid"], + prompt_variants=prompt_variants["valid"], + llm_response_cleaned_column_list=list(ref_col_set), + batch_size=batch_size, + max_new_tokens=max_new_tokens, + return_scores=False, + ) + + ds_col_list = ds.column_names + assert len(set(ds_col_list).intersection(ref_col_set)) == 3 + assert "reliability_score" not in ds_col_list + + +@pytest.mark.parametrize("answer_start_token", [None, "[EQUATION]"]) +@pytest.mark.parametrize("answer_end_token", [None, "[/PROPOSED_ANSWER]"]) +def test_answertoken_success( + data, + ds_column_mapping, + prompt_variants, + answer_start_token, + answer_end_token, + model_and_tokenizer, + ref_col_set, + 
batch_size=1, + max_new_tokens=1, + verbose=False, +): + model, tokenizer = model_and_tokenizer + llmc = LLMCurate(model=model, tokenizer=tokenizer, verbose=verbose) + ds = llmc.run( + data=data, + column_to_curate="Equation", + ds_column_mapping=ds_column_mapping["valid"], + prompt_variants=prompt_variants["valid"], + llm_response_cleaned_column_list=list(ref_col_set), + answer_start_token=answer_start_token, + answer_end_token=answer_end_token, + batch_size=batch_size, + max_new_tokens=max_new_tokens, + return_scores=False, + ) + + ds_col_list = ds.column_names + assert len(set(ds_col_list).intersection(ref_col_set)) == 3 + assert "reliability_score" not in ds_col_list + + +def test_llmc_run_success( + data, + ds_column_mapping, + prompt_variants, + model_and_tokenizer, + ref_col_set, + batch_size=1, + max_new_tokens=1, + verbose=False, +): + model, tokenizer = model_and_tokenizer + llmc = LLMCurate(model=model, tokenizer=tokenizer, verbose=verbose) + ds = llmc.run( + data=data, + column_to_curate="Equation", + ds_column_mapping=ds_column_mapping["valid"], + prompt_variants=prompt_variants["valid"], + llm_response_cleaned_column_list=list(ref_col_set), + answer_start_token="[EQUATION]", + answer_end_token="[/PROPOSED_ANSWER]", + batch_size=batch_size, + max_new_tokens=max_new_tokens, + ) + + ds_col_list = ds.column_names + assert len(set(ds_col_list).intersection(ref_col_set)) == 3 + assert "reliability_score" in ds_col_list + + +@pytest.mark.parametrize("verbose", [True, False]) +def test_verbosity( + data, + ds_column_mapping, + prompt_variants, + model_and_tokenizer, + ref_col_set, + verbose, + batch_size=1, + max_new_tokens=1, +): + model, tokenizer = model_and_tokenizer + llmc = LLMCurate(model=model, tokenizer=tokenizer) + ds = llmc.run( + data=data, + column_to_curate="Equation", + ds_column_mapping=ds_column_mapping["valid"], + prompt_variants=prompt_variants["valid"], + llm_response_cleaned_column_list=list(ref_col_set), + batch_size=batch_size, + max_new_tokens=max_new_tokens, + ) + + ds_col_list = ds.column_names + assert len(set(ds_col_list).intersection(ref_col_set)) == 3 + assert "reliability_score" in ds_col_list + + +def exact_match(text1: str, text2: str) -> float: + return 1.0 if text1 == text2 else 0.0 + + +@pytest.fixture +def scoring_method() -> Callable[[str, str], float]: + return exact_match + + +def test_exact_match_case_insensitive(reliability_dataset_row, scoring_method): + res = compute_reliability_score( + example=reliability_dataset_row, + target_column="target_text", + reference_column_list=["reference_1", "reference_2", "reference_3"], + scoring_method=scoring_method, + ) + + assert res["reliability_score"] == pytest.approx(2 / 3) + + +def test_exact_match_case_sensitive(reliability_dataset_row, scoring_method): + res = compute_reliability_score( + example=reliability_dataset_row, + target_column="target_text", + reference_column_list=["reference_1", "reference_2", "reference_3"], + scoring_method=scoring_method, + case_sensitive=True, + ) + + # assert 0 <= res['reliability_score'] <= 1 + assert res["reliability_score"] == pytest.approx(1 / 3) + + +def test_invalid_scoring_method(reliability_dataset_row): + with pytest.raises(ValueError): + compute_reliability_score( + example=reliability_dataset_row, + target_column="target_text", + reference_column_list=["reference_1", "reference_2", "reference_3"], + scoring_method="invalid_method", + ) From fc2aa9da4a390a49d1e5ebb2316d11ec8b119095 Mon Sep 17 00:00:00 2001 From: sumanthprabhu Date: Mon, 16 Sep 2024 
16:58:16 +0530 Subject: [PATCH 2/6] Add LLMCurate - Implemented `set_seed` in base.py to ensure reproducible results --- dqc/base.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/dqc/base.py b/dqc/base.py index c676906..14511d5 100644 --- a/dqc/base.py +++ b/dqc/base.py @@ -1,6 +1,9 @@ +import random from abc import ABC, abstractmethod from typing import Union +import numpy as np +import torch from pandas._typing import RandomState @@ -15,10 +18,19 @@ class BaseCurate(ABC): def __init__( self, - random_state: RandomState = 42, + random_state: Union[int, RandomState] = 42, **options, ): self.random_state = random_state + self._set_seed(random_state) + + def _set_seed(self, seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True @abstractmethod def fit_transform(self): ... From 18da18d810df455e1b4438242485a69ef86bdffc Mon Sep 17 00:00:00 2001 From: sumanthprabhu Date: Mon, 16 Sep 2024 20:38:40 +0530 Subject: [PATCH 3/6] Add LLMCurate - Implemented `set_seed` for LLM Utils to ensure reproducibility. - Added corresponding unit tests --- dqc/llm.py | 1 + dqc/llm_utils/inference.py | 17 +++++++++++++++++ tests/llm/test_llm.py | 27 ++++++++++++++++++++++++++- 3 files changed, 44 insertions(+), 1 deletion(-) diff --git a/dqc/llm.py b/dqc/llm.py index b2cc5ef..91d0946 100644 --- a/dqc/llm.py +++ b/dqc/llm.py @@ -135,6 +135,7 @@ def run( answer_start_token=answer_start_token, answer_end_token=answer_end_token, llm_response_cleaned_col_name=proposed_answer_col_name, + random_state=self.random_state, **options, ) diff --git a/dqc/llm_utils/inference.py b/dqc/llm_utils/inference.py index 87c33aa..24d7005 100644 --- a/dqc/llm_utils/inference.py +++ b/dqc/llm_utils/inference.py @@ -1,8 +1,11 @@ import gc +import random from typing import Union import datasets +import numpy as np import pandas as pd +import torch from datasets import Dataset from tqdm import tqdm from transformers import ( @@ -13,6 +16,16 @@ ) +def _set_seed(seed): + print("Seed : ", seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + + def _generate_predictions( example: datasets.formatting.formatting.LazyBatch, generator: pipeline, @@ -118,6 +131,10 @@ def infer_LLM( Returns: dataset: Dataset with generated predictions. 
""" + if options["random_state"]: + _set_seed(options["random_state"]) + del options["random_state"] + text_generator = pipeline( "text-generation", model=model, tokenizer=tokenizer, truncation=False, **options ) diff --git a/tests/llm/test_llm.py b/tests/llm/test_llm.py index 2280e80..0f6db5f 100644 --- a/tests/llm/test_llm.py +++ b/tests/llm/test_llm.py @@ -17,6 +17,7 @@ def test_init_failure( model, tokenizer, ref_col_set, + random_state=43, batch_size=1, max_new_tokens=1, verbose=False, @@ -28,6 +29,7 @@ def test_init_failure( batch_size=batch_size, max_new_tokens=max_new_tokens, verbose=verbose, + random_state=random_state, ) ds = llmc.run( data=data, @@ -46,12 +48,15 @@ def test_init_success( prompt_variants, model_and_tokenizer, ref_col_set, + random_state=43, batch_size=1, max_new_tokens=1, verbose=False, ): model, tokenizer = model_and_tokenizer - llmc = LLMCurate(model=model, tokenizer=tokenizer, verbose=verbose) + llmc = LLMCurate( + model=model, tokenizer=tokenizer, random_state=random_state, verbose=verbose + ) ds = llmc.run( data=data, column_to_curate="Equation", @@ -67,6 +72,26 @@ def test_init_success( assert "reliability_score" in ds_col_list +@pytest.mark.parametrize("random_state", [None, "random_str"]) +def test_random_state_failure( + model_and_tokenizer, + random_state, + batch_size=1, + max_new_tokens=1, + verbose=False, +): + model, tokenizer = model_and_tokenizer + with pytest.raises((TypeError, ValueError)): + llmc = LLMCurate( + model=model, + tokenizer=tokenizer, + batch_size=batch_size, + max_new_tokens=max_new_tokens, + verbose=verbose, + random_state=random_state, + ) + + @pytest.mark.parametrize("column_to_curate", [None, 1]) def test_column_to_curate_failure( data, From aff0809a11bf2cde4b755827f9aca9f077e8ae14 Mon Sep 17 00:00:00 2001 From: sumanthprabhu Date: Tue, 17 Sep 2024 10:53:30 +0530 Subject: [PATCH 4/6] Add LLMCurate - Code refactoring for LLM Util functions --- dqc/llm.py | 13 +++---- dqc/llm_utils/__init__.py | 9 +++-- dqc/llm_utils/compute_confidence_score.py | 15 ++++---- dqc/llm_utils/inference.py | 1 - tests/llm/conftest.py | 2 +- tests/llm/test_llm.py | 44 ++++++++++++----------- 6 files changed, 46 insertions(+), 38 deletions(-) diff --git a/dqc/llm.py b/dqc/llm.py index 91d0946..511fa60 100644 --- a/dqc/llm.py +++ b/dqc/llm.py @@ -9,7 +9,7 @@ _empty_ds_ensemble_handler, _validate_init_params, _validate_run_params, - compute_reliability_score, + compute_selfensembling_confidence_score, run_LLM, ) from dqc.utils import Logger @@ -94,13 +94,13 @@ def run( prompt_variants (List[str], optional): List of different LLM prompts to be used to curate the labels under `column_to_curate`. Defaults to ['']. skip_llm_inference (bool, optional): Indicator variable to prevent re-running LLM inference. Set to `True` if artifacts from the previous run of LLMCurate needs to be reused. Else `False`. Defaults to False. llm_response_cleaned_column_list (list, optional): Names of the columns that will contain LLM predictions for each input prompt in `prompt_variants`. Defaults to ['reference_prediction']. - return_scores (bool, optional): Indicator variable set to `True` if label reliability scores are to be computed for each label under `column_to_curate`. Defaults to True. + return_scores (bool, optional): Indicator variable set to `True` if label confidence scores are to be computed for each label under `column_to_curate`. Defaults to True. answer_start_token (str, optional): Token that indicates the start of answer generation. 
Defaults to '' answer_end_token (str, optional): Token that indicates the end of answer generation. Defaults to '' - scoring_method (Union[Callable[[str, str], float], str], optional): A function or the string 'exact_match' to compute the reliability score. Defaults to 'exact_match'. + scoring_method (Union[Callable[[str, str], float], str], optional): A function or the string 'exact_match' to compute the confidence score. Defaults to 'exact_match'. Returns: - Dataset: Input dataset with reference responses. If `return_scores=True`, then input dataset with reference responses and reliability scores. + Dataset: Input dataset with reference responses. If `return_scores=True`, then input dataset with reference responses and confidence scores. """ if not skip_llm_inference: empty_string_col_list = _validate_run_params( @@ -161,14 +161,15 @@ def run( _empty_ds_ensemble_handler(len(self.ds_ensemble) == 0, skip_llm_inference) logger.info( - "Computing reliability scores using the LLM reference responses.." + "Computing confidence scores using the LLM reference responses.." ) self.ds_ensemble = self.ds_ensemble.map( - compute_reliability_score, + compute_selfensembling_confidence_score, fn_kwargs={ "target_column": column_to_curate, "reference_column_list": llm_response_cleaned_column_list, "scoring_method": scoring_method, + **options, }, ) diff --git a/dqc/llm_utils/__init__.py b/dqc/llm_utils/__init__.py index 0dc8ef5..54faf14 100644 --- a/dqc/llm_utils/__init__.py +++ b/dqc/llm_utils/__init__.py @@ -3,7 +3,12 @@ _validate_init_params, _validate_run_params, ) -from .compute_confidence_score import compute_reliability_score +from .compute_confidence_score import compute_selfensembling_confidence_score from .inference import build_LLM_prompt, infer_LLM, run_LLM -__all__ = ["build_LLM_prompt", "compute_reliability_score", "infer_LLM", "run_LLM"] +__all__ = [ + "build_LLM_prompt", + "selfensembling_confidence_score", + "infer_LLM", + "run_LLM", +] diff --git a/dqc/llm_utils/compute_confidence_score.py b/dqc/llm_utils/compute_confidence_score.py index 5da13f5..fd7d813 100644 --- a/dqc/llm_utils/compute_confidence_score.py +++ b/dqc/llm_utils/compute_confidence_score.py @@ -32,7 +32,7 @@ def _compute_custom_match_score( Args: target_text (str): The text string that will be compared to each text in `reference_text_list`. reference_text_list (List[str]): List of text strings that need to be individually matched against `target_text` - scoring_method (Union[Callable, str]): A function or the string 'exact_match' to compute the reliability score. Defaults to 'exact_match'. + scoring_method (Union[Callable, str]): A function or the string 'exact_match' to compute the confidence score. Defaults to 'exact_match'. Returns: float: Score between 0 and 1 indicating how closely the texts in `reference_text_list` match the `target_text`. A score of 1 means perfect matches for all entries, @@ -44,26 +44,27 @@ def _compute_custom_match_score( return matches.mean() -def compute_reliability_score( +def compute_selfensembling_confidence_score( example: datasets.formatting.formatting.LazyRow, target_column: str, reference_column_list: List[str], scoring_method: Union[Callable[[str, str], float], str] = "exact_match", case_sensitive: bool = False, + **options, ) -> float: - """Util function to assess the reliability of a given target text using LLM generated reference texts. + """Util function to compute confidence score of a given target text using LLM generated reference texts. 
Args: example (datasets.formatting.formatting.LazyRow): A row of data from a dataset containing the target and reference texts. - target_column (str): Name of the column containing the target text for estimation of reliability. + target_column (str): Name of the column containing the target text for estimation of confidence score. reference_column_list (List[str]): Names of the columns containing the reference texts to be compared with the target text. - scoring_method (Union[Callable[[str, str], float], str], optional): A function or the string 'exact_match' to compute the reliability score. Defaults to 'exact_match'. + scoring_method (Union[Callable[[str, str], float], str], optional): A function or the string 'exact_match' to compute the confidence score. Defaults to 'exact_match'. case_sensitive (bool, optional): `True` if string comparisons need to be case aware. Else `False`. Defaults to `False` Raises: ValueError: If `scoring_method` is neither 'exact_match' nor a valid callable function Returns: - float: Score between 0 and 1 quantifying the reliability of the target text + float: Score between 0 and 1 quantifying the confidence score for the target text """ if not callable(scoring_method) and scoring_method != "exact_match": raise ValueError( @@ -90,4 +91,4 @@ def compute_reliability_score( target_text, reference_text_list, scoring_method ) - return {"reliability_score": score} + return {"confidence_score": score} diff --git a/dqc/llm_utils/inference.py b/dqc/llm_utils/inference.py index 24d7005..9833378 100644 --- a/dqc/llm_utils/inference.py +++ b/dqc/llm_utils/inference.py @@ -17,7 +17,6 @@ def _set_seed(seed): - print("Seed : ", seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) np.random.seed(seed) diff --git a/tests/llm/conftest.py b/tests/llm/conftest.py index 727a023..c1fd173 100644 --- a/tests/llm/conftest.py +++ b/tests/llm/conftest.py @@ -39,7 +39,7 @@ def prompt_variants(): @pytest.fixture(scope="session") -def reliability_dataset_row(): +def confidence_dataset_row(): return { "target_text": "sample sentence.", "reference_1": "sample sentence.", diff --git a/tests/llm/test_llm.py b/tests/llm/test_llm.py index 0f6db5f..01e8046 100644 --- a/tests/llm/test_llm.py +++ b/tests/llm/test_llm.py @@ -5,7 +5,7 @@ import pytest from dqc import LLMCurate -from dqc.llm_utils import compute_reliability_score +from dqc.llm_utils import compute_selfensembling_confidence_score @pytest.mark.parametrize("model", [None, "random_str", 1]) @@ -69,7 +69,7 @@ def test_init_success( ds_col_list = ds.column_names assert len(set(ds_col_list).intersection(ref_col_set)) == 3 - assert "reliability_score" in ds_col_list + assert "confidence_score" in ds_col_list @pytest.mark.parametrize("random_state", [None, "random_str"]) @@ -145,7 +145,7 @@ def test_dscolmap_failure( ds_col_list = ds.column_names assert len(set(ds_col_list).intersection(ref_col_set)) == 3 - assert "reliability_score" in ds_col_list + assert "confidence_score" in ds_col_list @pytest.mark.parametrize("p_variant", ["invalid_count", "invalid_prompt"]) @@ -175,15 +175,17 @@ def test_prompt_failure( ds_col_list = ds.column_names assert len(set(ds_col_list).intersection(ref_col_set)) == 3 - assert "reliability_score" in ds_col_list + assert "confidence_score" in ds_col_list +@pytest.mark.parametrize("case_sensitive", [True, False]) def test_skip_llm_inference_success( data, ds_column_mapping, prompt_variants, model_and_tokenizer, ref_col_set, + case_sensitive, batch_size=1, max_new_tokens=1, verbose=False, @@ -204,11 +206,12 
@@ def test_skip_llm_inference_success( column_to_curate="Equation", llm_response_cleaned_column_list=list(ref_col_set), skip_llm_inference=True, + case_sensitive=case_sensitive, ) ds_col_list = ds.column_names assert len(set(ds_col_list).intersection(ref_col_set)) == 3 - assert "reliability_score" in ds_col_list + assert "confidence_score" in ds_col_list def test_noscores_success( @@ -236,7 +239,7 @@ def test_noscores_success( ds_col_list = ds.column_names assert len(set(ds_col_list).intersection(ref_col_set)) == 3 - assert "reliability_score" not in ds_col_list + assert "confidence_score" not in ds_col_list @pytest.mark.parametrize("answer_start_token", [None, "[EQUATION]"]) @@ -270,7 +273,7 @@ def test_answertoken_success( ds_col_list = ds.column_names assert len(set(ds_col_list).intersection(ref_col_set)) == 3 - assert "reliability_score" not in ds_col_list + assert "confidence_score" not in ds_col_list def test_llmc_run_success( @@ -299,7 +302,7 @@ def test_llmc_run_success( ds_col_list = ds.column_names assert len(set(ds_col_list).intersection(ref_col_set)) == 3 - assert "reliability_score" in ds_col_list + assert "confidence_score" in ds_col_list @pytest.mark.parametrize("verbose", [True, False]) @@ -327,7 +330,7 @@ def test_verbosity( ds_col_list = ds.column_names assert len(set(ds_col_list).intersection(ref_col_set)) == 3 - assert "reliability_score" in ds_col_list + assert "confidence_score" in ds_col_list def exact_match(text1: str, text2: str) -> float: @@ -339,34 +342,33 @@ def scoring_method() -> Callable[[str, str], float]: return exact_match -def test_exact_match_case_insensitive(reliability_dataset_row, scoring_method): - res = compute_reliability_score( - example=reliability_dataset_row, +def test_exact_match_case_insensitive(confidence_dataset_row, scoring_method): + res = compute_selfensembling_confidence_score( + example=confidence_dataset_row, target_column="target_text", reference_column_list=["reference_1", "reference_2", "reference_3"], scoring_method=scoring_method, ) - assert res["reliability_score"] == pytest.approx(2 / 3) + assert res["confidence_score"] == pytest.approx(2 / 3) -def test_exact_match_case_sensitive(reliability_dataset_row, scoring_method): - res = compute_reliability_score( - example=reliability_dataset_row, +def test_exact_match_case_sensitive(confidence_dataset_row, scoring_method): + res = compute_selfensembling_confidence_score( + example=confidence_dataset_row, target_column="target_text", reference_column_list=["reference_1", "reference_2", "reference_3"], scoring_method=scoring_method, case_sensitive=True, ) - # assert 0 <= res['reliability_score'] <= 1 - assert res["reliability_score"] == pytest.approx(1 / 3) + assert res["confidence_score"] == pytest.approx(1 / 3) -def test_invalid_scoring_method(reliability_dataset_row): +def test_invalid_scoring_method(confidence_dataset_row): with pytest.raises(ValueError): - compute_reliability_score( - example=reliability_dataset_row, + compute_selfensembling_confidence_score( + example=confidence_dataset_row, target_column="target_text", reference_column_list=["reference_1", "reference_2", "reference_3"], scoring_method="invalid_method", From cf36d2a8003c0ac3e5a75753f679828f91c50fb1 Mon Sep 17 00:00:00 2001 From: sumanthprabhu Date: Tue, 17 Sep 2024 11:42:45 +0530 Subject: [PATCH 5/6] Add LLMCurate - Added test to verify that all imports work correctly --- dqc/llm_utils/__init__.py | 2 +- tests/llm/__init__.py | 0 tests/llm/test_llm.py | 8 ++++++++ 3 files changed, 9 insertions(+), 1 deletion(-) 
create mode 100644 tests/llm/__init__.py diff --git a/dqc/llm_utils/__init__.py b/dqc/llm_utils/__init__.py index 54faf14..18dd3a6 100644 --- a/dqc/llm_utils/__init__.py +++ b/dqc/llm_utils/__init__.py @@ -8,7 +8,7 @@ __all__ = [ "build_LLM_prompt", - "selfensembling_confidence_score", + "compute_selfensembling_confidence_score", "infer_LLM", "run_LLM", ] diff --git a/tests/llm/__init__.py b/tests/llm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/llm/test_llm.py b/tests/llm/test_llm.py index 01e8046..898fdae 100644 --- a/tests/llm/test_llm.py +++ b/tests/llm/test_llm.py @@ -333,6 +333,14 @@ def test_verbosity( assert "confidence_score" in ds_col_list +def test_import_star(): + try: + exec("from dqc import *") + exec("from dqc.llm_utils import *") + except Exception as e: + pytest.fail(f"Importing with * raised an error: {e}") + + def exact_match(text1: str, text2: str) -> float: return 1.0 if text1 == text2 else 0.0 From 99f915eb878c63b87681b2049c22ba1aea63be44 Mon Sep 17 00:00:00 2001 From: sumanthprabhu Date: Tue, 17 Sep 2024 16:52:33 +0530 Subject: [PATCH 6/6] Add LLMCurate - Refactored `run` method to aggregate parameters related computing confidence scores into a single dictionary `scoring_params` - Refactored corresponding tests --- dqc/llm.py | 10 ++++++---- tests/llm/test_llm.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/dqc/llm.py b/dqc/llm.py index 511fa60..c19c7c0 100644 --- a/dqc/llm.py +++ b/dqc/llm.py @@ -82,7 +82,10 @@ def run( return_scores: bool = True, answer_start_token: str = "", answer_end_token: str = "", - scoring_method: Union[Callable[[str, str], float], str] = "exact_match", + scoring_params: dict = { + "scoring_method": "exact_match", + "case_sensitive": False, + }, **options, ) -> Dataset: """Run LLMCurate on the input data @@ -97,7 +100,7 @@ def run( return_scores (bool, optional): Indicator variable set to `True` if label confidence scores are to be computed for each label under `column_to_curate`. Defaults to True. answer_start_token (str, optional): Token that indicates the start of answer generation. Defaults to '' answer_end_token (str, optional): Token that indicates the end of answer generation. Defaults to '' - scoring_method (Union[Callable[[str, str], float], str], optional): A function or the string 'exact_match' to compute the confidence score. Defaults to 'exact_match'. + scoring_params (dict, optional): Parameters related to util function `compute_selfensembling_confidence_score` to compute confidence scores of `column_to_curate` Returns: Dataset: Input dataset with reference responses. If `return_scores=True`, then input dataset with reference responses and confidence scores. @@ -168,8 +171,7 @@ def run( fn_kwargs={ "target_column": column_to_curate, "reference_column_list": llm_response_cleaned_column_list, - "scoring_method": scoring_method, - **options, + **scoring_params, }, ) diff --git a/tests/llm/test_llm.py b/tests/llm/test_llm.py index 898fdae..bca5f02 100644 --- a/tests/llm/test_llm.py +++ b/tests/llm/test_llm.py @@ -206,7 +206,7 @@ def test_skip_llm_inference_success( column_to_curate="Equation", llm_response_cleaned_column_list=list(ref_col_set), skip_llm_inference=True, - case_sensitive=case_sensitive, + scoring_params={"case_sensitive": case_sensitive}, ) ds_col_list = ds.column_names
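
A minimal usage sketch of the API introduced by this patch series, reflecting the final state after PATCH 6/6. The model name is the same small causal LM used in the tests (`distilbert/distilgpt2`); the input DataFrame, its column names, the prompt texts, the `[/EQUATION]` end token, and the generation settings (`batch_size`, `max_new_tokens`) are illustrative assumptions, not fixed requirements.

import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer

from dqc import LLMCurate

# Small causal LM, as used in tests/llm/conftest.py; any model/tokenizer pair that
# passes _validate_init_params (has `generate`, `encode`, `decode`) should work.
model_name = "distilbert/distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

# Hypothetical input: the column to curate plus the columns referenced in the prompt.
data = pd.DataFrame(
    {
        "question": ["A farmer has 3 pens with 7 hens each. How many hens are there in total?"],
        "equation": ["3 * 7"],
    }
)

llmc = LLMCurate(model=model, tokenizer=tokenizer, verbose=True, random_state=42)
ds = llmc.run(
    data=data,
    column_to_curate="equation",
    # Entity tags in the prompt -> column names in `data`
    ds_column_mapping={"QUESTION": "question", "EQUATION": "equation"},
    # One cleaned-response column name per prompt variant
    prompt_variants=[
        "You are a helpful assistant. Respond with the equation only.",
        "Answer the question with a single arithmetic equation.",
    ],
    llm_response_cleaned_column_list=["reference_prediction1", "reference_prediction2"],
    answer_start_token="[EQUATION]",
    answer_end_token="[/EQUATION]",
    scoring_params={"scoring_method": "exact_match", "case_sensitive": False},
    batch_size=1,
    max_new_tokens=16,
)

print(ds.column_names)

The returned dataset keeps the original columns, adds one reference-prediction column per prompt variant and, with `return_scores=True` (the default), a `confidence_score` column in [0, 1] computed by self-ensembling over the reference responses. Passing a custom callable such as `lambda a, b: float(a.strip() == b.strip())` as `scoring_params["scoring_method"]` swaps in a different string-similarity measure.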