Skip to content

Commit

Permalink
feat: delete suggestion from record on search engine (#4336)
Browse files Browse the repository at this point in the history
# Description

Delete a suggestion from the record document on search engine when the
suggestion is deleted using it's deletion endpoint.

Ref #4230 

**Type of change**

- [x] New feature (non-breaking change which adds functionality)

**How Has This Been Tested**

- [x] Running tests locally.
- [x] Checking manually that deleting a suggestion is affecting filters
later on the UI.

**Checklist**

- [ ] I added relevant documentation
- [x] follows the style guidelines of this project
- [x] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [ ] I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)

---------

Co-authored-by: Francisco Aranda <[email protected]>
  • Loading branch information
jfcalvo and frascuchon authored Nov 28, 2023
1 parent 968131c commit d42e888
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 12 deletions.
3 changes: 2 additions & 1 deletion src/argilla/server/apis/v1/handlers/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ async def upsert_suggestion(
async def delete_record_suggestions(
*,
db: AsyncSession = Depends(get_async_db),
search_engine: SearchEngine = Depends(get_search_engine),
record_id: UUID,
current_user: User = Security(auth.get_current_user),
ids: str = Query(..., description="A comma separated list with the IDs of the suggestions to be removed"),
Expand All @@ -195,7 +196,7 @@ async def delete_record_suggestions(
detail=f"Cannot delete more than {DELETE_RECORD_SUGGESTIONS_LIMIT} suggestions at once",
)

await datasets.delete_suggestions(db, record, suggestion_ids)
await datasets.delete_suggestions(db, search_engine, record, suggestion_ids)


@router.delete("/records/{record_id}", response_model=RecordSchema, response_model_exclude_unset=True)
Expand Down
4 changes: 3 additions & 1 deletion src/argilla/server/apis/v1/handlers/suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from argilla.server.models import Suggestion, User
from argilla.server.policies import SuggestionPolicyV1, authorize
from argilla.server.schemas.v1.suggestions import Suggestion as SuggestionSchema
from argilla.server.search_engine import SearchEngine, get_search_engine
from argilla.server.security import auth

router = APIRouter(tags=["suggestions"])
Expand All @@ -41,6 +42,7 @@ async def _get_suggestion(db: "AsyncSession", suggestion_id: UUID) -> Suggestion
async def delete_suggestion(
*,
db: AsyncSession = Depends(get_async_db),
search_engine: SearchEngine = Depends(get_search_engine),
suggestion_id: UUID,
current_user: User = Security(auth.get_current_user),
):
Expand All @@ -49,6 +51,6 @@ async def delete_suggestion(
await authorize(current_user, SuggestionPolicyV1.delete(suggestion))

try:
return await datasets.delete_suggestion(db, suggestion)
return await datasets.delete_suggestion(db, search_engine, suggestion)
except ValueError as err:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(err))
60 changes: 54 additions & 6 deletions src/argilla/server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,21 @@
# limitations under the License.
import copy
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Literal, Optional, Set, Tuple, TypeVar, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Literal,
Optional,
Sequence,
Set,
Tuple,
TypeVar,
Union,
)
from uuid import UUID

import sqlalchemy
Expand Down Expand Up @@ -1113,22 +1127,56 @@ async def upsert_suggestion(
return suggestion


async def delete_suggestions(db: "AsyncSession", record: Record, suggestions_ids: List[UUID]) -> None:
async def delete_suggestions(
db: "AsyncSession", search_engine: SearchEngine, record: Record, suggestions_ids: List[UUID]
) -> None:
params = [Suggestion.id.in_(suggestions_ids), Suggestion.record_id == record.id]
await Suggestion.delete_many(db=db, params=params)
suggestions = await list_suggestions_by_id_and_record_id(db, suggestions_ids, record.id)

async with db.begin_nested():
await Suggestion.delete_many(db=db, params=params, autocommit=False)
for suggestion in suggestions:
await search_engine.delete_record_suggestion(suggestion)

await db.commit()


async def get_suggestion_by_id(db: "AsyncSession", suggestion_id: "UUID") -> Union[Suggestion, None]:
result = await db.execute(
select(Suggestion)
.filter_by(id=suggestion_id)
.options(selectinload(Suggestion.record).selectinload(Record.dataset))
.options(
selectinload(Suggestion.record).selectinload(Record.dataset),
selectinload(Suggestion.question),
)
)

return result.scalar_one_or_none()


async def delete_suggestion(db: "AsyncSession", suggestion: Suggestion) -> Suggestion:
return await suggestion.delete(db)
async def list_suggestions_by_id_and_record_id(
db: "AsyncSession", suggestion_ids: List[UUID], record_id: UUID
) -> Sequence[Suggestion]:
result = await db.execute(
select(Suggestion)
.filter(Suggestion.record_id == record_id, Suggestion.id.in_(suggestion_ids))
.options(
selectinload(Suggestion.record).selectinload(Record.dataset),
selectinload(Suggestion.question),
)
)

return result.scalars().all()


async def delete_suggestion(db: "AsyncSession", search_engine: SearchEngine, suggestion: Suggestion) -> Suggestion:
async with db.begin_nested():
suggestion = await suggestion.delete(db, autocommit=False)
await search_engine.delete_record_suggestion(suggestion)

await db.commit()

return suggestion


async def get_metadata_property_by_id(db: "AsyncSession", metadata_property_id: UUID) -> Optional[MetadataProperty]:
Expand Down
6 changes: 4 additions & 2 deletions src/argilla/server/search_engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
Generic,
Iterable,
List,
Literal,
Optional,
Type,
TypeVar,
Expand All @@ -36,7 +35,6 @@
from argilla.server.enums import (
MetadataPropertyType,
RecordSortField,
ResponseStatus,
ResponseStatusFilter,
SimilarityOrder,
SortOrder,
Expand Down Expand Up @@ -315,6 +313,10 @@ async def delete_record_response(self, response: Response):
async def update_record_suggestion(self, suggestion: Suggestion):
pass

@abstractmethod
async def delete_record_suggestion(self, suggestion: Suggestion):
pass

@abstractmethod
async def search(
self,
Expand Down
9 changes: 9 additions & 0 deletions src/argilla/server/search_engine/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,15 @@ async def update_record_suggestion(self, suggestion: Suggestion):
body={"doc": {"suggestions": es_suggestions}},
)

async def delete_record_suggestion(self, suggestion: Suggestion):
index_name = await self._get_index_or_raise(suggestion.record.dataset)

await self._update_document_request(
index_name,
id=suggestion.record_id,
body={"script": f'ctx._source["suggestions"].remove("{suggestion.question.name}")'},
)

async def set_records_vectors(self, dataset: Dataset, vectors: Iterable[Vector]):
index_name = await self._get_index_or_raise(dataset)

Expand Down
6 changes: 5 additions & 1 deletion tests/unit/server/api/v1/test_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Type
from unittest.mock import call
from uuid import UUID, uuid4

import pytest
Expand Down Expand Up @@ -1342,7 +1343,7 @@ async def test_delete_record_non_existent(self, async_client: "AsyncClient", own

@pytest.mark.parametrize("role", [UserRole.admin, UserRole.owner])
async def test_delete_record_suggestions(
self, async_client: "AsyncClient", db: "AsyncSession", role: UserRole
self, async_client: "AsyncClient", db: "AsyncSession", mock_search_engine: SearchEngine, role: UserRole
) -> None:
dataset = await DatasetFactory.create()
user = await UserFactory.create(workspaces=[dataset.workspace], role=role)
Expand All @@ -1363,6 +1364,9 @@ async def test_delete_record_suggestions(
assert response.status_code == 204
assert (await db.execute(select(func.count(Suggestion.id)))).scalar() == 0

expected_calls = [call(suggestion) for suggestion in suggestions]
mock_search_engine.delete_record_suggestion.assert_has_calls(expected_calls)

async def test_delete_record_suggestions_with_no_ids(
self, async_client: "AsyncClient", owner_auth_header: dict
) -> None:
Expand Down
7 changes: 6 additions & 1 deletion tests/unit/server/api/v1/test_suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pytest
from argilla._constants import API_KEY_HEADER_NAME
from argilla.server.models import Suggestion, UserRole
from argilla.server.search_engine import SearchEngine
from sqlalchemy import func, select

from tests.factories import SuggestionFactory, UserFactory
Expand All @@ -30,7 +31,9 @@
@pytest.mark.asyncio
class TestSuiteSuggestions:
@pytest.mark.parametrize("role", [UserRole.admin, UserRole.owner])
async def test_delete_suggestion(self, async_client: "AsyncClient", db: "AsyncSession", role: UserRole) -> None:
async def test_delete_suggestion(
self, async_client: "AsyncClient", mock_search_engine: SearchEngine, db: "AsyncSession", role: UserRole
) -> None:
suggestion = await SuggestionFactory.create()
user = await UserFactory.create(role=role, workspaces=[suggestion.record.dataset.workspace])

Expand All @@ -50,6 +53,8 @@ async def test_delete_suggestion(self, async_client: "AsyncClient", db: "AsyncSe
}
assert (await db.execute(select(func.count(Suggestion.id)))).scalar() == 0

mock_search_engine.delete_record_suggestion.assert_called_once_with(suggestion)

async def test_delete_suggestion_non_existent(self, async_client: "AsyncClient", owner_auth_header: dict) -> None:
response = await async_client.delete(f"/api/v1/suggestions/{uuid4()}", headers=owner_auth_header)

Expand Down

0 comments on commit d42e888

Please sign in to comment.