This repository has been archived by the owner on Jun 14, 2024. It is now read-only.

feat: records bulk create and upsert endpoints (#106)

# Description

This is the main feature branch to create new ingestion endpoints for
records. This is still a WIP, and decisions made here will be aligned with
changes and workflows defined in the new [python
SDK](https://github.com/argilla-io/argilla-python).

Refs #79

**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [ ] New feature (non-breaking change which adds functionality)
- [ ] Refactor (change restructuring the codebase without changing
functionality)
- [ ] Improvement (change adding some improvement to an existing
functionality)

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And
ideally, reference `tests`)

- [ ] Test A
- [ ] Test B

**Checklist**

- [ ] I added relevant documentation
- [ ] I followed the style guidelines of this project
- [ ] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [ ] I have added relevant notes to the `CHANGELOG.md` file (See
https://keepachangelog.com/)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: José Francisco Calvo <[email protected]>
Co-authored-by: José Francisco Calvo <[email protected]>
4 people authored Apr 30, 2024
1 parent 1619a8e commit 7f706fe
Showing 25 changed files with 2,253 additions and 23 deletions.
1 change: 1 addition & 0 deletions .github/workflows/package.yml
@@ -25,6 +25,7 @@ on:
- "feature/**"
- "feat/**"
- "fix/**"
- "tests/**"
types:
- opened
- reopened
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -22,6 +22,13 @@ These are the section headers that we use:
- Added `GET /api/v1/settings` new endpoint exposing Argilla and Hugging Face settings when available. ([#127](https://github.com/argilla-io/argilla-server/pull/127))
- Added `ARGILLA_SHOW_HUGGINGFACE_SPACE_PERSISTANT_STORAGE_WARNING` new environment variable to disable warning message when Hugging Face Spaces persistent storage is disabled. ([#124](https://github.com/argilla-io/argilla-server/pull/124))
- Added `options_order` new settings attribute to support specifying an order for options in multi label selection questions. ([#133](https://github.com/argilla-io/argilla-server/pull/133))
- Added `POST /api/v1/datasets/:dataset_id/records/bulk` endpoint. ([#106](https://github.com/argilla-io/argilla-server/pull/106))
- Added `PUT /api/v1/datasets/:dataset_id/records/bulk` endpoint. ([#106](https://github.com/argilla-io/argilla-server/pull/106))

### Deprecated

- Deprecated `POST /api/v1/datasets/:dataset_id/records` in favour of `POST /api/v1/datasets/:dataset_id/records/bulk`. ([#130](https://github.com/argilla-io/argilla-server/pull/130))
- Deprecated `PATCH /api/v1/datasets/:dataset_id/records` in favour of `PUT /api/v1/datasets/:dataset_id/records/bulk`. ([#130](https://github.com/argilla-io/argilla-server/pull/130))
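The two deprecations above pair each old route with a bulk replacement. A minimal sketch of how a client might record that mapping follows; the helper name `migrate_call` and the `{dataset_id}` template style are assumptions for illustration, not part of Argilla. Going by the handlers in this commit, the deprecated routes return `204 No Content`, while the bulk routes return a `RecordsBulk` body, so callers would likely need to adjust response handling as well.

```python
from typing import Dict, Tuple

# Hypothetical client-side table mirroring the changelog entries above:
# (deprecated method, path template) -> (replacement method, path template).
DEPRECATED_TO_BULK: Dict[Tuple[str, str], Tuple[str, str]] = {
    ("POST", "/api/v1/datasets/{dataset_id}/records"): (
        "POST",
        "/api/v1/datasets/{dataset_id}/records/bulk",
    ),
    ("PATCH", "/api/v1/datasets/{dataset_id}/records"): (
        "PUT",
        "/api/v1/datasets/{dataset_id}/records/bulk",
    ),
}


def migrate_call(method: str, path_template: str) -> Tuple[str, str]:
    """Return the replacement (method, path) for a deprecated call.

    Calls that are not in the table are returned unchanged
    (with the method upper-cased).
    """
    key = (method.upper(), path_template)
    return DEPRECATED_TO_BULK.get(key, key)


if __name__ == "__main__":
    print(migrate_call("patch", "/api/v1/datasets/{dataset_id}/records"))
```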

### Removed

3 changes: 3 additions & 0 deletions src/argilla_server/apis/v1/handlers/datasets/__init__.py
@@ -17,8 +17,11 @@
from argilla_server.apis.v1.handlers.datasets.datasets import router as datasets_router
from argilla_server.apis.v1.handlers.datasets.questions import router as questions_router
from argilla_server.apis.v1.handlers.datasets.records import router as records_router
from argilla_server.apis.v1.handlers.datasets.records_bulk import router as records_bulk_router

router = APIRouter(tags=["datasets"])

router.include_router(datasets_router)
router.include_router(questions_router)
router.include_router(records_router)
router.include_router(records_bulk_router)
14 changes: 12 additions & 2 deletions src/argilla_server/apis/v1/handlers/datasets/records.py
@@ -462,7 +462,12 @@ async def list_dataset_records(
     return Records(items=records, total=total)


-@router.post("/datasets/{dataset_id}/records", status_code=status.HTTP_204_NO_CONTENT)
+@router.post(
+    "/datasets/{dataset_id}/records",
+    status_code=status.HTTP_204_NO_CONTENT,
+    deprecated=True,
+    description="Deprecated in favor of POST /datasets/{dataset_id}/records/bulk",
+)
 async def create_dataset_records(
     *,
     db: AsyncSession = Depends(get_async_db),
@@ -487,7 +492,12 @@ async def create_dataset_records(
         raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(err))


-@router.patch("/datasets/{dataset_id}/records", status_code=status.HTTP_204_NO_CONTENT)
+@router.patch(
+    "/datasets/{dataset_id}/records",
+    status_code=status.HTTP_204_NO_CONTENT,
+    deprecated=True,
+    description="Deprecated in favor of PUT /datasets/{dataset_id}/records/bulk",
+)
 async def update_dataset_records(
     *,
     db: AsyncSession = Depends(get_async_db),
100 changes: 100 additions & 0 deletions src/argilla_server/apis/v1/handlers/datasets/records_bulk.py
@@ -0,0 +1,100 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, Security
from sqlalchemy.ext.asyncio import AsyncSession
from starlette import status

from argilla_server.apis.v1.handlers.datasets.datasets import _get_dataset_or_raise
from argilla_server.bulk.records_bulk import CreateRecordsBulk, UpsertRecordsBulk
from argilla_server.database import get_async_db
from argilla_server.models import User
from argilla_server.policies import DatasetPolicyV1, authorize
from argilla_server.schemas.v1.records_bulk import RecordsBulk, RecordsBulkCreate, RecordsBulkUpsert
from argilla_server.search_engine import SearchEngine, get_search_engine
from argilla_server.security import auth
from argilla_server.telemetry import TelemetryClient, get_telemetry_client

router = APIRouter()


@router.post(
    "/datasets/{dataset_id}/records/bulk",
    response_model=RecordsBulk,
    status_code=status.HTTP_201_CREATED,
)
async def create_dataset_records_bulk(
    *,
    dataset_id: UUID,
    records_bulk_create: RecordsBulkCreate,
    db: AsyncSession = Depends(get_async_db),
    search_engine: SearchEngine = Depends(get_search_engine),
    current_user: User = Security(auth.get_current_user),
    telemetry_client: TelemetryClient = Depends(get_telemetry_client),
):
    dataset = await _get_dataset_or_raise(
        db,
        dataset_id,
        with_fields=True,
        with_questions=True,
        with_metadata_properties=True,
        with_vectors_settings=True,
    )

    try:
        await authorize(current_user, DatasetPolicyV1.create_records(dataset))

        records_bulk = await CreateRecordsBulk(db, search_engine).create_records_bulk(dataset, records_bulk_create)
        telemetry_client.track_data(action="DatasetRecordsCreated", data={"records": len(records_bulk.items)})

        return records_bulk
    except ValueError as err:
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(err))


@router.put("/datasets/{dataset_id}/records/bulk", response_model=RecordsBulk)
async def upsert_dataset_records_bulk(
    *,
    dataset_id: UUID,
    records_bulk_create: RecordsBulkUpsert,
    db: AsyncSession = Depends(get_async_db),
    search_engine: SearchEngine = Depends(get_search_engine),
    current_user: User = Security(auth.get_current_user),
    telemetry_client: TelemetryClient = Depends(get_telemetry_client),
):
    dataset = await _get_dataset_or_raise(
        db,
        dataset_id,
        with_fields=True,
        with_questions=True,
        with_metadata_properties=True,
        with_vectors_settings=True,
    )

    await authorize(current_user, DatasetPolicyV1.upsert_records(dataset))

    try:
        records_bulk = await UpsertRecordsBulk(db, search_engine).upsert_records_bulk(dataset, records_bulk_create)

        updated = len(records_bulk.updated_item_ids)
        created = len(records_bulk.items) - updated

        telemetry_client.track_data(action="DatasetRecordsCreated", data={"records": created})
        telemetry_client.track_data(action="DatasetRecordsUpdated", data={"records": updated})

        return records_bulk
    except ValueError as err:
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(err))
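
The upsert handler above derives its two telemetry counts from a single result: `updated` is the length of `updated_item_ids`, and `created` is whatever remains of `items`. The sketch below isolates that arithmetic. `RecordsBulkSketch` and `split_counts` are hypothetical stand-ins that model only the two attributes the handler reads; the real `RecordsBulk` schema is not shown in this diff.

```python
from dataclasses import dataclass, field
from typing import List, Tuple
from uuid import UUID, uuid4


@dataclass
class RecordsBulkSketch:
    """Hypothetical stand-in for the upsert result (not the real schema)."""

    items: List[UUID]
    updated_item_ids: List[UUID] = field(default_factory=list)


def split_counts(bulk: RecordsBulkSketch) -> Tuple[int, int]:
    """Return (created, updated) using the same arithmetic as the handler."""
    updated = len(bulk.updated_item_ids)
    created = len(bulk.items) - updated
    return created, updated


# One record already existed and was updated; the other two are new.
existing = uuid4()
bulk = RecordsBulkSketch(
    items=[existing, uuid4(), uuid4()],
    updated_item_ids=[existing],
)

if __name__ == "__main__":
    print(split_counts(bulk))
```

The subtraction only yields the created count if every entry in `updated_item_ids` also appears in `items`, which the handler seems to rely on. As written, the handler also tracks both events even when one of the counts is zero.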
13 changes: 13 additions & 0 deletions src/argilla_server/bulk/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.