Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: improve upsert responses bulk #4451

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 14 additions & 29 deletions src/argilla/server/apis/v1/handlers/responses.py
Original file line number Diff line number Diff line change
@@ -17,23 +17,25 @@
from fastapi import APIRouter, Depends, HTTPException, Security, status
from sqlalchemy.ext.asyncio import AsyncSession

import argilla.server.errors.future as errors
from argilla.server.contexts import datasets
from argilla.server.database import get_async_db
from argilla.server.models import Record, Response, User
from argilla.server.policies import RecordPolicyV1, ResponsePolicyV1, authorize
from argilla.server.errors.future import NotFoundError
from argilla.server.models import Response, User
from argilla.server.policies import ResponsePolicyV1, authorize
from argilla.server.schemas.v1.responses import (
Response as ResponseSchema,
)
from argilla.server.schemas.v1.responses import (
ResponseBulk,
ResponseBulkError,
ResponsesBulk,
ResponsesBulkCreate,
ResponseUpdate,
)
from argilla.server.search_engine import SearchEngine, get_search_engine
from argilla.server.security import auth
from argilla.server.use_cases.responses.upsert_responses_in_bulk import (
UpsertResponsesInBulkUseCase,
UpsertResponsesInBulkUseCaseFactory,
)

router = APIRouter(tags=["responses"])

@@ -49,36 +51,19 @@ async def _get_response(db: AsyncSession, response_id: UUID) -> Response:
return response


async def _get_record(db: AsyncSession, record_id: UUID) -> Record:
    """Return the record with `record_id` (dataset eagerly loaded).

    Raises `errors.NotFoundError` when no such record exists.
    """
    record = await datasets.get_record_by_id(db, record_id, with_dataset=True)
    if record is not None:
        return record

    raise errors.NotFoundError(f"Record with id `{record_id}` not found")


@router.post("/me/responses/bulk", response_model=ResponsesBulk)
async def create_current_user_responses_bulk(
*,
db: AsyncSession = Depends(get_async_db),
search_engine: SearchEngine = Depends(get_search_engine),
body: ResponsesBulkCreate,
current_user: User = Security(auth.get_current_user),
use_case: UpsertResponsesInBulkUseCase = Depends(UpsertResponsesInBulkUseCaseFactory()),
):
responses_bulk_items = []
for item in body.items:
try:
record = await _get_record(db, item.record_id)

await authorize(current_user, RecordPolicyV1.create_response(record))

response = await datasets.upsert_response(db, search_engine, record, current_user, item)
except Exception as err:
responses_bulk_items.append(ResponseBulk(item=None, error=ResponseBulkError(detail=str(err))))
else:
responses_bulk_items.append(ResponseBulk(item=ResponseSchema.from_orm(response), error=None))

return ResponsesBulk(items=responses_bulk_items)
try:
responses_bulk_items = await use_case.execute(body.items, user=current_user)
except NotFoundError as err:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(err))
else:
return ResponsesBulk(items=responses_bulk_items)


@router.put("/responses/{response_id}", response_model=ResponseSchema)
23 changes: 19 additions & 4 deletions src/argilla/server/contexts/datasets.py
Original file line number Diff line number Diff line change
@@ -421,12 +421,17 @@ async def get_record_by_id(

async def get_records_by_ids(
db: "AsyncSession",
dataset_id: UUID,
records_ids: Iterable[UUID],
dataset_id: Optional[UUID] = None,
include: Optional["RecordIncludeParam"] = None,
user_id: Optional[UUID] = None,
) -> List[Record]:
query = select(Record).filter(Record.dataset_id == dataset_id, Record.id.in_(records_ids))
) -> List[Union[Record, None]]:
query = select(Record)

if dataset_id:
query.filter(Record.dataset_id == dataset_id)

query = query.filter(Record.id.in_(records_ids))

if include and include.with_responses:
if not user_id:
@@ -443,7 +448,7 @@ async def get_records_by_ids(

# Preserve the order of the `record_ids` list
record_order_map = {record.id: record for record in records}
ordered_records = [record_order_map[record_id] for record_id in records_ids]
ordered_records = [record_order_map.get(record_id, None) for record_id in records_ids]

return ordered_records

@@ -821,6 +826,16 @@ async def _preload_record_relationships_before_index(db: "AsyncSession", record:
)


async def preload_records_relationships_before_validate(db: "AsyncSession", records: List[Record]) -> None:
    """Warm the session's identity map for the given records.

    Executes a single query that eagerly loads each record's dataset and the
    dataset's questions, so later per-record validation does not trigger
    lazy-load round trips.
    """
    record_ids = [record.id for record in records]
    stmt = (
        select(Record)
        .filter(Record.id.in_(record_ids))
        .options(selectinload(Record.dataset).selectinload(Dataset.questions))
    )
    await db.execute(stmt)


async def update_records(
db: "AsyncSession", search_engine: "SearchEngine", dataset: Dataset, records_update: "RecordsUpdate"
) -> None:
28 changes: 13 additions & 15 deletions src/argilla/server/search_engine/commons.py
Original file line number Diff line number Diff line change
@@ -270,7 +270,7 @@ async def create_index(self, dataset: Dataset):

async def configure_metadata_property(self, dataset: Dataset, metadata_property: MetadataProperty):
mapping = es_mapping_for_metadata_property(metadata_property)
index_name = await self._get_index_or_raise(dataset)
index_name = await self._get_dataset_index(dataset)

await self.put_index_mapping_request(index_name, mapping)

@@ -280,7 +280,7 @@ async def delete_index(self, dataset: Dataset):
await self._delete_index_request(index_name)

async def index_records(self, dataset: Dataset, records: Iterable[Record]):
index_name = await self._get_index_or_raise(dataset)
index_name = await self._get_dataset_index(dataset)

bulk_actions = [
{
@@ -297,30 +297,30 @@ async def index_records(self, dataset: Dataset, records: Iterable[Record]):
await self._refresh_index_request(index_name)

async def delete_records(self, dataset: Dataset, records: Iterable[Record]):
index_name = await self._get_index_or_raise(dataset)
index_name = await self._get_dataset_index(dataset)

bulk_actions = [{"_op_type": "delete", "_id": record.id, "_index": index_name} for record in records]

await self._bulk_op_request(bulk_actions)

async def update_record_response(self, response: Response):
record = response.record
index_name = await self._get_index_or_raise(record.dataset)
index_name = await self._get_dataset_index(record.dataset)

es_responses = self._map_record_responses_to_es([response])

await self._update_document_request(index_name, id=record.id, body={"doc": {"responses": es_responses}})

async def delete_record_response(self, response: Response):
record = response.record
index_name = await self._get_index_or_raise(record.dataset)
index_name = await self._get_dataset_index(record.dataset)

await self._update_document_request(
index_name, id=record.id, body={"script": f'ctx._source["responses"].remove("{response.user.username}")'}
)

async def update_record_suggestion(self, suggestion: Suggestion):
index_name = await self._get_index_or_raise(suggestion.record.dataset)
index_name = await self._get_dataset_index(suggestion.record.dataset)

es_suggestions = self._map_record_suggestions_to_es([suggestion])

@@ -331,7 +331,7 @@ async def update_record_suggestion(self, suggestion: Suggestion):
)

async def delete_record_suggestion(self, suggestion: Suggestion):
index_name = await self._get_index_or_raise(suggestion.record.dataset)
index_name = await self._get_dataset_index(suggestion.record.dataset)

await self._update_document_request(
index_name,
@@ -340,7 +340,7 @@ async def delete_record_suggestion(self, suggestion: Suggestion):
)

async def set_records_vectors(self, dataset: Dataset, vectors: Iterable[Vector]):
index_name = await self._get_index_or_raise(dataset)
index_name = await self._get_dataset_index(dataset)

bulk_actions = [
{
@@ -399,7 +399,7 @@ async def similarity_search(
# Wrapping filter in a list to use easily on each engine implementation
query_filters = [self.build_elasticsearch_filter(filter)]

index = await self._get_index_or_raise(dataset)
index = await self._get_dataset_index(dataset)
response = await self._request_similarity_search(
index=index,
vector_settings=vector_settings,
@@ -548,7 +548,7 @@ def _map_record_metadata_to_es(
return search_engine_metadata

async def configure_index_vectors(self, vector_settings: VectorSettings) -> None:
index = await self._get_index_or_raise(vector_settings.dataset)
index = await self._get_dataset_index(vector_settings.dataset)

mappings = self._mapping_for_vector_settings(vector_settings)
await self.put_index_mapping_request(index, mappings)
@@ -586,15 +586,15 @@ async def search(
bool_query["filter"] = self.build_elasticsearch_filter(filter)

es_query = {"bool": bool_query}
index = await self._get_index_or_raise(dataset)
index = await self._get_dataset_index(dataset)

es_sort = self.build_elasticsearch_sort(sort) if sort else None
response = await self._index_search_request(index, query=es_query, size=limit, from_=offset, sort=es_sort)

return await self._process_search_response(response)

async def compute_metrics_for(self, metadata_property: MetadataProperty) -> MetadataMetrics:
index_name = await self._get_index_or_raise(metadata_property.dataset)
index_name = await self._get_dataset_index(metadata_property.dataset)

if metadata_property.type == MetadataPropertyType.terms:
return await self._metrics_for_terms_property(index_name, metadata_property)
@@ -730,10 +730,8 @@ def _dynamic_templates_for_question_responses(self, questions: List[Question]) -
],
]

async def _get_index_or_raise(self, dataset: Dataset):
async def _get_dataset_index(self, dataset: Dataset):
index_name = es_index_name_for_dataset(dataset)
if not await self._index_exists_request(index_name):
raise ValueError(f"Cannot access to index for dataset {dataset.id}: the specified index does not exist")

return index_name

13 changes: 13 additions & 0 deletions src/argilla/server/use_cases/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13 changes: 13 additions & 0 deletions src/argilla/server/use_cases/responses/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
60 changes: 60 additions & 0 deletions src/argilla/server/use_cases/responses/upsert_responses_in_bulk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession

from argilla.server.contexts import datasets
from argilla.server.database import get_async_db
from argilla.server.errors import future as errors
from argilla.server.models import User
from argilla.server.policies import RecordPolicyV1, authorize
from argilla.server.schemas.v1.responses import Response, ResponseBulk, ResponseBulkError, ResponseUpsert
from argilla.server.search_engine import SearchEngine, get_search_engine


class UpsertResponsesInBulkUseCase:
    """Upsert a batch of responses for a user in one pass.

    All referenced records are fetched up front (order preserved, with
    ``None`` placeholders for missing ids) and their dataset/question
    relationships are preloaded, so the per-item loop performs no extra
    lookup queries beyond the upsert itself.
    """

    def __init__(self, db: AsyncSession, search_engine: SearchEngine):
        self.db = db
        self.search_engine = search_engine

    async def execute(self, responses: List[ResponseUpsert], user: User) -> List[ResponseBulk]:
        record_ids = [item.record_id for item in responses]
        records = await datasets.get_records_by_ids(self.db, record_ids)

        # Preload relationships only for the records that actually exist.
        found_records = [record for record in records if record is not None]
        await datasets.preload_records_relationships_before_validate(self.db, found_records)

        results: List[ResponseBulk] = []
        for item, record in zip(responses, records):
            try:
                if record is None:
                    raise errors.NotFoundError(f"Record with id `{item.record_id}` not found")

                await authorize(user, RecordPolicyV1.create_response(record))
                upserted = await datasets.upsert_response(self.db, self.search_engine, record, user, item)
            except Exception as err:
                # Per-item failures are reported back to the caller instead of
                # aborting the whole bulk operation.
                results.append(ResponseBulk(item=None, error=ResponseBulkError(detail=str(err))))
            else:
                results.append(ResponseBulk(item=Response.from_orm(upserted), error=None))

        return results


class UpsertResponsesInBulkUseCaseFactory:
    """FastAPI dependency factory that assembles the bulk-upsert use case
    from the request-scoped database session and search engine."""

    def __call__(
        self, db: AsyncSession = Depends(get_async_db), search_engine: SearchEngine = Depends(get_search_engine)
    ):
        use_case = UpsertResponsesInBulkUseCase(db, search_engine)
        return use_case
Loading