Skip to content

Commit

Permalink
SetFit dependency >= 1.0.0 (#4467)
Browse files Browse the repository at this point in the history
Closes #4435.

Updating SetFit to bump the version >= 1.0.0

**Type of change**
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [ ] Refactor (change restructuring the codebase without changing
functionality)
- [X] Improvement (change adding some improvement to an existing
functionality)
- [ ] Documentation update

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And
ideally, reference `tests`)

- [x] Running the SetFit tutorial
- [x] Running SetFit tests on the suite
  • Loading branch information
ignacioct authored Jan 31, 2024
1 parent f15cb07 commit bb51eda
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
2 changes: 1 addition & 1 deletion environment_dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ dependencies:
- transformers[torch]>=4.30.0 # <- required for DPO with TRL
- evaluate
- seqeval
- setfit>=0.7.0,<1.0.0
- setfit>=1.0.0
- span_marker
- openai>=0.27.10,<1.0.0
- peft
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ integrations = [
"evaluate",
"seqeval",
"sentence-transformers",
"setfit>=0.7.0,<1.0.0",
"setfit>=1.0.0",
"span_marker",
"sentence-transformers>=2.0.0,<3.0.0",
"textdescriptives>=2.7.0,<3.0.0",
Expand Down
7 changes: 6 additions & 1 deletion src/argilla/training/setfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class ArgillaSetFitTrainer(ArgillaTransformersTrainer):
_logger.setLevel(logging.INFO)

def __init__(self, *args, **kwargs):
require_dependencies(["torch", "datasets", "transformers", "setfit>=0.6"])
require_dependencies(["torch", "datasets", "transformers", "setfit>=1.0.0"])
if kwargs.get("model") is None and "model" in kwargs:
kwargs["model"] = "all-MiniLM-L6-v2"
self._logger.warning(f"No model defined. Using the default model {kwargs['model']}.")
Expand All @@ -39,6 +39,11 @@ def __init__(self, *args, **kwargs):
raise NotImplementedError("SetFit only supports the `TextClassification` task.")

if self._multi_label:
# We shall rename binarized_label as label, we need to remove the column that was previously called label.
# This change is due to SetFit version >=1.0.0
self._dataset = self._dataset.remove_columns("label")
self._eval_dataset = self._eval_dataset.remove_columns("label")
self._train_dataset = self._train_dataset.remove_columns("label")
self._column_mapping = {"text": "text", "binarized_label": "label"}
self.multi_target_strategy = "one-vs-rest"
else:
Expand Down

0 comments on commit bb51eda

Please sign in to comment.