diff --git a/CHANGELOG.md b/CHANGELOG.md index 3e1ecff982..de28a24bcf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,10 @@ These are the section headers that we use: ## [Unreleased]() +### Fixed + +- Fixed prepare for training when passing `RankingValueSchema` instances to suggestions. ([#4628](https://github.com/argilla-io/argilla/pull/4628)) + ## [1.25.0](https://github.com/argilla-io/argilla/compare/v1.24.0...v1.25.0) > [!NOTE] diff --git a/src/argilla/client/feedback/integrations/huggingface/dataset.py b/src/argilla/client/feedback/integrations/huggingface/dataset.py index faba10cdac..136e979604 100644 --- a/src/argilla/client/feedback/integrations/huggingface/dataset.py +++ b/src/argilla/client/feedback/integrations/huggingface/dataset.py @@ -148,7 +148,7 @@ def _huggingface_format(dataset: Union["FeedbackDataset", "RemoteFeedbackDataset if record.suggestions: for suggestion in record.suggestions: if question.name == suggestion.question_name: - suggestion_value = suggestion.value + suggestion_value = suggestion.dict(include={"value"})["value"] suggestion_metadata = { "type": suggestion.type, "score": suggestion.score, diff --git a/src/argilla/client/feedback/schemas/records.py b/src/argilla/client/feedback/schemas/records.py index b550a29bb0..1b72cb7836 100644 --- a/src/argilla/client/feedback/schemas/records.py +++ b/src/argilla/client/feedback/schemas/records.py @@ -130,15 +130,10 @@ class Config: def to_server_payload(self, question_name_to_id: Dict[str, UUID]) -> Dict[str, Any]: """Method that will be used to create the payload that will be sent to Argilla to create a `SuggestionSchema` for a `FeedbackRecord`.""" - payload = {} + # We can do this because there are no default values for the fields + payload = self.dict(exclude_unset=True, include={"type", "score", "value", "agent"}) payload["question_id"] = str(question_name_to_id[self.question_name]) - payload["value"] = self.value - if self.type: - payload["type"] = self.type - if self.score: - 
payload["score"] = self.score - if self.agent: - payload["agent"] = self.agent + return payload diff --git a/src/argilla/client/feedback/training/schemas/base.py b/src/argilla/client/feedback/training/schemas/base.py index 65c4544369..edf3ad5dc0 100644 --- a/src/argilla/client/feedback/training/schemas/base.py +++ b/src/argilla/client/feedback/training/schemas/base.py @@ -410,7 +410,7 @@ def for_supervised_fine_tuning( >>> from argilla import TrainingTask >>> dataset = rg.FeedbackDataset.from_argilla(name="...") >>> def formatting_func(sample: Dict[str, Any]): - ... annotations = sample["good] + ... annotations = sample["good"] ... if annotations and annotations[0]["value"] == "Bad": ... return ... return template.format(prompt=sample["prompt"][0]["value"], response=sample["response"][0]["value"]) @@ -973,7 +973,7 @@ class TrainingTaskForSFT(BaseModel, TrainingData): >>> from argilla import TrainingTaskForSFT >>> dataset = rg.FeedbackDataset.from_argilla(name="...") >>> def formatting_func(sample: Dict[str, Any]): - ... annotations = sample["good] + ... annotations = sample["good"] ... if annotations and annotations[0]["value"] == "Bad": ... return ... 
yield template.format(prompt=sample["prompt"][0]["value"], response=sample["response"][0]["value"]) diff --git a/tests/integration/client/conftest.py b/tests/integration/client/conftest.py index ae25690e20..37f7c7466b 100644 --- a/tests/integration/client/conftest.py +++ b/tests/integration/client/conftest.py @@ -25,6 +25,7 @@ LabelQuestion, MultiLabelQuestion, RankingQuestion, + RankingValueSchema, RatingQuestion, TextField, TextQuestion, @@ -481,7 +482,9 @@ def feedback_dataset_records() -> List[FeedbackRecord]: "question-2": {"value": 2}, "question-3": {"value": "b"}, "question-4": {"value": ["b", "c"]}, - "question-5": {"value": [{"rank": 1, "value": "a"}, {"rank": 2, "value": "b"}]}, + "question-5": { + "value": [RankingValueSchema(rank=1, value="a"), RankingValueSchema(rank=2, value="b")] + }, }, "status": "submitted", } @@ -517,7 +520,7 @@ def feedback_dataset_records() -> List[FeedbackRecord]: }, { "question_name": "question-5", - "value": [{"rank": 1, "value": "a"}, {"rank": 2, "value": "b"}], + "value": [RankingValueSchema(rank=1, value="a"), RankingValueSchema(rank=2, value="b")], "type": "human", "score": 0.0, "agent": "agent-1",