From 97aa33f992061987a1c6f8e1a3f99affd5a002d3 Mon Sep 17 00:00:00 2001
From: Sara Han <127759186+sdiazlor@users.noreply.github.com>
Date: Thu, 18 Jul 2024 16:47:36 +0200
Subject: [PATCH] bug: 5123 metrics (#5245)

# Description

Only adds `field_name` to the agreement metrics, as alpha is the only metric
influenced by the content field.

Closes #5123

**Type of change**

- Bug fix (non-breaking change which fixes an issue)
- Documentation update

**How Has This Been Tested**

**Checklist**

- I did a self-review of my code
- I made corresponding changes to the documentation
---
 .../argilla/client/feedback/dataset/mixins.py |   4 +-
 .../feedback/metrics/agreement_metrics.py     |  15 ++-
 .../metrics/test_agreement_metrics.py         | 111 +++++++++++-------
 .../practical_guides/collect_responses.md     |   4 +-
 4 files changed, 88 insertions(+), 46 deletions(-)

diff --git a/argilla/src/argilla/client/feedback/dataset/mixins.py b/argilla/src/argilla/client/feedback/dataset/mixins.py
index 2ed2d854dd..50a5a31e32 100644
--- a/argilla/src/argilla/client/feedback/dataset/mixins.py
+++ b/argilla/src/argilla/client/feedback/dataset/mixins.py
@@ -85,6 +85,7 @@ def compute_agreement_metrics(
         self,
         metric_names: Union[str, List[str]] = None,
         question_name: Union[str, LabelQuestion, MultiLabelQuestion, RatingQuestion, RankingQuestion] = None,
+        field_name: Union[str, List[str]] = None,
     ) -> Union["AgreementMetricResult", List["AgreementMetricResult"]]:
         """Compute agreement or reliability of annotation metrics.

@@ -94,6 +95,7 @@
         Args:
             metric_names: Metric name or list of metric names of the metrics, dependent on the question type.
             question_name: Question for which we want to compute the metrics.
+            field_name: Name of the field (or list of fields) whose content is used to compute the agreement.

         Note:
             Currently, TextQuestion is not supported.

@@ -104,7 +106,7 @@
         """
         from argilla.client.feedback.metrics.agreement_metrics import AgreementMetric

-        return AgreementMetric(self, question_name).compute(metric_names)
+        return AgreementMetric(self, question_name, field_name).compute(metric_names)


 class UnificationMixin:
diff --git a/argilla/src/argilla/client/feedback/metrics/agreement_metrics.py b/argilla/src/argilla/client/feedback/metrics/agreement_metrics.py
index 46299617dc..e4105a31d6 100644
--- a/argilla/src/argilla/client/feedback/metrics/agreement_metrics.py
+++ b/argilla/src/argilla/client/feedback/metrics/agreement_metrics.py
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""This module contains metrics to gather information related to inter-Annotator agreement. """
+"""This module contains metrics to gather information related to inter-Annotator agreement."""
+
 import warnings
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union

@@ -52,6 +53,7 @@ def __init__(self, **kwargs):
 def prepare_dataset_for_annotation_task(
     dataset: Union["FeedbackDataset", "RemoteFeedbackDataset"],
     question_name: str,
+    field_name: Union[str, List[str]],
     filter_by: Optional[Dict[str, Union["ResponseStatusFilter", List["ResponseStatusFilter"]]]] = None,
     sort_by: Optional[List["SortBy"]] = None,
     max_records: Optional[int] = None,
@@ -74,6 +76,7 @@
     Args:
         dataset: FeedbackDataset to compute the metrics.
         question_name: Name of the question for which we want to analyse the agreement.
+        field_name: Name of the field (or list of fields) whose content is used to compute the agreement.
         filter_by: A dict with key the field to filter by, and values the filters to apply.
             Can be one of: draft, pending, submitted, and discarded. If set to None,
             no filter will be applied. Defaults to None (no filter is applied).
@@ -108,7 +111,9 @@
     for row in hf_dataset:
         responses_ = row[question_name]
-        question_text = row["text"]
+        question_text = (
+            " ".join([row[field] for field in field_name]) if isinstance(field_name, list) else row[field_name]
+        )
         for response in responses_:
             user_id = response["user_id"]
             if user_id is None:
@@ -181,7 +186,7 @@ class AgreementMetric(MetricBase):
     Example:
         >>> import argilla as rg
         >>> from argilla.client.feedback.metrics import AgreementMetric
-        >>> metric = AgreementMetric(dataset=dataset, question_name=question, filter_by={"response_status": "submitted"})
+        >>> metric = AgreementMetric(dataset=dataset, question_name=question, field_name=field, filter_by={"response_status": "submitted"})
         >>> metrics_report = metric.compute("alpha")

     """
@@ -190,6 +195,7 @@ def __init__(
         self,
         dataset: FeedbackDataset,
         question_name: str,
+        field_name: Union[str, List[str]],
         filter_by: Optional[Dict[str, Union["ResponseStatusFilter", List["ResponseStatusFilter"]]]] = None,
         sort_by: Optional[List["SortBy"]] = None,
         max_records: Optional[int] = None,
@@ -199,6 +205,7 @@ def __init__(
         Args:
             dataset: FeedbackDataset to compute the metrics.
             question_name: Name of the question for which we want to analyse the agreement.
+            field_name: Name of the field (or list of fields) whose content is used to compute the agreement.
             filter_by: A dict with key the field to filter by, and values the filters to apply.
                 Can be one of: draft, pending, submitted, and discarded. If set to None,
                 no filter will be applied. Defaults to None (no filter is applied).
@@ -207,6 +214,7 @@ def __init__(
             max_records: The maximum number of records to use for training. Defaults to None.
""" self._metrics_per_question = METRICS_PER_QUESTION + self._field_name = field_name super().__init__(dataset, question_name) self._filter_by = filter_by self._sort_by = sort_by @@ -231,6 +239,7 @@ def compute(self, metric_names: Union[str, List[str]]) -> List[AgreementMetricRe dataset = prepare_dataset_for_annotation_task( self._dataset, self._question_name, + self._field_name, filter_by=self._filter_by, sort_by=self._sort_by, max_records=self._max_records, diff --git a/argilla/tests/integration/client/feedback/metrics/test_agreement_metrics.py b/argilla/tests/integration/client/feedback/metrics/test_agreement_metrics.py index d2cef9f226..41c2a91c50 100644 --- a/argilla/tests/integration/client/feedback/metrics/test_agreement_metrics.py +++ b/argilla/tests/integration/client/feedback/metrics/test_agreement_metrics.py @@ -65,18 +65,20 @@ def test_allowed_metrics( ) dataset.add_records(records=feedback_dataset_records_with_paired_suggestions) - metric = AgreementMetric(dataset=dataset, question_name=question) + metric = AgreementMetric( + dataset=dataset, question_name=question, field_name=[field.name for field in feedback_dataset_fields] + ) assert set(metric.allowed_metrics) == metric_names @pytest.mark.parametrize( - "question, num_items, type_of_data", + "field, question, num_items, type_of_data", [ - ("question-1", None, None), - ("question-2", 12, int), - ("question-3", 12, str), - ("question-4", 12, frozenset), - ("question-5", 12, tuple), + (["text"], "question-1", None, None), + (["text", "label"], "question-2", 12, int), + (["text", "label"], "question-3", 12, str), + (["text"], "question-4", 12, FrozenSet), + (["label"], "question-5", 12, Tuple), ], ) @pytest.mark.usefixtures( @@ -91,6 +93,7 @@ def test_prepare_dataset_for_annotation_task( feedback_dataset_questions: List["AllowedQuestionTypes"], feedback_dataset_records_with_paired_suggestions: List[FeedbackRecord], question: str, + field: Union[str, List[str]], num_items: int, type_of_data: Union[str, int, FrozenSet, Tuple[str]], ): @@ -101,11 +104,11 @@ def test_prepare_dataset_for_annotation_task( ) dataset.add_records(records=feedback_dataset_records_with_paired_suggestions) - if question in ("question-1",): + if question == "question-1": with pytest.raises(NotImplementedError, match=r"^Question '"): - prepare_dataset_for_annotation_task(dataset, question) + prepare_dataset_for_annotation_task(dataset, question, field) else: - formatted_dataset = prepare_dataset_for_annotation_task(dataset, question) + formatted_dataset = prepare_dataset_for_annotation_task(dataset, question, field) assert isinstance(formatted_dataset, list) assert len(formatted_dataset) == num_items item = formatted_dataset[0] @@ -113,7 +116,12 @@ def test_prepare_dataset_for_annotation_task( assert isinstance(item[0], str) assert item[0].startswith("00000000-") # beginning of our uuid for tests assert isinstance(item[1], str) - assert item[1] == feedback_dataset_records_with_paired_suggestions[0].fields["text"] + expected_field_value = ( + " ".join([feedback_dataset_records_with_paired_suggestions[0].fields[f] for f in field]) + if isinstance(field, list) + else feedback_dataset_records_with_paired_suggestions[0].fields[field] + ) + assert item[1] == expected_field_value assert isinstance(item[2], type_of_data) @@ -156,9 +164,17 @@ def test_agreement_metrics( if question in ("question-1",): with pytest.raises(NotImplementedError, match=r"^No metrics are defined currently for"): - AgreementMetric(dataset=dataset, question_name=question) + AgreementMetric( 
+                dataset=dataset,
+                question_name=question,
+                field_name=[field.name for field in feedback_dataset_fields],
+            )
     else:
-        metric = AgreementMetric(dataset=dataset, question_name=question)
+        metric = AgreementMetric(
+            dataset=dataset,
+            question_name=question,
+            field_name=[field.name for field in feedback_dataset_fields],
+        )
         # Test for repr method
         assert repr(metric) == f"AgreementMetric(question_name={question})"
         metrics_report = metric.compute(metric_names)
@@ -173,19 +189,19 @@
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
-    "question, metric_names",
+    "field, question, metric_names",
     [
         # TextQuestion
-        ("question-1", None),
+        (["text"], "question-1", None),
         # RatingQuestion
-        ("question-2", "alpha"),
-        ("question-2", ["alpha"]),
+        (["text", "label"], "question-2", "alpha"),
+        (["text", "label"], "question-2", ["alpha"]),
         # LabelQuestion
-        ("question-3", "alpha"),
+        ("text", "question-3", "alpha"),
         # MultiLabelQuestion
-        ("question-4", "alpha"),
+        ("label", "question-4", "alpha"),
         # RankingQuestion
-        ("question-5", "alpha"),
+        (["text", "label"], "question-5", "alpha"),
     ],
 )
 @pytest.mark.usefixtures(
@@ -200,6 +216,7 @@ async def test_agreement_metrics_remote(
     feedback_dataset_questions: List["AllowedQuestionTypes"],
     feedback_dataset_records_with_paired_suggestions: List[FeedbackRecord],
     question: str,
+    field: Union[str, List[str]],
     metric_names: Union[str, List[str]],
     owner: User,
 ):
@@ -219,9 +236,17 @@ async def test_agreement_metrics_remote(
     if question in ("question-1",):
         with pytest.raises(NotImplementedError, match=r"^No metrics are defined currently for"):
-            AgreementMetric(dataset=remote, question_name=question)
+            AgreementMetric(
+                dataset=remote,
+                question_name=question,
+                field_name=field,
+            )
     else:
-        metric = AgreementMetric(dataset=remote, question_name=question)
+        metric = AgreementMetric(
+            dataset=remote,
+            question_name=question,
+            field_name=field,
+        )
         # Test for repr method
         assert repr(metric) == f"AgreementMetric(question_name={question})"
         metrics_report = metric.compute(metric_names)
@@ -235,19 +260,19 @@
 @pytest.mark.parametrize(
-    "question, metric_names",
+    "field, question, metric_names",
     [
         # TextQuestion
-        ("question-1", None),
+        (["text"], "question-1", None),
         # RatingQuestion
-        ("question-2", "alpha"),
-        ("question-2", ["alpha"]),
+        (["text", "label"], "question-2", "alpha"),
+        (["text", "label"], "question-2", ["alpha"]),
         # LabelQuestion
-        ("question-3", "alpha"),
+        ("text", "question-3", "alpha"),
         # MultiLabelQuestion
-        ("question-4", "alpha"),
+        ("label", "question-4", "alpha"),
         # RankingQuestion
-        ("question-5", "alpha"),
+        (["text", "label"], "question-5", "alpha"),
     ],
 )
 @pytest.mark.usefixtures(
@@ -262,6 +287,7 @@ def test_agreement_metrics_from_feedback_dataset(
     feedback_dataset_questions: List["AllowedQuestionTypes"],
     feedback_dataset_records_with_paired_suggestions: List[FeedbackRecord],
     question: str,
+    field: Union[str, List[str]],
     metric_names: Union[str, List[str]],
 ):
     dataset = FeedbackDataset(
@@ -273,9 +299,11 @@ def test_agreement_metrics_from_feedback_dataset(
     if question in ("question-1",):
         with pytest.raises(NotImplementedError, match=r"^No metrics are defined currently for"):
-            dataset.compute_agreement_metrics(question_name=question, metric_names=metric_names)
+            dataset.compute_agreement_metrics(question_name=question, field_name=field, metric_names=metric_names)
     else:
-        metrics_report = dataset.compute_agreement_metrics(question_name=question, metric_names=metric_names)
+        metrics_report = dataset.compute_agreement_metrics(
+            question_name=question, field_name=field, metric_names=metric_names
+        )

     if isinstance(metric_names, str):
         metrics_report = [metrics_report]
@@ -288,19 +316,19 @@
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
-    "question, metric_names",
+    "field, question, metric_names",
     [
         # TextQuestion
-        ("question-1", None),
+        (["text"], "question-1", None),
         # RatingQuestion
-        ("question-2", "alpha"),
-        ("question-2", ["alpha"]),
+        (["text", "label"], "question-2", "alpha"),
+        (["text", "label"], "question-2", ["alpha"]),
         # LabelQuestion
-        ("question-3", "alpha"),
+        ("text", "question-3", "alpha"),
         # MultiLabelQuestion
-        ("question-4", "alpha"),
+        ("label", "question-4", "alpha"),
         # RankingQuestion
-        ("question-5", "alpha"),
+        (["text", "label"], "question-5", "alpha"),
     ],
 )
 @pytest.mark.usefixtures(
@@ -315,6 +343,7 @@ async def test_agreement_metrics_from_remote_feedback_dataset(
     feedback_dataset_questions: List["AllowedQuestionTypes"],
     feedback_dataset_records_with_paired_suggestions: List[FeedbackRecord],
     question: str,
+    field: Union[str, List[str]],
     metric_names: Union[str, List[str]],
     owner: User,
 ) -> None:
@@ -335,9 +364,11 @@ async def test_agreement_metrics_from_remote_feedback_dataset(
     if question in ("question-1",):
         with pytest.raises(NotImplementedError, match=r"^No metrics are defined currently for"):
-            remote.compute_agreement_metrics(question_name=question, metric_names=metric_names)
+            remote.compute_agreement_metrics(question_name=question, field_name=field, metric_names=metric_names)
     else:
-        metrics_report = remote.compute_agreement_metrics(question_name=question, metric_names=metric_names)
+        metrics_report = remote.compute_agreement_metrics(
+            question_name=question, field_name=field, metric_names=metric_names
+        )

     if isinstance(metric_names, str):
         metrics_report = [metrics_report]
diff --git a/docs/_source/practical_guides/collect_responses.md b/docs/_source/practical_guides/collect_responses.md
index ba49f3b65d..edce957a6b 100644
--- a/docs/_source/practical_guides/collect_responses.md
+++ b/docs/_source/practical_guides/collect_responses.md
@@ -141,7 +141,7 @@
 import argilla as rg
 from argilla.client.feedback.metrics import AgreementMetric
 feedback_dataset = rg.FeedbackDataset.from_argilla("...", workspace="...")
-metric = AgreementMetric(dataset=feedback_dataset, question_name="question_name")
+metric = AgreementMetric(dataset=feedback_dataset, field_name="text", question_name="question_name")
 agreement_metrics = metric.compute("alpha")
 # >>> agreement_metrics
 # [AgreementMetricResult(metric_name='alpha', count=1000, result=0.467889)]
@@ -156,7 +156,7 @@
 import argilla as rg

 #dataset = rg.FeedbackDataset.from_huggingface("argilla/go_emotions_raw")

-agreement_metrics = dataset.compute_agreement_metrics(question_name="label", metric_names="alpha")
+agreement_metrics = dataset.compute_agreement_metrics(question_name="label", field_name="text", metric_names="alpha")
 agreement_metrics
 # AgreementMetricResult(metric_name='alpha', count=191792, result=0.2703263452657748)
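
As a complementary illustration (not taken from the patch above), `field_name` also accepts a list of field names; in that case the contents of the listed fields are joined with a single space before the agreement is computed, mirroring `prepare_dataset_for_annotation_task`. The field names below are placeholders, assuming the dataset defines them:

```python
import argilla as rg
from argilla.client.feedback.metrics import AgreementMetric

# Connect to an existing dataset (same pattern as the docs example above).
feedback_dataset = rg.FeedbackDataset.from_argilla("...", workspace="...")

# Passing a list of field names concatenates their contents (joined with a
# space) before the agreement is computed. "field-1" and "field-2" are
# placeholder names; use the fields defined in your own dataset.
metric = AgreementMetric(
    dataset=feedback_dataset,
    question_name="question_name",
    field_name=["field-1", "field-2"],
)
agreement_metrics = metric.compute("alpha")
# >>> agreement_metrics
# [AgreementMetricResult(metric_name='alpha', count=..., result=...)]
```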