From 96a0bdea406c8cb4dd0955834b2dba35d4b2199a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Thu, 13 Jun 2024 17:27:32 +0200 Subject: [PATCH] feat: use dataset_attrs as dictionary to avoid the use of DatasetCreate schema in datasets context --- .../api/handlers/v1/datasets/datasets.py | 2 +- .../src/argilla_server/contexts/datasets.py | 25 ++++++++----------- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py b/argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py index d1b1a4ce8f..7866e24d15 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py @@ -190,7 +190,7 @@ async def create_dataset( ): await authorize(current_user, DatasetPolicy.create(dataset_create.workspace_id)) - return await datasets.create_dataset(db, dataset_create) + return await datasets.create_dataset(db, dataset_create.dict()) @router.post("/datasets/{dataset_id}/fields", status_code=status.HTTP_201_CREATED, response_model=Field) diff --git a/argilla-server/src/argilla_server/contexts/datasets.py b/argilla-server/src/argilla_server/contexts/datasets.py index ee153aff59..ea2b399424 100644 --- a/argilla-server/src/argilla_server/contexts/datasets.py +++ b/argilla-server/src/argilla_server/contexts/datasets.py @@ -37,10 +37,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import contains_eager, joinedload, selectinload -from argilla_server.api.schemas.v1.datasets import ( - DatasetCreate, - DatasetProgress, -) +from argilla_server.api.schemas.v1.datasets import DatasetProgress from argilla_server.api.schemas.v1.fields import FieldCreate from argilla_server.api.schemas.v1.metadata_properties import MetadataPropertyCreate, MetadataPropertyUpdate from argilla_server.api.schemas.v1.records import ( @@ -122,22 +119,22 @@ async def list_datasets_by_workspace_id(db: AsyncSession, workspace_id: UUID) -> return result.scalars().all() -async def create_dataset(db: AsyncSession, dataset_create: DatasetCreate): - if await Workspace.get(db, dataset_create.workspace_id) is None: - raise UnprocessableEntityError(f"Workspace with id `{dataset_create.workspace_id}` not found") +async def create_dataset(db: AsyncSession, dataset_attrs: dict): + if await Workspace.get(db, dataset_attrs["workspace_id"]) is None: + raise UnprocessableEntityError(f"Workspace with id `{dataset_attrs['workspace_id']}` not found") - if await Dataset.get_by(db, name=dataset_create.name, workspace_id=dataset_create.workspace_id): + if await Dataset.get_by(db, name=dataset_attrs["name"], workspace_id=dataset_attrs["workspace_id"]): raise NotUniqueError( - f"Dataset with name `{dataset_create.name}` already exists for workspace with id `{dataset_create.workspace_id}`" + f"Dataset with name `{dataset_attrs['name']}` already exists for workspace with id `{dataset_attrs['workspace_id']}`" ) return await Dataset.create( db, - name=dataset_create.name, - guidelines=dataset_create.guidelines, - allow_extra_metadata=dataset_create.allow_extra_metadata, - distribution=dataset_create.distribution.dict(), - workspace_id=dataset_create.workspace_id, + name=dataset_attrs["name"], + guidelines=dataset_attrs["guidelines"], + allow_extra_metadata=dataset_attrs["allow_extra_metadata"], + distribution=dataset_attrs["distribution"], + workspace_id=dataset_attrs["workspace_id"], )