fix: format suggestions for ranking values (#4628)
<!-- Thanks for your contribution! As part of our Community Growers
initiative 🌱, we're donating Justdiggit bunds in your name to reforest
sub-Saharan Africa. To claim your Community Growers certificate, please
contact David Berenstein in our Slack community or fill in this form
https://tally.so/r/n9XrxK once your PR has been merged. -->

# Description

This PR fixes problems with the Hugging Face dataset format when suggestion values are passed as ranking value objects (`RankingValueSchema` instances) instead of raw dictionaries.
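
For context, a hedged sketch of the situation the fix targets: a record whose ranking suggestion is built from `RankingValueSchema` instances rather than raw `{"rank": ..., "value": ...}` dictionaries, mirroring the updated test fixture further down. The field name and the exact import path are assumptions for illustration; the suggestion payload shape follows the fixture in this PR.

```python
# Sketch only: suggestion values passed as RankingValueSchema instances instead
# of raw dicts. The field name and import path are assumed for illustration.
import argilla as rg
from argilla.client.feedback.schemas import RankingValueSchema  # import path assumed

record = rg.FeedbackRecord(
    fields={"text": "Which option do you prefer?"},  # field name assumed
    suggestions=[
        {
            "question_name": "question-5",
            # Model instances rather than [{"rank": 1, "value": "a"}, ...]
            "value": [RankingValueSchema(rank=1, value="a"), RankingValueSchema(rank=2, value="b")],
            "type": "human",
            "score": 0.0,
            "agent": "agent-1",
        }
    ],
)
```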


**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [X] Bug fix (non-breaking change which fixes an issue)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And
ideally, reference `tests`)

Tested locally

- [ ] Test A
- [ ] Test B

**Checklist**

- [ ] I followed the style guidelines of this project
- [ ] I did a self-review of my code
- [ ] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [x] I have added relevant notes to the `CHANGELOG.md` file (See
https://keepachangelog.com/)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
frascuchon and pre-commit-ci[bot] authored Mar 6, 2024
1 parent 254f719 commit 54aa215
Showing 5 changed files with 15 additions and 13 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
```diff
@@ -16,6 +16,10 @@ These are the section headers that we use:

 ## [Unreleased]()

+### Fixed
+
+- Fixed prepare for training when passing `RankingValueSchema` instances to suggestions. ([#4628](https://github.com/argilla-io/argilla/pull/4628))
+
 ## [1.25.0](https://github.com/argilla-io/argilla/compare/v1.24.0...v1.25.0)

 > [!NOTE]
```
2 changes: 1 addition & 1 deletion
```diff
@@ -148,7 +148,7 @@ def _huggingface_format(dataset: Union["FeedbackDataset", "RemoteFeedbackDataset
         if record.suggestions:
             for suggestion in record.suggestions:
                 if question.name == suggestion.question_name:
-                    suggestion_value = suggestion.value
+                    suggestion_value = suggestion.dict(include={"value"})["value"]
                     suggestion_metadata = {
                         "type": suggestion.type,
                         "score": suggestion.score,
```
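
The one-line change above swaps direct attribute access for pydantic serialization. A minimal stand-in (plain pydantic models, not the actual argilla classes) of why that matters: `.value` hands back nested model instances, while `.dict(include={"value"})["value"]` converts them to plain dictionaries that downstream formatting (e.g. building a Hugging Face `datasets` dataset) can consume.

```python
# Minimal stand-in sketch, not the argilla code: shows the difference between
# reading `.value` directly and serializing it via `.dict(include={"value"})`.
from typing import List, Union
from pydantic import BaseModel

class RankingValue(BaseModel):  # stand-in for RankingValueSchema
    value: str
    rank: int

class Suggestion(BaseModel):  # stand-in for the suggestion schema
    question_name: str
    value: Union[str, List[RankingValue]]

suggestion = Suggestion(
    question_name="question-5",
    value=[RankingValue(value="a", rank=1), RankingValue(value="b", rank=2)],
)

print(suggestion.value)
# [RankingValue(value='a', rank=1), ...]  -> pydantic model instances

print(suggestion.dict(include={"value"})["value"])
# [{'value': 'a', 'rank': 1}, {'value': 'b', 'rank': 2}]  -> plain dicts
```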
11 changes: 3 additions & 8 deletions src/argilla/client/feedback/schemas/records.py
```diff
@@ -130,15 +130,10 @@ class Config:
     def to_server_payload(self, question_name_to_id: Dict[str, UUID]) -> Dict[str, Any]:
         """Method that will be used to create the payload that will be sent to Argilla
         to create a `SuggestionSchema` for a `FeedbackRecord`."""
-        payload = {}
+        # We can do this because there is no default values for the fields
+        payload = self.dict(exclude_unset=True, include={"type", "score", "value", "agent"})
         payload["question_id"] = str(question_name_to_id[self.question_name])
-        payload["value"] = self.value
-        if self.type:
-            payload["type"] = self.type
-        if self.score:
-            payload["score"] = self.score
-        if self.agent:
-            payload["agent"] = self.agent

         return payload
```
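
The rewritten `to_server_payload` leans on pydantic's `exclude_unset` to drop optional fields that were never provided, close to what the removed `if self.type: ...` guards did by hand. A minimal pydantic sketch (not the argilla class) of that behaviour:

```python
# Minimal sketch of dict(exclude_unset=True, include=...): only fields that were
# explicitly set (and listed in `include`) end up in the payload.
from typing import Optional
from pydantic import BaseModel

class Suggestion(BaseModel):  # stand-in for the suggestion schema
    question_name: str
    value: str
    type: Optional[str] = None
    score: Optional[float] = None
    agent: Optional[str] = None

s = Suggestion(question_name="question-1", value="answer", score=0.0)
print(s.dict(exclude_unset=True, include={"type", "score", "value", "agent"}))
# {'value': 'answer', 'score': 0.0} -> 'type' and 'agent' were never set, so they
# are omitted; an explicitly set 0.0 score is kept (unlike the old truthiness check).
```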
4 changes: 2 additions & 2 deletions src/argilla/client/feedback/training/schemas/base.py
```diff
@@ -410,7 +410,7 @@ def for_supervised_fine_tuning(
         >>> from argilla import TrainingTask
         >>> dataset = rg.FeedbackDataset.from_argilla(name="...")
         >>> def formatting_func(sample: Dict[str, Any]):
-        ...     annotations = sample["good]
+        ...     annotations = sample["good"]
         ...     if annotations and annotations[0]["value"] == "Bad":
         ...         return
         ...     return template.format(prompt=sample["prompt"][0]["value"], response=sample["response"][0]["value"])
@@ -973,7 +973,7 @@ class TrainingTaskForSFT(BaseModel, TrainingData):
         >>> from argilla import TrainingTaskForSFT
         >>> dataset = rg.FeedbackDataset.from_argilla(name="...")
         >>> def formatting_func(sample: Dict[str, Any]):
-        ...     annotations = sample["good]
+        ...     annotations = sample["good"]
         ...     if annotations and annotations[0]["value"] == "Bad":
         ...         return
         ...     yield template.format(prompt=sample["prompt"][0]["value"], response=sample["response"][0]["value"])
```
7 changes: 5 additions & 2 deletions tests/integration/client/conftest.py
```diff
@@ -25,6 +25,7 @@
     LabelQuestion,
     MultiLabelQuestion,
     RankingQuestion,
+    RankingValueSchema,
     RatingQuestion,
     TextField,
     TextQuestion,
@@ -481,7 +482,9 @@ def feedback_dataset_records() -> List[FeedbackRecord]:
                 "question-2": {"value": 2},
                 "question-3": {"value": "b"},
                 "question-4": {"value": ["b", "c"]},
-                "question-5": {"value": [{"rank": 1, "value": "a"}, {"rank": 2, "value": "b"}]},
+                "question-5": {
+                    "value": [RankingValueSchema(rank=1, value="a"), RankingValueSchema(rank=2, value="b")]
+                },
             },
             "status": "submitted",
         }
@@ -517,7 +520,7 @@ def feedback_dataset_records() -> List[FeedbackRecord]:
             },
             {
                 "question_name": "question-5",
-                "value": [{"rank": 1, "value": "a"}, {"rank": 2, "value": "b"}],
+                "value": [RankingValueSchema(rank=1, value="a"), RankingValueSchema(rank=2, value="b")],
                 "type": "human",
                 "score": 0.0,
                 "agent": "agent-1",
```
