diff --git a/src/argilla/client/feedback/integrations/huggingface/dataset.py b/src/argilla/client/feedback/integrations/huggingface/dataset.py index 185dd345ca..71fda2d26d 100644 --- a/src/argilla/client/feedback/integrations/huggingface/dataset.py +++ b/src/argilla/client/feedback/integrations/huggingface/dataset.py @@ -87,6 +87,7 @@ def _huggingface_format(dataset: Union["FeedbackDataset", "RemoteFeedbackDataset "start": Value(dtype="int32"), "end": Value(dtype="int32"), "label": Value(dtype="string"), + "text": Value(dtype="string"), }, id="question", ) @@ -95,6 +96,7 @@ def _huggingface_format(dataset: Union["FeedbackDataset", "RemoteFeedbackDataset "start": Value(dtype="int32"), "end": Value(dtype="int32"), "label": Value(dtype="string"), + "text": Value(dtype="string"), "score": Value(dtype="float32"), } ) @@ -165,6 +167,7 @@ def _huggingface_format(dataset: Union["FeedbackDataset", "RemoteFeedbackDataset "start": span.start, "end": span.end, "label": span.label, + "text": record.fields[question.field][span.start : span.end], } for span in response.values[question.name].value ] @@ -178,7 +181,20 @@ def _huggingface_format(dataset: Union["FeedbackDataset", "RemoteFeedbackDataset if record.suggestions: for suggestion in record.suggestions: if question.name == suggestion.question_name: - suggestion_value = suggestion.dict(include={"value"})["value"] + if question.type == QuestionTypes.span: + suggestion_value = [ + { + "start": span.start, + "end": span.end, + "label": span.label, + "score": span.score, + "text": record.fields[question.field][span.start : span.end], + } + for span in suggestion.value + ] + else: + suggestion_value = suggestion.dict(include={"value"})["value"] + suggestion_metadata = { "type": suggestion.type, "score": suggestion.score, diff --git a/tests/unit/client/feedback/schemas/test_questions.py b/tests/unit/client/feedback/schemas/test_questions.py index b695788c18..2f0eb03a45 100644 --- a/tests/unit/client/feedback/schemas/test_questions.py +++ b/tests/unit/client/feedback/schemas/test_questions.py @@ -502,6 +502,7 @@ def test_span_question_with_duplicated_labels() -> None: SpanQuestion( name="question", title="Question", + field="field", description="Description", labels=[SpanLabelOption(value="a", text="A text"), SpanLabelOption(value="a", text="Text for A")], )