diff --git a/src/argilla_sdk/_api/_fields.py b/src/argilla_sdk/_api/_fields.py index 8669c670..dc8a8f3a 100644 --- a/src/argilla_sdk/_api/_fields.py +++ b/src/argilla_sdk/_api/_fields.py @@ -16,14 +16,15 @@ from uuid import UUID import httpx + from argilla_sdk._api._base import ResourceAPI from argilla_sdk._exceptions import api_error_handler -from argilla_sdk._models import FieldBaseModel, TextFieldModel, FieldModel +from argilla_sdk._models import FieldModel __all__ = ["FieldsAPI"] -class FieldsAPI(ResourceAPI[FieldBaseModel]): +class FieldsAPI(ResourceAPI[FieldModel]): """Manage datasets via the API""" http_client: httpx.Client @@ -33,36 +34,39 @@ class FieldsAPI(ResourceAPI[FieldBaseModel]): ################ @api_error_handler - def create(self, dataset_id: UUID, field: FieldModel) -> FieldModel: - url = f"/api/v1/datasets/{dataset_id}/fields" + def get(self, id: UUID) -> FieldModel: + raise NotImplementedError() + + @api_error_handler + def create(self, field: FieldModel) -> FieldModel: + url = f"/api/v1/datasets/{field.dataset_id}/fields" response = self.http_client.post(url=url, json=field.model_dump()) response.raise_for_status() response_json = response.json() - field_model = self._model_from_json(response_json=response_json) - self._log_message(message=f"Created field {field_model.name} in dataset {dataset_id}") - return field_model + created_field = self._model_from_json(response_json=response_json) + self._log_message(message=f"Created field {created_field.name} in dataset {field.dataset_id}") + return created_field @api_error_handler def update(self, field: FieldModel) -> FieldModel: - # TODO: Implement update method for fields with server side ID - raise NotImplementedError + url = f"/api/v1/fields/{field.id}" + response = self.http_client.patch(url, json=field.model_dump()) + response.raise_for_status() + response_json = response.json() + updated_field = self._model_from_json(response_json) + self._log_message(message=f"Update field {updated_field.name} with id {field.id}") + return updated_field @api_error_handler - def delete(self, dataset_id: UUID) -> None: - # TODO: Implement delete method for fields with server side ID - raise NotImplementedError + def delete(self, field_id: UUID) -> None: + url = f"/api/v1/fields/{field_id}" + self.http_client.delete(url).raise_for_status() + self._log_message(message=f"Deleted field {field_id}") #################### # Utility methods # #################### - def create_many(self, dataset_id: UUID, fields: List[FieldModel]) -> List[FieldModel]: - field_models = [] - for field in fields: - field_model = self.create(dataset_id=dataset_id, field=field) - field_models.append(field_model) - return field_models - @api_error_handler def list(self, dataset_id: UUID) -> List[FieldModel]: response = self.http_client.get(f"/api/v1/datasets/{dataset_id}/fields") @@ -78,19 +82,7 @@ def list(self, dataset_id: UUID) -> List[FieldModel]: def _model_from_json(self, response_json: Dict) -> FieldModel: response_json["inserted_at"] = self._date_from_iso_format(date=response_json["inserted_at"]) response_json["updated_at"] = self._date_from_iso_format(date=response_json["updated_at"]) - return self._get_model_from_response(response_json=response_json) + return FieldModel(**response_json) def _model_from_jsons(self, response_jsons: List[Dict]) -> List[FieldModel]: return list(map(self._model_from_json, response_jsons)) - - def _get_model_from_response(self, response_json: Dict) -> FieldModel: - try: - field_type = response_json.get("settings", {}).get("type") - except Exception as e: - raise ValueError("Invalid response type: missing 'settings.type' in response") from e - if field_type == "text": - # TODO: Avoid apply validations here (check_fields=False?) - return TextFieldModel(**response_json) - else: - # TODO: Add more field types - raise ValueError(f"Invalid field type: {field_type}") diff --git a/src/argilla_sdk/_api/_metadata.py b/src/argilla_sdk/_api/_metadata.py index 5fb540a9..beef1f78 100644 --- a/src/argilla_sdk/_api/_metadata.py +++ b/src/argilla_sdk/_api/_metadata.py @@ -16,6 +16,7 @@ from uuid import UUID import httpx + from argilla_sdk._api._base import ResourceAPI from argilla_sdk._exceptions import api_error_handler from argilla_sdk._models import MetadataFieldModel @@ -33,53 +34,44 @@ class MetadataAPI(ResourceAPI[MetadataFieldModel]): ################ @api_error_handler - def create(self, dataset_id: UUID, metadata_field: MetadataFieldModel) -> MetadataFieldModel: - url = f"/api/v1/datasets/{dataset_id}/metadata-properties" - response = self.http_client.post(url=url, json=metadata_field.model_dump()) + def get(self, metadata_id: UUID) -> MetadataFieldModel: + raise NotImplementedError() + + @api_error_handler + def create(self, metadata: MetadataFieldModel) -> MetadataFieldModel: + url = f"/api/v1/datasets/{metadata.dataset_id}/metadata-properties" + response = self.http_client.post(url=url, json=metadata.model_dump()) response.raise_for_status() response_json = response.json() - metadata_field_model = self._model_from_json(response_json=response_json) - self._log_message(message=f"Created metadata field {metadata_field_model.name} in dataset {dataset_id}") - return metadata_field_model + created_metadata = self._model_from_json(response_json=response_json) + self._log_message(message=f"Created metadata field {created_metadata.name} in dataset {metadata.dataset_id}") + return created_metadata @api_error_handler - def update(self, metadata_field: MetadataFieldModel) -> MetadataFieldModel: - url = f"/api/v1/metadata-properties/{metadata_field.id}" - response = self.http_client.patch(url=url, json=metadata_field.model_dump()) + def update(self, metadata: MetadataFieldModel) -> MetadataFieldModel: + url = f"/api/v1/metadata-properties/{metadata.id}" + response = self.http_client.patch(url=url, json=metadata.model_dump()) response.raise_for_status() response_json = response.json() - metadata_field_model = self._model_from_json(response_json=response_json) - self._log_message(message=f"Updated field {metadata_field_model.name}") - return metadata_field_model + updated_metadata = self._model_from_json(response_json=response_json) + self._log_message(message=f"Updated metadata field {updated_metadata.name}") + return updated_metadata - @api_error_handler - def delete(self, id: UUID) -> None: - url = f"/api/v1/metadata-properties/{id}" + def delete(self, metadata_id: UUID) -> None: + url = f"/api/v1/metadata-properties/{metadata_id}" self.http_client.delete(url=url).raise_for_status() - self._log_message(message=f"Deleted field {id}") - - @api_error_handler - def get(self, id: UUID) -> MetadataFieldModel: - raise NotImplementedError() + self._log_message(message=f"Deleted metadata field {metadata_id}") #################### # Utility methods # #################### - def create_many(self, dataset_id: UUID, metadata_fields: List[MetadataFieldModel]) -> List[MetadataFieldModel]: - metadata_field_models = [] - for metadata_field in metadata_fields: - metadata_field_model = self.create(dataset_id=dataset_id, metadata_field=metadata_field) - metadata_field_models.append(metadata_field_model) - return metadata_field_models - @api_error_handler def list(self, dataset_id: UUID) -> List[MetadataFieldModel]: response = self.http_client.get(f"/api/v1/me/datasets/{dataset_id}/metadata-properties") response.raise_for_status() response_json = response.json() - metadata_field_model = self._model_from_jsons(response_jsons=response_json["items"]) - return metadata_field_model + return self._model_from_jsons(response_jsons=response_json["items"]) #################### # Private methods # diff --git a/src/argilla_sdk/_api/_vectors.py b/src/argilla_sdk/_api/_vectors.py index 4c08006a..211924de 100644 --- a/src/argilla_sdk/_api/_vectors.py +++ b/src/argilla_sdk/_api/_vectors.py @@ -16,6 +16,7 @@ from uuid import UUID import httpx + from argilla_sdk._api._base import ResourceAPI from argilla_sdk._exceptions import api_error_handler from argilla_sdk._models import VectorFieldModel @@ -33,36 +34,36 @@ class VectorsAPI(ResourceAPI[VectorFieldModel]): ################ @api_error_handler - def create(self, dataset_id: UUID, vector: VectorFieldModel) -> VectorFieldModel: - url = f"/api/v1/datasets/{dataset_id}/vectors-settings" + def create(self, vector: VectorFieldModel) -> VectorFieldModel: + url = f"/api/v1/datasets/{vector.dataset_id}/vectors-settings" response = self.http_client.post(url=url, json=vector.model_dump()) response.raise_for_status() response_json = response.json() - vector_model = self._model_from_json(response_json=response_json) - self._log_message(message=f"Created vector {vector_model.name} in dataset {dataset_id}") - return vector_model + created_vector = self._model_from_json(response_json=response_json) + self._log_message(message=f"Created vector {created_vector.name} in dataset {created_vector.dataset_id}") + return created_vector @api_error_handler def update(self, vector: VectorFieldModel) -> VectorFieldModel: - # TODO: Implement update method for vectors with server side ID - raise NotImplementedError + url = f"/api/v1/vectors-settings/{vector.id}" + response = self.http_client.patch(url, json=vector.model_dump()) + response.raise_for_status() + response_json = response.json() + updated_vector = self._model_from_json(response_json) + self._log_message(message=f"Updated vector {updated_vector.name} with id {updated_vector.id}") + return updated_vector @api_error_handler def delete(self, vector_id: UUID) -> None: - # TODO: Implement delete method for vectors with server side ID - raise NotImplementedError + url = f"/api/v1/vectors-settings/{vector_id}" + response = self.http_client.delete(url) + response.raise_for_status() + self._log_message(message=f"Deleted vector with id {vector_id}") #################### # Utility methods # #################### - def create_many(self, dataset_id: UUID, vectors: List[VectorFieldModel]) -> List[VectorFieldModel]: - vector_models = [] - for vector in vectors: - vector_model = self.create(dataset_id=dataset_id, vector=vector) - vector_models.append(vector_model) - return vector_models - @api_error_handler def list(self, dataset_id: UUID) -> List[VectorFieldModel]: response = self.http_client.get(f"/api/v1/datasets/{dataset_id}/vectors-settings") diff --git a/src/argilla_sdk/_models/__init__.py b/src/argilla_sdk/_models/__init__.py index f8de4d78..b7a2fa5c 100644 --- a/src/argilla_sdk/_models/__init__.py +++ b/src/argilla_sdk/_models/__init__.py @@ -32,9 +32,8 @@ ScopeModel, ) from argilla_sdk._models._settings._fields import ( - TextFieldModel, - FieldSettings, - FieldBaseModel, + FieldModel, + TextFieldSettings, FieldModel, ) from argilla_sdk._models._settings._questions import ( diff --git a/src/argilla_sdk/_models/_settings/_fields.py b/src/argilla_sdk/_models/_settings/_fields.py index 5e8dbc7d..fe7d7c90 100644 --- a/src/argilla_sdk/_models/_settings/_fields.py +++ b/src/argilla_sdk/_models/_settings/_fields.py @@ -12,27 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Optional, Literal from uuid import UUID -from pydantic import BaseModel, field_serializer, field_validator, Field +from pydantic import BaseModel, field_serializer, field_validator from pydantic_core.core_schema import ValidationInfo from argilla_sdk._helpers import log_message +from argilla_sdk._models import ResourceModel -class FieldSettings(BaseModel): - type: str = Field(validate_default=True) +class TextFieldSettings(BaseModel): + type: Literal["text"] = "text" use_markdown: Optional[bool] = False -class FieldBaseModel(BaseModel): - id: Optional[UUID] = None +class FieldModel(ResourceModel): name: str - title: Optional[str] = None required: bool = True description: Optional[str] = None + settings: TextFieldSettings = TextFieldSettings(use_markdown=False) + dataset_id: Optional[UUID] = None @field_validator("name") @classmethod @@ -48,13 +49,6 @@ def __title_default(cls, title: str, info: ValidationInfo) -> str: log_message(f"TextField title is {validated_title}") return validated_title - @field_serializer("id", when_used="unless-none") + @field_serializer("id", "dataset_id", when_used="unless-none") def serialize_id(self, value: UUID) -> str: return str(value) - - -class TextFieldModel(FieldBaseModel): - settings: FieldSettings = FieldSettings(type="text", use_markdown=False) - - -FieldModel = TextFieldModel diff --git a/src/argilla_sdk/_models/_settings/_metadata.py b/src/argilla_sdk/_models/_settings/_metadata.py index 5fff8e19..272c051c 100644 --- a/src/argilla_sdk/_models/_settings/_metadata.py +++ b/src/argilla_sdk/_models/_settings/_metadata.py @@ -19,6 +19,7 @@ from pydantic import BaseModel, Field, field_serializer, field_validator, model_validator from argilla_sdk._exceptions import MetadataError +from argilla_sdk._models import ResourceModel class MetadataPropertyType(str, Enum): @@ -36,9 +37,7 @@ class TermsMetadataPropertySettings(BaseMetadataPropertySettings): type: Literal[MetadataPropertyType.terms] values: Optional[List[str]] = None - @field_validator( - "values", - ) + @field_validator("values") @classmethod def __validate_values(cls, values): if values is None: @@ -94,10 +93,9 @@ class FloatMetadataPropertySettings(NumericMetadataPropertySettings): ] -class MetadataFieldModel(BaseModel): +class MetadataFieldModel(ResourceModel): """The schema definition of a metadata field in an Argilla dataset.""" - id: Optional[UUID] = None name: str settings: MetadataPropertySettings @@ -105,6 +103,8 @@ class MetadataFieldModel(BaseModel): title: Optional[str] = None visible_for_annotators: Optional[bool] = True + dataset_id: Optional[UUID] = None + @field_validator("name") @classmethod def __name_lower(cls, name): @@ -117,7 +117,7 @@ def __title_default(cls, title, values): validated_title = title or values.data["name"] return validated_title - @field_serializer("id", when_used="unless-none") + @field_serializer("id", "dataset_id", when_used="unless-none") def serialize_id(self, value: UUID) -> str: return str(value) diff --git a/src/argilla_sdk/_models/_settings/_vectors.py b/src/argilla_sdk/_models/_settings/_vectors.py index f1d1b466..8b932938 100644 --- a/src/argilla_sdk/_models/_settings/_vectors.py +++ b/src/argilla_sdk/_models/_settings/_vectors.py @@ -15,18 +15,17 @@ from typing import Optional from uuid import UUID -from pydantic import BaseModel, field_validator, field_serializer +from pydantic import field_validator, field_serializer from pydantic_core.core_schema import ValidationInfo +from argilla_sdk._models import ResourceModel from argilla_sdk._helpers import log_message -class VectorFieldModel(BaseModel): +class VectorFieldModel(ResourceModel): name: str title: Optional[str] = None dimensions: int - - id: Optional[UUID] = None dataset_id: Optional[UUID] = None @field_serializer("id", "dataset_id", when_used="unless-none") diff --git a/src/argilla_sdk/datasets/_resource.py b/src/argilla_sdk/datasets/_resource.py index 43527d27..2fe31dba 100644 --- a/src/argilla_sdk/datasets/_resource.py +++ b/src/argilla_sdk/datasets/_resource.py @@ -164,6 +164,15 @@ def create(self) -> "Dataset": self.__rollback_dataset_creation() raise SettingsError from e + def update(self) -> "Dataset": + """Updates the dataset on the server with the current settings. + + Returns: + Dataset: The updated dataset object. + """ + self.settings.update() + return self + @classmethod def from_model(cls, model: DatasetModel, client: "Argilla") -> "Dataset": return cls(client=client, _model=model) @@ -173,7 +182,6 @@ def from_model(cls, model: DatasetModel, client: "Argilla") -> "Dataset": ##################### def _publish(self) -> "Dataset": - self.settings.validate() self._settings.create() self._api.publish(dataset_id=self._model.id) diff --git a/src/argilla_sdk/responses.py b/src/argilla_sdk/responses.py index 1f02fbbc..4cfb955b 100644 --- a/src/argilla_sdk/responses.py +++ b/src/argilla_sdk/responses.py @@ -169,7 +169,7 @@ def api_model(self): def to_dict(self) -> Dict[str, Any]: """Returns the UserResponse as a dictionary""" - return self._model.dict() + return self._model.model_dump() @staticmethod def __responses_as_model_values(answers: List[Response]) -> Dict[str, Dict[str, Any]]: diff --git a/src/argilla_sdk/settings/__init__.py b/src/argilla_sdk/settings/__init__.py index 122130d4..4e376a96 100644 --- a/src/argilla_sdk/settings/__init__.py +++ b/src/argilla_sdk/settings/__init__.py @@ -14,5 +14,6 @@ from argilla_sdk.settings._field import * # noqa: F403 from argilla_sdk.settings._metadata import * # noqa: F403 +from argilla_sdk.settings._vector import * # noqa: F403 from argilla_sdk.settings._question import * # noqa: F403 from argilla_sdk.settings._resource import * # noqa: F403 diff --git a/src/argilla_sdk/settings/_common.py b/src/argilla_sdk/settings/_common.py index 59fd1147..64e198cf 100644 --- a/src/argilla_sdk/settings/_common.py +++ b/src/argilla_sdk/settings/_common.py @@ -14,7 +14,7 @@ from typing import Any, Optional, Union -from argilla_sdk._models import FieldBaseModel, QuestionBaseModel +from argilla_sdk._models import FieldModel, QuestionBaseModel from argilla_sdk._resource import Resource __all__ = ["SettingsPropertyBase"] @@ -23,7 +23,7 @@ class SettingsPropertyBase(Resource): """Base class for dataset fields or questions in Settings class""" - _model: Union[FieldBaseModel, QuestionBaseModel] + _model: Union[FieldModel, QuestionBaseModel] def __repr__(self) -> str: return ( diff --git a/src/argilla_sdk/settings/_field.py b/src/argilla_sdk/settings/_field.py index 14483eb8..c2c85448 100644 --- a/src/argilla_sdk/settings/_field.py +++ b/src/argilla_sdk/settings/_field.py @@ -12,19 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union +from typing import Optional, Union, TYPE_CHECKING -from argilla_sdk._models import FieldSettings, MetadataFieldModel, TextFieldModel, VectorFieldModel +from argilla_sdk import Argilla +from argilla_sdk._api import FieldsAPI +from argilla_sdk._models import FieldModel, TextFieldSettings from argilla_sdk.settings._common import SettingsPropertyBase from argilla_sdk.settings._metadata import MetadataField, MetadataType +from argilla_sdk.settings._vector import VectorField -__all__ = ["TextField", "FieldType", "VectorField"] +if TYPE_CHECKING: + from argilla_sdk.datasets import Dataset + +__all__ = ["TextField"] class TextField(SettingsPropertyBase): """Text field for use in Argilla `Dataset` `Settings`""" - _model: TextFieldModel + _model: FieldModel + _api: FieldsAPI + + _dataset: "Dataset" def __init__( self, @@ -33,6 +42,7 @@ def __init__( use_markdown: Optional[bool] = False, required: Optional[bool] = True, description: Optional[str] = None, + client: Optional[Argilla] = None, ) -> None: """Text field for use in Argilla `Dataset` `Settings` Parameters: @@ -43,20 +53,19 @@ def __init__( description (Optional[str], optional): The description of the field. Defaults to None. """ - self._model = TextFieldModel( + client = client or Argilla._get_default() + + super().__init__(api=client.api.fields, client=client) + self._model = FieldModel( name=name, title=title, required=required or True, description=description, - settings=FieldSettings(type="text", use_markdown=use_markdown), + settings=TextFieldSettings(use_markdown=use_markdown), ) - @property - def use_markdown(self) -> Optional[bool]: - return self._model.settings.use_markdown - @classmethod - def from_model(cls, model: TextFieldModel) -> "TextField": + def from_model(cls, model: FieldModel) -> "TextField": instance = cls(name=model.name) instance._model = model @@ -64,90 +73,28 @@ def from_model(cls, model: TextFieldModel) -> "TextField": @classmethod def from_dict(cls, data: dict) -> "TextField": - model = TextFieldModel(**data) + model = FieldModel(**data) return cls.from_model(model=model) - -class VectorField(SettingsPropertyBase): - """Vector field for use in Argilla `Dataset` `Settings`""" - - _model: VectorFieldModel - - def __init__( - self, - name: str, - dimensions: int, - title: Optional[str] = None, - ) -> None: - """Vector field for use in Argilla `Dataset` `Settings` - - Parameters: - name (str): The name of the field - dimensions (int): The number of dimensions in the vector - title (Optional[str], optional): The title of the field. Defaults to None. - """ - self._model = VectorFieldModel( - name=name, - title=title, - dimensions=dimensions, - ) - - @classmethod - def from_model(cls, model: VectorFieldModel) -> "VectorField": - instance = cls(name=model.name, dimensions=model.dimensions) - instance._model = model - - return instance - - @classmethod - def from_dict(cls, data: dict) -> "VectorField": - model = VectorFieldModel(**data) - return cls.from_model(model=model) - - @property - def dimensions(self) -> int: - return self._model.dimensions - @property - def title(self) -> Optional[str]: - return self._model.title - - @property - def name(self) -> str: - return self._model.name + def use_markdown(self) -> Optional[bool]: + return self._model.settings.use_markdown - @property - def description(self) -> Optional[str]: - # TODO: Setting resources should be aligned at the API level - return None + @use_markdown.setter + def use_markdown(self, value: bool) -> None: + self._model.settings.use_markdown = value @property - def required(self) -> bool: - # TODO: Setting resources should be aligned at the API level - return False + def dataset(self) -> "Dataset": + return self._dataset - @property - def type(self) -> str: - # TODO: Setting resources should be aligned at the API level - return "vector" - - -FieldType = Union[TextField, VectorField, MetadataType] - - -def field_from_model(model: Union[TextFieldModel, VectorFieldModel, MetadataFieldModel]) -> FieldType: - """Create a field instance from a field model""" - if isinstance(model, TextFieldModel): - return TextField.from_model(model) - elif isinstance(model, VectorFieldModel): - return VectorField.from_model(model) - elif isinstance(model, MetadataFieldModel): - return MetadataField.from_model(model) - else: - raise ValueError(f"Unsupported field model type: {type(model)}") + @dataset.setter + def dataset(self, value: "Dataset") -> None: + self._dataset = value + self._model.dataset_id = self._dataset.id -def field_from_dict(data: dict) -> FieldType: +def field_from_dict(data: dict) -> Union[TextField, VectorField, MetadataType]: """Create a field instance from a field dictionary""" if data["type"] == "text": return TextField.from_dict(data) diff --git a/src/argilla_sdk/settings/_metadata.py b/src/argilla_sdk/settings/_metadata.py index d69eebee..9c54376f 100644 --- a/src/argilla_sdk/settings/_metadata.py +++ b/src/argilla_sdk/settings/_metadata.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union, List +from typing import Optional, Union, List, TYPE_CHECKING +from argilla_sdk._api._metadata import MetadataAPI from argilla_sdk._exceptions import MetadataError from argilla_sdk._models import ( MetadataPropertyType, @@ -22,8 +23,11 @@ IntegerMetadataPropertySettings, MetadataFieldModel, ) -from argilla_sdk.settings._common import SettingsPropertyBase +from argilla_sdk._resource import Resource +from argilla_sdk.client import Argilla +if TYPE_CHECKING: + from argilla_sdk import Dataset __all__ = [ "TermsMetadataProperty", @@ -33,18 +37,31 @@ ] -class MetadataPropertyBase(SettingsPropertyBase): +class MetadataPropertyBase(Resource): _model: MetadataFieldModel + _api: MetadataAPI + + _dataset: "Dataset" + + def __init__(self, client: Optional[Argilla] = None) -> None: + client = client or Argilla._get_default() + super().__init__(client=client, api=client.api.metadata) @property - def required(self) -> bool: - # This attribute is not present in the MetadataFieldModel - return False + def name(self) -> str: + return self._model.name + + @name.setter + def name(self, value: str) -> None: + self._model.name = value @property - def description(self) -> Optional[str]: - # This attribute is not present in the MetadataFieldModel - return None + def title(self) -> Optional[str]: + return self._model.title + + @title.setter + def title(self, value: Optional[str]) -> None: + self._model.title = value @property def visible_for_annotators(self) -> Optional[bool]: @@ -54,6 +71,20 @@ def visible_for_annotators(self) -> Optional[bool]: def visible_for_annotators(self, value: Optional[bool]) -> None: self._model.visible_for_annotators = value + @property + def dataset(self) -> Optional["Dataset"]: + return self._dataset + + @dataset.setter + def dataset(self, value: "Dataset") -> None: + self._dataset = value + self._model.dataset_id = value.id + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(name={self.name}, title={self.title}, dimensions={self.visible_for_annotators})" + ) + class TermsMetadataProperty(MetadataPropertyBase): def __init__( @@ -62,6 +93,7 @@ def __init__( options: Optional[List[str]] = None, title: Optional[str] = None, visible_for_annotators: Optional[bool] = True, + client: Optional[Argilla] = None, ) -> None: """Create a metadata field with terms settings. @@ -73,7 +105,7 @@ def __init__( Raises: MetadataError: If an error occurs while defining metadata settings """ - super().__init__() + super().__init__(client=client) try: settings = TermsMetadataPropertySettings(values=options, type=MetadataPropertyType.terms) @@ -106,7 +138,12 @@ def from_model(cls, model: MetadataFieldModel) -> "TermsMetadataProperty": class FloatMetadataProperty(MetadataPropertyBase): def __init__( - self, name: str, min: Optional[float] = None, max: Optional[float] = None, title: Optional[str] = None + self, + name: str, + min: Optional[float] = None, + max: Optional[float] = None, + title: Optional[str] = None, + client: Optional[Argilla] = None, ) -> None: """Create a metadata field with float settings. @@ -115,11 +152,14 @@ def __init__( min (Optional[float]): The minimum value max (Optional[float]): The maximum value title (Optional[str]): The title of the metadata field + client (Optional[Argilla]): The client to use for API requests Raises: MetadataError: If an error occurs while defining metadata settings - """ + + super().__init__(client=client) + try: settings = FloatMetadataPropertySettings(min=min, max=max, type=MetadataPropertyType.float) except ValueError as e: @@ -158,7 +198,12 @@ def from_model(cls, model: MetadataFieldModel) -> "FloatMetadataProperty": class IntegerMetadataProperty(MetadataPropertyBase): def __init__( - self, name: str, min: Optional[int] = None, max: Optional[int] = None, title: Optional[str] = None + self, + name: str, + min: Optional[int] = None, + max: Optional[int] = None, + title: Optional[str] = None, + client: Optional[Argilla] = None, ) -> None: """Create a metadata field with integer settings. @@ -170,6 +215,7 @@ def __init__( Raises: MetadataError: If an error occurs while defining metadata settings """ + super().__init__(client=client) try: settings = IntegerMetadataPropertySettings(min=min, max=max, type=MetadataPropertyType.integer) @@ -207,7 +253,11 @@ def from_model(cls, model: MetadataFieldModel) -> "IntegerMetadataProperty": return instance -MetadataType = Union[TermsMetadataProperty, FloatMetadataProperty, IntegerMetadataProperty] +MetadataType = Union[ + TermsMetadataProperty, + FloatMetadataProperty, + IntegerMetadataProperty, +] class MetadataField: @@ -233,6 +283,7 @@ def from_dict(cls, data: dict) -> MetadataType: } metadata_type = data["type"] try: - return switch[metadata_type](**data) + metadata_model = MetadataFieldModel(**data) + return switch[metadata_type].from_model(metadata_model) except KeyError as e: raise MetadataError(f"Unknown metadata property type: {metadata_type}") from e diff --git a/src/argilla_sdk/settings/_resource.py b/src/argilla_sdk/settings/_resource.py index 2a1ff772..d62f16b4 100644 --- a/src/argilla_sdk/settings/_resource.py +++ b/src/argilla_sdk/settings/_resource.py @@ -16,15 +16,16 @@ import os from functools import cached_property from pathlib import Path -from typing import List, Optional, TYPE_CHECKING, Dict, Union +from typing import List, Optional, TYPE_CHECKING, Dict, Union, Iterator, Sequence from uuid import UUID from argilla_sdk._exceptions import SettingsError, ArgillaAPIError, ArgillaSerializeError -from argilla_sdk._models import TextFieldModel, TextQuestionModel, DatasetModel +from argilla_sdk._models._dataset import DatasetModel from argilla_sdk._resource import Resource -from argilla_sdk.settings._field import FieldType, VectorField, field_from_model, field_from_dict -from argilla_sdk.settings._metadata import MetadataType +from argilla_sdk.settings._field import TextField +from argilla_sdk.settings._metadata import MetadataType, MetadataField from argilla_sdk.settings._question import QuestionType, question_from_model, question_from_dict, QuestionPropertyBase +from argilla_sdk.settings._vector import VectorField if TYPE_CHECKING: from argilla_sdk.datasets import Dataset @@ -42,7 +43,7 @@ class Settings(Resource): def __init__( self, - fields: Optional[List[FieldType]] = None, + fields: Optional[List[TextField]] = None, questions: Optional[List[QuestionType]] = None, vectors: Optional[List[VectorField]] = None, metadata: Optional[List[MetadataType]] = None, @@ -62,9 +63,9 @@ def __init__( super().__init__(client=_dataset._client if _dataset else None) self.__questions = questions or [] - self.__fields = fields or [] - self.__vectors = vectors or [] - self.__metadata = metadata or [] + self.__fields = SettingsProperties(self, fields) + self.__vectors = SettingsProperties(self, vectors) + self.__metadata = SettingsProperties(self, metadata) self.__guidelines = self.__process_guidelines(guidelines) self.__allow_extra_metadata = allow_extra_metadata @@ -76,12 +77,12 @@ def __init__( ##################### @property - def fields(self) -> List[FieldType]: + def fields(self) -> "SettingsProperties": return self.__fields @fields.setter - def fields(self, fields: List[FieldType]): - self.__fields = fields + def fields(self, fields: List[TextField]): + self.__fields = SettingsProperties(self, fields) @property def questions(self) -> List[QuestionType]: @@ -92,28 +93,28 @@ def questions(self, questions: List[QuestionType]): self.__questions = questions @property - def guidelines(self) -> str: - return self.__guidelines - - @guidelines.setter - def guidelines(self, guidelines: str): - self.__guidelines = self.__process_guidelines(guidelines) - - @property - def vectors(self) -> List[VectorField]: + def vectors(self) -> "SettingsProperties": return self.__vectors @vectors.setter def vectors(self, vectors: List[VectorField]): - self.__vectors = vectors + self.__vectors = SettingsProperties(self, vectors) @property - def metadata(self) -> List[MetadataType]: + def metadata(self) -> "SettingsProperties": return self.__metadata @metadata.setter def metadata(self, metadata: List[MetadataType]): - self.__metadata = metadata + self.__metadata = SettingsProperties(self, metadata) + + @property + def guidelines(self) -> str: + return self.__guidelines + + @guidelines.setter + def guidelines(self, guidelines: str): + self.__guidelines = self.__process_guidelines(guidelines) @property def allow_extra_metadata(self) -> bool: @@ -151,7 +152,7 @@ def schema(self) -> dict: return schema_dict @cached_property - def schema_by_id(self) -> Dict[UUID, Union[FieldType, QuestionType]]: + def schema_by_id(self) -> Dict[UUID, Union[TextField, QuestionType, MetadataType, VectorField]]: return {v.id: v for v in self.schema.values()} def validate(self) -> None: @@ -163,21 +164,35 @@ def validate(self) -> None: ##################### def get(self) -> "Settings": - self.__fetch_fields() - self.__fetch_questions() - self.__fetch_vectors() - self.__fetch_metadata() + self.fields = self._fetch_fields() + self.questions = self._fetch_questions() + self.vectors = self._fetch_vectors() + self.metadata = self._fetch_metadata() self.__get_dataset_related_attributes() self._update_last_api_call() return self def create(self) -> "Settings": - self.__upsert_fields() - self.__upsert_questions() - self.__upsert_vectors() - self.__upsert_metadata() - self.__update_dataset_related_attributes() + self.validate() + + self._update_dataset_related_attributes() + self.__fields.create() + self._create_questions() + self.__vectors.create() + self.__metadata.create() + + self._update_last_api_call() + return self + + def update(self) -> "Resource": + self.validate() + + self._update_dataset_related_attributes() + self.__fields.update() + self.__vectors.update() + self.__metadata.update() + # self.questions.update() self._update_last_api_call() return self @@ -198,8 +213,10 @@ def serialize(self): try: return { "guidelines": self.guidelines, - "fields": self.__serialize_fields(fields=self.fields), - "questions": self.__serialize_questions(questions=self.questions), + "questions": self.__serialize_questions(self.questions), + "fields": self.__fields.serialize(), + "vectors": self.vectors.serialize(), + "metadata": self.metadata.serialize(), "allow_extra_metadata": self.allow_extra_metadata, } except Exception as e: @@ -225,18 +242,23 @@ def from_json(cls, path: Union[Path, str]) -> "Settings": with open(path, "r") as file: settings_dict = json.load(file) + fields = settings_dict.get("fields", []) + vectors = settings_dict.get("vectors", []) + metadata = settings_dict.get("metadata", []) guidelines = settings_dict.get("guidelines") - fields = settings_dict.get("fields") - questions = settings_dict.get("questions") allow_extra_metadata = settings_dict.get("allow_extra_metadata") - fields = [field_from_dict(field) for field in fields] - questions = [question_from_dict(question) for question in questions] + questions = [question_from_dict(question) for question in settings_dict.get("questions", [])] + fields = [TextField.from_dict(field) for field in fields] + vectors = [VectorField.from_dict(vector) for vector in vectors] + metadata = [MetadataField.from_dict(metadata) for metadata in metadata] return cls( - guidelines=guidelines, - fields=fields, questions=questions, + fields=fields, + vectors=vectors, + metadata=metadata, + guidelines=guidelines, allow_extra_metadata=allow_extra_metadata, ) @@ -257,29 +279,21 @@ def __repr__(self) -> str: # Private methods # ##################### - def __fetch_fields(self) -> List[FieldType]: + def _fetch_fields(self) -> List[TextField]: models = self._client.api.fields.list(dataset_id=self._dataset.id) - self.__fields = [field_from_model(model) for model in models] + return [TextField.from_model(model) for model in models] - return self.__fields - - def __fetch_questions(self) -> List[QuestionType]: + def _fetch_questions(self) -> List[QuestionType]: models = self._client.api.questions.list(dataset_id=self._dataset.id) - self.__questions = [question_from_model(model) for model in models] - - return self.__questions + return [question_from_model(model) for model in models] - def __fetch_vectors(self) -> List[VectorField]: - models = self._client.api.vectors.list(dataset_id=self._dataset.id) - self.__vectors = [field_from_model(model) for model in models] + def _fetch_vectors(self) -> List[VectorField]: + models = self.dataset._client.api.vectors.list(self.dataset.id) + return [VectorField.from_model(model) for model in models] - return self.__vectors - - def __fetch_metadata(self) -> List[MetadataType]: + def _fetch_metadata(self) -> List[MetadataType]: models = self._client.api.metadata.list(dataset_id=self._dataset.id) - self.__metadata = [field_from_model(model) for model in models] - - return self.__metadata + return [MetadataField.from_model(model) for model in models] def __get_dataset_related_attributes(self): # This flow may be a bit weird, but it's the only way to update the dataset related attributes @@ -295,7 +309,7 @@ def __get_dataset_related_attributes(self): self.guidelines = dataset_model.guidelines self.allow_extra_metadata = dataset_model.allow_extra_metadata - def __update_dataset_related_attributes(self): + def _update_dataset_related_attributes(self): # This flow may be a bit weird, but it's the only way to update the dataset related attributes # Everything is point that we should have several settings-related endpoints in the API to handle this. # POST /api/v1/datasets/{dataset_id}/settings @@ -312,7 +326,7 @@ def __update_dataset_related_attributes(self): ) self._client.api.datasets.update(dataset_model) - def __upsert_questions(self) -> None: + def _create_questions(self) -> None: for question in self.__questions: try: question_model = self._client.api.questions.create( @@ -322,29 +336,6 @@ def __upsert_questions(self) -> None: except ArgillaAPIError as e: raise SettingsError(f"Failed to create question {question.name}") from e - def __upsert_fields(self) -> None: - for field in self.__fields: - try: - field_model = self._client.api.fields.create(dataset_id=self._dataset.id, field=field._model) - field._model = field_model - except ArgillaAPIError as e: - raise SettingsError(f"Failed to create field {field.name}") from e - - def __upsert_vectors(self) -> None: - for vector in self.__vectors: - try: - vector_model = self._client.api.vectors.create(dataset_id=self._dataset.id, vector=vector._model) - vector._model = vector_model - except ArgillaAPIError as e: - raise SettingsError(f"Failed to create vector {vector.name}") from e - - def __upsert_metadata(self) -> None: - for metadata in self.__metadata: - metadata_model = self._client.api.metadata.create( - dataset_id=self._dataset.id, metadata_field=metadata._model - ) - metadata._model = metadata_model - def _validate_empty_settings(self): if not all([self.fields, self.questions]): message = "Fields and questions are required" @@ -353,33 +344,14 @@ def _validate_empty_settings(self): def _validate_duplicate_names(self) -> None: dataset_properties_by_name = {} - for prop in self.fields + self.questions + self.vectors + self.metadata: - if prop.name in dataset_properties_by_name: - raise SettingsError( - f"names of dataset settings must be unique, " - f"but the name {prop.name!r} is used by {type(prop).__name__!r} and {type(dataset_properties_by_name[prop.name]).__name__!r} " - ) - dataset_properties_by_name[prop.name] = prop - - def __process_fields(self, fields: List[FieldType]) -> List["TextFieldModel"]: - processed_fields = [] - for field in fields: - try: - processed_field = field._model - except Exception as e: - raise SettingsError(f"Failed to process field {field.name}") from e - processed_fields.append(processed_field) - return processed_fields - - def __process_questions(self, questions: List[QuestionType]) -> List["TextQuestionModel"]: - processed_questions = [] - for question in questions: - try: - processed_question = question._model - except Exception as e: - raise SettingsError(f"Failed to process question {question.name}") from e - processed_questions.append(processed_question) - return processed_questions + for properties in [self.fields, self.questions, self.vectors, self.metadata]: + for property in properties: + if property.name in dataset_properties_by_name: + raise SettingsError( + f"names of dataset settings must be unique, " + f"but the name {property.name!r} is used by {type(property).__name__!r} and {type(dataset_properties_by_name[property.name]).__name__!r} " + ) + dataset_properties_by_name[property.name] = property def __process_guidelines(self, guidelines): if guidelines is None: @@ -394,8 +366,71 @@ def __process_guidelines(self, guidelines): return guidelines - def __serialize_fields(self, fields: List[FieldType]): - return [field.serialize() for field in fields] - def __serialize_questions(self, questions: List[QuestionType]): return [question.serialize() for question in questions] + + +Property = Union[TextField, VectorField, MetadataType, QuestionType] + + +class SettingsProperties(Sequence[Property]): + """A collection of properties (fields, questions, vectors and metadata) for a dataset settings object. + + This class is used to store the properties of a dataset settings object + """ + + def __init__(self, settings: "Settings", properties: List[Property]): + self._properties_by_name = {} + self._settings = settings + + for property in properties or []: + self.add(property) + + def __getitem__(self, key: Union[str, int]) -> Optional[Property]: + if isinstance(key, int): + return list(self._properties_by_name.values())[key] + return self._properties_by_name.get(key) + + def __iter__(self) -> Iterator[Property]: + return iter(self._properties_by_name.values()) + + def __len__(self): + return len(self._properties_by_name) + + def __eq__(self, other): + """Check if two instances are equal. Overloads the == operator.""" + if not isinstance(other, SettingsProperties): + return False + return self._properties_by_name == other._properties_by_name + + def add(self, property: Property) -> Property: + self._validate_new_property(property) + self._properties_by_name[property.name] = property + setattr(self, property.name, property) + return property + + def create(self): + for property in self: + try: + property.dataset = self._settings.dataset + property.create() + except ArgillaAPIError as e: + raise SettingsError(f"Failed to create property {property.name!r}: {e.message}") from e + + def update(self): + for item in self: + try: + item.dataset = self._settings.dataset + item.update() if item.id else item.create() + except ArgillaAPIError as e: + raise SettingsError(f"Failed to update {item.name!r}: {e.message}") from e + + def serialize(self) -> List[dict]: + return [property.serialize() for property in self] + + def _validate_new_property(self, property: Property) -> None: + if property.name in self._properties_by_name: + raise ValueError(f"Property with name {property.name!r} already exists in the collection") + + if property.name in dir(self): + raise ValueError(f"Property with name {property.name!r} conflicts with an existing attribute") diff --git a/src/argilla_sdk/settings/_vector.py b/src/argilla_sdk/settings/_vector.py new file mode 100644 index 00000000..cc6a5727 --- /dev/null +++ b/src/argilla_sdk/settings/_vector.py @@ -0,0 +1,100 @@ +# Copyright 2024-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, TYPE_CHECKING + +from argilla_sdk._api._vectors import VectorsAPI +from argilla_sdk._models import VectorFieldModel +from argilla_sdk._resource import Resource +from argilla_sdk.client import Argilla + +if TYPE_CHECKING: + from argilla_sdk import Dataset + +__all__ = ["VectorField"] + + +class VectorField(Resource): + """Vector field for use in Argilla `Dataset` `Settings`""" + + _model: VectorFieldModel + _api: VectorsAPI + _dataset: "Dataset" + + def __init__( + self, + name: str, + dimensions: int, + title: Optional[str] = None, + _client: Optional["Argilla"] = None, + ) -> None: + """Vector field for use in Argilla `Dataset` `Settings` + + Parameters: + name (str): The name of the field + dimensions (int): The number of dimensions in the vector + title (Optional[str], optional): The title of the field. Defaults to None. + """ + client = _client or Argilla._get_default() + super().__init__(api=client.api.vectors, client=client) + self._model = VectorFieldModel(name=name, title=title, dimensions=dimensions) + self._dataset = None + + @property + def name(self) -> str: + return self._model.name + + @name.setter + def name(self, value: str) -> None: + self._model.name = value + + @property + def title(self) -> Optional[str]: + return self._model.title + + @title.setter + def title(self, value: Optional[str]) -> None: + self._model.title = value + + @property + def dimensions(self) -> int: + return self._model.dimensions + + @dimensions.setter + def dimensions(self, value: int) -> None: + self._model.dimensions = value + + @property + def dataset(self) -> "Dataset": + return self._dataset + + @dataset.setter + def dataset(self, value: "Dataset") -> None: + self._dataset = value + self._model.dataset_id = self._dataset.id + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(name={self.name}, title={self.title}, dimensions={self.dimensions})" + + @classmethod + def from_model(cls, model: VectorFieldModel) -> "VectorField": + instance = cls(name=model.name, dimensions=model.dimensions) + instance._model = model + + return instance + + @classmethod + def from_dict(cls, data: dict) -> "VectorField": + model = VectorFieldModel(**data) + return cls.from_model(model=model) diff --git a/src/argilla_sdk/suggestions.py b/src/argilla_sdk/suggestions.py index 7d37a752..db9ec8c1 100644 --- a/src/argilla_sdk/suggestions.py +++ b/src/argilla_sdk/suggestions.py @@ -125,7 +125,7 @@ def from_model(cls, model: SuggestionModel, dataset: "Dataset") -> "Suggestion": model.question_name = question.name model.value = cls.__from_model_value(model.value, question) - return cls(**model.dict()) + return cls(**model.model_dump()) def api_model(self) -> SuggestionModel: if self.record is None or self.record.dataset is None: diff --git a/tests/integration/test_export_dataset.py b/tests/integration/test_export_dataset.py index b90d7615..b57158ac 100644 --- a/tests/integration/test_export_dataset.py +++ b/tests/integration/test_export_dataset.py @@ -50,17 +50,17 @@ def test_export_dataset_to_disk(dataset: rg.Dataset): { "text": "Hello World, how are you?", "label": "positive", - "external_id": uuid.uuid4(), + "id": uuid.uuid4(), }, { "text": "Hello World, how are you?", "label": "negative", - "external_id": uuid.uuid4(), + "id": uuid.uuid4(), }, { "text": "Hello World, how are you?", "label": "positive", - "external_id": uuid.uuid4(), + "id": uuid.uuid4(), }, ] dataset.records.log(records=mock_data) diff --git a/tests/integration/test_update_dataset_settings.py b/tests/integration/test_update_dataset_settings.py new file mode 100644 index 00000000..0be6aa6c --- /dev/null +++ b/tests/integration/test_update_dataset_settings.py @@ -0,0 +1,52 @@ +# Copyright 2024-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import uuid + +import pytest + +from argilla_sdk import Dataset, Settings, TextField, LabelQuestion, Argilla, VectorField, FloatMetadataProperty + + +@pytest.fixture +def dataset(): + return Dataset( + name=f"test_dataset_{uuid.uuid4().int}", + settings=Settings( + fields=[TextField(name="text", use_markdown=False)], + questions=[LabelQuestion(name="label", labels=["a", "b", "c"])], + ), + ).create() + + +class TestUpdateDatasetSettings: + def test_update_settings(self, client: Argilla, dataset: Dataset): + settings = dataset.settings + + settings.fields.text.use_markdown = True + dataset.settings.vectors.add(VectorField(name="vector", dimensions=10)) + dataset.settings.metadata.add(FloatMetadataProperty(name="metadata")) + dataset.settings.update() + + dataset = client.datasets(dataset.name) + settings = dataset.settings + assert settings.fields.text.use_markdown is True + assert settings.vectors.vector.dimensions == 10 + assert isinstance(settings.metadata.metadata, FloatMetadataProperty) + + settings.vectors.vector.title = "A new title for vector" + + settings.update() + dataset = client.datasets(dataset.name) + assert dataset.settings.vectors.vector.title == "A new title for vector" diff --git a/tests/unit/export/test_settings_export_import_compatibillity.py b/tests/unit/export/test_settings_export_import_compatibillity.py index c72ccdde..69617a00 100644 --- a/tests/unit/export/test_settings_export_import_compatibillity.py +++ b/tests/unit/export/test_settings_export_import_compatibillity.py @@ -76,6 +76,22 @@ def dataset(httpx_mock: HTTPXMock, settings) -> rg.Dataset: yield dataset +def test_settings_to_json(settings): + with TemporaryDirectory() as temp_dir: + temp_file_path = f"{temp_dir}/settings.json" + settings.to_json(temp_file_path) + with open(temp_file_path, "r") as f: + settings_json = f.read() + + assert "fields" in settings_json + assert "questions" in settings_json + assert "metadata" in settings_json + assert "vectors" in settings_json + + loaded_settings = rg.Settings.from_json(temp_file_path) + assert settings == loaded_settings + + def test_export_settings_from_disk(settings): with TemporaryDirectory() as temp_dir: temp_file_path = f"{temp_dir}/settings.json" diff --git a/tests/unit/test_resources/test_datasets.py b/tests/unit/test_resources/test_datasets.py index 80f0b1b0..34b00052 100644 --- a/tests/unit/test_resources/test_datasets.py +++ b/tests/unit/test_resources/test_datasets.py @@ -74,34 +74,6 @@ def dataset(httpx_mock: HTTPXMock) -> rg.Dataset: class TestDatasets: - def mock_dataset_settings(self, httpx_mock: HTTPXMock, dataset_id: uuid.UUID, dataset_dict: dict): - mock_field = { - "id": str(uuid.uuid4()), - "name": "text", - "settings": {"type": "text", "use_markdown": True}, - "inserted_at": datetime.utcnow().isoformat(), - "updated_at": datetime.utcnow().isoformat(), - } - mock_question = { - "id": str(uuid.uuid4()), - "name": "response", - "settings": {"type": "text", "use_markdown": True}, - "inserted_at": datetime.utcnow().isoformat(), - "updated_at": datetime.utcnow().isoformat(), - } - httpx_mock.add_response( - json=dataset_dict, - url=self.url(f"/api/v1/datasets/{dataset_id}"), - method="PATCH", - status_code=200, - ) - httpx_mock.add_response( - json=mock_field, url=self.url(f"/api/v1/datasets/{dataset_id}/fields"), method="POST", status_code=200 - ) - httpx_mock.add_response( - json=mock_question, url=self.url(f"/api/v1/datasets/{dataset_id}/questions"), method="POST", status_code=200 - ) - def url(self, path: str) -> str: return f"http://test_url{path}" @@ -148,7 +120,7 @@ def test_create_dataset(self, httpx_mock: HTTPXMock, status_code, expected_excep method="PUT", status_code=200, ) - self.mock_dataset_settings(httpx_mock, mock_dataset_id, mock_return_value) + self._mock_dataset_settings(httpx_mock, mock_dataset_id, mock_return_value) with httpx.Client(): if expected_exception: with pytest.raises(expected_exception=expected_exception) as excinfo: @@ -183,6 +155,12 @@ def test_update_dataset(self, httpx_mock: HTTPXMock, status_code, expected_excep "inserted_at": datetime.utcnow().isoformat(), "updated_at": datetime.utcnow().isoformat(), } + httpx_mock.add_response( + json=mock_patch_return_value, + url=self.url(f"/api/v1/datasets/{mock_dataset_id}"), + method="GET", + status_code=200, + ) httpx_mock.add_response( json=mock_patch_return_value, url=self.url(f"/api/v1/datasets/{mock_dataset_id}"), @@ -191,6 +169,20 @@ def test_update_dataset(self, httpx_mock: HTTPXMock, status_code, expected_excep ) dataset.id = mock_dataset_id + if status_code == 200: + httpx_mock.add_response( + json={ + "id": str(uuid.uuid4()), + "name": "text", + "settings": {"type": "text", "use_markdown": True}, + "inserted_at": datetime.utcnow().isoformat(), + "updated_at": datetime.utcnow().isoformat(), + }, + url=self.url(f"/api/v1/datasets/{mock_dataset_id}/fields"), + method="POST", + status_code=200, + ) + with httpx.Client(): if expected_exception: with pytest.raises(expected_exception=expected_exception) as excinfo: @@ -233,6 +225,34 @@ def test_delete_dataset(self, httpx_mock: HTTPXMock, status_code, expected_excep dataset.delete() assert dataset.name == mock_return_value["name"] + def _mock_dataset_settings(self, httpx_mock: HTTPXMock, dataset_id: uuid.UUID, dataset_dict: dict): + mock_field = { + "id": str(uuid.uuid4()), + "name": "text", + "settings": {"type": "text", "use_markdown": True}, + "inserted_at": datetime.utcnow().isoformat(), + "updated_at": datetime.utcnow().isoformat(), + } + mock_question = { + "id": str(uuid.uuid4()), + "name": "response", + "settings": {"type": "text", "use_markdown": True}, + "inserted_at": datetime.utcnow().isoformat(), + "updated_at": datetime.utcnow().isoformat(), + } + httpx_mock.add_response( + json=dataset_dict, + url=self.url(f"/api/v1/datasets/{dataset_id}"), + method="PATCH", + status_code=200, + ) + httpx_mock.add_response( + json=mock_field, url=self.url(f"/api/v1/datasets/{dataset_id}/fields"), method="POST", status_code=200 + ) + httpx_mock.add_response( + json=mock_question, url=self.url(f"/api/v1/datasets/{dataset_id}/questions"), method="POST", status_code=200 + ) + class TestDatasetsAPI: def test_delete_dataset(self, httpx_mock: HTTPXMock): diff --git a/tests/unit/test_resources/test_fields.py b/tests/unit/test_resources/test_fields.py index 1064556f..44a1edb2 100644 --- a/tests/unit/test_resources/test_fields.py +++ b/tests/unit/test_resources/test_fields.py @@ -19,11 +19,11 @@ from pytest_httpx import HTTPXMock import argilla_sdk as rg -from argilla_sdk._models import TextFieldModel +from argilla_sdk._models import FieldModel class TestFieldsAPI: - def test_create_many_fields(self, httpx_mock: HTTPXMock): + def test_create_field(self, httpx_mock: HTTPXMock): # TODO: Add a test for the delete method in client mock_dataset_id = uuid.uuid4() mock_return_value = { @@ -42,7 +42,7 @@ def test_create_many_fields(self, httpx_mock: HTTPXMock): "required": True, "settings": {"type": "text", "use_markdown": False}, } - mock_field = TextFieldModel(**mock_field) + mock_field = FieldModel(**mock_field, dataset_id=mock_dataset_id) httpx_mock.add_response( json=mock_return_value, url=f"http://test_url/api/v1/datasets/{mock_dataset_id}/fields", @@ -51,4 +51,4 @@ def test_create_many_fields(self, httpx_mock: HTTPXMock): ) with httpx.Client() as client: client = rg.Argilla(api_url="http://test_url") - client.api.fields.create_many(dataset_id=mock_dataset_id, fields=[mock_field]) + client.api.fields.create(mock_field) diff --git a/tests/unit/test_settings/test_settings.py b/tests/unit/test_settings/test_settings.py index c39230f7..92650e49 100644 --- a/tests/unit/test_settings/test_settings.py +++ b/tests/unit/test_settings/test_settings.py @@ -20,8 +20,8 @@ class TestSettings: def test_init_settings(self): settings = rg.Settings() - assert settings.fields == [] - assert settings.questions == [] + assert len(settings.fields) == 0 + assert len(settings.questions) == 0 def test_with_guidelines(self): mock_guidelines = "This is a guideline" diff --git a/tests/unit/test_settings/test_terms_metadata.py b/tests/unit/test_settings/test_terms_metadata.py index d4eda709..0cc39acd 100644 --- a/tests/unit/test_settings/test_terms_metadata.py +++ b/tests/unit/test_settings/test_terms_metadata.py @@ -32,11 +32,14 @@ def test_create_metadata_terms(self): assert property.api_model().model_dump() == { "id": None, + "dataset_id": None, "name": "metadata", "settings": {"type": "terms", "values": ["option1", "option2"], "visible_for_annotators": True}, "title": "A metadata property", "type": "terms", "visible_for_annotators": True, + "inserted_at": None, + "updated_at": None, } def test_create_terms_metadata_without_options(self): @@ -51,11 +54,14 @@ def test_create_terms_metadata_without_options(self): assert model.type == "terms" assert model.model_dump() == { "id": None, + "dataset_id": None, "name": "metadata", "title": "metadata", "settings": {"type": "terms", "values": None, "visible_for_annotators": True}, "type": "terms", "visible_for_annotators": True, + "inserted_at": None, + "updated_at": None, } def test_create_from_model(self):