From bc3f9307753e4be5fa26bd8c6180fdff3e77dd56 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Wed, 13 Mar 2024 09:36:17 +0100 Subject: [PATCH] chore: apply some of suggested changes (#4644) This PR adds the `text` value for span suggestions and responses when publishing HF datasets. Commented by @davidberenstein1957 [here](https://github.com/argilla-io/argilla/pull/4623#discussion_r1519318718) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../integrations/huggingface/dataset.py | 18 +++++++++++++++++- .../client/feedback/schemas/test_questions.py | 1 + 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/argilla/client/feedback/integrations/huggingface/dataset.py b/src/argilla/client/feedback/integrations/huggingface/dataset.py index 185dd345ca..71fda2d26d 100644 --- a/src/argilla/client/feedback/integrations/huggingface/dataset.py +++ b/src/argilla/client/feedback/integrations/huggingface/dataset.py @@ -87,6 +87,7 @@ def _huggingface_format(dataset: Union["FeedbackDataset", "RemoteFeedbackDataset "start": Value(dtype="int32"), "end": Value(dtype="int32"), "label": Value(dtype="string"), + "text": Value(dtype="string"), }, id="question", ) @@ -95,6 +96,7 @@ def _huggingface_format(dataset: Union["FeedbackDataset", "RemoteFeedbackDataset "start": Value(dtype="int32"), "end": Value(dtype="int32"), "label": Value(dtype="string"), + "text": Value(dtype="string"), "score": Value(dtype="float32"), } ) @@ -165,6 +167,7 @@ def _huggingface_format(dataset: Union["FeedbackDataset", "RemoteFeedbackDataset "start": span.start, "end": span.end, "label": span.label, + "text": record.fields[question.field][span.start : span.end], } for span in response.values[question.name].value ] @@ -178,7 +181,20 @@ def _huggingface_format(dataset: Union["FeedbackDataset", "RemoteFeedbackDataset if record.suggestions: for suggestion in record.suggestions: if question.name == suggestion.question_name: - suggestion_value = suggestion.dict(include={"value"})["value"] + if question.type == QuestionTypes.span: + suggestion_value = [ + { + "start": span.start, + "end": span.end, + "label": span.label, + "score": span.score, + "text": record.fields[question.field][span.start : span.end], + } + for span in suggestion.value + ] + else: + suggestion_value = suggestion.dict(include={"value"})["value"] + suggestion_metadata = { "type": suggestion.type, "score": suggestion.score, diff --git a/tests/unit/client/feedback/schemas/test_questions.py b/tests/unit/client/feedback/schemas/test_questions.py index b695788c18..2f0eb03a45 100644 --- a/tests/unit/client/feedback/schemas/test_questions.py +++ b/tests/unit/client/feedback/schemas/test_questions.py @@ -502,6 +502,7 @@ def test_span_question_with_duplicated_labels() -> None: SpanQuestion( name="question", title="Question", + field="field", description="Description", labels=[SpanLabelOption(value="a", text="A text"), SpanLabelOption(value="a", text="Text for A")], )