Skip to content

Commit

Permalink
feat: use dataset_attrs as dictionary to avoid the use of DatasetCreate schema in datasets context
Browse files Browse the repository at this point in the history
  • Loading branch information
jfcalvo committed Jun 13, 2024
1 parent d5b762c commit 96a0bde
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ async def create_dataset(
):
await authorize(current_user, DatasetPolicy.create(dataset_create.workspace_id))

return await datasets.create_dataset(db, dataset_create)
return await datasets.create_dataset(db, dataset_create.dict())


@router.post("/datasets/{dataset_id}/fields", status_code=status.HTTP_201_CREATED, response_model=Field)
Expand Down
25 changes: 11 additions & 14 deletions argilla-server/src/argilla_server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,7 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import contains_eager, joinedload, selectinload

from argilla_server.api.schemas.v1.datasets import (
DatasetCreate,
DatasetProgress,
)
from argilla_server.api.schemas.v1.datasets import DatasetProgress
from argilla_server.api.schemas.v1.fields import FieldCreate
from argilla_server.api.schemas.v1.metadata_properties import MetadataPropertyCreate, MetadataPropertyUpdate
from argilla_server.api.schemas.v1.records import (
Expand Down Expand Up @@ -122,22 +119,22 @@ async def list_datasets_by_workspace_id(db: AsyncSession, workspace_id: UUID) ->
return result.scalars().all()


async def create_dataset(db: AsyncSession, dataset_attrs: dict) -> Dataset:
    """Create a new dataset from a plain attributes dictionary.

    Args:
        db: async database session used for all lookups and the insert.
        dataset_attrs: dataset attributes. Keys read here are "name",
            "guidelines", "allow_extra_metadata", "distribution" and
            "workspace_id" — presumably produced by ``DatasetCreate.dict()``
            at the API layer, so all keys are expected to be present
            (missing keys raise ``KeyError``); TODO confirm against callers.

    Returns:
        The newly created ``Dataset`` row.

    Raises:
        UnprocessableEntityError: if the referenced workspace does not exist.
        NotUniqueError: if a dataset with the same name already exists in
            the workspace.
    """
    # Validate the foreign key up front so the caller gets a 422 instead of
    # a database integrity error.
    if await Workspace.get(db, dataset_attrs["workspace_id"]) is None:
        raise UnprocessableEntityError(f"Workspace with id `{dataset_attrs['workspace_id']}` not found")

    # Dataset names are unique per workspace.
    if await Dataset.get_by(db, name=dataset_attrs["name"], workspace_id=dataset_attrs["workspace_id"]):
        raise NotUniqueError(
            f"Dataset with name `{dataset_attrs['name']}` already exists for workspace with id `{dataset_attrs['workspace_id']}`"
        )

    return await Dataset.create(
        db,
        name=dataset_attrs["name"],
        guidelines=dataset_attrs["guidelines"],
        allow_extra_metadata=dataset_attrs["allow_extra_metadata"],
        # "distribution" is expected to already be a plain dict (the caller
        # serialized the schema), so it is passed through unchanged.
        distribution=dataset_attrs["distribution"],
        workspace_id=dataset_attrs["workspace_id"],
    )


Expand Down

0 comments on commit 96a0bde

Please sign in to comment.