From b226ef530f2fc9c3908551ead3ad96236b74a74b Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Tue, 5 Mar 2024 17:38:53 +0100 Subject: [PATCH 01/12] fix: Parse ranking suggestions for huggingface format --- .../client/feedback/integrations/huggingface/dataset.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/argilla/client/feedback/integrations/huggingface/dataset.py b/src/argilla/client/feedback/integrations/huggingface/dataset.py index faba10cdac..72677c5fbf 100644 --- a/src/argilla/client/feedback/integrations/huggingface/dataset.py +++ b/src/argilla/client/feedback/integrations/huggingface/dataset.py @@ -148,7 +148,10 @@ def _huggingface_format(dataset: Union["FeedbackDataset", "RemoteFeedbackDataset if record.suggestions: for suggestion in record.suggestions: if question.name == suggestion.question_name: - suggestion_value = suggestion.value + if question.type == QuestionTypes.ranking: + suggestion_value = [r.dict() for r in suggestion.value] + else: + suggestion_value = suggestion.value suggestion_metadata = { "type": suggestion.type, "score": suggestion.score, From 7f1ac4058740e88ee4454c6d2971e29d908e656d Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Tue, 5 Mar 2024 17:39:16 +0100 Subject: [PATCH 02/12] fix some code examples --- src/argilla/client/feedback/training/schemas/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/argilla/client/feedback/training/schemas/base.py b/src/argilla/client/feedback/training/schemas/base.py index 65c4544369..edf3ad5dc0 100644 --- a/src/argilla/client/feedback/training/schemas/base.py +++ b/src/argilla/client/feedback/training/schemas/base.py @@ -410,7 +410,7 @@ def for_supervised_fine_tuning( >>> from argilla import TrainingTask >>> dataset = rg.FeedbackDataset.from_argilla(name="...") >>> def formatting_func(sample: Dict[str, Any]): - ... annotations = sample["good] + ... annotations = sample["good"] ... if annotations and annotations[0]["value"] == "Bad": ... return ... return template.format(prompt=sample["prompt"][0]["value"], response=sample["response"][0]["value"]) @@ -973,7 +973,7 @@ class TrainingTaskForSFT(BaseModel, TrainingData): >>> from argilla import TrainingTaskForSFT >>> dataset = rg.FeedbackDataset.from_argilla(name="...") >>> def formatting_func(sample: Dict[str, Any]): - ... annotations = sample["good] + ... annotations = sample["good"] ... if annotations and annotations[0]["value"] == "Bad": ... return ... yield template.format(prompt=sample["prompt"][0]["value"], response=sample["response"][0]["value"]) From 74c70f116c6ebc5fba9cd77fd73e66f4a73b9fb8 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Tue, 5 Mar 2024 17:39:47 +0100 Subject: [PATCH 03/12] adapt tests --- tests/integration/client/conftest.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/integration/client/conftest.py b/tests/integration/client/conftest.py index ae25690e20..1ddc5f7e60 100644 --- a/tests/integration/client/conftest.py +++ b/tests/integration/client/conftest.py @@ -25,7 +25,7 @@ LabelQuestion, MultiLabelQuestion, RankingQuestion, - RatingQuestion, + RankingValueSchema, RatingQuestion, TextField, TextQuestion, ) @@ -481,7 +481,7 @@ def feedback_dataset_records() -> List[FeedbackRecord]: "question-2": {"value": 2}, "question-3": {"value": "b"}, "question-4": {"value": ["b", "c"]}, - "question-5": {"value": [{"rank": 1, "value": "a"}, {"rank": 2, "value": "b"}]}, + "question-5": {"value": [RankingValueSchema(rank=1, value="a"), RankingValueSchema(rank= 2, value="b")]}, }, "status": "submitted", } @@ -517,7 +517,7 @@ def feedback_dataset_records() -> List[FeedbackRecord]: }, { "question_name": "question-5", - "value": [{"rank": 1, "value": "a"}, {"rank": 2, "value": "b"}], + "value": [RankingValueSchema(rank=1, value="a"), RankingValueSchema(rank= 2, value="b")], "type": "human", "score": 0.0, "agent": "agent-1", From c5eee3944325c1b623fe30f84e167cf86f9a5b85 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 5 Mar 2024 16:41:47 +0000 Subject: [PATCH 04/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/integration/client/conftest.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/integration/client/conftest.py b/tests/integration/client/conftest.py index 1ddc5f7e60..37f7c7466b 100644 --- a/tests/integration/client/conftest.py +++ b/tests/integration/client/conftest.py @@ -25,7 +25,8 @@ LabelQuestion, MultiLabelQuestion, RankingQuestion, - RankingValueSchema, RatingQuestion, + RankingValueSchema, + RatingQuestion, TextField, TextQuestion, ) @@ -481,7 +482,9 @@ def feedback_dataset_records() -> List[FeedbackRecord]: "question-2": {"value": 2}, "question-3": {"value": "b"}, "question-4": {"value": ["b", "c"]}, - "question-5": {"value": [RankingValueSchema(rank=1, value="a"), RankingValueSchema(rank= 2, value="b")]}, + "question-5": { + "value": [RankingValueSchema(rank=1, value="a"), RankingValueSchema(rank=2, value="b")] + }, }, "status": "submitted", } @@ -517,7 +520,7 @@ def feedback_dataset_records() -> List[FeedbackRecord]: }, { "question_name": "question-5", - "value": [RankingValueSchema(rank=1, value="a"), RankingValueSchema(rank= 2, value="b")], + "value": [RankingValueSchema(rank=1, value="a"), RankingValueSchema(rank=2, value="b")], "type": "human", "score": 0.0, "agent": "agent-1", From 4fa98ee6083d3e64262bb64f9099ccf948779c25 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Tue, 5 Mar 2024 17:42:56 +0100 Subject: [PATCH 05/12] chore: Updaet CHANGELOG --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3e1ecff982..de28a24bcf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,10 @@ These are the section headers that we use: ## [Unreleased]() +### Fixed + +- Fixed prepare for training when passing `RankingValueSchema` instances to suggestions. ([#4628](https://github.com/argilla-io/argilla/pull/4628)) + ## [1.25.0](https://github.com/argilla-io/argilla/compare/v1.24.0...v1.25.0) > [!NOTE] From 1d972e3142db59ff92b73f1eb04e602efa272f0f Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Wed, 6 Mar 2024 10:18:48 +0100 Subject: [PATCH 06/12] fix: Generalize suggestion serialization --- src/argilla/client/feedback/schemas/records.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/argilla/client/feedback/schemas/records.py b/src/argilla/client/feedback/schemas/records.py index b550a29bb0..bfd9d5267a 100644 --- a/src/argilla/client/feedback/schemas/records.py +++ b/src/argilla/client/feedback/schemas/records.py @@ -130,15 +130,10 @@ class Config: def to_server_payload(self, question_name_to_id: Dict[str, UUID]) -> Dict[str, Any]: """Method that will be used to create the payload that will be sent to Argilla to create a `SuggestionSchema` for a `FeedbackRecord`.""" - payload = {} + # We can do this because there is no default values for the fields + payload = self.dict(exclude_unset=True, include={"type", "score", "value", "agent "}) payload["question_id"] = str(question_name_to_id[self.question_name]) - payload["value"] = self.value - if self.type: - payload["type"] = self.type - if self.score: - payload["score"] = self.score - if self.agent: - payload["agent"] = self.agent + return payload From 28b045032888c14f48acf761c8dd30aff89c640d Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Wed, 6 Mar 2024 10:19:29 +0100 Subject: [PATCH 07/12] refactor: use pydantic serialization to extract suggestion value for HF properly --- .../client/feedback/integrations/huggingface/dataset.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/argilla/client/feedback/integrations/huggingface/dataset.py b/src/argilla/client/feedback/integrations/huggingface/dataset.py index 72677c5fbf..136e979604 100644 --- a/src/argilla/client/feedback/integrations/huggingface/dataset.py +++ b/src/argilla/client/feedback/integrations/huggingface/dataset.py @@ -148,10 +148,7 @@ def _huggingface_format(dataset: Union["FeedbackDataset", "RemoteFeedbackDataset if record.suggestions: for suggestion in record.suggestions: if question.name == suggestion.question_name: - if question.type == QuestionTypes.ranking: - suggestion_value = [r.dict() for r in suggestion.value] - else: - suggestion_value = suggestion.value + suggestion_value = suggestion.dict(include={"value"})["value"] suggestion_metadata = { "type": suggestion.type, "score": suggestion.score, From d1f2f69822e4c2aae730913f363c69f2cd657e69 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Wed, 6 Mar 2024 10:59:21 +0100 Subject: [PATCH 08/12] fix agent naming --- src/argilla/client/feedback/schemas/records.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/argilla/client/feedback/schemas/records.py b/src/argilla/client/feedback/schemas/records.py index bfd9d5267a..1b72cb7836 100644 --- a/src/argilla/client/feedback/schemas/records.py +++ b/src/argilla/client/feedback/schemas/records.py @@ -131,7 +131,7 @@ def to_server_payload(self, question_name_to_id: Dict[str, UUID]) -> Dict[str, A """Method that will be used to create the payload that will be sent to Argilla to create a `SuggestionSchema` for a `FeedbackRecord`.""" # We can do this because there is no default values for the fields - payload = self.dict(exclude_unset=True, include={"type", "score", "value", "agent "}) + payload = self.dict(exclude_unset=True, include={"type", "score", "value", "agent"}) payload["question_id"] = str(question_name_to_id[self.question_name]) return payload From 732d7922ebaa3fe3392f3f4a3f24ccbf2e09260c Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Wed, 6 Mar 2024 11:12:00 +0100 Subject: [PATCH 09/12] Add failing condition to test --- .../integration/client/feedback/dataset/local/test_dataset.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/integration/client/feedback/dataset/local/test_dataset.py b/tests/integration/client/feedback/dataset/local/test_dataset.py index 955ebc472a..f115fa2a8e 100644 --- a/tests/integration/client/feedback/dataset/local/test_dataset.py +++ b/tests/integration/client/feedback/dataset/local/test_dataset.py @@ -566,6 +566,10 @@ def test_push_to_huggingface_and_from_huggingface( hf_response.dict() == response.dict() for hf_response, response in zip(hf_record.responses, record.responses) ) + assert all( + hf_suggestion.dict() == suggestion.dict() + for hf_suggestion, suggestion in zip(hf_record.suggestions, record.suggestions) + ), f"{[s.dict() for s in hf_record.suggestions]} != {[s.dict() for s in record.suggestions]}" dataset.add_records( records=[ From 7545ec9852a2999a3b11f1e441885bab437cf5a1 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Wed, 6 Mar 2024 13:26:04 +0100 Subject: [PATCH 10/12] fix: parse back hf format to ds format for ranking values --- .../client/feedback/integrations/huggingface/dataset.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/argilla/client/feedback/integrations/huggingface/dataset.py b/src/argilla/client/feedback/integrations/huggingface/dataset.py index 136e979604..06ee29dfe1 100644 --- a/src/argilla/client/feedback/integrations/huggingface/dataset.py +++ b/src/argilla/client/feedback/integrations/huggingface/dataset.py @@ -428,10 +428,11 @@ def from_huggingface( f"{question.name}-suggestion" in hfds[index] and hfds[index][f"{question.name}-suggestion"] is not None ): - suggestion = { - "question_name": question.name, - "value": hfds[index][f"{question.name}-suggestion"], - } + value = hfds[index][f"{question.name}-suggestion"] + if question.type == QuestionTypes.ranking: + value = [{"rank": r, "value": v} for r, v in zip(value["rank"], value["value"])] + + suggestion = {"question_name": question.name, "value": value} if hfds[index][f"{question.name}-suggestion-metadata"] is not None: suggestion.update(hfds[index][f"{question.name}-suggestion-metadata"]) suggestions.append(suggestion) From 26a38e75e39dc2c4ea21c96dc72b5708aa2f97bc Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Wed, 6 Mar 2024 13:26:27 +0100 Subject: [PATCH 11/12] chore: fix doc example --- src/argilla/client/feedback/schemas/questions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/argilla/client/feedback/schemas/questions.py b/src/argilla/client/feedback/schemas/questions.py index 5b50c9ba83..7acda4ee8c 100644 --- a/src/argilla/client/feedback/schemas/questions.py +++ b/src/argilla/client/feedback/schemas/questions.py @@ -274,7 +274,7 @@ class RankingQuestion(QuestionSchema, LabelMappingMixin): Examples: >>> from argilla.client.feedback.schemas.questions import RankingQuestion - >>> RankingQuestion(name="ranking_question", title="Ranking Question", labels=["label_1", "label_2"]) + >>> RankingQuestion(name="ranking_question", title="Ranking Question", values=["label_1", "label_2"]) """ type: Literal[QuestionTypes.ranking] = Field(QuestionTypes.ranking.value, allow_mutation=False) From 431f6fa422896bcf96b92db75ba3de3cb1ed4096 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Wed, 6 Mar 2024 13:28:09 +0100 Subject: [PATCH 12/12] chore: update CHANGELOG --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index de28a24bcf..60844f0762 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ These are the section headers that we use: ### Fixed - Fixed prepare for training when passing `RankingValueSchema` instances to suggestions. ([#4628](https://github.com/argilla-io/argilla/pull/4628)) +- Fixed parsing ranking values in suggestions from HF datasets. ([#4629](https://github.com/argilla-io/argilla/pull/4629)) ## [1.25.0](https://github.com/argilla-io/argilla/compare/v1.24.0...v1.25.0)