refactor: SummarizationAccuracyMetrics transform to handle multiple target outputs more efficiently (#317)

* Added metric to factual knowledge + unit/integration tests

cr: https://code.amazon.com/reviews/CR-135854933

* addressed PR comments

* Deleted metrics.py and restored code in util.py

* added factual knowledge metrics to constants.py

* added factual knowledge metrics to be included in binary score

* updated score descriptions for factual knowledge

* feat: add configurable param logical_operator (OR/AND) to factual knowledge

* addressed PR comments

* added warning and fixed typo

* modified warnings and fixed invalid config tests for factual_knowledge

* feat: Adding BERTScore to QAAccuracy + QAAccuracySemanticRobustness

* fix: documentation and tests for qa accuracy + qa accuracy semantic robustness

* fix: lint checks

* fix: created dataset for qa_accuracy, reverted to js_model_runner

* fix: integration tests by adding approx for BertScore

* fix: moved BertScoreWithDelimiter to qa_accuracy and updated tests

* fix: restored qa_accuracy_semantic_robustness

* fix: smaller dataset for integ tests to reduce runtime

* fix: smaller dataset for integ tests to reduce runtime

* Add BertScoreMax transform for qa_accuracy

* fix: lint checks

* fix: cleaning up code and checking reporting folder for changes

* fix: refactored SummarizationAccuracyMetric

* fix: deleted dataset file from previous PR

* refactor: added target_output_keys_provider to SummarizationAccuracyMetric

* edited description of test function

* updated assert description for target_output_keys

* addressed comments from PR

* fixed errors in test from previous run

* added unit test and fixed type issues

* fix: Mocked BertScore in unit test
kirupang-code authored Aug 8, 2024
1 parent 2c6a402 commit bc5a15f
Showing 4 changed files with 135 additions and 36 deletions.
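For orientation, here is a minimal usage sketch of the refactored API introduced by this PR, assuming the constructor signatures shown in the diff below; the record values and key names are invented for illustration. A metric can now be created with `target_output_keys=None` plus a `target_output_keys_provider` key whose record value is a list of acceptable targets, and the transform stores the maximum score over that list.

from fmeval.constants import BERTSCORE_DEFAULT_MODEL
from fmeval.eval_algorithms.helper_models.helper_model import BertscoreHelperModel
from fmeval.transforms.summarization_accuracy_metrics import BertScore

# The record stores a list of acceptable targets under a single key.
record = {
    "target_output": ["Paris", "The capital of France is Paris."],
    "model_output": "Paris is the capital of France.",
}

bert_score = BertScore(
    target_output_keys=None,                      # fall back to the provider key
    model_output_keys=["model_output"],
    output_keys=["bertscore"],
    allow_duplicate_input_keys=False,
    target_output_keys_provider="target_output",  # key holding the list of targets
    bertscore_model=BertscoreHelperModel(BERTSCORE_DEFAULT_MODEL),  # loads the default model; may download weights
)

record = bert_score(record)  # record["bertscore"] holds the max score over both targets

Note that `bertscore_model` now defaults to `BertscoreHelperModel(BERTSCORE_DEFAULT_MODEL)`, so the last argument can be omitted entirely.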
2 changes: 1 addition & 1 deletion src/fmeval/eval_algorithms/general_semantic_robustness.py
@@ -161,7 +161,7 @@ def _build_pipeline(
# Compute BERTScores with target_output = the original model output
# and model_output = the output from invoking the model with the perturbed prompt.
get_bert_scores = BertScore(
target_output_keys=[original_model_output_key for _ in range(self.num_perturbations)],
target_output_keys=[original_model_output_key],
model_output_keys=get_perturbed_responses.output_keys,
output_keys=[create_output_key(BertScore.__name__, i) for i in range(self.num_perturbations)],
allow_duplicate_input_keys=True,
@@ -157,7 +157,7 @@ def _build_pipeline(
)

perturbed_meteor, perturbed_rouge, perturbed_bert_score = SummarizationAccuracy._create_transforms(
target_output_keys=[DatasetColumns.TARGET_OUTPUT.value.name for _ in range(self.config.num_perturbations)],
target_output_keys=[DatasetColumns.TARGET_OUTPUT.value.name],
model_output_keys=get_perturbed_responses.output_keys,
meteor_keys=[
create_output_key(MeteorScore.__name__, "perturbed", i) for i in range(self.config.num_perturbations)
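Both hunks above make the same simplification: previously, `target_output_keys` had to repeat the single target key once per perturbation so that its length matched `model_output_keys`. With the refactor, the length check only ties `model_output_keys` to `output_keys`, and every model output is scored against all keys in `target_output_keys` (keeping the max), so a one-element list suffices. A toy sketch of the new loop, with a stand-in `compute_metric` and made-up keys:

def compute_metric(target: str, model_output: str) -> float:
    # Stand-in for MeteorScore/RougeScore/BertScore.compute_metric.
    return float(target == model_output)

record = {"original": "yes", "perturbed_0": "yes", "perturbed_1": "no"}
target_output_keys = ["original"]                   # no longer repeated per perturbation
model_output_keys = ["perturbed_0", "perturbed_1"]
output_keys = ["score_0", "score_1"]

targets = [record[key] for key in target_output_keys]
for model_output_key, output_key in zip(model_output_keys, output_keys):
    record[output_key] = max(compute_metric(t, record[model_output_key]) for t in targets)

print(record["score_0"], record["score_1"])  # 1.0 0.0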
113 changes: 79 additions & 34 deletions src/fmeval/transforms/summarization_accuracy_metrics.py
@@ -3,13 +3,14 @@
import evaluate as hf_evaluate

from abc import abstractmethod
from typing import Any, Dict, Union, List
from typing import Any, Dict, Union, List, Optional
from ray.actor import ActorHandle
from nltk import word_tokenize
from nltk.translate import meteor_score

from fmeval.transforms.transform import Transform
from fmeval.transforms.util import validate_call
from fmeval.constants import BERTSCORE_DEFAULT_MODEL
from fmeval.eval_algorithms.helper_models.helper_model import BertscoreHelperModel
from fmeval.util import assert_condition

@@ -36,67 +37,82 @@ class SummarizationAccuracyMetric(Transform):

def __init__(
self,
target_output_keys: List[str],
target_output_keys: Optional[List[str]],
model_output_keys: List[str],
output_keys: List[str],
allow_duplicate_input_keys: bool,
target_output_keys_provider: Optional[str],
*args,
**kwargs,
):
"""SummarizationAccuracyMetric initializer.
Note that the ordering of the elements in `target_output_keys`, `model_output_keys`,
and `output_keys` must match, i.e. the kth element of `target_output_keys` and the
kth element of `model_output_keys` are used to compute the kth metric, which has
an output key of `output_keys[k]`.
Note that the ordering of the elements in `model_output_keys` and `output_keys`
must match, i.e. the kth element of `model_output_keys` is used to compute
the kth metric, which has an output key of `output_keys[k]`.
:param target_output_keys: The keys corresponding to target outputs.
:param target_output_keys: The keys corresponding to target outputs. If this is
set to None, then we will use `target_output_keys_provider` to get the
list of target outputs.
:param model_output_keys: The keys corresponding to model outputs.
:param output_keys: The output keys for this Transform, which correspond
to the metrics/scores that get computed.
:param allow_duplicate_input_keys: Whether to allow duplicate keys in
`target_output_keys` and `model_output_keys`. This parameter is usually
False, but will be True when a SummarizationAccuracyMetric is created
to compute metrics on perturbed model outputs. In this case,
`target_output_keys` will be a list of a single repeated key, while
`model_output_keys` will contain the keys for perturbed model outputs.
False.
:param target_output_keys_provider: The key corresponding to a list of target
outputs. Will only be used if `target_output_keys` is set to None.
:param *args: Variable length argument list.
:param **kwargs: Arbitrary keyword arguments.
"""
assert_condition(
len(target_output_keys) == len(model_output_keys) and len(target_output_keys) == len(output_keys),
"len(target_output_keys), len(model_output_keys) and len(output_keys) should all match. "
f"len(target_output_keys) is {len(target_output_keys)}, len(model_output_keys) is "
len(model_output_keys) == len(output_keys),
"len(model_output_keys) and len(output_keys) should match. "
f"len(model_output_keys) is "
f"{len(model_output_keys)}, and len(output_keys) is {len(output_keys)}.",
)
if target_output_keys is None:
assert_condition(
target_output_keys_provider is not None,
f"target_output_keys is {target_output_keys}, but target_output_keys_provider"
f" (the fallback value) is {target_output_keys_provider} which is invalid.",
)
super().__init__(
target_output_keys,
model_output_keys,
output_keys,
allow_duplicate_input_keys,
target_output_keys_provider,
*args,
**kwargs,
)
input_keys = target_output_keys if target_output_keys else [target_output_keys_provider] # type: ignore
self.register_input_output_keys(
target_output_keys + model_output_keys,
input_keys + model_output_keys,
output_keys,
allow_duplicates=allow_duplicate_input_keys,
)
self.target_output_keys = target_output_keys
self.model_output_keys = model_output_keys
self.target_output_keys_provider = target_output_keys_provider

@validate_call
def __call__(self, record: Dict[str, Any]) -> Dict[str, Any]:
"""Augment the input record with metrics computed via self.compute_metric.
The max score is computed over all possible targets (those listed in
self.target_output_keys, or the list stored under self.target_output_keys_provider
when target_output_keys is None) and stored in the input record.
:param record: The input record.
:returns: The input record with metrics added in.
"""
for target_output_key, model_output_key, output_key in zip(
self.target_output_keys, self.model_output_keys, self.output_keys
):
score = self.compute_metric(record[target_output_key], record[model_output_key])
record[output_key] = score
target_output_list = (
[record[target_output_key] for target_output_key in self.target_output_keys]
if self.target_output_keys
else record[self.target_output_keys_provider] # type: ignore[index]
)
for model_output_key, output_key in zip(self.model_output_keys, self.output_keys):
scores = [self.compute_metric(target, record[model_output_key]) for target in target_output_list]
record[output_key] = max(scores)
return record

@abstractmethod
@@ -126,26 +142,34 @@ class MeteorScore(SummarizationAccuracyMetric):

def __init__(
self,
target_output_keys: List[str],
target_output_keys: Optional[List[str]],
model_output_keys: List[str],
output_keys: List[str],
allow_duplicate_input_keys,
allow_duplicate_input_keys: bool,
target_output_keys_provider: Optional[str] = None,
load_modules: bool = True,
):
"""MeteorScore initializer.
:param target_output_keys: The keys corresponding to target outputs.
:param target_output_keys: The keys corresponding to target outputs. If this is
set to None, then we will use `target_output_keys_provider` to get the
list of target outputs.
:param model_output_keys: The keys corresponding to model outputs.
:param output_keys: The output keys for this Transform, which correspond
to the Meteor scores that get computed.
:param allow_duplicate_input_keys: See docstring for SummarizationAccuracyMetric.
:param allow_duplicate_input_keys: Whether to allow duplicate keys in
`target_output_keys` and `model_output_keys`. This parameter is usually
False.
:param target_output_keys_provider: The key corresponding to a list of target
outputs. Will only be used if `target_output_keys` is set to None.
:param load_modules: Whether to load the meteor helper modules.
"""
super().__init__(
target_output_keys,
model_output_keys,
output_keys,
allow_duplicate_input_keys,
target_output_keys_provider,
# The first instance of this class that gets created will
# load the helper modules, so copies of this instance
# need not load them again.
@@ -190,20 +214,27 @@ class RougeScore(SummarizationAccuracyMetric):

def __init__(
self,
target_output_keys: List[str],
target_output_keys: Optional[List[str]],
model_output_keys: List[str],
output_keys: List[str],
allow_duplicate_input_keys: bool,
target_output_keys_provider: Optional[str] = None,
rouge_type: str = ROUGE_2,
use_stemmer: bool = True,
):
"""RougeScore initializer.
:param target_output_keys: The keys corresponding to target outputs.
:param target_output_keys: The keys corresponding to target outputs. If this is
set to None, then we will use `target_output_keys_provider` to get the
list of target outputs.
:param model_output_keys: The keys corresponding to model outputs.
:param output_keys: The output keys for this Transform, which correspond
to the ROUGE scores that get computed.
:param allow_duplicate_input_keys: See docstring for SummarizationAccuracyMetric.
to the Rouge scores that get computed.
:param allow_duplicate_input_keys: Whether to allow duplicate keys in
`target_output_keys` and `model_output_keys`. This parameter is usually
False.
:param target_output_keys_provider: The key corresponding to a list of target
outputs. Will only be used if `target_output_keys` is set to None.
:param rouge_type: Which ROUGE type to use (1, 2, L).
:param use_stemmer: Whether to use a stemmer for ROUGE.
"""
@@ -212,6 +243,7 @@ def __init__(
model_output_keys,
output_keys,
allow_duplicate_input_keys,
target_output_keys_provider,
rouge_type=rouge_type,
use_stemmer=use_stemmer,
)
@@ -247,23 +279,36 @@ class BertScore(SummarizationAccuracyMetric):

def __init__(
self,
target_output_keys: List[str],
target_output_keys: Optional[List[str]],
model_output_keys: List[str],
output_keys: List[str],
allow_duplicate_input_keys: bool,
bertscore_model: Union[BertscoreHelperModel, ActorHandle],
target_output_keys_provider: Optional[str] = None,
bertscore_model: Union[BertscoreHelperModel, ActorHandle] = BertscoreHelperModel(BERTSCORE_DEFAULT_MODEL),
):
"""BertScore initializer.
:param target_output_keys: The keys corresponding to target outputs.
:param target_output_keys: The keys corresponding to target outputs. If this is
set to None, then we will use `target_output_keys_provider` to get the
list of target outputs.
:param model_output_keys: The keys corresponding to model outputs.
:param output_keys: The output keys for this Transform, which correspond
to the BERT scores that get computed.
:param allow_duplicate_input_keys: See docstring for SummarizationAccuracyMetric.
to the BERT_SCORES that get computed.
:param allow_duplicate_input_keys: Whether to allow duplicate keys in
`target_output_keys` and `model_output_keys`. This parameter is usually
False.
:param target_output_keys_provider: The key corresponding to a list of target
outputs. Will only be used if `target_output_keys` is set to None.
:param bertscore_model: A BertscoreHelperModel instance or a Ray actor handle for a BertscoreHelperModel.
If no model is provided, this parameter defaults to a BertscoreHelperModel constructed with BERTSCORE_DEFAULT_MODEL.
"""
super().__init__(
target_output_keys, model_output_keys, output_keys, allow_duplicate_input_keys, bertscore_model
target_output_keys,
model_output_keys,
output_keys,
allow_duplicate_input_keys,
target_output_keys_provider,
bertscore_model,
)
self.bertscore_model = bertscore_model

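The explicit-key path benefits as well: `target_output_keys` may now name several target columns for a single model output, which the previous length assertion would have rejected. A sketch under the `RougeScore` signature shown above, assuming the Hugging Face `evaluate` ROUGE metric can be loaded; all record keys are invented for illustration:

from fmeval.transforms.summarization_accuracy_metrics import RougeScore

rouge = RougeScore(
    target_output_keys=["target_a", "target_b"],  # two candidate references, one model output
    model_output_keys=["model_output"],
    output_keys=["rouge"],
    allow_duplicate_input_keys=False,
)
record = {
    "target_a": "The cat sat on the mat.",
    "target_b": "A cat was sitting on the mat.",
    "model_output": "The cat is sitting on the mat.",
}
record = rouge(record)  # record["rouge"] holds the max ROUGE-2 score over both targets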
54 changes: 54 additions & 0 deletions test/unit/transforms/test_summarization_accuracy_metrics.py
@@ -128,6 +128,60 @@ def test_bert_score_call_with_bertscore_model_object():
mock_bertscore_model.get_helper_scores.assert_called_once_with("Hello there!", "Hi")


def test_bert_score_call_with_target_output_keys_provider():
"""
GIVEN a BertScore instance with a valid `target_output_keys_provider`.
WHEN its __call__ method is invoked.
THEN self.bertscore_model is invoked with the correct arguments.
Note: we don't validate the structure of the __call__ output since
we already have @validate_call to handle that.
"""
mock_bertscore_model = Mock(spec=BertscoreHelperModel)
mock_bertscore_model.get_helper_scores = Mock()

bs = BertScore(
target_output_keys=None,
model_output_keys=["model_output"],
output_keys=["bertscore"],
allow_duplicate_input_keys=False,
target_output_keys_provider="target_output",
bertscore_model=mock_bertscore_model,
)
sample = {"target_output": ["Hello there!"], "model_output": "Hi"}
bs(sample)
mock_bertscore_model.get_helper_scores.assert_called_once_with("Hello there!", "Hi")


def test_bertscore_multiple_targets_max_score():
"""
GIVEN a BertScore instance with multiple possible target answers.
WHEN its __call__ method is invoked.
THEN the maximum score is returned.
"""
mock_bertscore_model = Mock(spec=BertscoreHelperModel)
mock_bertscore_model.get_helper_scores = Mock()

output_scores = [0.2, 0.1, 0.3]
mock_bertscore_model.get_helper_scores.side_effect = output_scores

bs = BertScore(
target_output_keys=None,
model_output_keys=["model_output"],
output_keys=["bertscore"],
allow_duplicate_input_keys=False,
target_output_keys_provider="target_output",
bertscore_model=mock_bertscore_model,
)
sample = {"target_output": ["random output", "hello", "something"], "model_output": "hello"}
output = bs(sample)

mock_bertscore_model.get_helper_scores.assert_has_calls(
[call("random output", "hello"), call("hello", "hello"), call("something", "hello")]
)
assert output["bertscore"] == pytest.approx(max(output_scores), rel=1e-5)


def test_bert_score_call_with_ray_actor_handle():
"""
GIVEN a BertScore instance, where its `bertscore_model` is a Ray actor handle.
