diff --git a/environment_dev.yml b/environment_dev.yml index 8056f88224..2ecf1ed8e0 100644 --- a/environment_dev.yml +++ b/environment_dev.yml @@ -52,7 +52,7 @@ dependencies: - transformers[torch]>=4.30.0 # <- required for DPO with TRL - evaluate - seqeval - - setfit>=0.7.0,<1.0.0 + - setfit>=1.0.0 - span_marker - openai>=0.27.10,<1.0.0 - peft diff --git a/pyproject.toml b/pyproject.toml index 2d8cb4b5e5..cb780d653b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,7 +109,7 @@ integrations = [ "evaluate", "seqeval", "sentence-transformers", - "setfit>=0.7.0,<1.0.0", + "setfit>=1.0.0", "span_marker", "sentence-transformers>=2.0.0,<3.0.0", "textdescriptives>=2.7.0,<3.0.0", diff --git a/src/argilla/training/setfit.py b/src/argilla/training/setfit.py index 418988b47b..fb2c537db4 100644 --- a/src/argilla/training/setfit.py +++ b/src/argilla/training/setfit.py @@ -27,7 +27,7 @@ class ArgillaSetFitTrainer(ArgillaTransformersTrainer): _logger.setLevel(logging.INFO) def __init__(self, *args, **kwargs): - require_dependencies(["torch", "datasets", "transformers", "setfit>=0.6"]) + require_dependencies(["torch", "datasets", "transformers", "setfit>=1.0.0"]) if kwargs.get("model") is None and "model" in kwargs: kwargs["model"] = "all-MiniLM-L6-v2" self._logger.warning(f"No model defined. Using the default model {kwargs['model']}.") @@ -39,6 +39,11 @@ def __init__(self, *args, **kwargs): raise NotImplementedError("SetFit only supports the `TextClassification` task.") if self._multi_label: + # Because binarized_label is renamed to label below, the column that was previously called label must be removed first. + # This change is due to SetFit version >=1.0.0 + self._dataset = self._dataset.remove_columns("label") + self._eval_dataset = self._eval_dataset.remove_columns("label") + self._train_dataset = self._train_dataset.remove_columns("label") self._column_mapping = {"text": "text", "binarized_label": "label"} self.multi_target_strategy = "one-vs-rest" else: