From fdb99b148e84bce8be716c078e92f6255ad3b374 Mon Sep 17 00:00:00 2001 From: Kirupa Gunaseelan Date: Thu, 25 Jul 2024 15:28:03 -0700 Subject: [PATCH] Add BertScoreMax transform for qa_accuracy --- src/fmeval/eval_algorithms/qa_accuracy.py | 120 ++---------------- src/fmeval/transforms/common.py | 79 +++++++++++- test/unit/eval_algorithms/test_qa_accuracy.py | 33 +---- 3 files changed, 90 insertions(+), 142 deletions(-) diff --git a/src/fmeval/eval_algorithms/qa_accuracy.py b/src/fmeval/eval_algorithms/qa_accuracy.py index efb45bbc..b71dd273 100644 --- a/src/fmeval/eval_algorithms/qa_accuracy.py +++ b/src/fmeval/eval_algorithms/qa_accuracy.py @@ -4,9 +4,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Union from dataclasses import dataclass -import ray from nltk.metrics.scores import f_measure, precision, recall -from ray.actor import ActorHandle from fmeval.constants import ( BERTSCORE_DEFAULT_MODEL, @@ -25,6 +23,7 @@ EvalOutput, EvalScore, ) +from fmeval.transforms.common import BertScoreMax, BERT_SCORE from fmeval.model_runners.model_runner import ModelRunner from fmeval.transforms.transform import Transform from fmeval.transforms.transform_pipeline import TransformPipeline @@ -37,14 +36,11 @@ assert_condition, ) -# from fmeval.transforms.qa_accuracy_metrics import BertScoreWithDelimiter, BERT_SCORE - F1_SCORE = "f1_score" EXACT_MATCH_SCORE = "exact_match_score" QUASI_EXACT_MATCH_SCORE = "quasi_exact_match_score" PRECISION_OVER_WORDS = "precision_over_words" RECALL_OVER_WORDS = "recall_over_words" -BERT_SCORE = "bert_score" # for metrics that are included in the QAAccuracyScores Transform QA_ACCURACY_SCORE_NAMES = [ @@ -249,104 +245,6 @@ def __call__(self, record: Dict[str, Any]) -> Dict[str, Any]: return record -class BertScoreWithDelimiter(Transform): - """The abstract base class for QA Accuracy metric transforms. - - Concrete subclasses of QAAccuracyMetric should simply implement the - `compute_metric` method and their own __init__ method. Subclasses need not implement - the __call__ method, as it is already implemented in this class, but are - free to do so if additional customization is required. - """ - - def __init__( - self, - target_output_keys: List[str], - model_output_keys: List[str], - output_keys: List[str], - allow_duplicate_input_keys: bool, - bertscore_model: Union[BertscoreHelperModel, ActorHandle], - target_output_delimiter: Optional[str] = "", - ): - """QAAccuracyMetric initializer. - - Note that the ordering of the elements in `target_output_keys`, `model_output_keys`, - and `output_keys` must match, i.e. the kth element of `target_output_keys` and the - kth element of `model_output_keys` are used to compute the kth metric, which has - an output key of `output_keys[k]`. - - :param target_output_keys: The keys corresponding to target outputs. - :param model_output_keys: The keys corresponding to model outputs. - :param output_keys: The output keys for this Transform, which correspond - to the metrics/scores that get computed. - :param allow_duplicate_input_keys: Whether to allow duplicate keys in - `target_output_keys` and `model_output_keys`. This parameter is usually - False, but will be True when a SummarizationAccuracyMetric is created - to compute metrics on perturbed model outputs. In this case, - `target_output_keys` will be a list of a single repeated key, while - `model_output_keys` will contain the keys for perturbed model outputs. - :param *args: Variable length argument list. - :param **kwargs: Arbitrary keyword arguments. 
- """ - assert_condition( - len(target_output_keys) == len(model_output_keys) and len(target_output_keys) == len(output_keys), - "len(target_output_keys), len(model_output_keys) and len(output_keys) should all match. " - f"len(target_output_keys) is {len(target_output_keys)}, len(model_output_keys) is " - f"{len(model_output_keys)}, and len(output_keys) is {len(output_keys)}.", - ) - super().__init__( - target_output_keys, - model_output_keys, - output_keys, - allow_duplicate_input_keys, - bertscore_model, - target_output_delimiter, - ) - self.register_input_output_keys( - target_output_keys + model_output_keys, - output_keys, - allow_duplicates=allow_duplicate_input_keys, - ) - self.target_output_keys = target_output_keys - self.model_output_keys = model_output_keys - self.bertscore_model = bertscore_model - self.target_output_delimiter = target_output_delimiter - - @validate_call - def __call__(self, record: Dict[str, Any]) -> Dict[str, Any]: - """Augment the input record with metrics computed via self.compute_metric. - - :param record: The input record. - :returns: The input record with metrics added in. - """ - for target_output_key, model_output_key, output_key in zip( - self.target_output_keys, self.model_output_keys, self.output_keys - ): - score = self.compute_metric(record[target_output_key], record[model_output_key]) - record[output_key] = score - return record - - def compute_metric(self, target_output: str, model_output: str) -> float: - """Compute the metric that is specific to this Transform. - - :param target_output: The target/reference output. - :param model_output: The actual output produced by the model. - :returns: A float representing the computed metric value. - """ - possible_targets = target_output.split(self.target_output_delimiter) - if isinstance(self.bertscore_model, BertscoreHelperModel): - return max([self.bertscore_model.get_helper_scores(target, model_output) for target in possible_targets]) - else: - possible_scores = list( - map( - lambda x: self.bertscore_model.get_helper_scores.remote(x, model_output), # type: ignore[return-value] - possible_targets, - ) - ) - all_scores = ray.get(possible_scores) - - return max(all_scores) - - @dataclass(frozen=True) class QAAccuracyConfig(EvalAlgorithmConfig): """Configures the QA Accuracy evaluation algorithm. 
@@ -407,8 +305,11 @@ def __init__(self, eval_algorithm_config: QAAccuracyConfig = QAAccuracyConfig()) super().__init__(eval_algorithm_config) self.bertscore_model = BertscoreHelperModel(eval_algorithm_config.model_type_for_bertscore) - qa_accuracy_score = QAAccuracyScores(target_output_delimiter=eval_algorithm_config.target_output_delimiter) - bert_score = BertScoreWithDelimiter( + + # Saving QAAccuracyScores in the original self.transform + self.transform = QAAccuracyScores(target_output_delimiter=eval_algorithm_config.target_output_delimiter) + + bert_score = BertScoreMax( target_output_keys=[DatasetColumns.TARGET_OUTPUT.value.name], model_output_keys=[DatasetColumns.MODEL_OUTPUT.value.name], output_keys=[BERT_SCORE], @@ -417,9 +318,8 @@ def __init__(self, eval_algorithm_config: QAAccuracyConfig = QAAccuracyConfig()) target_output_delimiter=eval_algorithm_config.target_output_delimiter, ) self._eval_algorithm_config = eval_algorithm_config - self.qa_accuracy_score = qa_accuracy_score # saving the QAAccuracyScores transform self.bert_score = bert_score # saving the BertScore transform - self.pipeline = TransformPipeline([qa_accuracy_score, bert_score]) + self.pipeline = TransformPipeline([self.transform, bert_score]) def evaluate_sample(self, target_output: str, model_output: str) -> List[EvalScore]: """Compute QA accuracy metrics for a single sample. @@ -465,17 +365,17 @@ def evaluate( """ # Create a shared resource to be used during the evaluation. bertscore_shared_resource = create_shared_resource(self.bertscore_model) - bert_score = BertScoreWithDelimiter( + bert_score = BertScoreMax( target_output_keys=[DatasetColumns.TARGET_OUTPUT.value.name], model_output_keys=[DatasetColumns.MODEL_OUTPUT.value.name], output_keys=[BERT_SCORE], allow_duplicate_input_keys=True, bertscore_model=bertscore_shared_resource, - target_output_delimiter=self.qa_accuracy_score.target_output_delimiter, + target_output_delimiter=self.transform.target_output_delimiter, ) # Create a new pipeline that uses the shared resource instead of self.bertscore_model. - pipeline = TransformPipeline([self.qa_accuracy_score, bert_score]) + pipeline = TransformPipeline([self.transform, bert_score]) dataset_configs = get_dataset_configs(dataset_config, self.eval_name) eval_outputs = [] diff --git a/src/fmeval/transforms/common.py b/src/fmeval/transforms/common.py index bfed98e7..16497c41 100644 --- a/src/fmeval/transforms/common.py +++ b/src/fmeval/transforms/common.py @@ -1,8 +1,12 @@ import numpy as np -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union, Callable +from ray.actor import ActorHandle + +from fmeval.eval_algorithms.helper_models.helper_model import BertscoreHelperModel from fmeval.model_runners.composers.composers import PromptComposer from fmeval.model_runners.model_runner import ModelRunner +from fmeval.transforms.summarization_accuracy_metrics import BertScore, BERT_SCORE from fmeval.transforms.transform import Transform from fmeval.transforms.util import validate_call @@ -191,3 +195,76 @@ def __call__(self, record: Dict[str, Any]) -> Dict[str, Any]: avg = np.mean([record[input_key] for input_key in self.input_keys]) record[self.output_key] = avg return record + + +class BertScoreMax(Transform): + """This Transform augments its input record with the maximum BERTScore metric computed over + possible targets separated by a target_output_delimiter. 
+    """
+
+    def __init__(
+        self,
+        target_output_keys: List[str],
+        model_output_keys: List[str],
+        output_keys: List[str],
+        allow_duplicate_input_keys: bool,
+        bertscore_model: Union[BertscoreHelperModel, ActorHandle],
+        target_output_delimiter: Optional[str] = "",
+    ):
+        """BertScoreMax initializer.
+        :param target_output_keys: The keys corresponding to target outputs.
+        :param model_output_keys: The keys corresponding to model outputs.
+        :param output_keys: The output keys for this Transform, which correspond
+            to the BERT scores that get computed.
+        :param allow_duplicate_input_keys: See docstring for SummarizationAccuracyMetric.
+        :param bertscore_model: A BertscoreHelperModel instance or a Ray actor handle for a BertscoreHelperModel.
+        :param target_output_delimiter: The delimiter used to split the target output into
+            a list of possible target answers.
+        """
+        super().__init__(
+            target_output_keys,
+            model_output_keys,
+            output_keys,
+            allow_duplicate_input_keys,
+            bertscore_model,
+            target_output_delimiter,
+        )
+        self.register_input_output_keys(
+            target_output_keys + model_output_keys,
+            output_keys,
+            allow_duplicates=allow_duplicate_input_keys,
+        )
+        self.target_output_keys = target_output_keys
+        self.model_output_keys = model_output_keys
+        self.bertscore_model = bertscore_model
+        self.target_output_delimiter = target_output_delimiter
+
+        # BertScore transform used to compute metrics
+        self.bert_score_transform = BertScore(
+            target_output_keys=self.target_output_keys,
+            model_output_keys=self.model_output_keys,
+            output_keys=[BERT_SCORE],
+            allow_duplicate_input_keys=allow_duplicate_input_keys,
+            bertscore_model=self.bertscore_model,
+        )
+
+    @validate_call
+    def __call__(self, record: Dict[str, Any]) -> Dict[str, Any]:
+        """Augment the input record with the maximum BERT_SCORE metric computed via BertScore.compute_metric.
+
+        :param record: The input record.
+        :returns: The input record with the BERT_SCORE metric added in.
+        """
+
+        for target_output_key, model_output_key, output_key in zip(
+            self.target_output_keys, self.model_output_keys, self.output_keys
+        ):
+            # Split the target output on the delimiter to get the list of possible targets.
+            possible_targets = record[target_output_key].split(self.target_output_delimiter)
+
+            scores = [
+                BertScore.compute_metric(self.bert_score_transform, target, record[model_output_key])
+                for target in possible_targets
+            ]
+            record[output_key] = max(scores)
+        return record
diff --git a/test/unit/eval_algorithms/test_qa_accuracy.py b/test/unit/eval_algorithms/test_qa_accuracy.py
index 49c8ae15..f7a61194 100644
--- a/test/unit/eval_algorithms/test_qa_accuracy.py
+++ b/test/unit/eval_algorithms/test_qa_accuracy.py
@@ -5,7 +5,6 @@
 import pytest
 import ray
 from _pytest.fixtures import fixture
-from ray.actor import ActorHandle
 from ray.data import Dataset
 
 from fmeval.constants import (
@@ -42,7 +41,7 @@
     _split,
     _quasi_exact_match_score,
     SCORE_NAMES,
-    BertScoreWithDelimiter,
+    BertScoreMax,
 )
 
 from fmeval.exceptions import EvalAlgorithmClientError
@@ -184,34 +183,6 @@ def test_qa_accuracy_invalid_config(self):
         with pytest.raises(EvalAlgorithmClientError, match=re.escape(expected_error_message)):
             QAAccuracyConfig(target_output_delimiter="")
 
-    def test_bert_score_with_delimiter_call_with_ray_actor_handle(self):
-        """
-        GIVEN a BertScoreWithDelimiter instance, where its `bertscore_model` is a Ray actor handle.
-        WHEN its __call__ method is invoked.
-        THEN the correct Ray APIs are called.
- - Note: we don't validate the structure of the __call__ output since - we already have @validate_call to handle that. - """ - mock_bertscore_model = Mock(spec=ActorHandle) - mock_bertscore_model.get_helper_scores = Mock() - mock_bertscore_model.get_helper_scores.remote = Mock(return_value="remote invocation result") - - with patch("fmeval.eval_algorithms.qa_accuracy.ray.get") as mock_ray_get: - mock_ray_get.return_value = [0.0] - bs = BertScoreWithDelimiter( - target_output_keys=["target_output"], - model_output_keys=["model_output"], - output_keys=["bertscore"], - allow_duplicate_input_keys=False, - bertscore_model=mock_bertscore_model, - ) - sample = {"target_output": "Hello there!", "model_output": "Hi"} - bs(sample) - mock_bertscore_model.get_helper_scores.remote.assert_called_once_with("Hello there!", "Hi") - mock_ray_get.assert_called_once_with(["remote invocation result"]) # this must be a list because ray.get - # takes in possible_scores which is a list (corresponding to possible targets) - class TestCaseQAAccuracyEvaluateSample(NamedTuple): model_input: str model_output: str @@ -330,7 +301,7 @@ def test_qa_accuracy_evaluate_sample(self, mock_isinstance, bertscore_model_cls, @patch("fmeval.eval_algorithms.qa_accuracy.evaluate_dataset") @patch("fmeval.eval_algorithms.qa_accuracy.create_shared_resource") @patch("fmeval.eval_algorithms.qa_accuracy.TransformPipeline") - @patch("fmeval.eval_algorithms.qa_accuracy.BertScoreWithDelimiter") + @patch("fmeval.eval_algorithms.qa_accuracy.BertScoreMax") @patch("fmeval.eval_algorithms.qa_accuracy.QAAccuracyScores") @patch("fmeval.eval_algorithms.qa_accuracy.get_dataset") @patch("fmeval.eval_algorithms.qa_accuracy.get_dataset_configs")