diff --git a/src/hope_dedup_engine/apps/api/views.py b/src/hope_dedup_engine/apps/api/views.py index 0d3bf7a..bc8bf87 100644 --- a/src/hope_dedup_engine/apps/api/views.py +++ b/src/hope_dedup_engine/apps/api/views.py @@ -3,9 +3,10 @@ from typing import Any from uuid import UUID -from django.db.models import QuerySet +from django.db.models import Q, QuerySet -from drf_spectacular.utils import extend_schema +from drf_spectacular.types import OpenApiTypes +from drf_spectacular.utils import OpenApiParameter, extend_schema from rest_framework import mixins, status, viewsets from rest_framework.decorators import action from rest_framework.permissions import IsAuthenticated @@ -228,6 +229,9 @@ def create(self, request: Request, *args: Any, **kwargs: Any) -> Response: return super().create(request, *args, **kwargs) +REFERENCE_PK = "reference_pk" + + class DuplicateViewSet( nested_viewsets.NestedViewSetMixin[Duplicate], mixins.ListModelMixin, @@ -245,7 +249,25 @@ class DuplicateViewSet( DEDUPLICATION_SET_PARAM: DEDUPLICATION_SET_FILTER, } - @extend_schema(description="List all duplicates found in the deduplication set") + def get_queryset(self) -> QuerySet[Duplicate]: + queryset = super().get_queryset() + if reference_pk := self.request.query_params.get(REFERENCE_PK): + return queryset.filter( + Q(first_reference_pk=reference_pk) | Q(second_reference_pk=reference_pk) + ) + return queryset + + @extend_schema( + description="List all duplicates found in the deduplication set", + parameters=[ + OpenApiParameter( + REFERENCE_PK, + OpenApiTypes.STR, + OpenApiParameter.QUERY, + description="Filters results by reference pk", + ) + ], + ) def list(self, request: Request, *args: Any, **kwargs: Any) -> Response: return super().list(request, *args, **kwargs) diff --git a/tests/api/test_duplicate_list.py b/tests/api/test_duplicate_list.py index 3aff045..a1ed169 100644 --- a/tests/api/test_duplicate_list.py +++ b/tests/api/test_duplicate_list.py @@ -1,10 +1,17 @@ +from collections.abc import Callable +from operator import attrgetter +from urllib.parse import urlencode + from api_const import DUPLICATE_LIST_VIEW +from factory.fuzzy import FuzzyText +from pytest import mark from rest_framework import status from rest_framework.reverse import reverse from rest_framework.test import APIClient from hope_dedup_engine.apps.api.models import DeduplicationSet from hope_dedup_engine.apps.api.models.deduplication import Duplicate +from hope_dedup_engine.apps.api.views import REFERENCE_PK def test_can_list_duplicates( @@ -26,3 +33,30 @@ def test_cannot_list_duplicates_between_systems( reverse(DUPLICATE_LIST_VIEW, (deduplication_set.pk,)) ) assert response.status_code == status.HTTP_403_FORBIDDEN + + +@mark.parametrize( + ("filter_value_getter", "expected_amount"), + ( + # filter by first_reference_pk + (attrgetter("first_reference_pk"), 1), + # filter by second_reference_pk + (attrgetter("second_reference_pk"), 1), + # filter by random string + (lambda _: FuzzyText().fuzz(), 0), + ), +) +def test_can_filter_by_reference_pk( + api_client: APIClient, + deduplication_set: DeduplicationSet, + duplicate: Duplicate, + filter_value_getter: Callable[[Duplicate], str], + expected_amount: int, +) -> None: + url = f"{reverse(DUPLICATE_LIST_VIEW, (deduplication_set.pk, ))}?" + urlencode( + {REFERENCE_PK: filter_value_getter(duplicate)} + ) + response = api_client.get(url) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data) == expected_amount