bug: 5123 metrics (#5245)
# Description

Only added `field_name` to the agreement metrics, since alpha is the only metric influenced by the content field.
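
An illustrative sketch of how the new argument is meant to be used (the dataset, workspace, and field/question names here are hypothetical):

```python
import argilla as rg

dataset = rg.FeedbackDataset.from_argilla("my-dataset", workspace="my-workspace")

# Pass the field backing the question being analysed, or a list of fields
# whose contents are joined before computing agreement.
metrics = dataset.compute_agreement_metrics(
    question_name="label",
    field_name="text",  # or e.g. ["text", "context"]
    metric_names="alpha",
)
# e.g. AgreementMetricResult(metric_name='alpha', count=..., result=...)
```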

Closes #5123 

**Type of change**

- Bug fix (non-breaking change which fixes an issue)
- Documentation update

**How Has This Been Tested**

**Checklist**

- I did a self-review of my code
- I made corresponding changes to the documentation
sdiazlor authored Jul 18, 2024
1 parent 5606746 commit 97aa33f
Showing 4 changed files with 88 additions and 46 deletions.
4 changes: 3 additions & 1 deletion argilla/src/argilla/client/feedback/dataset/mixins.py
@@ -85,6 +85,7 @@ def compute_agreement_metrics(
self,
metric_names: Union[str, List[str]] = None,
question_name: Union[str, LabelQuestion, MultiLabelQuestion, RatingQuestion, RankingQuestion] = None,
field_name: Union[str, List[str]] = None,
) -> Union["AgreementMetricResult", List["AgreementMetricResult"]]:
"""Compute agreement or reliability of annotation metrics.
@@ -94,6 +95,7 @@
Args:
metric_names: Metric name or list of metric names of the metrics, dependent on the question type.
question_name: Question for which we want to compute the metrics.
field_name: Name of the field (or list of field names) related to the question for which we want to analyse the agreement.
Note:
Currently, TextQuestion is not supported.
@@ -104,7 +106,7 @@
"""
from argilla.client.feedback.metrics.agreement_metrics import AgreementMetric

return AgreementMetric(self, question_name).compute(metric_names)
return AgreementMetric(self, question_name, field_name).compute(metric_names)


class UnificationMixin:
15 changes: 12 additions & 3 deletions argilla/src/argilla/client/feedback/metrics/agreement_metrics.py
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module contains metrics to gather information related to inter-Annotator agreement. """
"""This module contains metrics to gather information related to inter-Annotator agreement."""

import warnings
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union

@@ -52,6 +53,7 @@ def __init__(self, **kwargs):
def prepare_dataset_for_annotation_task(
dataset: Union["FeedbackDataset", "RemoteFeedbackDataset"],
question_name: str,
field_name: Union[str, List[str]],
filter_by: Optional[Dict[str, Union["ResponseStatusFilter", List["ResponseStatusFilter"]]]] = None,
sort_by: Optional[List["SortBy"]] = None,
max_records: Optional[int] = None,
@@ -74,6 +76,7 @@ def prepare_dataset_for_annotation_task(
Args:
dataset: FeedbackDataset to compute the metrics.
question_name: Name of the question for which we want to analyse the agreement.
field_name: Name of the field (or list of field names) related to the question for which we want to analyse the agreement.
filter_by: A dict with key the field to filter by, and values the filters to apply.
Can be one of: draft, pending, submitted, and discarded. If set to None,
no filter will be applied. Defaults to None (no filter is applied).
@@ -108,7 +111,9 @@ def prepare_dataset_for_annotation_task(

for row in hf_dataset:
responses_ = row[question_name]
question_text = row["text"]
question_text = (
" ".join([row[field] for field in field_name]) if isinstance(field_name, list) else row[field_name]
)
for response in responses_:
user_id = response["user_id"]
if user_id is None:
@@ -181,7 +186,7 @@ class AgreementMetric(MetricBase):
Example:
>>> import argilla as rg
>>> from argilla.client.feedback.metrics import AgreementMetric
>>> metric = AgreementMetric(dataset=dataset, question_name=question, filter_by={"response_status": "submitted"})
>>> metric = AgreementMetric(dataset=dataset, question_name=question, field_name=field, filter_by={"response_status": "submitted"})
>>> metrics_report = metric.compute("alpha")
"""
@@ -190,6 +195,7 @@ def __init__(
self,
dataset: FeedbackDataset,
question_name: str,
field_name: Union[str, List[str]],
filter_by: Optional[Dict[str, Union["ResponseStatusFilter", List["ResponseStatusFilter"]]]] = None,
sort_by: Optional[List["SortBy"]] = None,
max_records: Optional[int] = None,
@@ -199,6 +205,7 @@ def __init__(
Args:
dataset: FeedbackDataset to compute the metrics.
question_name: Name of the question for which we want to analyse the agreement.
field_name: Name of the field (or list of field names) related to the question for which we want to analyse the agreement.
filter_by: A dict with key the field to filter by, and values the filters to apply.
Can be one of: draft, pending, submitted, and discarded. If set to None,
no filter will be applied. Defaults to None (no filter is applied).
@@ -207,6 +214,7 @@ def __init__(
max_records: The maximum number of records to use when computing the metrics. Defaults to None.
"""
self._metrics_per_question = METRICS_PER_QUESTION
self._field_name = field_name
super().__init__(dataset, question_name)
self._filter_by = filter_by
self._sort_by = sort_by
@@ -231,6 +239,7 @@ def compute(self, metric_names: Union[str, List[str]]) -> List[AgreementMetricRe
dataset = prepare_dataset_for_annotation_task(
self._dataset,
self._question_name,
self._field_name,
filter_by=self._filter_by,
sort_by=self._sort_by,
max_records=self._max_records,
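
For context, a minimal sketch (not part of the diff) of the field-joining behaviour that `prepare_dataset_for_annotation_task` now applies when `field_name` is a list; the record dict below is made up for illustration:

```python
from typing import List, Union

def join_fields(row: dict, field_name: Union[str, List[str]]) -> str:
    # Mirrors the new question_text logic: concatenate several fields with a
    # space, or return the single named field unchanged.
    if isinstance(field_name, list):
        return " ".join(row[field] for field in field_name)
    return row[field_name]

row = {"text": "A short prompt", "label": "positive"}
assert join_fields(row, ["text", "label"]) == "A short prompt positive"
assert join_fields(row, "text") == "A short prompt"
```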
@@ -65,18 +65,20 @@ def test_allowed_metrics(
)
dataset.add_records(records=feedback_dataset_records_with_paired_suggestions)

metric = AgreementMetric(dataset=dataset, question_name=question)
metric = AgreementMetric(
dataset=dataset, question_name=question, field_name=[field.name for field in feedback_dataset_fields]
)
assert set(metric.allowed_metrics) == metric_names


@pytest.mark.parametrize(
"question, num_items, type_of_data",
"field, question, num_items, type_of_data",
[
("question-1", None, None),
("question-2", 12, int),
("question-3", 12, str),
("question-4", 12, frozenset),
("question-5", 12, tuple),
(["text"], "question-1", None, None),
(["text", "label"], "question-2", 12, int),
(["text", "label"], "question-3", 12, str),
(["text"], "question-4", 12, FrozenSet),
(["label"], "question-5", 12, Tuple),
],
)
@pytest.mark.usefixtures(
@@ -91,6 +93,7 @@ def test_prepare_dataset_for_annotation_task(
feedback_dataset_questions: List["AllowedQuestionTypes"],
feedback_dataset_records_with_paired_suggestions: List[FeedbackRecord],
question: str,
field: Union[str, List[str]],
num_items: int,
type_of_data: Union[str, int, FrozenSet, Tuple[str]],
):
@@ -101,19 +104,24 @@
)
dataset.add_records(records=feedback_dataset_records_with_paired_suggestions)

if question in ("question-1",):
if question == "question-1":
with pytest.raises(NotImplementedError, match=r"^Question '"):
prepare_dataset_for_annotation_task(dataset, question)
prepare_dataset_for_annotation_task(dataset, question, field)
else:
formatted_dataset = prepare_dataset_for_annotation_task(dataset, question)
formatted_dataset = prepare_dataset_for_annotation_task(dataset, question, field)
assert isinstance(formatted_dataset, list)
assert len(formatted_dataset) == num_items
item = formatted_dataset[0]
assert isinstance(item, tuple)
assert isinstance(item[0], str)
assert item[0].startswith("00000000-") # beginning of our uuid for tests
assert isinstance(item[1], str)
assert item[1] == feedback_dataset_records_with_paired_suggestions[0].fields["text"]
expected_field_value = (
" ".join([feedback_dataset_records_with_paired_suggestions[0].fields[f] for f in field])
if isinstance(field, list)
else feedback_dataset_records_with_paired_suggestions[0].fields[field]
)
assert item[1] == expected_field_value
assert isinstance(item[2], type_of_data)


@@ -156,9 +164,17 @@ def test_agreement_metrics(

if question in ("question-1",):
with pytest.raises(NotImplementedError, match=r"^No metrics are defined currently for"):
AgreementMetric(dataset=dataset, question_name=question)
AgreementMetric(
dataset=dataset,
question_name=question,
field_name=[field.name for field in feedback_dataset_fields],
)
else:
metric = AgreementMetric(dataset=dataset, question_name=question)
metric = AgreementMetric(
dataset=dataset,
question_name=question,
field_name=[field.name for field in feedback_dataset_fields],
)
# Test for repr method
assert repr(metric) == f"AgreementMetric(question_name={question})"
metrics_report = metric.compute(metric_names)
@@ -173,19 +189,19 @@

@pytest.mark.asyncio
@pytest.mark.parametrize(
"question, metric_names",
"field, question, metric_names",
[
# TextQuestion
("question-1", None),
(["text"], "question-1", None),
# RatingQuestion
("question-2", "alpha"),
("question-2", ["alpha"]),
(["text", "label"], "question-2", "alpha"),
(["text", "label"], "question-2", ["alpha"]),
# LabelQuestion
("question-3", "alpha"),
("text", "question-3", "alpha"),
# MultiLabelQuestion
("question-4", "alpha"),
("label", "question-4", "alpha"),
# RankingQuestion
("question-5", "alpha"),
(["text", "label"], "question-5", "alpha"),
],
)
@pytest.mark.usefixtures(
@@ -200,6 +216,7 @@ async def test_agreement_metrics_remote(
feedback_dataset_questions: List["AllowedQuestionTypes"],
feedback_dataset_records_with_paired_suggestions: List[FeedbackRecord],
question: str,
field: Union[str, List[str]],
metric_names: Union[str, List[str]],
owner: User,
):
@@ -219,9 +236,17 @@

if question in ("question-1",):
with pytest.raises(NotImplementedError, match=r"^No metrics are defined currently for"):
AgreementMetric(dataset=remote, question_name=question)
AgreementMetric(
dataset=remote,
question_name=question,
field_name=field,
)
else:
metric = AgreementMetric(dataset=remote, question_name=question)
metric = AgreementMetric(
dataset=remote,
question_name=question,
field_name=field,
)
# Test for repr method
assert repr(metric) == f"AgreementMetric(question_name={question})"
metrics_report = metric.compute(metric_names)
@@ -235,19 +260,19 @@


@pytest.mark.parametrize(
"question, metric_names",
"field, question, metric_names",
[
# TextQuestion
("question-1", None),
(["text"], "question-1", None),
# RatingQuestion
("question-2", "alpha"),
("question-2", ["alpha"]),
(["text", "label"], "question-2", "alpha"),
(["text", "label"], "question-2", ["alpha"]),
# LabelQuestion
("question-3", "alpha"),
("text", "question-3", "alpha"),
# MultiLabelQuestion
("question-4", "alpha"),
("label", "question-4", "alpha"),
# RankingQuestion
("question-5", "alpha"),
(["text", "label"], "question-5", "alpha"),
],
)
@pytest.mark.usefixtures(
@@ -262,6 +287,7 @@ def test_agreement_metrics_from_feedback_dataset(
feedback_dataset_questions: List["AllowedQuestionTypes"],
feedback_dataset_records_with_paired_suggestions: List[FeedbackRecord],
question: str,
field: Union[str, List[str]],
metric_names: Union[str, List[str]],
):
dataset = FeedbackDataset(
@@ -273,9 +299,11 @@

if question in ("question-1",):
with pytest.raises(NotImplementedError, match=r"^No metrics are defined currently for"):
dataset.compute_agreement_metrics(question_name=question, metric_names=metric_names)
dataset.compute_agreement_metrics(question_name=question, field_name=field, metric_names=metric_names)
else:
metrics_report = dataset.compute_agreement_metrics(question_name=question, metric_names=metric_names)
metrics_report = dataset.compute_agreement_metrics(
question_name=question, field_name=field, metric_names=metric_names
)

if isinstance(metric_names, str):
metrics_report = [metrics_report]
Expand All @@ -288,19 +316,19 @@ def test_agreement_metrics_from_feedback_dataset(

@pytest.mark.asyncio
@pytest.mark.parametrize(
"question, metric_names",
"field, question, metric_names",
[
# TextQuestion
("question-1", None),
(["text"], "question-1", None),
# RatingQuestion
("question-2", "alpha"),
("question-2", ["alpha"]),
(["text", "label"], "question-2", "alpha"),
(["text", "label"], "question-2", ["alpha"]),
# LabelQuestion
("question-3", "alpha"),
("text", "question-3", "alpha"),
# MultiLabelQuestion
("question-4", "alpha"),
("label", "question-4", "alpha"),
# RankingQuestion
("question-5", "alpha"),
(["text", "label"], "question-5", "alpha"),
],
)
@pytest.mark.usefixtures(
@@ -315,6 +343,7 @@ async def test_agreement_metrics_from_remote_feedback_dataset(
feedback_dataset_questions: List["AllowedQuestionTypes"],
feedback_dataset_records_with_paired_suggestions: List[FeedbackRecord],
question: str,
field: Union[str, List[str]],
metric_names: Union[str, List[str]],
owner: User,
) -> None:
@@ -335,9 +364,11 @@

if question in ("question-1",):
with pytest.raises(NotImplementedError, match=r"^No metrics are defined currently for"):
remote.compute_agreement_metrics(question_name=question, metric_names=metric_names)
remote.compute_agreement_metrics(question_name=question, field_name=field, metric_names=metric_names)
else:
metrics_report = remote.compute_agreement_metrics(question_name=question, metric_names=metric_names)
metrics_report = remote.compute_agreement_metrics(
question_name=question, field_name=field, metric_names=metric_names
)

if isinstance(metric_names, str):
metrics_report = [metrics_report]
4 changes: 2 additions & 2 deletions docs/_source/practical_guides/collect_responses.md
@@ -141,7 +141,7 @@ import argilla as rg
from argilla.client.feedback.metrics import AgreementMetric

feedback_dataset = rg.FeedbackDataset.from_argilla("...", workspace="...")
metric = AgreementMetric(dataset=feedback_dataset, question_name="question_name")
metric = AgreementMetric(dataset=feedback_dataset, field_name="text", question_name="question_name")
agreement_metrics = metric.compute("alpha")
# >>> agreement_metrics
# [AgreementMetricResult(metric_name='alpha', count=1000, result=0.467889)]
@@ -156,7 +156,7 @@

#dataset = rg.FeedbackDataset.from_huggingface("argilla/go_emotions_raw")

agreement_metrics = dataset.compute_agreement_metrics(question_name="label", metric_names="alpha")
agreement_metrics = dataset.compute_agreement_metrics(question_name="label", field_name="text", metric_names="alpha")
agreement_metrics

# AgreementMetricResult(metric_name='alpha', count=191792, result=0.2703263452657748)
