From e3938e89b74e50947ecdd4b88583b3f9a2f9ee15 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Mon, 16 Dec 2024 16:13:23 +0100 Subject: [PATCH] [BUGFIX] `argilla`: review datasets import with new export flow (#5756) # Description This PR reviews and fixes errors when importing datasets from exported datasets **Type of change** - Bug fix (non-breaking change which fixes an issue) **How Has This Been Tested** **Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm my changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --- argilla/src/argilla/datasets/_io/_disk.py | 1 + argilla/src/argilla/datasets/_io/_hub.py | 30 +++++++++++++++++------ argilla/src/argilla/responses.py | 20 ++++++++++----- argilla/src/argilla/settings/_field.py | 22 +++-------------- 4 files changed, 40 insertions(+), 33 deletions(-) diff --git a/argilla/src/argilla/datasets/_io/_disk.py b/argilla/src/argilla/datasets/_io/_disk.py index 09b4473182..b0d53ae634 100644 --- a/argilla/src/argilla/datasets/_io/_disk.py +++ b/argilla/src/argilla/datasets/_io/_disk.py @@ -114,6 +114,7 @@ def from_disk( name=dataset_model.name, workspace_id=workspace.id ) dataset = cls.from_model(model=dataset_model, client=client) + dataset.get() else: # Create a new dataset and load the settings and records if not os.path.exists(settings_path): diff --git a/argilla/src/argilla/datasets/_io/_hub.py b/argilla/src/argilla/datasets/_io/_hub.py index a2914022e7..d45ac44a8a 100644 --- a/argilla/src/argilla/datasets/_io/_hub.py +++ b/argilla/src/argilla/datasets/_io/_hub.py @@ -220,8 +220,11 @@ def _log_dataset_records(hf_dataset: "HFDataset", dataset: "Dataset"): for col in responses_columns: question_name = col.split(".")[0] if 
col.endswith("users"): - response_questions[question_name]["users"] = hf_dataset[col] - user_ids.update({UUID(user_id): UUID(user_id) for user_id in set(sum(hf_dataset[col], []))}) + response_questions[question_name]["users"] = hf_dataset[col] or [] + for users in hf_dataset[col]: + if users is None: + continue + user_ids.update({UUID(user_id): user_id for user_id in users}) elif col.endswith("responses"): response_questions[question_name]["responses"] = hf_dataset[col] elif col.endswith("status"): @@ -240,7 +243,15 @@ def _log_dataset_records(hf_dataset: "HFDataset", dataset: "Dataset"): user_ids[unknown_user_id] = my_user.id # Create a mapper to map the Hugging Face dataset to a Record object - mapping = {col: col for col in hf_dataset.column_names if ".suggestion" in col} + mapping = {} + for col in hf_dataset.column_names: + if ".suggestion" in col: + mapping[col] = col + elif col.startswith("metadata.") and col.replace("metadata.", "") in dataset.schema: + mapping[col] = col.replace("metadata.", "") + elif col.startswith("vector.") and col.replace("vector.", "") in dataset.schema: + mapping[col] = col.replace("vector.", "") + mapper = IngestedRecordMapper(dataset=dataset, mapping=mapping, user_id=my_user.id) # Extract responses and create Record objects @@ -249,14 +260,17 @@ def _log_dataset_records(hf_dataset: "HFDataset", dataset: "Dataset"): for idx, row in enumerate(hf_dataset): record = mapper(row) for question_name, values in response_questions.items(): - response_values = values["responses"][idx] - response_users = values["users"][idx] - response_status = values["status"][idx] + response_values = values["responses"][idx] or [] + response_users = values["users"][idx] or [] + response_status = values["status"][idx] or [] + + used_users = set() for value, user_id, status in zip(response_values, response_users, response_status): user_id = user_ids[UUID(user_id)] - if user_id in response_users: + if user_id in used_users: continue - response_users[user_id] = 
True + + used_users.add(user_id) response = Response( user_id=user_id, question_name=question_name, diff --git a/argilla/src/argilla/responses.py b/argilla/src/argilla/responses.py index 807627f624..2d8192cf5a 100644 --- a/argilla/src/argilla/responses.py +++ b/argilla/src/argilla/responses.py @@ -61,16 +61,16 @@ def __init__( status (Union[ResponseStatus, str]): The status of the response as "draft", "submitted", "discarded". """ + if isinstance(status, str): + status = ResponseStatus(status) + if question_name is None: raise ValueError("question_name is required") - if value is None: + if value is None and status == ResponseStatus.submitted: raise ValueError("value is required") if user_id is None: raise ValueError("user_id is required") - if isinstance(status, str): - status = ResponseStatus(status) - self._record = _record self.question_name = question_name self.value = value @@ -253,7 +253,7 @@ def _compute_user_id_from_responses(responses: List[Response]) -> Optional[UUID] @staticmethod def __responses_as_model_values(responses: List[Response]) -> Dict[str, Dict[str, Any]]: """Creates a dictionary of response values from a list of Responses""" - return {answer.question_name: {"value": answer.value} for answer in responses} + return {answer.question_name: {"value": answer.value} for answer in responses if answer.value is not None} @classmethod def __model_as_responses_list(cls, model: UserResponseModel, record: "Record") -> List[Response]: @@ -276,4 +276,12 @@ def __ranking_from_model_value(cls, value: List[Dict[str, Any]]) -> List[str]: @classmethod def __ranking_to_model_value(cls, value: List[str]) -> List[Dict[str, str]]: - return [{"value": v} for v in value] + values = [] + for v in value or []: + if isinstance(v, dict): + values.append(v) + elif isinstance(v, str): + values.append({"value": v}) + else: + raise RecordResponsesError(f"Invalid value for ranking question: {v}") + return values diff --git a/argilla/src/argilla/settings/_field.py 
b/argilla/src/argilla/settings/_field.py index 3f39c1c1f7..722248053f 100644 --- a/argilla/src/argilla/settings/_field.py +++ b/argilla/src/argilla/settings/_field.py @@ -29,8 +29,7 @@ FieldSettings, ) from argilla.settings._common import SettingsPropertyBase -from argilla.settings._metadata import MetadataField, MetadataType -from argilla.settings._vector import VectorField + try: from typing import Self @@ -296,21 +295,6 @@ def _field_from_model(model: FieldModel) -> Field: raise ArgillaError(f"Unsupported field type: {model.settings.type}") -def _field_from_dict(data: dict) -> Union[Field, VectorField, MetadataType]: +def _field_from_dict(data: dict) -> Field: """Create a field instance from a field dictionary""" - field_type = data["type"] - - if field_type == "text": - return TextField.from_dict(data) - elif field_type == "image": - return ImageField.from_dict(data) - elif field_type == "chat": - return ChatField.from_dict(data) - elif field_type == "custom": - return CustomField.from_dict(data) - elif field_type == "vector": - return VectorField.from_dict(data) - elif field_type == "metadata": - return MetadataField.from_dict(data) - else: - raise ArgillaError(f"Unsupported field type: {field_type}") + return _field_from_model(FieldModel(**data))