Add AnswerExactMatchEvaluator (#7381)
* Add AnswerExactMatchEvaluator
* Add release notes
* Fix linting
* Update docstrings
* Update docstrings
* Remove to_dict and from_dict
* Fix linting
1 parent f69c3e5 · commit 610ad6f
Showing 4 changed files with 129 additions and 0 deletions.
haystack/components/evaluators/__init__.py (3 additions)

```python
from .answer_exact_match import AnswerExactMatchEvaluator

__all__ = ["AnswerExactMatchEvaluator"]
```
haystack/components/evaluators/answer_exact_match.py (59 additions)

````python
from typing import Dict, List

from haystack.core.component import component


@component
class AnswerExactMatchEvaluator:
    """
    Evaluator that checks if the predicted answers match any of the ground truth answers exactly.

    The result is a number from 0.0 to 1.0 that represents the proportion of questions where any predicted
    answer matched one of the ground truth answers. Each question can have multiple ground truth answers
    and multiple predicted answers.

    Usage example:
    ```python
    from haystack.components.evaluators import AnswerExactMatchEvaluator

    evaluator = AnswerExactMatchEvaluator()
    result = evaluator.run(
        questions=["What is the capital of Germany?", "What is the capital of France?"],
        ground_truth_answers=[["Berlin"], ["Paris"]],
        predicted_answers=[["Berlin"], ["Paris"]],
    )
    print(result["result"])
    # 1.0
    ```
    """

    @component.output_types(result=float)
    def run(
        self, questions: List[str], ground_truth_answers: List[List[str]], predicted_answers: List[List[str]]
    ) -> Dict[str, float]:
        """
        Run the AnswerExactMatchEvaluator on the given inputs.

        All lists must have the same length.

        :param questions:
            A list of questions.
        :param ground_truth_answers:
            A list of expected answers for each question.
        :param predicted_answers:
            A list of predicted answers for each question.
        :returns:
            A dictionary with the following outputs:
            - `result` - A number from 0.0 to 1.0 that represents the proportion of questions where any
              predicted answer matched one of the ground truth answers.
        """
        if not len(questions) == len(ground_truth_answers) == len(predicted_answers):
            raise ValueError("The length of questions, ground_truth_answers, and predicted_answers must be the same.")

        matches = 0
        for truths, extracted in zip(ground_truth_answers, predicted_answers):
            if set(truths) & set(extracted):
                matches += 1

        # The proportion of questions where any predicted answer matched one of the ground truth answers
        result = matches / len(questions)

        return {"result": result}
````
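For a quick sense of the scoring, here is an illustrative sketch of the partial-match case (the questions and answers are made up): only the first question's predicted answer exactly matches a ground truth answer, so the score is 1 match out of 2 questions.

```python
from haystack.components.evaluators import AnswerExactMatchEvaluator

evaluator = AnswerExactMatchEvaluator()
# Only "Berlin" matches a ground truth answer exactly;
# the score is 1 match / 2 questions = 0.5.
result = evaluator.run(
    questions=["What is the capital of Germany?", "What is the capital of France?"],
    ground_truth_answers=[["Berlin"], ["Paris"]],
    predicted_answers=[["Berlin"], ["Marseille"]],
)
print(result["result"])
# 0.5
```

Note that matching is plain, case-sensitive string equality via set intersection, so for example "berlin" would not match "Berlin".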
releasenotes/notes/exact-match-evaluator-197bb87b65e19d0c.yaml (6 additions)
```yaml
---
features:
  - |
    Add `AnswerExactMatchEvaluator`, a Component that can be used to calculate the Exact Match metric
    given a list of questions, a list of expected answers for each question and the list of predicted
    answers for each question.
```
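Because the class is decorated with `@component`, it should also be usable inside a `Pipeline` rather than only standalone. A minimal sketch, assuming the standard `haystack.Pipeline` API; the component name `em_evaluator` and the inputs are illustrative:

```python
from haystack import Pipeline
from haystack.components.evaluators import AnswerExactMatchEvaluator

pipeline = Pipeline()
pipeline.add_component("em_evaluator", AnswerExactMatchEvaluator())

# Nothing is connected upstream of the evaluator in this sketch,
# so all of its inputs are passed directly at run time.
results = pipeline.run(
    {
        "em_evaluator": {
            "questions": ["What is the capital of Germany?"],
            "ground_truth_answers": [["Berlin"]],
            "predicted_answers": [["Berlin"]],
        }
    }
)
print(results["em_evaluator"]["result"])
# 1.0
```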
Tests for AnswerExactMatchEvaluator (61 additions)

```python
import pytest

from haystack.components.evaluators import AnswerExactMatchEvaluator


def test_run_with_all_matching():
    evaluator = AnswerExactMatchEvaluator()
    result = evaluator.run(
        questions=["What is the capital of Germany?", "What is the capital of France?"],
        ground_truth_answers=[["Berlin"], ["Paris"]],
        predicted_answers=[["Berlin"], ["Paris"]],
    )

    assert result["result"] == 1.0


def test_run_with_no_matching():
    evaluator = AnswerExactMatchEvaluator()
    result = evaluator.run(
        questions=["What is the capital of Germany?", "What is the capital of France?"],
        ground_truth_answers=[["Berlin"], ["Paris"]],
        predicted_answers=[["Paris"], ["London"]],
    )

    assert result["result"] == 0.0


def test_run_with_partial_matching():
    evaluator = AnswerExactMatchEvaluator()
    result = evaluator.run(
        questions=["What is the capital of Germany?", "What is the capital of France?"],
        ground_truth_answers=[["Berlin"], ["Paris"]],
        predicted_answers=[["Berlin"], ["London"]],
    )

    assert result["result"] == 0.5


def test_run_with_different_lengths():
    evaluator = AnswerExactMatchEvaluator()

    with pytest.raises(ValueError):
        evaluator.run(
            questions=["What is the capital of Germany?"],
            ground_truth_answers=[["Berlin"], ["Paris"]],
            predicted_answers=[["Berlin"], ["London"]],
        )

    with pytest.raises(ValueError):
        evaluator.run(
            questions=["What is the capital of Germany?", "What is the capital of France?"],
            ground_truth_answers=[["Berlin"]],
            predicted_answers=[["Berlin"], ["London"]],
        )

    with pytest.raises(ValueError):
        evaluator.run(
            questions=["What is the capital of Germany?", "What is the capital of France?"],
            ground_truth_answers=[["Berlin"], ["Paris"]],
            predicted_answers=[["Berlin"]],
        )
```
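One edge case the tests above leave open: with empty inputs the length check passes, but `matches / len(questions)` divides by zero. A hedged sketch of a test documenting that behavior (an illustration, not part of this commit):

```python
def test_run_with_empty_inputs():
    evaluator = AnswerExactMatchEvaluator()

    # matches / len(questions) is 0 / 0 here, so the implementation
    # raises ZeroDivisionError rather than returning 0.0.
    with pytest.raises(ZeroDivisionError):
        evaluator.run(questions=[], ground_truth_answers=[], predicted_answers=[])
```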