Skip to content

Commit

Permalink
fix: Use RatingQuestion instead of RankingQuestion for sentence s…
Browse files Browse the repository at this point in the history
…imilarity in the `ArgillaTrainer` (#4171)

<!-- Thanks for your contribution! As part of our Community Growers
initiative 🌱, we're donating Justdiggit bunds in your name to reforest
sub-Saharan Africa. To claim your Community Growers certificate, please
contact David Berenstein in our Slack community or fill in this form
https://tally.so/r/n9XrxK once your PR has been merged. -->

# Description

Update the `ArgillaTrainer` to use `RatingQuestion` instead of
`RankingQuestion` when training for `sentence-similarity` with numerical
labels.

Closes #4049 

**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And
ideally, reference `tests`)

- [x]
`test/integration/client/feedback/training/test_sentence_transformers.py`

**Checklist**

- [x] I followed the style guidelines of this project
- [x] I did a self-review of my code
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [x] I have added relevant notes to the `CHANGELOG.md` file (See
https://keepachangelog.com/)
  • Loading branch information
plaguss authored Nov 9, 2023
1 parent 5d3b40c commit da043b1
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 26 deletions.
5 changes: 2 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@ These are the section headers that we use:

## [Unreleased]()

### Contributors

- @Racso-3141 Added a progress bar for parsing records process to `from_huggingface()` method with `trange` in `tqdm`.([#4132](https://github.com/argilla-io/argilla/pull/4132)).
### Fixed

- Fixed error in `ArgillaTrainer` when training for sentence similarity with numerical labels: `RatingQuestion` is now used instead of `RankingQuestion` ([#4171](https://github.com/argilla-io/argilla/pull/4171))

## [1.19.0]()

Expand Down
10 changes: 10 additions & 0 deletions docs/_source/practical_guides/fine_tune.md
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,16 @@ task = TrainingTask.for_sentence_similarity(
)
```

For datasets that were annotated with numerical values, we can also pass the label strategy we want to use (let's assume we have another question in the dataset named "other-question" that contains values that come from rated answers):

```python
task = TrainingTask.for_sentence_similarity(
texts=[dataset.field_by_name("premise"), dataset.field_by_name("hypothesis")],
label=dataset.question_by_name("other-question"),
label_strategy="majority" # or "mean" for RatingQuestion
)
```

:::

:::{tab-item} formatting_func
Expand Down
43 changes: 25 additions & 18 deletions src/argilla/client/feedback/training/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ def for_sentence_similarity(
List[Dict[str, str]],
],
] = None,
label_strategy: Optional[LabelQuestionUnification] = None,
label_strategy: Optional[Union[LabelQuestionUnification, RatingQuestionUnification]] = None,
) -> "TrainingTaskForSentenceSimilarity":
"""
Expand Down Expand Up @@ -557,12 +557,12 @@ def for_sentence_similarity(
)

if formatting_func is not None:
return TrainingTaskForSentenceSimilarity(formatting_func=formatting_func)
return TrainingTaskForSentenceSimilarity(formatting_func=formatting_func, label=label_strategy)
else:
if not label:
return TrainingTaskForSentenceSimilarity(texts=texts)
return TrainingTaskForSentenceSimilarity(texts=texts, label=label_strategy)

if isinstance(label, LabelQuestionUnification):
if isinstance(label, (LabelQuestionUnification, RatingQuestionUnification)):
if label_strategy is not None:
raise ValueError("label_strategy is already defined via Unification class.")
else:
Expand All @@ -573,8 +573,8 @@ def for_sentence_similarity(
_LOGGER.info(f"No label strategy defined. Using default strategy for {type(label)}.")
if isinstance(label, LabelQuestion):
label = LabelQuestionUnification(**unification_kwargs)
elif isinstance(label, RankingQuestion):
label = RankingQuestionUnification(**unification_kwargs)
elif isinstance(label, RatingQuestion):
label = RatingQuestionUnification(**unification_kwargs)
else:
raise ValueError(f"Label type {type(label)} is not supported.")
return TrainingTaskForSentenceSimilarity(texts=texts, label=label)
Expand Down Expand Up @@ -1485,7 +1485,7 @@ class TrainingTaskForSentenceSimilarity(BaseModel, TrainingData):
],
] = None
texts: Optional[List[TextField]] = None
label: Optional[Union[LabelQuestionUnification, RankingQuestionUnification]] = None
label: Optional[Union[LabelQuestionUnification, RatingQuestionUnification]] = None

@property
def supported_frameworks(self):
Expand Down Expand Up @@ -1537,9 +1537,15 @@ def _format_data(self, dataset: "FeedbackDataset") -> List[Dict[str, Any]]:
else:
_all_labels.add(sample["label"])

self.label = LabelQuestionUnification(
question=LabelQuestion(name="custom_func", labels=list(_all_labels))
)
if self.label is None:
labels = list(_all_labels)
if isinstance(labels[0], int):
label = RatingQuestionUnification(
question=RatingQuestion(name="custom_func", values=labels), strategy="majority"
)
else:
label = LabelQuestionUnification(question=LabelQuestion(name="custom_func", labels=labels))
self.label = label

return outputs

Expand All @@ -1555,16 +1561,17 @@ def _format_data(self, dataset: "FeedbackDataset") -> List[Dict[str, Any]]:
for example in formatted_data:
record = {}
for k, v in new_keys.items():
value = example[k]
if v == "label":
value = example[v]
# At this point the label must be either an int or a float; determine which one it is.
if value.lstrip("-").isdigit():
value = int(value)
else:
value = float(value)
if isinstance(self.label, RankingQuestionUnification):
max_value = max([float(x) for x in self.label.question.__all_labels__])
value = (value / 100) * float(max_value)
if isinstance(value, str):
if value.lstrip("-").isdigit():
value = int(value)
else:
value = float(value)
else:
value = example[k]

record[v] = value
outputs.append(record)

Expand Down
19 changes: 15 additions & 4 deletions tests/integration/client/feedback/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def formatting_func_sentence_transformers(sample: dict):
elif labels[0] == "c":
return [
{"sentence-1": sample["text"], "sentence-2": sample["text"], "label": 1},
{"sentence-1": sample["text"], "sentence-2": sample["text"], "label": 0},
{"sentence-1": sample["text"], "sentence-2": sample["text"], "label": 2},
]


Expand Down Expand Up @@ -180,7 +180,7 @@ def formatting_func_sentence_transformers_case_3_a(sample):
elif labels[0] == "b":
return {"sentence": sample["text"], "label": 1}
elif labels[0] == "c":
return [{"sentence": sample["text"], "label": 1}, {"sentence": sample["text"], "label": 0}]
return [{"sentence": sample["text"], "label": 1}, {"sentence": sample["text"], "label": 2}]


def formatting_func_sentence_transformers_case_3_b(sample):
Expand All @@ -202,7 +202,7 @@ def formatting_func_sentence_transformers_case_3_b(sample):
elif labels[0] == "c":
return [
{"sentence-1": sample["text"], "sentence-2": sample["text"], "sentence-3": sample["text"], "label": 1},
{"sentence-1": sample["text"], "sentence-2": sample["text"], "sentence-3": sample["text"], "label": 0},
{"sentence-1": sample["text"], "sentence-2": sample["text"], "sentence-3": sample["text"], "label": 2},
]


Expand All @@ -221,6 +221,17 @@ def formatting_func_sentence_transformers_case_4(sample):
return [{"sentence-1": sample["text"], "sentence-2": sample["text"], "sentence-3": sample["text"]}] * 2


def formatting_func_sentence_transformers_rating_question(sample: dict):
    """Formatting function used to test the RatingQuestion.

    Collects the values of every `question-2` annotation whose status is
    "submitted" and whose value is not None, and uses the first one as label.

    Args:
        sample: a record as a dict; reads `sample["text"]` and the list of
            annotation dicts (`"status"`, `"value"`) in `sample["question-2"]`.

    Returns:
        A dict with `sentence-1`, `sentence-2` (both set to `sample["text"]`)
        and `label`, or None when no submitted, non-null value is found.
    """
    labels = [
        annotation["value"]
        for annotation in sample["question-2"]
        if annotation["status"] == "submitted" and annotation["value"] is not None
    ]
    if labels:
        return {"sentence-1": sample["text"], "sentence-2": sample["text"], "label": labels[0]}
    # No usable rating for this sample: return None explicitly.
    return None


def model_card_pattern(framework: Framework, training_task: Any) -> str:
# def model_card_pattern() -> str:
# def inner(framework: Framework, training_task: Any):
Expand Down Expand Up @@ -277,7 +288,7 @@ def formatting_func_sentence_transformers(sample: dict):
elif labels[0] == "c":
return [
{"sentence-1": sample["text"], "sentence-2": sample["text"], "label": 1},
{"sentence-1": sample["text"], "sentence-2": sample["text"], "label": 0},
{"sentence-1": sample["text"], "sentence-2": sample["text"], "label": 2},
]
task = TrainingTask.for_sentence_similarity(formatting_func=formatting_func_sentence_transformers)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,15 @@

import pytest
from argilla.client.feedback.dataset import FeedbackDataset
from argilla.client.feedback.schemas.fields import TextField
from argilla.client.feedback.schemas.questions import LabelQuestion
from argilla.client.feedback.schemas.records import FeedbackRecord
from argilla.client.feedback.training.base import ArgillaTrainer
from argilla.client.feedback.training.schemas import (
LabelQuestion,
LabelQuestionUnification,
RatingQuestion,
RatingQuestionUnification,
TrainingTask,
)
from sentence_transformers import CrossEncoder, InputExample, SentenceTransformer
Expand All @@ -31,6 +37,7 @@
formatting_func_sentence_transformers_case_3_a,
formatting_func_sentence_transformers_case_3_b,
formatting_func_sentence_transformers_case_4,
formatting_func_sentence_transformers_rating_question,
)
from tests.integration.training.helpers import train_with_cleanup

Expand Down Expand Up @@ -60,6 +67,7 @@ def formatting_func_errored(sample):
formatting_func_sentence_transformers_case_2,
formatting_func_sentence_transformers_case_3_b,
formatting_func_sentence_transformers_case_4,
formatting_func_sentence_transformers_rating_question,
],
)
@pytest.mark.usefixtures(
Expand All @@ -84,7 +92,12 @@ def test_prepare_for_training_sentence_transformers(
)
dataset.add_records(records=feedback_dataset_records * 2)

task = TrainingTask.for_sentence_similarity(formatting_func=formatting_func)
if formatting_func.__name__ == "formatting_func_sentence_transformers_rating_question":
label_strategy = RatingQuestionUnification(question=dataset.question_by_name("question-2"), strategy="majority")
task = TrainingTask.for_sentence_similarity(formatting_func=formatting_func, label_strategy=label_strategy)
else:
task = TrainingTask.for_sentence_similarity(formatting_func=formatting_func)

train_dataset = dataset.prepare_for_training(framework=__FRAMEWORK__, task=task)

assert isinstance(train_dataset, list)
Expand Down Expand Up @@ -128,6 +141,38 @@ def test_prepare_for_training_sentence_transformers(
assert len(eval_trainer.predict(["first sentence", ["to compare", "another one"]])) == 2


def test_task_with_different_naming():
    """Check sentence-similarity training data built from custom field/question names.

    Builds a dataset whose fields and label question do not use the default
    names, adds one record per label ("0", "1", "2"), and verifies that the
    prepared training examples carry the labels 0, 1 and 2 in that order.
    """
    dataset = FeedbackDataset(
        fields=[
            TextField(name="query"),
            TextField(name="retrieved_document_1"),
        ],
        questions=[
            LabelQuestion(
                name="sentence_similarity",
                labels={"0": "Not-similar", "1": "Missing-information", "2": "Similar"},
            ),
        ],
    )

    records = [
        FeedbackRecord(
            fields={"query": "some text", "retrieved_document_1": "retrieved data"},
            responses=[{"values": {"sentence_similarity": {"value": value}}}],
        )
        for value in ["0", "1", "2"]
    ]

    dataset.add_records(records)

    task = TrainingTask.for_sentence_similarity(
        texts=[dataset.field_by_name("query"), dataset.field_by_name("retrieved_document_1")],
        label=dataset.question_by_name("sentence_similarity"),
    )
    train_dataset = dataset.prepare_for_training(framework=__FRAMEWORK__, task=task)
    # Compare the full list: `all(... zip(...))` would pass vacuously if
    # `train_dataset` came back empty or shorter than expected.
    assert [example.label for example in train_dataset] == [0, 1, 2]


@pytest.mark.parametrize("cross_encoder", [False, True])
@pytest.mark.parametrize("formatting_func", [formatting_func_sentence_transformers_case_3_a, formatting_func_errored])
@pytest.mark.usefixtures(
Expand Down

0 comments on commit da043b1

Please sign in to comment.