From 8f24de67863acda33f936b96a410f59034ab53c4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 10 Jun 2024 23:13:39 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- argilla/src/argilla/cli/datasets/__main__.py | 8 +- argilla/src/argilla/client/datasets.py | 32 +++++--- argilla/src/argilla/client/feedback/config.py | 6 +- .../client/feedback/dataset/local/mixins.py | 82 +++++++++++-------- .../feedback/integrations/textdescriptives.py | 10 +-- .../client/feedback/schemas/metadata.py | 15 ++-- .../client/feedback/schemas/remote/records.py | 26 +++--- .../argilla/client/sdk/text2text/models.py | 14 ++-- .../client/sdk/text_classification/models.py | 20 +++-- .../client/sdk/token_classification/models.py | 22 +++-- .../argilla/training/autotrain_advanced.py | 8 +- .../text_classification/test_weak_labels.py | 28 ++++--- 12 files changed, 152 insertions(+), 119 deletions(-) diff --git a/argilla/src/argilla/cli/datasets/__main__.py b/argilla/src/argilla/cli/datasets/__main__.py index c6c022609f1..08ddbd6959e 100644 --- a/argilla/src/argilla/cli/datasets/__main__.py +++ b/argilla/src/argilla/cli/datasets/__main__.py @@ -44,9 +44,11 @@ def callback( dataset = FeedbackDataset.from_argilla(name=name, workspace=workspace) except ValueError as e: echo_in_panel( - f"`FeedbackDataset` with name={name} not found in Argilla. Try using '--workspace' option." - if not workspace - else f"`FeedbackDataset with name={name} and workspace={workspace} not found in Argilla.", + ( + f"`FeedbackDataset` with name={name} not found in Argilla. Try using '--workspace' option." + if not workspace + else f"`FeedbackDataset with name={name} and workspace={workspace} not found in Argilla." + ), title="Dataset not found", title_align="left", success=False, diff --git a/argilla/src/argilla/client/datasets.py b/argilla/src/argilla/client/datasets.py index 48ae55b3f8f..1dade709214 100644 --- a/argilla/src/argilla/client/datasets.py +++ b/argilla/src/argilla/client/datasets.py @@ -744,16 +744,20 @@ def _to_datasets_dict(self) -> Dict: for key in self._RECORD_TYPE.__fields__: if key == "prediction": ds_dict[key] = [ - [{"label": pred[0], "score": pred[1]} for pred in rec.prediction] - if rec.prediction is not None - else None + ( + [{"label": pred[0], "score": pred[1]} for pred in rec.prediction] + if rec.prediction is not None + else None + ) for rec in self._records ] elif key == "explanation": ds_dict[key] = [ - {key: list(map(dict, tokattrs)) for key, tokattrs in rec.explanation.items()} - if rec.explanation is not None - else None + ( + {key: list(map(dict, tokattrs)) for key, tokattrs in rec.explanation.items()} + if rec.explanation is not None + else None + ) for rec in self._records ] elif key == "id": @@ -1255,9 +1259,11 @@ def entities_to_dict( if entities is None: return None return [ - {"label": ent[0], "start": ent[1], "end": ent[2]} - if len(ent) == 3 - else {"label": ent[0], "start": ent[1], "end": ent[2], "score": ent[3]} + ( + {"label": ent[0], "start": ent[1], "end": ent[2]} + if len(ent) == 3 + else {"label": ent[0], "start": ent[1], "end": ent[2], "score": ent[3]} + ) for ent in entities ] @@ -1281,9 +1287,11 @@ def __entities_to_tuple__( entities, ) -> List[Union[Tuple[str, int, int], Tuple[str, int, int, float]]]: return [ - (ent["label"], ent["start"], ent["end"]) - if len(ent) == 3 - else (ent["label"], ent["start"], ent["end"], ent["score"] or 0.0) + ( + (ent["label"], ent["start"], ent["end"]) + if len(ent) == 3 + else (ent["label"], ent["start"], ent["end"], ent["score"] or 0.0) + ) for ent in entities ] diff --git a/argilla/src/argilla/client/feedback/config.py b/argilla/src/argilla/client/feedback/config.py index a559164eefa..4a4c08faeb7 100644 --- a/argilla/src/argilla/client/feedback/config.py +++ b/argilla/src/argilla/client/feedback/config.py @@ -43,9 +43,9 @@ class DatasetConfig(BaseModel): fields: List[AllowedFieldTypes] questions: List[Annotated[AllowedQuestionTypes, Field(..., discriminator="type")]] guidelines: Optional[str] = None - metadata_properties: Optional[ - List[Annotated[AllowedMetadataPropertyTypes, Field(..., discriminator="type")]] - ] = None + metadata_properties: Optional[List[Annotated[AllowedMetadataPropertyTypes, Field(..., discriminator="type")]]] = ( + None + ) allow_extra_metadata: bool = True vectors_settings: Optional[List[VectorSettings]] = None diff --git a/argilla/src/argilla/client/feedback/dataset/local/mixins.py b/argilla/src/argilla/client/feedback/dataset/local/mixins.py index 48b1f1ff763..0ffe6e51933 100644 --- a/argilla/src/argilla/client/feedback/dataset/local/mixins.py +++ b/argilla/src/argilla/client/feedback/dataset/local/mixins.py @@ -504,23 +504,27 @@ def for_text_classification( return cls( fields=[TextField(name="text", use_markdown=use_markdown)], questions=[ - LabelQuestion( - name="label", - labels=labels, - description=description, - ) - if not multi_label - else MultiLabelQuestion( - name="label", - labels=labels, - description=description, + ( + LabelQuestion( + name="label", + labels=labels, + description=description, + ) + if not multi_label + else MultiLabelQuestion( + name="label", + labels=labels, + description=description, + ) ) ], - guidelines=guidelines - if guidelines is not None - else default_guidelines - if multi_label - else default_guidelines.replace("one or more labels", "one label"), + guidelines=( + guidelines + if guidelines is not None + else ( + default_guidelines if multi_label else default_guidelines.replace("one or more labels", "one label") + ) + ), metadata_properties=metadata_properties, vectors_settings=vectors_settings, ) @@ -739,11 +743,15 @@ def for_supervised_fine_tuning( name="response", description="Write the response to the instruction.", use_markdown=use_markdown ) ], - guidelines=guidelines - if guidelines is not None - else default_guidelines + " Take the context into account when writing the response." - if context - else default_guidelines, + guidelines=( + guidelines + if guidelines is not None + else ( + default_guidelines + " Take the context into account when writing the response." + if context + else default_guidelines + ) + ), metadata_properties=metadata_properties, vectors_settings=vectors_settings, ) @@ -977,23 +985,27 @@ def for_multi_modal_classification( return cls( fields=[TextField(name="content", use_markdown=True, required=True)], questions=[ - LabelQuestion( - name="label", - labels=labels, - description=description, - ) - if not multi_label - else MultiLabelQuestion( - name="label", - labels=labels, - description=description, + ( + LabelQuestion( + name="label", + labels=labels, + description=description, + ) + if not multi_label + else MultiLabelQuestion( + name="label", + labels=labels, + description=description, + ) ) ], - guidelines=guidelines - if guidelines is not None - else default_guidelines - if multi_label - else default_guidelines.replace("one or more labels", "one label"), + guidelines=( + guidelines + if guidelines is not None + else ( + default_guidelines if multi_label else default_guidelines.replace("one or more labels", "one label") + ) + ), metadata_properties=metadata_properties, vectors_settings=vectors_settings, ) diff --git a/argilla/src/argilla/client/feedback/integrations/textdescriptives.py b/argilla/src/argilla/client/feedback/integrations/textdescriptives.py index a5e552250db..70e5a564ee6 100644 --- a/argilla/src/argilla/client/feedback/integrations/textdescriptives.py +++ b/argilla/src/argilla/client/feedback/integrations/textdescriptives.py @@ -308,11 +308,11 @@ def _add_text_descriptives_to_metadata( filtered_metrics = {key: value for key, value in metrics.items() if not pd.isna(value)} if metadata_prop_types is not None: filtered_metrics = { - key: int(value) - if metadata_prop_types.get(key) == "integer" - else float(value) - if metadata_prop_types.get(key) == "float" - else value + key: ( + int(value) + if metadata_prop_types.get(key) == "integer" + else float(value) if metadata_prop_types.get(key) == "float" else value + ) for key, value in filtered_metrics.items() } record.metadata.update(filtered_metrics) diff --git a/argilla/src/argilla/client/feedback/schemas/metadata.py b/argilla/src/argilla/client/feedback/schemas/metadata.py index d39f46094c5..65e96592ab6 100644 --- a/argilla/src/argilla/client/feedback/schemas/metadata.py +++ b/argilla/src/argilla/client/feedback/schemas/metadata.py @@ -71,8 +71,7 @@ def title_must_have_value(cls, v: Optional[str], values: Dict[str, Any]) -> str: @property @abstractmethod - def server_settings(self) -> Dict[str, Any]: - ... + def server_settings(self) -> Dict[str, Any]: ... def to_server_payload(self) -> Dict[str, Any]: return { @@ -84,20 +83,17 @@ def to_server_payload(self) -> Dict[str, Any]: @property @abstractmethod - def _pydantic_field_with_validator(self) -> Tuple[Dict[str, Tuple[Any, ...]], Dict[str, Callable]]: - ... + def _pydantic_field_with_validator(self) -> Tuple[Dict[str, Tuple[Any, ...]], Dict[str, Callable]]: ... @abstractmethod def _validate_filter(self, metadata_filter: "MetadataFilters") -> None: pass @abstractmethod - def _check_allowed_value_type(self, value: Any) -> Any: - ... + def _check_allowed_value_type(self, value: Any) -> Any: ... @abstractmethod - def _validator(self, value: Any) -> Any: - ... + def _validator(self, value: Any) -> Any: ... def _validator_definition(schema: MetadataPropertySchema) -> Dict[str, Any]: @@ -395,8 +391,7 @@ class Config: @property @abstractmethod - def query_string(self) -> str: - ... + def query_string(self) -> str: ... class TermsMetadataFilter(MetadataFilterSchema): diff --git a/argilla/src/argilla/client/feedback/schemas/remote/records.py b/argilla/src/argilla/client/feedback/schemas/remote/records.py index 469e9d04f82..77298858c49 100644 --- a/argilla/src/argilla/client/feedback/schemas/remote/records.py +++ b/argilla/src/argilla/client/feedback/schemas/remote/records.py @@ -207,9 +207,9 @@ def __updated_record_data(self) -> None: updated_record = self.from_api( payload=response.parsed, - question_id_to_name={value: key for key, value in self.question_name_to_id.items()} - if self.question_name_to_id - else None, + question_id_to_name=( + {value: key for key, value in self.question_name_to_id.items()} if self.question_name_to_id else None + ), client=self.client, ) @@ -306,15 +306,17 @@ def from_api( id=payload.id, client=client, fields=payload.fields, - responses=[RemoteResponseSchema.from_api(response) for response in payload.responses] - if payload.responses - else [], - suggestions=[ - RemoteSuggestionSchema.from_api(suggestion, question_id_to_name=question_id_to_name, client=client) - for suggestion in payload.suggestions - ] - if payload.suggestions - else [], + responses=( + [RemoteResponseSchema.from_api(response) for response in payload.responses] if payload.responses else [] + ), + suggestions=( + [ + RemoteSuggestionSchema.from_api(suggestion, question_id_to_name=question_id_to_name, client=client) + for suggestion in payload.suggestions + ] + if payload.suggestions + else [] + ), metadata=payload.metadata if payload.metadata else {}, vectors=payload.vectors if payload.vectors else {}, external_id=payload.external_id if payload.external_id else None, diff --git a/argilla/src/argilla/client/sdk/text2text/models.py b/argilla/src/argilla/client/sdk/text2text/models.py index d6e7a7a34ad..748feef9963 100644 --- a/argilla/src/argilla/client/sdk/text2text/models.py +++ b/argilla/src/argilla/client/sdk/text2text/models.py @@ -48,9 +48,11 @@ def from_client(cls, record: ClientText2TextRecord): if record.prediction is not None: prediction = Text2TextAnnotation( sentences=[ - Text2TextPrediction(text=pred[0], score=pred[1]) - if isinstance(pred, tuple) - else Text2TextPrediction(text=pred) + ( + Text2TextPrediction(text=pred[0], score=pred[1]) + if isinstance(pred, tuple) + else Text2TextPrediction(text=pred) + ) for pred in record.prediction ], agent=record.prediction_agent or MACHINE_NAME, @@ -81,9 +83,9 @@ class Text2TextRecord(CreationText2TextRecord): def to_client(self) -> ClientText2TextRecord: return ClientText2TextRecord( text=self.text, - prediction=[(sentence.text, sentence.score) for sentence in self.prediction.sentences] - if self.prediction - else None, + prediction=( + [(sentence.text, sentence.score) for sentence in self.prediction.sentences] if self.prediction else None + ), prediction_agent=self.prediction.agent if self.prediction else None, annotation=self.annotation.sentences[0].text if self.annotation else None, annotation_agent=self.annotation.agent if self.annotation else None, diff --git a/argilla/src/argilla/client/sdk/text_classification/models.py b/argilla/src/argilla/client/sdk/text_classification/models.py index eb6e64c41f3..9ccae010b8e 100644 --- a/argilla/src/argilla/client/sdk/text_classification/models.py +++ b/argilla/src/argilla/client/sdk/text_classification/models.py @@ -100,19 +100,21 @@ def to_client(self) -> ClientTextClassificationRecord: multi_label=self.multi_label, status=self.status, metadata=self.metadata or {}, - prediction=[(label.class_label, label.score) for label in self.prediction.labels] - if self.prediction - else None, + prediction=( + [(label.class_label, label.score) for label in self.prediction.labels] if self.prediction else None + ), prediction_agent=self.prediction.agent if self.prediction else None, annotation=annotations, annotation_agent=self.annotation.agent if self.annotation else None, vectors=self._to_client_vectors(self.vectors), - explanation={ - key: [ClientTokenAttributions.parse_obj(attribution) for attribution in attributions] - for key, attributions in self.explanation.items() - } - if self.explanation - else None, + explanation=( + { + key: [ClientTokenAttributions.parse_obj(attribution) for attribution in attributions] + for key, attributions in self.explanation.items() + } + if self.explanation + else None + ), metrics=self.metrics or None, search_keywords=self.search_keywords or None, ) diff --git a/argilla/src/argilla/client/sdk/token_classification/models.py b/argilla/src/argilla/client/sdk/token_classification/models.py index a0f4becd17b..8fa2f64aac7 100644 --- a/argilla/src/argilla/client/sdk/token_classification/models.py +++ b/argilla/src/argilla/client/sdk/token_classification/models.py @@ -58,9 +58,11 @@ def from_client(cls, record: ClientTokenClassificationRecord): if record.prediction is not None: prediction = TokenClassificationAnnotation( entities=[ - EntitySpan(label=ent[0], start=ent[1], end=ent[2]) - if len(ent) == 3 - else EntitySpan(label=ent[0], start=ent[1], end=ent[2], score=ent[3]) + ( + EntitySpan(label=ent[0], start=ent[1], end=ent[2]) + if len(ent) == 3 + else EntitySpan(label=ent[0], start=ent[1], end=ent[2], score=ent[3]) + ) for ent in record.prediction ], agent=record.prediction_agent or MACHINE_NAME, @@ -94,13 +96,15 @@ def to_client(self) -> ClientTokenClassificationRecord: return ClientTokenClassificationRecord( text=self.text, tokens=self.tokens, - prediction=[(ent.label, ent.start, ent.end, ent.score) for ent in self.prediction.entities] - if self.prediction - else None, + prediction=( + [(ent.label, ent.start, ent.end, ent.score) for ent in self.prediction.entities] + if self.prediction + else None + ), prediction_agent=self.prediction.agent if self.prediction else None, - annotation=[(ent.label, ent.start, ent.end) for ent in self.annotation.entities] - if self.annotation - else None, + annotation=( + [(ent.label, ent.start, ent.end) for ent in self.annotation.entities] if self.annotation else None + ), annotation_agent=self.annotation.agent if self.annotation else None, vectors=self._to_client_vectors(self.vectors), id=self.id, diff --git a/argilla/src/argilla/training/autotrain_advanced.py b/argilla/src/argilla/training/autotrain_advanced.py index e847f6bd270..b7acf768ea7 100644 --- a/argilla/src/argilla/training/autotrain_advanced.py +++ b/argilla/src/argilla/training/autotrain_advanced.py @@ -65,9 +65,11 @@ def get_project_cost(self): token=self.HF_TOKEN, task=self.task, num_samples=self._num_samples, - num_models=self.trainer_kwargs["autotrain"][0]["num_models"] - if self._model.lower() == "autotrain" - else len(self.trainer_kwargs["hub_model"]), + num_models=( + self.trainer_kwargs["autotrain"][0]["num_models"] + if self._model.lower() == "autotrain" + else len(self.trainer_kwargs["hub_model"]) + ), ) def initialize_project(self): diff --git a/argilla/tests/integration/labeling/text_classification/test_weak_labels.py b/argilla/tests/integration/labeling/text_classification/test_weak_labels.py index 358243fe7a8..e26e81ce00e 100644 --- a/argilla/tests/integration/labeling/text_classification/test_weak_labels.py +++ b/argilla/tests/integration/labeling/text_classification/test_weak_labels.py @@ -53,12 +53,14 @@ def log_dataset(mocked_client: SecuredClient) -> str: CreationTextClassificationRecord.parse_obj( { "inputs": {"text": text}, - "annotation": { - "labels": [{"class": label, "score": 1}], - "agent": "test", - } - if label is not None - else None, + "annotation": ( + { + "labels": [{"class": label, "score": 1}], + "agent": "test", + } + if label is not None + else None + ), "id": idx, } ) @@ -114,12 +116,14 @@ def log_multilabel_dataset(mocked_client: SecuredClient) -> str: CreationTextClassificationRecord.parse_obj( { "inputs": {"text": text}, - "annotation": { - "labels": [{"class": label, "score": 1} for label in labels], - "agent": "test", - } - if labels is not None - else None, + "annotation": ( + { + "labels": [{"class": label, "score": 1} for label in labels], + "agent": "test", + } + if labels is not None + else None + ), "multi_label": True, "id": idx, }