Feature LLM based Curation #10

Merged · 6 commits · Sep 26, 2024
3 changes: 2 additions & 1 deletion dqc/__init__.py
@@ -1,4 +1,5 @@
from .crossval import CrossValCurate
from .llm import LLMCurate
from .version import __version__, show_versions

__all__ = ["CrossValCurate"]
__all__ = ["CrossValCurate", "LLMCurate"]
14 changes: 13 additions & 1 deletion dqc/base.py
@@ -1,6 +1,9 @@
import random
from abc import ABC, abstractmethod
from typing import Union

import numpy as np
import torch
from pandas._typing import RandomState


@@ -15,10 +18,19 @@ class BaseCurate(ABC):

def __init__(
self,
random_state: RandomState = 42,
random_state: Union[int, RandomState] = 42,
**options,
):
self.random_state = random_state
self._set_seed(random_state)

def _set_seed(self, seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

@abstractmethod
def fit_transform(self): ...
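To illustrate the effect of the new `_set_seed` hook, here is a small sketch. `DummyCurate` is a hypothetical subclass used only for demonstration (not part of this PR), and it assumes `torch` and `numpy` are installed:

# Hypothetical subclass used only to demonstrate the seeding behaviour added to
# BaseCurate.__init__; it is not part of this PR.
import random

import numpy as np
import torch

from dqc.base import BaseCurate


class DummyCurate(BaseCurate):
    def fit_transform(self):
        # Draw once from each RNG that _set_seed() seeds in the constructor.
        return random.random(), np.random.rand(), torch.rand(1).item()


a = DummyCurate(random_state=42).fit_transform()
b = DummyCurate(random_state=42).fit_transform()
assert a == b  # same random_state -> identical draws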
178 changes: 178 additions & 0 deletions dqc/llm.py
@@ -0,0 +1,178 @@
from typing import Callable, List, Tuple, Union

import pandas as pd
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from dqc.base import BaseCurate
from dqc.llm_utils import (
_empty_ds_ensemble_handler,
_validate_init_params,
_validate_run_params,
compute_selfensembling_confidence_score,
run_LLM,
)
from dqc.utils import Logger

logger = Logger("DQC-Toolkit")


class LLMCurate(BaseCurate):
"""
Args:
model (AutoModelForCausalLM): Instantiated LLM
tokenizer (AutoTokenizer): Instantiated tokenizer corresponding to the `model`
verbose (bool, optional): Sets the verbosity level during execution. `True` sets the logging level to INFO and `False` sets it to WARNING. Defaults to False.

Examples:

"""

def __init__(
self,
model: AutoModelForCausalLM,
tokenizer: AutoTokenizer,
verbose: bool = False,
**options,
):
super().__init__(**options)

_validate_init_params(model, tokenizer, verbose)
self.model = model
self.tokenizer = tokenizer
self.verbose = verbose
self._set_verbosity(verbose)

self.ds_ensemble = None

def __str__(self):
display_dict = self.__dict__.copy()

for key in list(display_dict.keys()):
if key in ["ds_ensemble"]:
## Don't need to display these attributes
del display_dict[key]

return str(display_dict)

__repr__ = __str__

def _set_verbosity(self, verbose: bool):
"""Set logger level based on user input for parameter `verbose`

Args:
verbose (bool): Indicator for verbosity
"""
if verbose:
logger.set_level("INFO")
else:
logger.set_level("WARNING")

def fit_transform(self):
pass

def run(
self,
column_to_curate: str,
data: Union[pd.DataFrame, Dataset] = None,
ds_column_mapping: dict = {},
prompt_variants: List[str] = [""],
skip_llm_inference: bool = False,
llm_response_cleaned_column_list: List[str] = ["reference_prediction"],
return_scores: bool = True,
answer_start_token: str = "",
answer_end_token: str = "",
scoring_params: dict = {
"scoring_method": "exact_match",
"case_sensitive": False,
},
**options,
) -> Dataset:
"""Run LLMCurate on the input data

Args:
column_to_curate (str): Column name in `data` with the text that needs to be curated
data (Union[pd.DataFrame, Dataset]): Input data for LLM based curation
ds_column_mapping (dict, optional): Mapping of entities to be used in the LLM prompt to the corresponding columns in the input data. Defaults to {}.
prompt_variants (List[str], optional): List of different LLM prompts to be used to curate the labels under `column_to_curate`. Defaults to [''].
skip_llm_inference (bool, optional): Indicator variable to prevent re-running LLM inference. Set to `True` if artifacts from the previous run of LLMCurate need to be reused; else `False`. Defaults to False.
llm_response_cleaned_column_list (List[str], optional): Names of the columns that will contain the LLM predictions, one for each prompt in `prompt_variants`. Defaults to ['reference_prediction'].
return_scores (bool, optional): Indicator variable set to `True` if label confidence scores are to be computed for each label under `column_to_curate`. Defaults to True.
answer_start_token (str, optional): Token that indicates the start of answer generation. Defaults to ''.
answer_end_token (str, optional): Token that indicates the end of answer generation. Defaults to ''.
scoring_params (dict, optional): Parameters passed to the util function `compute_selfensembling_confidence_score` to compute confidence scores for the labels under `column_to_curate`. Defaults to {'scoring_method': 'exact_match', 'case_sensitive': False}.

Returns:
Dataset: Input dataset with reference responses. If `return_scores=True`, then input dataset with reference responses and confidence scores.
"""
if not skip_llm_inference:
empty_string_col_list = _validate_run_params(
data,
column_to_curate,
ds_column_mapping,
prompt_variants,
llm_response_cleaned_column_list,
)

if len(empty_string_col_list) > 0:
logger.warning(
"Found empty string(s) in the input data under column(s) {empty_string_col_list}"
)

logger.info(
f"Running the LLM to generate the {len(prompt_variants)} reference responses using `prompt_variants`.."
)
ds_ensemble = None

model = self.model
tokenizer = self.tokenizer

for index, prompt_template_prefix in enumerate(prompt_variants):
proposed_answer_col_name = llm_response_cleaned_column_list[index]
ds = run_LLM(
data,
model,
tokenizer,
ds_column_mapping=ds_column_mapping,
prompt_template_prefix=prompt_template_prefix,
answer_start_token=answer_start_token,
answer_end_token=answer_end_token,
llm_response_cleaned_col_name=proposed_answer_col_name,
random_state=self.random_state,
**options,
)

if not ds_ensemble:
ds_ensemble = ds
else:
ds_ensemble = ds_ensemble.add_column(
proposed_answer_col_name, ds[proposed_answer_col_name]
)
self.ds_ensemble = ds_ensemble

if return_scores:
if skip_llm_inference:
if (
isinstance(data, pd.DataFrame)
or ds_column_mapping
or prompt_variants
):
logger.warning(
"Ignoring params `data`, `ds_column_mapping` and `prompt_variants` since `skip_llm_inference` is set to `True`"
)

_empty_ds_ensemble_handler(len(self.ds_ensemble) == 0, skip_llm_inference)

logger.info(
"Computing confidence scores using the LLM reference responses.."
)
self.ds_ensemble = self.ds_ensemble.map(
compute_selfensembling_confidence_score,
fn_kwargs={
"target_column": column_to_curate,
"reference_column_list": llm_response_cleaned_column_list,
**scoring_params,
},
)

return self.ds_ensemble
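For context, a hedged usage sketch of the new API (not part of the diff). The model checkpoint, the example data, the `ds_column_mapping` semantics and the output column names below are all illustrative assumptions:

# Illustrative sketch only -- checkpoint, data and column names are assumptions.
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer

from dqc import LLMCurate

model_name = "HuggingFaceTB/SmolLM-135M-Instruct"  # assumed; any small causal LM works
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

data = pd.DataFrame(
    {
        "question": ["What is the capital of France?"],
        "label": ["Paris"],
    }
)

curator = LLMCurate(model=model, tokenizer=tokenizer, verbose=True)
ds = curator.run(
    column_to_curate="label",
    data=data,
    ds_column_mapping={"question": "question"},  # prompt entity -> data column (assumed semantics)
    prompt_variants=[
        "Answer the question concisely: ",
        "Reply with only the answer: ",
    ],
    # One output column per prompt variant.
    llm_response_cleaned_column_list=["ref_1", "ref_2"],
    return_scores=True,
)
print(ds.to_pandas().head())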
14 changes: 14 additions & 0 deletions dqc/llm_utils/__init__.py
@@ -0,0 +1,14 @@
from ._sanity_checks import (
_empty_ds_ensemble_handler,
_validate_init_params,
_validate_run_params,
)
from .compute_confidence_score import compute_selfensembling_confidence_score
from .inference import build_LLM_prompt, infer_LLM, run_LLM

__all__ = [
"build_LLM_prompt",
"compute_selfensembling_confidence_score",
"infer_LLM",
"run_LLM",
]
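The confidence-score util can also be applied to a single example dict, mirroring how `LLMCurate.run` applies it via `Dataset.map`. The keyword arguments below are inferred from the `fn_kwargs` in `dqc/llm.py`, so treat the exact signature as an assumption:

# Sketch of a direct call; the signature is inferred from the fn_kwargs used in
# dqc/llm.py and may differ in the actual implementation.
from dqc.llm_utils import compute_selfensembling_confidence_score

example = {"label": "Paris", "ref_1": "paris", "ref_2": "Paris"}
scored = compute_selfensembling_confidence_score(
    example,
    target_column="label",
    reference_column_list=["ref_1", "ref_2"],
    scoring_method="exact_match",
    case_sensitive=False,
)
print(scored)  # expected to include a confidence score for "label"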