From 9459e8f79df3fdf257d4dbc594d99e001052b98d Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Wed, 6 Mar 2024 15:06:02 +0100 Subject: [PATCH] bugfix: Ranking value from suggestions are not properly converted from HF datasets (#4629) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description This PR adds a custom conversion for suggestions with ranking values when loading from HF datasets. Without it, ranking values are not properly created when reading from an HF dataset; they keep the columnar form: ```python {'rank': [1, 2], 'value': ['a', 'b']} ``` instead of the expected list of items: ```python [{'value': 'a', 'rank': 1}, {'value': 'b', 'rank': 2}] ``` **Type of change** (Please delete options that are not relevant. Remember to title the PR according to the type of change) - [X] Bug fix (non-breaking change which fixes an issue) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) **How Has This Been Tested** (Please describe the tests that you ran to verify your changes. 
And ideally, reference `tests`) Tested locally - [ ] Test A - [ ] Test B **Checklist** - [ ] I followed the style guidelines of this project - [ ] I did a self-review of my code - [ ] My changes generate no new warnings - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK) (see text above) - [x] I have added relevant notes to the `CHANGELOG.md` file (See https://keepachangelog.com/) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- CHANGELOG.md | 1 + .../client/feedback/integrations/huggingface/dataset.py | 9 +++++---- src/argilla/client/feedback/schemas/questions.py | 2 +- .../client/feedback/dataset/local/test_dataset.py | 4 ++++ 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index de28a24bcf..60844f0762 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ These are the section headers that we use: ### Fixed - Fixed prepare for training when passing `RankingValueSchema` instances to suggestions. ([#4628](https://github.com/argilla-io/argilla/pull/4628)) +- Fixed parsing ranking values in suggestions from HF datasets. 
([#4629](https://github.com/argilla-io/argilla/pull/4629)) ## [1.25.0](https://github.com/argilla-io/argilla/compare/v1.24.0...v1.25.0) diff --git a/src/argilla/client/feedback/integrations/huggingface/dataset.py b/src/argilla/client/feedback/integrations/huggingface/dataset.py index 136e979604..06ee29dfe1 100644 --- a/src/argilla/client/feedback/integrations/huggingface/dataset.py +++ b/src/argilla/client/feedback/integrations/huggingface/dataset.py @@ -428,10 +428,11 @@ def from_huggingface( f"{question.name}-suggestion" in hfds[index] and hfds[index][f"{question.name}-suggestion"] is not None ): - suggestion = { - "question_name": question.name, - "value": hfds[index][f"{question.name}-suggestion"], - } + value = hfds[index][f"{question.name}-suggestion"] + if question.type == QuestionTypes.ranking: + value = [{"rank": r, "value": v} for r, v in zip(value["rank"], value["value"])] + + suggestion = {"question_name": question.name, "value": value} if hfds[index][f"{question.name}-suggestion-metadata"] is not None: suggestion.update(hfds[index][f"{question.name}-suggestion-metadata"]) suggestions.append(suggestion) diff --git a/src/argilla/client/feedback/schemas/questions.py b/src/argilla/client/feedback/schemas/questions.py index 5b50c9ba83..7acda4ee8c 100644 --- a/src/argilla/client/feedback/schemas/questions.py +++ b/src/argilla/client/feedback/schemas/questions.py @@ -274,7 +274,7 @@ class RankingQuestion(QuestionSchema, LabelMappingMixin): Examples: >>> from argilla.client.feedback.schemas.questions import RankingQuestion - >>> RankingQuestion(name="ranking_question", title="Ranking Question", labels=["label_1", "label_2"]) + >>> RankingQuestion(name="ranking_question", title="Ranking Question", values=["label_1", "label_2"]) """ type: Literal[QuestionTypes.ranking] = Field(QuestionTypes.ranking.value, allow_mutation=False) diff --git a/tests/integration/client/feedback/dataset/local/test_dataset.py 
b/tests/integration/client/feedback/dataset/local/test_dataset.py index 955ebc472a..f115fa2a8e 100644 --- a/tests/integration/client/feedback/dataset/local/test_dataset.py +++ b/tests/integration/client/feedback/dataset/local/test_dataset.py @@ -566,6 +566,10 @@ def test_push_to_huggingface_and_from_huggingface( hf_response.dict() == response.dict() for hf_response, response in zip(hf_record.responses, record.responses) ) + assert all( + hf_suggestion.dict() == suggestion.dict() + for hf_suggestion, suggestion in zip(hf_record.suggestions, record.suggestions) + ), f"{[s.dict() for s in hf_record.suggestions]} != {[s.dict() for s in record.suggestions]}" dataset.add_records( records=[