Add BertScoreMax transform for qa_accuracy
kirupang-code committed Jul 25, 2024
1 parent 4859531 commit fdb99b1
Showing 3 changed files with 90 additions and 142 deletions.
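
This commit replaces the ad-hoc BertScoreWithDelimiter transform in qa_accuracy.py with a reusable BertScoreMax transform in fmeval/transforms/common.py. The shared semantics: a target output may contain several acceptable answers separated by a delimiter (by default "<OR>"), and the transform records the highest BERTScore across them. A minimal sketch of that max-over-candidates logic, with a toy scoring function standing in for the real BertscoreHelperModel.get_helper_scores:

    def max_score(target_output, model_output, score, delimiter="<OR>"):
        # Split the target into candidate answers and keep the best score.
        candidates = target_output.split(delimiter)
        return max(score(t, model_output) for t in candidates)

    # Toy similarity: 1.0 on exact match, 0.0 otherwise.
    print(max_score("Paris<OR>The city of Paris", "Paris", lambda t, m: float(t == m)))  # 1.0
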
120 changes: 10 additions & 110 deletions src/fmeval/eval_algorithms/qa_accuracy.py
@@ -4,9 +4,7 @@
from typing import Any, Callable, Dict, List, Optional, Set, Union
from dataclasses import dataclass

import ray
from nltk.metrics.scores import f_measure, precision, recall
from ray.actor import ActorHandle

from fmeval.constants import (
BERTSCORE_DEFAULT_MODEL,
@@ -25,6 +23,7 @@
EvalOutput,
EvalScore,
)
from fmeval.transforms.common import BertScoreMax, BERT_SCORE
from fmeval.model_runners.model_runner import ModelRunner
from fmeval.transforms.transform import Transform
from fmeval.transforms.transform_pipeline import TransformPipeline
@@ -37,14 +36,11 @@
assert_condition,
)

# from fmeval.transforms.qa_accuracy_metrics import BertScoreWithDelimiter, BERT_SCORE

F1_SCORE = "f1_score"
EXACT_MATCH_SCORE = "exact_match_score"
QUASI_EXACT_MATCH_SCORE = "quasi_exact_match_score"
PRECISION_OVER_WORDS = "precision_over_words"
RECALL_OVER_WORDS = "recall_over_words"
BERT_SCORE = "bert_score"

# for metrics that are included in the QAAccuracyScores Transform
QA_ACCURACY_SCORE_NAMES = [
@@ -249,104 +245,6 @@ def __call__(self, record: Dict[str, Any]) -> Dict[str, Any]:
return record


class BertScoreWithDelimiter(Transform):
"""The abstract base class for QA Accuracy metric transforms.
Concrete subclasses of QAAccuracyMetric should simply implement the
`compute_metric` method and their own __init__ method. Subclasses need not implement
the __call__ method, as it is already implemented in this class, but are
free to do so if additional customization is required.
"""

def __init__(
self,
target_output_keys: List[str],
model_output_keys: List[str],
output_keys: List[str],
allow_duplicate_input_keys: bool,
bertscore_model: Union[BertscoreHelperModel, ActorHandle],
target_output_delimiter: Optional[str] = "<OR>",
):
"""QAAccuracyMetric initializer.
Note that the ordering of the elements in `target_output_keys`, `model_output_keys`,
and `output_keys` must match, i.e. the kth element of `target_output_keys` and the
kth element of `model_output_keys` are used to compute the kth metric, which has
an output key of `output_keys[k]`.
:param target_output_keys: The keys corresponding to target outputs.
:param model_output_keys: The keys corresponding to model outputs.
:param output_keys: The output keys for this Transform, which correspond
to the metrics/scores that get computed.
:param allow_duplicate_input_keys: Whether to allow duplicate keys in
`target_output_keys` and `model_output_keys`. This parameter is usually
False, but will be True when a SummarizationAccuracyMetric is created
to compute metrics on perturbed model outputs. In this case,
`target_output_keys` will be a list of a single repeated key, while
`model_output_keys` will contain the keys for perturbed model outputs.
:param *args: Variable length argument list.
:param **kwargs: Arbitrary keyword arguments.
"""
assert_condition(
len(target_output_keys) == len(model_output_keys) and len(target_output_keys) == len(output_keys),
"len(target_output_keys), len(model_output_keys) and len(output_keys) should all match. "
f"len(target_output_keys) is {len(target_output_keys)}, len(model_output_keys) is "
f"{len(model_output_keys)}, and len(output_keys) is {len(output_keys)}.",
)
super().__init__(
target_output_keys,
model_output_keys,
output_keys,
allow_duplicate_input_keys,
bertscore_model,
target_output_delimiter,
)
self.register_input_output_keys(
target_output_keys + model_output_keys,
output_keys,
allow_duplicates=allow_duplicate_input_keys,
)
self.target_output_keys = target_output_keys
self.model_output_keys = model_output_keys
self.bertscore_model = bertscore_model
self.target_output_delimiter = target_output_delimiter

@validate_call
def __call__(self, record: Dict[str, Any]) -> Dict[str, Any]:
"""Augment the input record with metrics computed via self.compute_metric.
:param record: The input record.
:returns: The input record with metrics added in.
"""
for target_output_key, model_output_key, output_key in zip(
self.target_output_keys, self.model_output_keys, self.output_keys
):
score = self.compute_metric(record[target_output_key], record[model_output_key])
record[output_key] = score
return record

def compute_metric(self, target_output: str, model_output: str) -> float:
"""Compute the metric that is specific to this Transform.
:param target_output: The target/reference output.
:param model_output: The actual output produced by the model.
:returns: A float representing the computed metric value.
"""
possible_targets = target_output.split(self.target_output_delimiter)
if isinstance(self.bertscore_model, BertscoreHelperModel):
return max([self.bertscore_model.get_helper_scores(target, model_output) for target in possible_targets])
else:
possible_scores = list(
map(
lambda x: self.bertscore_model.get_helper_scores.remote(x, model_output), # type: ignore[return-value]
possible_targets,
)
)
all_scores = ray.get(possible_scores)

return max(all_scores)
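
Aside: the else-branch above is the Ray path, which fans each candidate target out to a remote BertscoreHelperModel actor and gathers the scores with ray.get. A self-contained sketch of that fan-out/gather pattern, using a trivial stand-in actor rather than the real helper model:

    import ray

    @ray.remote
    class ToyScorer:
        # Stand-in for a BertscoreHelperModel actor.
        def get_helper_scores(self, target: str, model_output: str) -> float:
            return float(target == model_output)

    ray.init(ignore_reinit_error=True)
    scorer = ToyScorer.remote()
    targets = "Paris<OR>paris".split("<OR>")
    # One remote call per candidate, then a single blocking gather.
    futures = [scorer.get_helper_scores.remote(t, "Paris") for t in targets]
    print(max(ray.get(futures)))  # 1.0
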


@dataclass(frozen=True)
class QAAccuracyConfig(EvalAlgorithmConfig):
"""Configures the QA Accuracy evaluation algorithm.
@@ -407,8 +305,11 @@ def __init__(self, eval_algorithm_config: QAAccuracyConfig = QAAccuracyConfig())
super().__init__(eval_algorithm_config)

self.bertscore_model = BertscoreHelperModel(eval_algorithm_config.model_type_for_bertscore)
qa_accuracy_score = QAAccuracyScores(target_output_delimiter=eval_algorithm_config.target_output_delimiter)
bert_score = BertScoreWithDelimiter(

# Save the QAAccuracyScores transform on self.transform so evaluate() can reuse its configuration.
self.transform = QAAccuracyScores(target_output_delimiter=eval_algorithm_config.target_output_delimiter)

bert_score = BertScoreMax(
target_output_keys=[DatasetColumns.TARGET_OUTPUT.value.name],
model_output_keys=[DatasetColumns.MODEL_OUTPUT.value.name],
output_keys=[BERT_SCORE],
@@ -417,9 +318,8 @@ def __init__(self, eval_algorithm_config: QAAccuracyConfig = QAAccuracyConfig())
target_output_delimiter=eval_algorithm_config.target_output_delimiter,
)
self._eval_algorithm_config = eval_algorithm_config
self.qa_accuracy_score = qa_accuracy_score # saving the QAAccuracyScores transform
self.bert_score = bert_score # saving the BertScore transform
self.pipeline = TransformPipeline([qa_accuracy_score, bert_score])
self.pipeline = TransformPipeline([self.transform, bert_score])

def evaluate_sample(self, target_output: str, model_output: str) -> List[EvalScore]:
"""Compute QA accuracy metrics for a single sample.
@@ -465,17 +365,17 @@ def evaluate(
"""
# Create a shared resource to be used during the evaluation.
bertscore_shared_resource = create_shared_resource(self.bertscore_model)
bert_score = BertScoreWithDelimiter(
bert_score = BertScoreMax(
target_output_keys=[DatasetColumns.TARGET_OUTPUT.value.name],
model_output_keys=[DatasetColumns.MODEL_OUTPUT.value.name],
output_keys=[BERT_SCORE],
allow_duplicate_input_keys=True,
bertscore_model=bertscore_shared_resource,
target_output_delimiter=self.qa_accuracy_score.target_output_delimiter,
target_output_delimiter=self.transform.target_output_delimiter,
)

# Create a new pipeline that uses the shared resource instead of self.bertscore_model.
pipeline = TransformPipeline([self.qa_accuracy_score, bert_score])
pipeline = TransformPipeline([self.transform, bert_score])

dataset_configs = get_dataset_configs(dataset_config, self.eval_name)
eval_outputs = []
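
Taken together, the evaluate() path above builds the BERTScore model once, wraps it in a shared resource, and rebuilds the pipeline around the resulting Ray actor handle so all workers reuse a single copy of the model. A rough sketch of that wiring under the fmeval names shown above (dataset handling omitted; the record key names and delimiter are illustrative, and BERTSCORE_DEFAULT_MODEL is assumed to be a valid model type string):

    bertscore_model = BertscoreHelperModel(BERTSCORE_DEFAULT_MODEL)
    shared = create_shared_resource(bertscore_model)  # Ray actor handle

    bert_score = BertScoreMax(
        target_output_keys=["target_output"],
        model_output_keys=["model_output"],
        output_keys=["bert_score"],
        allow_duplicate_input_keys=True,
        bertscore_model=shared,  # every worker talks to the one shared actor
        target_output_delimiter="<OR>",
    )
    pipeline = TransformPipeline([QAAccuracyScores(target_output_delimiter="<OR>"), bert_score])
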
79 changes: 78 additions & 1 deletion src/fmeval/transforms/common.py
@@ -1,8 +1,12 @@
import numpy as np
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union, Callable

from ray.actor import ActorHandle

from fmeval.eval_algorithms.helper_models.helper_model import BertscoreHelperModel
from fmeval.model_runners.composers.composers import PromptComposer
from fmeval.model_runners.model_runner import ModelRunner
from fmeval.transforms.summarization_accuracy_metrics import BertScore, BERT_SCORE
from fmeval.transforms.transform import Transform
from fmeval.transforms.util import validate_call

@@ -191,3 +195,76 @@ def __call__(self, record: Dict[str, Any]) -> Dict[str, Any]:
avg = np.mean([record[input_key] for input_key in self.input_keys])
record[self.output_key] = avg
return record


class BertScoreMax(Transform):
"""This Transform augments its input record with the maximum BERTScore metric computed over
possible targets separated by a target_output_delimiter.
"""

def __init__(
self,
target_output_keys: List[str],
model_output_keys: List[str],
output_keys: List[str],
allow_duplicate_input_keys: bool,
bertscore_model: Union[BertscoreHelperModel, ActorHandle],
target_output_delimiter: Optional[str] = "<OR>",
):
"""BertScoreMax initializer.
:param target_output_keys: The keys corresponding to target outputs.
:param model_output_keys: The keys corresponding to model outputs.
:param output_keys: The output keys for this Transform, which correspond
to the BERT scores that get computed.
:param allow_duplicate_input_keys: See docstring for SummarizationAccuracyMetric.
:param bertscore_model: A BertscoreHelperModel instance or a Ray actor handle for a BertscoreHelperModel.
:param target_output_delimiter: The delimiter used to split the target output
into a list of possible targets.
"""
super().__init__(
target_output_keys,
model_output_keys,
output_keys,
allow_duplicate_input_keys,
bertscore_model,
target_output_delimiter,
)
self.register_input_output_keys(
target_output_keys + model_output_keys,
output_keys,
allow_duplicates=allow_duplicate_input_keys,
)
self.target_output_keys = target_output_keys
self.model_output_keys = model_output_keys
self.bertscore_model = bertscore_model
self.target_output_delimiter = target_output_delimiter

# BertScore transform used to compute metrics
self.bert_score_transform = BertScore(
target_output_keys=self.target_output_keys,
model_output_keys=self.model_output_keys,
output_keys=[BERT_SCORE],
allow_duplicate_input_keys=allow_duplicate_input_keys,
bertscore_model=self.bertscore_model,
)

@validate_call
def __call__(self, record: Dict[str, Any]) -> Dict[str, Any]:
"""Augment the input record with BERT_SCORE metrics computed via BertScore.compute_metric.
:param record: The input record.
:returns: The input record with BERT_SCORE metric added in.
"""

for target_output_key, model_output_key, output_key in zip(
self.target_output_keys, self.model_output_keys, self.output_keys
):
# Split the target output on the delimiter to obtain all acceptable targets.
possible_targets = record[target_output_key].split(self.target_output_delimiter)

scores = [
self.bert_score_transform.compute_metric(target, record[model_output_key])
for target in possible_targets
]
record[output_key] = max(scores)
return record
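
A usage sketch for the new transform on the local (non-Ray) path; the model type string is an assumption about what BertscoreHelperModel accepts, and real weights are loaded on first use:

    model = BertscoreHelperModel("microsoft/deberta-xlarge-mnli")  # assumed model type
    bert_score_max = BertScoreMax(
        target_output_keys=["target_output"],
        model_output_keys=["model_output"],
        output_keys=["bert_score"],
        allow_duplicate_input_keys=False,
        bertscore_model=model,
    )
    record = bert_score_max({"target_output": "Paris<OR>The city of Paris", "model_output": "Paris"})
    print(record["bert_score"])  # max BERTScore over the two candidate targets
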
33 changes: 2 additions & 31 deletions test/unit/eval_algorithms/test_qa_accuracy.py
@@ -5,7 +5,6 @@
import pytest
import ray
from _pytest.fixtures import fixture
from ray.actor import ActorHandle
from ray.data import Dataset

from fmeval.constants import (
@@ -42,7 +41,7 @@
_split,
_quasi_exact_match_score,
SCORE_NAMES,
BertScoreWithDelimiter,
BertScoreMax,
)
from fmeval.exceptions import EvalAlgorithmClientError

@@ -184,34 +183,6 @@ def test_qa_accuracy_invalid_config(self):
with pytest.raises(EvalAlgorithmClientError, match=re.escape(expected_error_message)):
QAAccuracyConfig(target_output_delimiter="")

def test_bert_score_with_delimiter_call_with_ray_actor_handle(self):
"""
GIVEN a BertScoreWithDelimiter instance, where its `bertscore_model` is a Ray actor handle.
WHEN its __call__ method is invoked.
THEN the correct Ray APIs are called.
Note: we don't validate the structure of the __call__ output since
we already have @validate_call to handle that.
"""
mock_bertscore_model = Mock(spec=ActorHandle)
mock_bertscore_model.get_helper_scores = Mock()
mock_bertscore_model.get_helper_scores.remote = Mock(return_value="remote invocation result")

with patch("fmeval.eval_algorithms.qa_accuracy.ray.get") as mock_ray_get:
mock_ray_get.return_value = [0.0]
bs = BertScoreWithDelimiter(
target_output_keys=["target_output"],
model_output_keys=["model_output"],
output_keys=["bertscore"],
allow_duplicate_input_keys=False,
bertscore_model=mock_bertscore_model,
)
sample = {"target_output": "Hello there!", "model_output": "Hi"}
bs(sample)
mock_bertscore_model.get_helper_scores.remote.assert_called_once_with("Hello there!", "Hi")
mock_ray_get.assert_called_once_with(["remote invocation result"]) # this must be a list because ray.get
# takes in possible_scores which is a list (corresponding to possible targets)

class TestCaseQAAccuracyEvaluateSample(NamedTuple):
model_input: str
model_output: str
@@ -330,7 +301,7 @@ def test_qa_accuracy_evaluate_sample(self, mock_isinstance, bertscore_model_cls,
@patch("fmeval.eval_algorithms.qa_accuracy.evaluate_dataset")
@patch("fmeval.eval_algorithms.qa_accuracy.create_shared_resource")
@patch("fmeval.eval_algorithms.qa_accuracy.TransformPipeline")
@patch("fmeval.eval_algorithms.qa_accuracy.BertScoreWithDelimiter")
@patch("fmeval.eval_algorithms.qa_accuracy.BertScoreMax")
@patch("fmeval.eval_algorithms.qa_accuracy.QAAccuracyScores")
@patch("fmeval.eval_algorithms.qa_accuracy.get_dataset")
@patch("fmeval.eval_algorithms.qa_accuracy.get_dataset_configs")
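
The Ray-specific test removed above has no direct replacement in this commit. A hypothetical unit-test sketch for BertScoreMax on the local-model path, mocking BertscoreHelperModel so no weights are loaded (relies on the imports already present in this test module):

    from unittest.mock import Mock

    def test_bert_score_max_takes_max_over_candidates():
        mock_model = Mock(spec=BertscoreHelperModel)
        # Score 1.0 only on exact match, so the max over candidates is predictable.
        mock_model.get_helper_scores = Mock(side_effect=lambda t, m: float(t == m))
        bs = BertScoreMax(
            target_output_keys=["target_output"],
            model_output_keys=["model_output"],
            output_keys=["bert_score"],
            allow_duplicate_input_keys=False,
            bertscore_model=mock_model,
        )
        record = bs({"target_output": "Hi<OR>Hello", "model_output": "Hello"})
        assert record["bert_score"] == 1.0
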
