Skip to content

Commit

Permalink
[FEATURE] Add support to update record fields (#5685)
Browse files Browse the repository at this point in the history
# Description
<!-- Please include a summary of the changes and the related issue.
Please also include relevant motivation and context. List any
dependencies that are required for this change. -->

This PR adds backend support to update record fields.

**Type of change**
<!-- Please delete options that are not relevant. Remember to title the
PR according to the type of change -->

- Improvement (change adding some improvement to an existing
functionality)


**How Has This Been Tested**
<!-- Please add some reference about how your feature has been tested.
-->

**Checklist**
<!-- Please go over the list and make sure you've taken everything into
account -->

- I added relevant documentation
- I followed the style guidelines of this project
- I did a self-review of my code
- I made corresponding changes to the documentation
- I confirm my changes generate no new warnings
- I have added tests that prove my fix is effective or that my feature
works
- I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)
  • Loading branch information
frascuchon authored Dec 11, 2024
1 parent 21d07c9 commit 0487936
Show file tree
Hide file tree
Showing 14 changed files with 397 additions and 403 deletions.
5 changes: 5 additions & 0 deletions argilla-server/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ These are the section headers that we use:

## [Unreleased]()

### Added

- Added support to update record fields in `PATCH /api/v1/records/:record_id` endpoint. ([#5685](https://github.com/argilla-io/argilla/pull/5685))
- Added support to update record fields in `PUT /api/v1/datasets/:dataset_id/records/bulk` endpoint. ([#5685](https://github.com/argilla-io/argilla/pull/5685))

## [2.5.0](https://github.com/argilla-io/argilla/compare/v2.4.1...v2.5.0)

### Added
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ async def delete_dataset_records(
if num_records > DELETE_DATASET_RECORDS_LIMIT:
raise UnprocessableEntityError(f"Cannot delete more than {DELETE_DATASET_RECORDS_LIMIT} records at once")

await datasets.delete_records(db, search_engine, dataset, record_ids)
await records.delete_records(db, search_engine, dataset, record_ids)


@router.post(
Expand Down
15 changes: 10 additions & 5 deletions argilla-server/src/argilla_server/api/handlers/v1/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from argilla_server.api.schemas.v1.responses import Response, ResponseCreate
from argilla_server.api.schemas.v1.suggestions import Suggestion as SuggestionSchema
from argilla_server.api.schemas.v1.suggestions import SuggestionCreate, Suggestions
from argilla_server.contexts import datasets
from argilla_server.contexts import datasets, records
from argilla_server.database import get_async_db
from argilla_server.errors.future.base_errors import NotFoundError, UnprocessableEntityError
from argilla_server.models import Dataset, Question, Record, Suggestion, User
Expand Down Expand Up @@ -74,16 +74,21 @@ async def update_record(
db,
record_id,
options=[
selectinload(Record.dataset).selectinload(Dataset.questions),
selectinload(Record.dataset).selectinload(Dataset.metadata_properties),
selectinload(Record.dataset).options(
selectinload(Dataset.questions),
selectinload(Dataset.metadata_properties),
selectinload(Dataset.vectors_settings),
selectinload(Dataset.fields),
),
selectinload(Record.suggestions),
selectinload(Record.responses),
selectinload(Record.vectors),
],
)

await authorize(current_user, RecordPolicy.update(record))

return await datasets.update_record(db, search_engine, record, record_update)
return await records.update_record(db, search_engine, record, record_update)


@router.post("/records/{record_id}/responses", status_code=status.HTTP_201_CREATED, response_model=Response)
Expand Down Expand Up @@ -233,4 +238,4 @@ async def delete_record(

await authorize(current_user, RecordPolicy.delete(record))

return await datasets.delete_record(db, search_engine, record)
return await records.delete_record(db, search_engine, record)
29 changes: 10 additions & 19 deletions argilla-server/src/argilla_server/api/schemas/v1/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
BaseModel,
Field,
StrictStr,
root_validator,
validator,
ValidationError,
ConfigDict,
model_validator,
Expand Down Expand Up @@ -183,17 +181,12 @@ def prevent_nan_values(cls, metadata: Optional[Dict[str, Any]]) -> Optional[Dict


class RecordUpdate(UpdateSchema):
metadata_: Optional[Dict[str, Any]] = Field(None, alias="metadata")
fields: Optional[Dict[str, FieldValueCreate]] = None
metadata: Optional[Dict[str, Any]] = None
suggestions: Optional[List[SuggestionCreate]] = None
vectors: Optional[Dict[str, List[float]]] = None

@property
def metadata(self) -> Optional[Dict[str, Any]]:
# Align with the RecordCreate model. Both should have the same name for the metadata field.
# TODO(@frascuchon): This will be properly adapted once the bulk records refactor is completed.
return self.metadata_

@field_validator("metadata_")
@field_validator("metadata")
@classmethod
def prevent_nan_values(cls, metadata: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
if metadata is None:
Expand All @@ -205,15 +198,20 @@ def prevent_nan_values(cls, metadata: Optional[Dict[str, Any]]) -> Optional[Dict

return {k: v for k, v in metadata.items() if v == v} # By definition, NaN != NaN

def is_set(self, attribute: str) -> bool:
    """Return True if *attribute* was explicitly provided in the request payload.

    Uses pydantic's ``model_fields_set``, which records the fields present in
    the input, so callers can distinguish "field omitted" from "field set to
    its default/None".
    """
    return attribute in self.model_fields_set

class RecordUpdateWithId(RecordUpdate):
id: UUID
def has_changes(self) -> bool:
    """Return True if at least one field was explicitly set on this update.

    ``model_dump(exclude_unset=True)`` yields only fields the client sent;
    an empty dict therefore means the update carries no changes at all.
    """
    return self.model_dump(exclude_unset=True) != {}


class RecordUpsert(RecordCreate):
    # Optional id: when present and matching an existing record, the record
    # is updated; when omitted, a new record is created.
    id: Optional[UUID] = None
    # Overrides RecordCreate.fields to make fields optional, so an upsert of
    # an existing record does not have to resend every field value.
    fields: Optional[Dict[str, FieldValueCreate]] = None

    def is_set(self, attribute: str) -> bool:
        """Return True if *attribute* was explicitly provided in the payload.

        Backed by pydantic's ``model_fields_set``; lets the bulk-upsert logic
        apply only the attributes the client actually sent.
        """
        return attribute in self.model_fields_set


class RecordIncludeParam(BaseModel):
relationships: Optional[List[RecordInclude]] = Field(None, alias="keys")
Expand Down Expand Up @@ -278,13 +276,6 @@ class RecordsCreate(BaseModel):
items: List[RecordCreate] = Field(..., min_length=RECORDS_CREATE_MIN_ITEMS, max_length=RECORDS_CREATE_MAX_ITEMS)


class RecordsUpdate(BaseModel):
# TODO: review this definition and align to create model
items: List[RecordUpdateWithId] = Field(
..., min_length=RECORDS_UPDATE_MIN_ITEMS, max_length=RECORDS_UPDATE_MAX_ITEMS
)


class MetadataParsedQueryParam:
def __init__(self, string: str):
k, *v = string.split(":", maxsplit=1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class RecordsBulk(BaseModel):
items: List[Record]


class RecordsBulkWithUpdateInfo(RecordsBulk):
class RecordsBulkWithUpdatedItemIds(RecordsBulk):
    # IDs of the items that already existed and were updated during the bulk
    # upsert (as opposed to items newly created by the same request).
    updated_item_ids: List[UUID]


Expand Down
24 changes: 13 additions & 11 deletions argilla-server/src/argilla_server/bulk/records_bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Dict, List, Sequence, Tuple, Union
from uuid import UUID

from datetime import UTC
from fastapi.encoders import jsonable_encoder
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
Expand All @@ -26,7 +27,7 @@
RecordsBulk,
RecordsBulkCreate,
RecordsBulkUpsert,
RecordsBulkWithUpdateInfo,
RecordsBulkWithUpdatedItemIds,
)
from argilla_server.api.schemas.v1.responses import UserResponseCreate
from argilla_server.api.schemas.v1.suggestions import SuggestionCreate
Expand All @@ -39,7 +40,7 @@
fetch_records_by_ids_as_dict,
)
from argilla_server.errors.future import UnprocessableEntityError
from argilla_server.models import Dataset, Record, Response, Suggestion, Vector, VectorSettings
from argilla_server.models import Dataset, Record, Response, Suggestion, Vector
from argilla_server.search_engine import SearchEngine
from argilla_server.validators.records import RecordsBulkCreateValidator, RecordUpsertValidator

Expand Down Expand Up @@ -154,15 +155,11 @@ async def _upsert_records_vectors(
autocommit=False,
)

@classmethod
def _metadata_is_set(cls, record_create: RecordCreate) -> bool:
return "metadata" in record_create.model_fields_set


class UpsertRecordsBulk(CreateRecordsBulk):
async def upsert_records_bulk(
self, dataset: Dataset, bulk_upsert: RecordsBulkUpsert, raise_on_error: bool = True
) -> RecordsBulkWithUpdateInfo:
) -> RecordsBulkWithUpdatedItemIds:
found_records = await self._fetch_existing_dataset_records(dataset, bulk_upsert.items)

records = []
Expand All @@ -185,9 +182,14 @@ async def upsert_records_bulk(
external_id=record_upsert.external_id,
dataset_id=dataset.id,
)
elif self._metadata_is_set(record_upsert):
record.metadata_ = record_upsert.metadata
record.updated_at = datetime.utcnow()
else:
if record_upsert.is_set("metadata"):
record.metadata_ = record_upsert.metadata
if record_upsert.is_set("fields"):
record.fields = jsonable_encoder(record_upsert.fields)

if self._db.is_modified(record):
record.updated_at = datetime.now(UTC)

records.append(record)

Expand All @@ -203,7 +205,7 @@ async def upsert_records_bulk(

await self._notify_upsert_record_events(records)

return RecordsBulkWithUpdateInfo(
return RecordsBulkWithUpdatedItemIds(
items=records,
updated_item_ids=[record.id for record in found_records.values()],
)
Expand Down
Loading

0 comments on commit 0487936

Please sign in to comment.