Skip to content
This repository has been archived by the owner on Oct 2, 2024. It is now read-only.

Commit

Permalink
[FEATURE] Allow update dataset settings for fields, vectors and metad…
Browse files Browse the repository at this point in the history
…ata (#232)

* refactor: Define VectorFieldModel as a ResourceModel

* feat: Align VectorsAPI methods with endpoints

* refactor: Define VectoField as Resource

* refactor: Using VectorField methods

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* chore: apply PR suggestions

* refactor: Redefine field model to customize only settings

* feat: Implement Fields API methods using new model def

* chore: Review some naming

* feat: Redefine TextField as a Resource

* chore: define private VectoField properties

* chore: Redefine fields refs based on TextField

* chore: Remove unused defs

* tests: Update tests

* chore: Using proper import

* refactor: Review Metadata API model and methods

* refactor: Align metadata fields with Resource class

* refactor: implement upsert metadata using metadata resource methods

* tests: Adapt tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* feat: Allow update fields, vectors and metadata settings

* feat: update dataset with its settings

* tests: Add integration tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: dataset resource tests

* chore: using model_dump instead of dict

* refactor: Using SettingsProperties class for field, vectors, and metadata management

* chore: Using id instead of external_id

* chore: Change test conditioN

* chore: Update settings tests with new settings properties container

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [BUGFIX] Export and import settings including vectors and metadata (#235)

* fix: metadata and vector infos are included for serialization

* tests: Add more tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply suggestions from code review

Co-authored-by: David Berenstein <[email protected]>

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: David Berenstein <[email protected]>
  • Loading branch information
3 people authored Jun 3, 2024
1 parent d3c82e1 commit 607cd27
Show file tree
Hide file tree
Showing 23 changed files with 575 additions and 362 deletions.
56 changes: 24 additions & 32 deletions src/argilla_sdk/_api/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@
from uuid import UUID

import httpx

from argilla_sdk._api._base import ResourceAPI
from argilla_sdk._exceptions import api_error_handler
from argilla_sdk._models import FieldBaseModel, TextFieldModel, FieldModel
from argilla_sdk._models import FieldModel

__all__ = ["FieldsAPI"]


class FieldsAPI(ResourceAPI[FieldBaseModel]):
class FieldsAPI(ResourceAPI[FieldModel]):
"""Manage datasets via the API"""

http_client: httpx.Client
Expand All @@ -33,36 +34,39 @@ class FieldsAPI(ResourceAPI[FieldBaseModel]):
################

@api_error_handler
def create(self, dataset_id: UUID, field: FieldModel) -> FieldModel:
url = f"/api/v1/datasets/{dataset_id}/fields"
def get(self, id: UUID) -> FieldModel:
raise NotImplementedError()

@api_error_handler
def create(self, field: FieldModel) -> FieldModel:
url = f"/api/v1/datasets/{field.dataset_id}/fields"
response = self.http_client.post(url=url, json=field.model_dump())
response.raise_for_status()
response_json = response.json()
field_model = self._model_from_json(response_json=response_json)
self._log_message(message=f"Created field {field_model.name} in dataset {dataset_id}")
return field_model
created_field = self._model_from_json(response_json=response_json)
self._log_message(message=f"Created field {created_field.name} in dataset {field.dataset_id}")
return created_field

@api_error_handler
def update(self, field: FieldModel) -> FieldModel:
# TODO: Implement update method for fields with server side ID
raise NotImplementedError
url = f"/api/v1/fields/{field.id}"
response = self.http_client.patch(url, json=field.model_dump())
response.raise_for_status()
response_json = response.json()
updated_field = self._model_from_json(response_json)
self._log_message(message=f"Update field {updated_field.name} with id {field.id}")
return updated_field

@api_error_handler
def delete(self, dataset_id: UUID) -> None:
# TODO: Implement delete method for fields with server side ID
raise NotImplementedError
def delete(self, field_id: UUID) -> None:
url = f"/api/v1/fields/{field_id}"
self.http_client.delete(url).raise_for_status()
self._log_message(message=f"Deleted field {field_id}")

####################
# Utility methods #
####################

def create_many(self, dataset_id: UUID, fields: List[FieldModel]) -> List[FieldModel]:
field_models = []
for field in fields:
field_model = self.create(dataset_id=dataset_id, field=field)
field_models.append(field_model)
return field_models

@api_error_handler
def list(self, dataset_id: UUID) -> List[FieldModel]:
response = self.http_client.get(f"/api/v1/datasets/{dataset_id}/fields")
Expand All @@ -78,19 +82,7 @@ def list(self, dataset_id: UUID) -> List[FieldModel]:
def _model_from_json(self, response_json: Dict) -> FieldModel:
response_json["inserted_at"] = self._date_from_iso_format(date=response_json["inserted_at"])
response_json["updated_at"] = self._date_from_iso_format(date=response_json["updated_at"])
return self._get_model_from_response(response_json=response_json)
return FieldModel(**response_json)

def _model_from_jsons(self, response_jsons: List[Dict]) -> List[FieldModel]:
return list(map(self._model_from_json, response_jsons))

def _get_model_from_response(self, response_json: Dict) -> FieldModel:
try:
field_type = response_json.get("settings", {}).get("type")
except Exception as e:
raise ValueError("Invalid response type: missing 'settings.type' in response") from e
if field_type == "text":
# TODO: Avoid apply validations here (check_fields=False?)
return TextFieldModel(**response_json)
else:
# TODO: Add more field types
raise ValueError(f"Invalid field type: {field_type}")
50 changes: 21 additions & 29 deletions src/argilla_sdk/_api/_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from uuid import UUID

import httpx

from argilla_sdk._api._base import ResourceAPI
from argilla_sdk._exceptions import api_error_handler
from argilla_sdk._models import MetadataFieldModel
Expand All @@ -33,53 +34,44 @@ class MetadataAPI(ResourceAPI[MetadataFieldModel]):
################

@api_error_handler
def create(self, dataset_id: UUID, metadata_field: MetadataFieldModel) -> MetadataFieldModel:
url = f"/api/v1/datasets/{dataset_id}/metadata-properties"
response = self.http_client.post(url=url, json=metadata_field.model_dump())
def get(self, metadata_id: UUID) -> MetadataFieldModel:
raise NotImplementedError()

@api_error_handler
def create(self, metadata: MetadataFieldModel) -> MetadataFieldModel:
url = f"/api/v1/datasets/{metadata.dataset_id}/metadata-properties"
response = self.http_client.post(url=url, json=metadata.model_dump())
response.raise_for_status()
response_json = response.json()
metadata_field_model = self._model_from_json(response_json=response_json)
self._log_message(message=f"Created metadata field {metadata_field_model.name} in dataset {dataset_id}")
return metadata_field_model
created_metadata = self._model_from_json(response_json=response_json)
self._log_message(message=f"Created metadata field {created_metadata.name} in dataset {metadata.dataset_id}")
return created_metadata

@api_error_handler
def update(self, metadata_field: MetadataFieldModel) -> MetadataFieldModel:
url = f"/api/v1/metadata-properties/{metadata_field.id}"
response = self.http_client.patch(url=url, json=metadata_field.model_dump())
def update(self, metadata: MetadataFieldModel) -> MetadataFieldModel:
url = f"/api/v1/metadata-properties/{metadata.id}"
response = self.http_client.patch(url=url, json=metadata.model_dump())
response.raise_for_status()
response_json = response.json()
metadata_field_model = self._model_from_json(response_json=response_json)
self._log_message(message=f"Updated field {metadata_field_model.name}")
return metadata_field_model
updated_metadata = self._model_from_json(response_json=response_json)
self._log_message(message=f"Updated metadata field {updated_metadata.name}")
return updated_metadata

@api_error_handler
def delete(self, id: UUID) -> None:
url = f"/api/v1/metadata-properties/{id}"
def delete(self, metadata_id: UUID) -> None:
url = f"/api/v1/metadata-properties/{metadata_id}"
self.http_client.delete(url=url).raise_for_status()
self._log_message(message=f"Deleted field {id}")

@api_error_handler
def get(self, id: UUID) -> MetadataFieldModel:
raise NotImplementedError()
self._log_message(message=f"Deleted metadata field {metadata_id}")

####################
# Utility methods #
####################

def create_many(self, dataset_id: UUID, metadata_fields: List[MetadataFieldModel]) -> List[MetadataFieldModel]:
metadata_field_models = []
for metadata_field in metadata_fields:
metadata_field_model = self.create(dataset_id=dataset_id, metadata_field=metadata_field)
metadata_field_models.append(metadata_field_model)
return metadata_field_models

@api_error_handler
def list(self, dataset_id: UUID) -> List[MetadataFieldModel]:
response = self.http_client.get(f"/api/v1/me/datasets/{dataset_id}/metadata-properties")
response.raise_for_status()
response_json = response.json()
metadata_field_model = self._model_from_jsons(response_jsons=response_json["items"])
return metadata_field_model
return self._model_from_jsons(response_jsons=response_json["items"])

####################
# Private methods #
Expand Down
33 changes: 17 additions & 16 deletions src/argilla_sdk/_api/_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from uuid import UUID

import httpx

from argilla_sdk._api._base import ResourceAPI
from argilla_sdk._exceptions import api_error_handler
from argilla_sdk._models import VectorFieldModel
Expand All @@ -33,36 +34,36 @@ class VectorsAPI(ResourceAPI[VectorFieldModel]):
################

@api_error_handler
def create(self, dataset_id: UUID, vector: VectorFieldModel) -> VectorFieldModel:
url = f"/api/v1/datasets/{dataset_id}/vectors-settings"
def create(self, vector: VectorFieldModel) -> VectorFieldModel:
url = f"/api/v1/datasets/{vector.dataset_id}/vectors-settings"
response = self.http_client.post(url=url, json=vector.model_dump())
response.raise_for_status()
response_json = response.json()
vector_model = self._model_from_json(response_json=response_json)
self._log_message(message=f"Created vector {vector_model.name} in dataset {dataset_id}")
return vector_model
created_vector = self._model_from_json(response_json=response_json)
self._log_message(message=f"Created vector {created_vector.name} in dataset {created_vector.dataset_id}")
return created_vector

@api_error_handler
def update(self, vector: VectorFieldModel) -> VectorFieldModel:
# TODO: Implement update method for vectors with server side ID
raise NotImplementedError
url = f"/api/v1/vectors-settings/{vector.id}"
response = self.http_client.patch(url, json=vector.model_dump())
response.raise_for_status()
response_json = response.json()
updated_vector = self._model_from_json(response_json)
self._log_message(message=f"Updated vector {updated_vector.name} with id {updated_vector.id}")
return updated_vector

@api_error_handler
def delete(self, vector_id: UUID) -> None:
# TODO: Implement delete method for vectors with server side ID
raise NotImplementedError
url = f"/api/v1/vectors-settings/{vector_id}"
response = self.http_client.delete(url)
response.raise_for_status()
self._log_message(message=f"Deleted vector with id {vector_id}")

####################
# Utility methods #
####################

def create_many(self, dataset_id: UUID, vectors: List[VectorFieldModel]) -> List[VectorFieldModel]:
vector_models = []
for vector in vectors:
vector_model = self.create(dataset_id=dataset_id, vector=vector)
vector_models.append(vector_model)
return vector_models

@api_error_handler
def list(self, dataset_id: UUID) -> List[VectorFieldModel]:
response = self.http_client.get(f"/api/v1/datasets/{dataset_id}/vectors-settings")
Expand Down
5 changes: 2 additions & 3 deletions src/argilla_sdk/_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,8 @@
ScopeModel,
)
from argilla_sdk._models._settings._fields import (
TextFieldModel,
FieldSettings,
FieldBaseModel,
FieldModel,
TextFieldSettings,
FieldModel,
)
from argilla_sdk._models._settings._questions import (
Expand Down
24 changes: 9 additions & 15 deletions src/argilla_sdk/_models/_settings/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
from typing import Optional, Literal
from uuid import UUID

from pydantic import BaseModel, field_serializer, field_validator, Field
from pydantic import BaseModel, field_serializer, field_validator
from pydantic_core.core_schema import ValidationInfo

from argilla_sdk._helpers import log_message
from argilla_sdk._models import ResourceModel


class FieldSettings(BaseModel):
type: str = Field(validate_default=True)
class TextFieldSettings(BaseModel):
type: Literal["text"] = "text"
use_markdown: Optional[bool] = False


class FieldBaseModel(BaseModel):
id: Optional[UUID] = None
class FieldModel(ResourceModel):
name: str

title: Optional[str] = None
required: bool = True
description: Optional[str] = None
settings: TextFieldSettings = TextFieldSettings(use_markdown=False)
dataset_id: Optional[UUID] = None

@field_validator("name")
@classmethod
Expand All @@ -48,13 +49,6 @@ def __title_default(cls, title: str, info: ValidationInfo) -> str:
log_message(f"TextField title is {validated_title}")
return validated_title

@field_serializer("id", when_used="unless-none")
@field_serializer("id", "dataset_id", when_used="unless-none")
def serialize_id(self, value: UUID) -> str:
return str(value)


class TextFieldModel(FieldBaseModel):
settings: FieldSettings = FieldSettings(type="text", use_markdown=False)


FieldModel = TextFieldModel
12 changes: 6 additions & 6 deletions src/argilla_sdk/_models/_settings/_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from pydantic import BaseModel, Field, field_serializer, field_validator, model_validator

from argilla_sdk._exceptions import MetadataError
from argilla_sdk._models import ResourceModel


class MetadataPropertyType(str, Enum):
Expand All @@ -36,9 +37,7 @@ class TermsMetadataPropertySettings(BaseMetadataPropertySettings):
type: Literal[MetadataPropertyType.terms]
values: Optional[List[str]] = None

@field_validator(
"values",
)
@field_validator("values")
@classmethod
def __validate_values(cls, values):
if values is None:
Expand Down Expand Up @@ -94,17 +93,18 @@ class FloatMetadataPropertySettings(NumericMetadataPropertySettings):
]


class MetadataFieldModel(BaseModel):
class MetadataFieldModel(ResourceModel):
"""The schema definition of a metadata field in an Argilla dataset."""

id: Optional[UUID] = None
name: str
settings: MetadataPropertySettings

type: Optional[MetadataPropertyType] = Field(None, validate_default=True)
title: Optional[str] = None
visible_for_annotators: Optional[bool] = True

dataset_id: Optional[UUID] = None

@field_validator("name")
@classmethod
def __name_lower(cls, name):
Expand All @@ -117,7 +117,7 @@ def __title_default(cls, title, values):
validated_title = title or values.data["name"]
return validated_title

@field_serializer("id", when_used="unless-none")
@field_serializer("id", "dataset_id", when_used="unless-none")
def serialize_id(self, value: UUID) -> str:
return str(value)

Expand Down
7 changes: 3 additions & 4 deletions src/argilla_sdk/_models/_settings/_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,17 @@
from typing import Optional
from uuid import UUID

from pydantic import BaseModel, field_validator, field_serializer
from pydantic import field_validator, field_serializer
from pydantic_core.core_schema import ValidationInfo

from argilla_sdk._models import ResourceModel
from argilla_sdk._helpers import log_message


class VectorFieldModel(BaseModel):
class VectorFieldModel(ResourceModel):
name: str
title: Optional[str] = None
dimensions: int

id: Optional[UUID] = None
dataset_id: Optional[UUID] = None

@field_serializer("id", "dataset_id", when_used="unless-none")
Expand Down
10 changes: 9 additions & 1 deletion src/argilla_sdk/datasets/_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,15 @@ def create(self) -> "Dataset":
self.__rollback_dataset_creation()
raise SettingsError from e

def update(self) -> "Dataset":
"""Updates the dataset on the server with the current settings.
Returns:
Dataset: The updated dataset object.
"""
self.settings.update()
return self

@classmethod
def from_model(cls, model: DatasetModel, client: "Argilla") -> "Dataset":
return cls(client=client, _model=model)
Expand All @@ -173,7 +182,6 @@ def from_model(cls, model: DatasetModel, client: "Argilla") -> "Dataset":
#####################

def _publish(self) -> "Dataset":
self.settings.validate()
self._settings.create()
self._api.publish(dataset_id=self._model.id)

Expand Down
Loading

0 comments on commit 607cd27

Please sign in to comment.