Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jun 10, 2024
1 parent be22203 commit 8f24de6
Show file tree
Hide file tree
Showing 12 changed files with 152 additions and 119 deletions.
8 changes: 5 additions & 3 deletions argilla/src/argilla/cli/datasets/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,11 @@ def callback(
dataset = FeedbackDataset.from_argilla(name=name, workspace=workspace)
except ValueError as e:
echo_in_panel(
f"`FeedbackDataset` with name={name} not found in Argilla. Try using '--workspace' option."
if not workspace
else f"`FeedbackDataset with name={name} and workspace={workspace} not found in Argilla.",
(
f"`FeedbackDataset` with name={name} not found in Argilla. Try using '--workspace' option."
if not workspace
else f"`FeedbackDataset with name={name} and workspace={workspace} not found in Argilla."
),
title="Dataset not found",
title_align="left",
success=False,
Expand Down
32 changes: 20 additions & 12 deletions argilla/src/argilla/client/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,16 +744,20 @@ def _to_datasets_dict(self) -> Dict:
for key in self._RECORD_TYPE.__fields__:
if key == "prediction":
ds_dict[key] = [
[{"label": pred[0], "score": pred[1]} for pred in rec.prediction]
if rec.prediction is not None
else None
(
[{"label": pred[0], "score": pred[1]} for pred in rec.prediction]
if rec.prediction is not None
else None
)
for rec in self._records
]
elif key == "explanation":
ds_dict[key] = [
{key: list(map(dict, tokattrs)) for key, tokattrs in rec.explanation.items()}
if rec.explanation is not None
else None
(
{key: list(map(dict, tokattrs)) for key, tokattrs in rec.explanation.items()}
if rec.explanation is not None
else None
)
for rec in self._records
]
elif key == "id":
Expand Down Expand Up @@ -1255,9 +1259,11 @@ def entities_to_dict(
if entities is None:
return None
return [
{"label": ent[0], "start": ent[1], "end": ent[2]}
if len(ent) == 3
else {"label": ent[0], "start": ent[1], "end": ent[2], "score": ent[3]}
(
{"label": ent[0], "start": ent[1], "end": ent[2]}
if len(ent) == 3
else {"label": ent[0], "start": ent[1], "end": ent[2], "score": ent[3]}
)
for ent in entities
]

Expand All @@ -1281,9 +1287,11 @@ def __entities_to_tuple__(
entities,
) -> List[Union[Tuple[str, int, int], Tuple[str, int, int, float]]]:
return [
(ent["label"], ent["start"], ent["end"])
if len(ent) == 3
else (ent["label"], ent["start"], ent["end"], ent["score"] or 0.0)
(
(ent["label"], ent["start"], ent["end"])
if len(ent) == 3
else (ent["label"], ent["start"], ent["end"], ent["score"] or 0.0)
)
for ent in entities
]

Expand Down
6 changes: 3 additions & 3 deletions argilla/src/argilla/client/feedback/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ class DatasetConfig(BaseModel):
fields: List[AllowedFieldTypes]
questions: List[Annotated[AllowedQuestionTypes, Field(..., discriminator="type")]]
guidelines: Optional[str] = None
metadata_properties: Optional[
List[Annotated[AllowedMetadataPropertyTypes, Field(..., discriminator="type")]]
] = None
metadata_properties: Optional[List[Annotated[AllowedMetadataPropertyTypes, Field(..., discriminator="type")]]] = (
None
)
allow_extra_metadata: bool = True
vectors_settings: Optional[List[VectorSettings]] = None

Expand Down
82 changes: 47 additions & 35 deletions argilla/src/argilla/client/feedback/dataset/local/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,23 +504,27 @@ def for_text_classification(
return cls(
fields=[TextField(name="text", use_markdown=use_markdown)],
questions=[
LabelQuestion(
name="label",
labels=labels,
description=description,
)
if not multi_label
else MultiLabelQuestion(
name="label",
labels=labels,
description=description,
(
LabelQuestion(
name="label",
labels=labels,
description=description,
)
if not multi_label
else MultiLabelQuestion(
name="label",
labels=labels,
description=description,
)
)
],
guidelines=guidelines
if guidelines is not None
else default_guidelines
if multi_label
else default_guidelines.replace("one or more labels", "one label"),
guidelines=(
guidelines
if guidelines is not None
else (
default_guidelines if multi_label else default_guidelines.replace("one or more labels", "one label")
)
),
metadata_properties=metadata_properties,
vectors_settings=vectors_settings,
)
Expand Down Expand Up @@ -739,11 +743,15 @@ def for_supervised_fine_tuning(
name="response", description="Write the response to the instruction.", use_markdown=use_markdown
)
],
guidelines=guidelines
if guidelines is not None
else default_guidelines + " Take the context into account when writing the response."
if context
else default_guidelines,
guidelines=(
guidelines
if guidelines is not None
else (
default_guidelines + " Take the context into account when writing the response."
if context
else default_guidelines
)
),
metadata_properties=metadata_properties,
vectors_settings=vectors_settings,
)
Expand Down Expand Up @@ -977,23 +985,27 @@ def for_multi_modal_classification(
return cls(
fields=[TextField(name="content", use_markdown=True, required=True)],
questions=[
LabelQuestion(
name="label",
labels=labels,
description=description,
)
if not multi_label
else MultiLabelQuestion(
name="label",
labels=labels,
description=description,
(
LabelQuestion(
name="label",
labels=labels,
description=description,
)
if not multi_label
else MultiLabelQuestion(
name="label",
labels=labels,
description=description,
)
)
],
guidelines=guidelines
if guidelines is not None
else default_guidelines
if multi_label
else default_guidelines.replace("one or more labels", "one label"),
guidelines=(
guidelines
if guidelines is not None
else (
default_guidelines if multi_label else default_guidelines.replace("one or more labels", "one label")
)
),
metadata_properties=metadata_properties,
vectors_settings=vectors_settings,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,11 +308,11 @@ def _add_text_descriptives_to_metadata(
filtered_metrics = {key: value for key, value in metrics.items() if not pd.isna(value)}
if metadata_prop_types is not None:
filtered_metrics = {
key: int(value)
if metadata_prop_types.get(key) == "integer"
else float(value)
if metadata_prop_types.get(key) == "float"
else value
key: (
int(value)
if metadata_prop_types.get(key) == "integer"
else float(value) if metadata_prop_types.get(key) == "float" else value
)
for key, value in filtered_metrics.items()
}
record.metadata.update(filtered_metrics)
Expand Down
15 changes: 5 additions & 10 deletions argilla/src/argilla/client/feedback/schemas/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@ def title_must_have_value(cls, v: Optional[str], values: Dict[str, Any]) -> str:

@property
@abstractmethod
def server_settings(self) -> Dict[str, Any]:
...
def server_settings(self) -> Dict[str, Any]: ...

def to_server_payload(self) -> Dict[str, Any]:
return {
Expand All @@ -84,20 +83,17 @@ def to_server_payload(self) -> Dict[str, Any]:

@property
@abstractmethod
def _pydantic_field_with_validator(self) -> Tuple[Dict[str, Tuple[Any, ...]], Dict[str, Callable]]:
...
def _pydantic_field_with_validator(self) -> Tuple[Dict[str, Tuple[Any, ...]], Dict[str, Callable]]: ...

@abstractmethod
def _validate_filter(self, metadata_filter: "MetadataFilters") -> None:
pass

@abstractmethod
def _check_allowed_value_type(self, value: Any) -> Any:
...
def _check_allowed_value_type(self, value: Any) -> Any: ...

@abstractmethod
def _validator(self, value: Any) -> Any:
...
def _validator(self, value: Any) -> Any: ...


def _validator_definition(schema: MetadataPropertySchema) -> Dict[str, Any]:
Expand Down Expand Up @@ -395,8 +391,7 @@ class Config:

@property
@abstractmethod
def query_string(self) -> str:
...
def query_string(self) -> str: ...


class TermsMetadataFilter(MetadataFilterSchema):
Expand Down
26 changes: 14 additions & 12 deletions argilla/src/argilla/client/feedback/schemas/remote/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,9 @@ def __updated_record_data(self) -> None:

updated_record = self.from_api(
payload=response.parsed,
question_id_to_name={value: key for key, value in self.question_name_to_id.items()}
if self.question_name_to_id
else None,
question_id_to_name=(
{value: key for key, value in self.question_name_to_id.items()} if self.question_name_to_id else None
),
client=self.client,
)

Expand Down Expand Up @@ -306,15 +306,17 @@ def from_api(
id=payload.id,
client=client,
fields=payload.fields,
responses=[RemoteResponseSchema.from_api(response) for response in payload.responses]
if payload.responses
else [],
suggestions=[
RemoteSuggestionSchema.from_api(suggestion, question_id_to_name=question_id_to_name, client=client)
for suggestion in payload.suggestions
]
if payload.suggestions
else [],
responses=(
[RemoteResponseSchema.from_api(response) for response in payload.responses] if payload.responses else []
),
suggestions=(
[
RemoteSuggestionSchema.from_api(suggestion, question_id_to_name=question_id_to_name, client=client)
for suggestion in payload.suggestions
]
if payload.suggestions
else []
),
metadata=payload.metadata if payload.metadata else {},
vectors=payload.vectors if payload.vectors else {},
external_id=payload.external_id if payload.external_id else None,
Expand Down
14 changes: 8 additions & 6 deletions argilla/src/argilla/client/sdk/text2text/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,11 @@ def from_client(cls, record: ClientText2TextRecord):
if record.prediction is not None:
prediction = Text2TextAnnotation(
sentences=[
Text2TextPrediction(text=pred[0], score=pred[1])
if isinstance(pred, tuple)
else Text2TextPrediction(text=pred)
(
Text2TextPrediction(text=pred[0], score=pred[1])
if isinstance(pred, tuple)
else Text2TextPrediction(text=pred)
)
for pred in record.prediction
],
agent=record.prediction_agent or MACHINE_NAME,
Expand Down Expand Up @@ -81,9 +83,9 @@ class Text2TextRecord(CreationText2TextRecord):
def to_client(self) -> ClientText2TextRecord:
return ClientText2TextRecord(
text=self.text,
prediction=[(sentence.text, sentence.score) for sentence in self.prediction.sentences]
if self.prediction
else None,
prediction=(
[(sentence.text, sentence.score) for sentence in self.prediction.sentences] if self.prediction else None
),
prediction_agent=self.prediction.agent if self.prediction else None,
annotation=self.annotation.sentences[0].text if self.annotation else None,
annotation_agent=self.annotation.agent if self.annotation else None,
Expand Down
20 changes: 11 additions & 9 deletions argilla/src/argilla/client/sdk/text_classification/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,19 +100,21 @@ def to_client(self) -> ClientTextClassificationRecord:
multi_label=self.multi_label,
status=self.status,
metadata=self.metadata or {},
prediction=[(label.class_label, label.score) for label in self.prediction.labels]
if self.prediction
else None,
prediction=(
[(label.class_label, label.score) for label in self.prediction.labels] if self.prediction else None
),
prediction_agent=self.prediction.agent if self.prediction else None,
annotation=annotations,
annotation_agent=self.annotation.agent if self.annotation else None,
vectors=self._to_client_vectors(self.vectors),
explanation={
key: [ClientTokenAttributions.parse_obj(attribution) for attribution in attributions]
for key, attributions in self.explanation.items()
}
if self.explanation
else None,
explanation=(
{
key: [ClientTokenAttributions.parse_obj(attribution) for attribution in attributions]
for key, attributions in self.explanation.items()
}
if self.explanation
else None
),
metrics=self.metrics or None,
search_keywords=self.search_keywords or None,
)
Expand Down
22 changes: 13 additions & 9 deletions argilla/src/argilla/client/sdk/token_classification/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,11 @@ def from_client(cls, record: ClientTokenClassificationRecord):
if record.prediction is not None:
prediction = TokenClassificationAnnotation(
entities=[
EntitySpan(label=ent[0], start=ent[1], end=ent[2])
if len(ent) == 3
else EntitySpan(label=ent[0], start=ent[1], end=ent[2], score=ent[3])
(
EntitySpan(label=ent[0], start=ent[1], end=ent[2])
if len(ent) == 3
else EntitySpan(label=ent[0], start=ent[1], end=ent[2], score=ent[3])
)
for ent in record.prediction
],
agent=record.prediction_agent or MACHINE_NAME,
Expand Down Expand Up @@ -94,13 +96,15 @@ def to_client(self) -> ClientTokenClassificationRecord:
return ClientTokenClassificationRecord(
text=self.text,
tokens=self.tokens,
prediction=[(ent.label, ent.start, ent.end, ent.score) for ent in self.prediction.entities]
if self.prediction
else None,
prediction=(
[(ent.label, ent.start, ent.end, ent.score) for ent in self.prediction.entities]
if self.prediction
else None
),
prediction_agent=self.prediction.agent if self.prediction else None,
annotation=[(ent.label, ent.start, ent.end) for ent in self.annotation.entities]
if self.annotation
else None,
annotation=(
[(ent.label, ent.start, ent.end) for ent in self.annotation.entities] if self.annotation else None
),
annotation_agent=self.annotation.agent if self.annotation else None,
vectors=self._to_client_vectors(self.vectors),
id=self.id,
Expand Down
Loading

0 comments on commit 8f24de6

Please sign in to comment.