From bb51edacef842812db35f7709659b6b01db5c7af Mon Sep 17 00:00:00 2001 From: Ignacio Talavera Date: Wed, 31 Jan 2024 13:27:27 +0100 Subject: [PATCH] SetFit dependency >= 1.0.0 (#4467) Closes #4435. Updating SetFit to bump the version >= 1.0.0 **Type of change** - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Refactor (change restructuring the codebase without changing functionality) - [X] Improvement (change adding some improvement to an existing functionality) - [ ] Documentation update **How Has This Been Tested** (Please describe the tests that you ran to verify your changes. And ideally, reference `tests`) - [x] Running the SetFit tutorial - [x] Running SetFit tests on the suite --- environment_dev.yml | 2 +- pyproject.toml | 2 +- src/argilla/training/setfit.py | 7 ++++++- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/environment_dev.yml b/environment_dev.yml index 8056f88224..2ecf1ed8e0 100644 --- a/environment_dev.yml +++ b/environment_dev.yml @@ -52,7 +52,7 @@ dependencies: - transformers[torch]>=4.30.0 # <- required for DPO with TRL - evaluate - seqeval - - setfit>=0.7.0,<1.0.0 + - setfit>=1.0.0 - span_marker - openai>=0.27.10,<1.0.0 - peft diff --git a/pyproject.toml b/pyproject.toml index 2d8cb4b5e5..cb780d653b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,7 +109,7 @@ integrations = [ "evaluate", "seqeval", "sentence-transformers", - "setfit>=0.7.0,<1.0.0", + "setfit>=1.0.0", "span_marker", "sentence-transformers>=2.0.0,<3.0.0", "textdescriptives>=2.7.0,<3.0.0", diff --git a/src/argilla/training/setfit.py b/src/argilla/training/setfit.py index 418988b47b..fb2c537db4 100644 --- a/src/argilla/training/setfit.py +++ b/src/argilla/training/setfit.py @@ -27,7 +27,7 @@ class ArgillaSetFitTrainer(ArgillaTransformersTrainer): _logger.setLevel(logging.INFO) def __init__(self, *args, **kwargs): - require_dependencies(["torch", "datasets", "transformers", "setfit>=0.6"]) + require_dependencies(["torch", "datasets", "transformers", "setfit>=1.0.0"]) if kwargs.get("model") is None and "model" in kwargs: kwargs["model"] = "all-MiniLM-L6-v2" self._logger.warning(f"No model defined. Using the default model {kwargs['model']}.") @@ -39,6 +39,11 @@ def __init__(self, *args, **kwargs): raise NotImplementedError("SetFit only supports the `TextClassification` task.") if self._multi_label: + # We shall rename binarized_label as label, we need to remove the column that was previously called label. + # This change is due to SetFit version >=1.0.0 + self._dataset = self._dataset.remove_columns("label") + self._eval_dataset = self._eval_dataset.remove_columns("label") + self._train_dataset = self._train_dataset.remove_columns("label") self._column_mapping = {"text": "text", "binarized_label": "label"} self.multi_target_strategy = "one-vs-rest" else: