Feature LLM based Curation (#10)
* Add LLMCurate
- Implemented LLM-based curation using self-ensembling to compute reliability scores for text generation use cases
- Implemented util functions to perform batch inference using LLMs
- Added corresponding tests
- Updated the dependencies in pyproject.toml
- Bumped the version from 0.1.3 to 0.2.0
sumanthprabhu authored Sep 26, 2024
1 parent 2265381 commit f6a96b3
Showing 12 changed files with 1,237 additions and 5 deletions.
3 changes: 2 additions & 1 deletion dqc/__init__.py
@@ -1,4 +1,5 @@
 from .crossval import CrossValCurate
+from .llm import LLMCurate
 from .version import __version__, show_versions
 
-__all__ = ["CrossValCurate"]
+__all__ = ["CrossValCurate", "LLMCurate"]
14 changes: 13 additions & 1 deletion dqc/base.py
@@ -1,6 +1,9 @@
+import random
 from abc import ABC, abstractmethod
 from typing import Union
 
+import numpy as np
+import torch
 from pandas._typing import RandomState
 
 
@@ -15,10 +18,19 @@ class BaseCurate(ABC):
 
     def __init__(
         self,
-        random_state: RandomState = 42,
+        random_state: Union[int, RandomState] = 42,
         **options,
     ):
         self.random_state = random_state
+        self._set_seed(random_state)
+
+    def _set_seed(self, seed):
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+        np.random.seed(seed)
+        random.seed(seed)
+        torch.backends.cudnn.benchmark = False
+        torch.backends.cudnn.deterministic = True
 
     @abstractmethod
     def fit_transform(self): ...
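
The `BaseCurate` change above means every curator now seeds `torch`, `numpy`, and Python's `random` from `random_state` at construction time, and `random_state` accepts a plain int or a pandas `RandomState`. A minimal sketch of the effect, using a hypothetical `_DummyCurate` subclass (needed only because `BaseCurate` is abstract):

import numpy as np
import torch

from dqc.base import BaseCurate


class _DummyCurate(BaseCurate):
    # Hypothetical concrete subclass used only to exercise the seeding logic
    def fit_transform(self):
        pass


_DummyCurate(random_state=42)
first = (torch.rand(3), np.random.rand(3))

_DummyCurate(random_state=42)
second = (torch.rand(3), np.random.rand(3))

# The global seeds set in `_set_seed` make repeated runs reproducible
assert torch.equal(first[0], second[0]) and np.allclose(first[1], second[1])
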
178 changes: 178 additions & 0 deletions dqc/llm.py
@@ -0,0 +1,178 @@
from typing import Callable, List, Tuple, Union

import pandas as pd
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from dqc.base import BaseCurate
from dqc.llm_utils import (
    _empty_ds_ensemble_handler,
    _validate_init_params,
    _validate_run_params,
    compute_selfensembling_confidence_score,
    run_LLM,
)
from dqc.utils import Logger

logger = Logger("DQC-Toolkit")


class LLMCurate(BaseCurate):
    """
    Args:
        model (AutoModelForCausalLM): Instantiated LLM
        tokenizer (AutoTokenizer): Instantiated tokenizer corresponding to the `model`
        verbose (bool, optional): Sets the verbosity level during execution. `True` indicates logging level INFO and `False` indicates logging level WARNING. Defaults to False.
    Examples:
    """

    def __init__(
        self,
        model: AutoModelForCausalLM,
        tokenizer: AutoTokenizer,
        verbose: bool = False,
        **options,
    ):
        super().__init__(**options)

        _validate_init_params(model, tokenizer, verbose)
        self.model = model
        self.tokenizer = tokenizer
        self.verbose = verbose
        self._set_verbosity(verbose)

        self.ds_ensemble = None

    def __str__(self):
        display_dict = self.__dict__.copy()

        for key in list(display_dict.keys()):
            if key in ["ds_ensemble"]:
                ## Don't need to display these attributes
                del display_dict[key]

        return str(display_dict)

    __repr__ = __str__

    def _set_verbosity(self, verbose: bool):
        """Set logger level based on user input for parameter `verbose`
        Args:
            verbose (bool): Indicator for verbosity
        """
        if verbose:
            logger.set_level("INFO")
        else:
            logger.set_level("WARNING")

    def fit_transform(self):
        pass

    def run(
        self,
        column_to_curate: str,
        data: Union[pd.DataFrame, Dataset] = None,
        ds_column_mapping: dict = {},
        prompt_variants: List[str] = [""],
        skip_llm_inference: bool = False,
        llm_response_cleaned_column_list: List[str] = ["reference_prediction"],
        return_scores: bool = True,
        answer_start_token: str = "",
        answer_end_token: str = "",
        scoring_params: dict = {
            "scoring_method": "exact_match",
            "case_sensitive": False,
        },
        **options,
    ) -> Dataset:
        """Run LLMCurate on the input data
        Args:
            column_to_curate (str): Column name in `data` with the text that needs to be curated
            data (Union[pd.DataFrame, Dataset]): Input data for LLM based curation
            ds_column_mapping (dict, optional): Mapping of entities to be used in the LLM prompt to the corresponding columns in the input data. Defaults to {}.
            prompt_variants (List[str], optional): List of different LLM prompts to be used to curate the labels under `column_to_curate`. Defaults to [''].
            skip_llm_inference (bool, optional): Indicator variable to prevent re-running LLM inference. Set to `True` if artifacts from the previous run of LLMCurate need to be reused. Else `False`. Defaults to False.
            llm_response_cleaned_column_list (list, optional): Names of the columns that will contain LLM predictions for each input prompt in `prompt_variants`. Defaults to ['reference_prediction'].
            return_scores (bool, optional): Indicator variable set to `True` if label confidence scores are to be computed for each label under `column_to_curate`. Defaults to True.
            answer_start_token (str, optional): Token that indicates the start of answer generation. Defaults to ''
            answer_end_token (str, optional): Token that indicates the end of answer generation. Defaults to ''
            scoring_params (dict, optional): Parameters for the util function `compute_selfensembling_confidence_score` used to compute confidence scores for `column_to_curate`
        Returns:
            Dataset: Input dataset with reference responses. If `return_scores=True`, then input dataset with reference responses and confidence scores.
        """
        if not skip_llm_inference:
            empty_string_col_list = _validate_run_params(
                data,
                column_to_curate,
                ds_column_mapping,
                prompt_variants,
                llm_response_cleaned_column_list,
            )

            if len(empty_string_col_list) > 0:
                logger.warning(
                    f"Found empty string(s) in the input data under column(s) {empty_string_col_list}"
                )

            logger.info(
                f"Running the LLM to generate the {len(prompt_variants)} reference responses using `prompt_variants`.."
            )
            ds_ensemble = None

            model = self.model
            tokenizer = self.tokenizer

            for index, prompt_template_prefix in enumerate(prompt_variants):
                proposed_answer_col_name = llm_response_cleaned_column_list[index]
                ds = run_LLM(
                    data,
                    model,
                    tokenizer,
                    ds_column_mapping=ds_column_mapping,
                    prompt_template_prefix=prompt_template_prefix,
                    answer_start_token=answer_start_token,
                    answer_end_token=answer_end_token,
                    llm_response_cleaned_col_name=proposed_answer_col_name,
                    random_state=self.random_state,
                    **options,
                )

                if not ds_ensemble:
                    ds_ensemble = ds
                else:
                    ds_ensemble = ds_ensemble.add_column(
                        proposed_answer_col_name, ds[proposed_answer_col_name]
                    )
            self.ds_ensemble = ds_ensemble

        if return_scores:
            if skip_llm_inference:
                if (
                    isinstance(data, pd.DataFrame)
                    or ds_column_mapping
                    or prompt_variants
                ):
                    logger.warning(
                        "Ignoring params `data`, `ds_column_mapping` and `prompt_variants` since `skip_llm_inference` is set to `True`"
                    )

            _empty_ds_ensemble_handler(len(self.ds_ensemble) == 0, skip_llm_inference)

            logger.info(
                "Computing confidence scores using the LLM reference responses.."
            )
            self.ds_ensemble = self.ds_ensemble.map(
                compute_selfensembling_confidence_score,
                fn_kwargs={
                    "target_column": column_to_curate,
                    "reference_column_list": llm_response_cleaned_column_list,
                    **scoring_params,
                },
            )

        return self.ds_ensemble
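
An illustrative end-to-end sketch of `LLMCurate`, based on the constructor and `run` signature above. The checkpoint name, column names, prompts, and answer tokens are placeholders, not values taken from this commit:

import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer

from dqc import LLMCurate

checkpoint = "microsoft/Phi-3-mini-4k-instruct"  # assumed; any causal LM checkpoint should work
model = AutoModelForCausalLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

data = pd.DataFrame(
    {
        "question": ["What is the capital of France?"],
        "label": ["Paris"],
    }
)

curator = LLMCurate(model, tokenizer, verbose=True)
ds = curator.run(
    column_to_curate="label",
    data=data,
    ds_column_mapping={"question": "question"},  # prompt entity -> input column
    prompt_variants=[
        "Answer the question concisely.\n",
        "You are a quiz master. Give only the answer.\n",
    ],
    llm_response_cleaned_column_list=["ref_1", "ref_2"],
    answer_start_token="<ANS>",
    answer_end_token="</ANS>",
    scoring_params={"scoring_method": "exact_match", "case_sensitive": False},
)
# `ds` is a datasets.Dataset holding one reference-response column per prompt
# variant plus a confidence score for each value under `column_to_curate`.
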
14 changes: 14 additions & 0 deletions dqc/llm_utils/__init__.py
@@ -0,0 +1,14 @@
from ._sanity_checks import (
    _empty_ds_ensemble_handler,
    _validate_init_params,
    _validate_run_params,
)
from .compute_confidence_score import compute_selfensembling_confidence_score
from .inference import build_LLM_prompt, infer_LLM, run_LLM

__all__ = [
    "build_LLM_prompt",
    "compute_selfensembling_confidence_score",
    "infer_LLM",
    "run_LLM",
]
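
The diffs for `compute_confidence_score.py` and `inference.py` are not expanded in this view. For intuition only: with the default `scoring_method="exact_match"`, the self-ensembling score plausibly reduces to the fraction of LLM reference responses that agree with the text being curated. The toy function below is an assumption for illustration, not the committed implementation; the helper name and the `confidence_score` column name are hypothetical:

def toy_exact_match_score(row, target_column, reference_column_list, case_sensitive=False):
    # Stand-in for compute_selfensembling_confidence_score: fraction of
    # reference responses that exactly match the value under target_column.
    target = row[target_column].strip()
    if not case_sensitive:
        target = target.lower()
    matches = 0
    for col in reference_column_list:
        ref = row[col].strip()
        if not case_sensitive:
            ref = ref.lower()
        matches += int(ref == target)
    row["confidence_score"] = matches / len(reference_column_list)
    return row

row = {"label": "Paris", "ref_1": "paris", "ref_2": "Lyon"}
print(toy_exact_match_score(row, "label", ["ref_1", "ref_2"]))  # confidence_score = 0.5
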
