Skip to content

Commit

Permalink
78: Allow per deduplication set face distance threshold configuration (
Browse files Browse the repository at this point in the history
…#85)

* Add possibility to configure face distance threshold per deduplication set

* Check face_distance_threshold value passed to DuplicationDetector

* Resolve merge issues
  • Loading branch information
sergey-misuk-valor authored Sep 24, 2024
1 parent 5a16805 commit e15ffc5
Show file tree
Hide file tree
Showing 14 changed files with 175 additions and 28 deletions.
10 changes: 9 additions & 1 deletion src/hope_dedup_engine/apps/api/deduplication/adapters.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from collections.abc import Generator

from constance import config

from hope_dedup_engine.apps.api.deduplication.registry import DuplicateKeyPair
from hope_dedup_engine.apps.api.models import DeduplicationSet
from hope_dedup_engine.apps.faces.services.duplication_detector import (
Expand All @@ -20,8 +22,14 @@ def run(self) -> Generator[DuplicateKeyPair, None, None]:
"reference_pk", "filename"
)
}
face_distance_threshold: float = (
self.deduplication_set.config
and self.deduplication_set.config.face_distance_threshold
) or config.FACE_DISTANCE_THRESHOLD
# ignored key pairs are not handled correctly in DuplicationDetector
detector = DuplicationDetector(tuple[str](filename_to_reference_pk.keys()), ())
detector = DuplicationDetector(
tuple[str](filename_to_reference_pk.keys()), face_distance_threshold
)
for first_filename, second_filename, distance in detector.find_duplicates():
yield filename_to_reference_pk[first_filename], filename_to_reference_pk[
second_filename
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Generated by Django 5.0.7 on 2024-09-24 09:05

import django.db.models.deletion
from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
("api", "0004_remove_deduplicationset_error_and_more"),
]

operations = [
migrations.CreateModel(
name="Config",
fields=[
(
"id",
models.BigAutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("face_distance_threshold", models.FloatField(null=True)),
],
),
migrations.AddField(
model_name="deduplicationset",
name="config",
field=models.OneToOneField(
null=True, on_delete=django.db.models.deletion.SET_NULL, to="api.config"
),
),
]
5 changes: 5 additions & 0 deletions src/hope_dedup_engine/apps/api/models/deduplication.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
REFERENCE_PK_LENGTH: Final[int] = 100


class Config(models.Model):
face_distance_threshold = models.FloatField(null=True)


class DeduplicationSet(models.Model):
"""
Bucket for entries we want to deduplicate
Expand Down Expand Up @@ -52,6 +56,7 @@ class State(models.IntegerChoices):
)
updated_at = models.DateTimeField(auto_now=True)
notification_url = models.CharField(max_length=255, null=True, blank=True)
config = models.OneToOneField(Config, null=True, on_delete=models.SET_NULL)

def __str__(self) -> str:
return f"ID: {self.pk}" if not self.name else f"{self.name}"
Expand Down
23 changes: 22 additions & 1 deletion src/hope_dedup_engine/apps/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,24 @@

from hope_dedup_engine.apps.api.models import DeduplicationSet
from hope_dedup_engine.apps.api.models.deduplication import (
Config,
Duplicate,
IgnoredKeyPair,
Image,
)

CONFIG = "config"


class ConfigSerializer(serializers.ModelSerializer):
class Meta:
model = Config
exclude = ("id",)


class DeduplicationSetSerializer(serializers.ModelSerializer):
state = serializers.CharField(source="get_state_display", read_only=True)
config = ConfigSerializer(required=False)

class Meta:
model = DeduplicationSet
Expand All @@ -25,11 +35,22 @@ class Meta:
"updated_by",
)

def create(self, validated_data) -> DeduplicationSet:
config_data = validated_data.get(CONFIG) and validated_data.pop(CONFIG)
config = Config.objects.create(**config_data) if config_data else None
return DeduplicationSet.objects.create(config=config, **validated_data)


class CreateConfigSerializer(ConfigSerializer):
pass


class CreateDeduplicationSetSerializer(serializers.ModelSerializer):
config = CreateConfigSerializer(required=False)

class Meta:
model = DeduplicationSet
fields = ("reference_pk", "notification_url")
fields = ("config", "reference_pk", "notification_url")


class ImageSerializer(serializers.ModelSerializer):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import face_recognition
import numpy as np
from constance import config

from hope_dedup_engine.apps.faces.managers import StorageManager
from hope_dedup_engine.apps.faces.services.image_processor import ImageProcessor
Expand All @@ -20,7 +19,10 @@ class DuplicationDetector:
logger: logging.Logger = logging.getLogger(__name__)

def __init__(
self, filenames: tuple[str], ignore_pairs: tuple[tuple[str, str], ...] = tuple()
self,
filenames: tuple[str],
face_distance_threshold: float,
ignore_pairs: tuple[tuple[str, str], ...] = (),
) -> None:
"""
Initialize the DuplicationDetector with the given filenames and ignore pairs.
Expand All @@ -31,9 +33,10 @@ def __init__(
The pairs of filenames to ignore. Defaults to an empty tuple.
"""
self.filenames = filenames
self.face_distance_threshold = face_distance_threshold
self.ignore_set = IgnorePairsValidator.validate(ignore_pairs)
self.storages = StorageManager()
self.image_processor = ImageProcessor()
self.image_processor = ImageProcessor(face_distance_threshold)

def _encodings_filename(self, filename: str) -> str:
"""
Expand Down Expand Up @@ -122,7 +125,7 @@ def find_duplicates(self) -> Generator[tuple[str, str, float], None, None]:
encodings_all = self._load_encodings_all()

for path1, path2 in combinations(existed_images_name, 2):
min_distance = config.FACE_DISTANCE_THRESHOLD
min_distance = self.face_distance_threshold
encodings1 = encodings_all.get(path1)
encodings2 = encodings_all.get(path2)
if encodings1 is None or encodings2 is None:
Expand All @@ -136,7 +139,7 @@ def find_duplicates(self) -> Generator[tuple[str, str, float], None, None]:
) < min_distance:
min_distance = current_min

if min_distance < config.FACE_DISTANCE_THRESHOLD:
if min_distance < self.face_distance_threshold:
yield (path1, path2, round(min_distance, 5))
except Exception as e:
self.logger.exception(
Expand Down
4 changes: 2 additions & 2 deletions src/hope_dedup_engine/apps/faces/services/image_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class ImageProcessor:

logger: logging.Logger = logging.getLogger(__name__)

def __init__(self) -> None:
def __init__(self, face_distance_threshold: float) -> None:
"""
Initialize the ImageProcessor with the required configurations.
"""
Expand All @@ -75,7 +75,7 @@ def __init__(self) -> None:
model=config.FACE_ENCODINGS_MODEL,
)
self.face_detection_confidence: float = config.FACE_DETECTION_CONFIDENCE
self.distance_threshold: float = config.FACE_DISTANCE_THRESHOLD
self.distance_threshold: float = face_distance_threshold
self.nms_threshold: float = config.NMS_THRESHOLD

def _get_face_detections_dnn(
Expand Down
4 changes: 3 additions & 1 deletion tests/api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
NoDuplicateFinder,
)
from testutils.factories.api import (
ConfigFactory,
DeduplicationSetFactory,
DuplicateFactory,
IgnoredKeyPairFactory,
Expand All @@ -26,14 +27,15 @@
register(ExternalSystemFactory)
register(UserFactory)
register(DeduplicationSetFactory, external_system=LazyFixture("external_system"))
register(ImageFactory, deduplication_Set=LazyFixture("deduplication_set"))
register(ImageFactory, deduplication_set=LazyFixture("deduplication_set"))
register(
ImageFactory,
_name="second_image",
deduplication_Set=LazyFixture("deduplication_set"),
)
register(DuplicateFactory, deduplication_set=LazyFixture("deduplication_set"))
register(IgnoredKeyPairFactory, deduplication_set=LazyFixture("deduplication_set"))
register(ConfigFactory)


@fixture
Expand Down
48 changes: 43 additions & 5 deletions tests/api/test_adapters.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,27 @@
from random import random
from unittest.mock import MagicMock

from constance.test.unittest import override_config
from pytest import fixture
from pytest_mock import MockerFixture

from hope_dedup_engine.apps.api.deduplication.adapters import DuplicateFaceFinder
from hope_dedup_engine.apps.api.models import DeduplicationSet, Image


@fixture
def duplication_detector(mocker: MockerFixture) -> MagicMock:
yield mocker.patch(
"hope_dedup_engine.apps.api.deduplication.adapters.DuplicationDetector"
)


def test_duplicate_face_finder_uses_duplication_detector(
deduplication_set: DeduplicationSet,
image: Image,
second_image: Image,
mocker: MockerFixture,
duplication_detector: MagicMock,
) -> None:
duplication_detector = mocker.patch(
"hope_dedup_engine.apps.api.deduplication.adapters.DuplicationDetector"
)
duplication_detector.return_value.find_duplicates.return_value = iter(
(
(
Expand All @@ -27,7 +36,8 @@ def test_duplicate_face_finder_uses_duplication_detector(
found_pairs = tuple(finder.run())

duplication_detector.assert_called_once_with(
(image.filename, second_image.filename), ()
(image.filename, second_image.filename),
deduplication_set.config.face_distance_threshold,
)
duplication_detector.return_value.find_duplicates.assert_called_once()
assert len(found_pairs) == 1
Expand All @@ -36,3 +46,31 @@ def test_duplicate_face_finder_uses_duplication_detector(
second_image.reference_pk,
1 - distance,
)


def _run_duplicate_face_finder(deduplication_set: DeduplicationSet) -> None:
finder = DuplicateFaceFinder(deduplication_set)
tuple(finder.run()) # tuple is used to make generator finish execution


def test_duplication_detector_is_initiated_with_correct_face_distance_threshold_value(
deduplication_set: DeduplicationSet,
duplication_detector: MagicMock,
) -> None:
# deduplication set face_distance_threshold config value is used
_run_duplicate_face_finder(deduplication_set)
duplication_detector.assert_called_once_with(
(), deduplication_set.config.face_distance_threshold
)
face_distance_threshold = random()
with override_config(FACE_DISTANCE_THRESHOLD=face_distance_threshold):
# value from global config is used when face_distance_threshold is not set in deduplication set config
duplication_detector.reset_mock()
deduplication_set.config.face_distance_threshold = None
_run_duplicate_face_finder(deduplication_set)
duplication_detector.assert_called_once_with((), face_distance_threshold)
# value from global config is used when deduplication set has no config
duplication_detector.reset_mock()
deduplication_set.config = None
_run_duplicate_face_finder(deduplication_set)
duplication_detector.assert_called_once_with((), face_distance_threshold)
31 changes: 24 additions & 7 deletions tests/api/test_deduplication_set_create.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,63 @@
from api_const import DEDUPLICATION_SET_LIST_VIEW, JSON
from pytest import mark
from rest_framework import status
from rest_framework.reverse import reverse
from rest_framework.test import APIClient
from testutils.factories.api import DeduplicationSetFactory

from hope_dedup_engine.apps.api.models import DeduplicationSet
from hope_dedup_engine.apps.api.serializers import DeduplicationSetSerializer
from hope_dedup_engine.apps.api.serializers import CreateDeduplicationSetSerializer


def test_can_create_deduplication_set(api_client: APIClient) -> None:
previous_amount = DeduplicationSet.objects.count()
data = DeduplicationSetSerializer(DeduplicationSetFactory.build()).data
data = CreateDeduplicationSetSerializer(DeduplicationSetFactory.build()).data

response = api_client.post(
reverse(DEDUPLICATION_SET_LIST_VIEW), data=data, format=JSON
)

assert response.status_code == status.HTTP_201_CREATED
assert DeduplicationSet.objects.count() == previous_amount + 1
data = response.json()
assert data["state"] == DeduplicationSet.State.CLEAN.label


def test_missing_fields_handling(api_client: APIClient) -> None:
data = DeduplicationSetSerializer(DeduplicationSetFactory.build()).data
data = CreateDeduplicationSetSerializer(DeduplicationSetFactory.build()).data
del data["reference_pk"]

response = api_client.post(
reverse(DEDUPLICATION_SET_LIST_VIEW), data=data, format=JSON
)

assert response.status_code == status.HTTP_400_BAD_REQUEST
errors = response.json()
assert len(errors) == 1
assert "reference_pk" in errors


def test_invalid_values_handling(api_client: APIClient) -> None:
data = DeduplicationSetSerializer(DeduplicationSetFactory.build()).data
data["reference_pk"] = None
@mark.parametrize("field", ("reference_pk", "config"))
def test_invalid_values_handling(field: str, api_client: APIClient) -> None:
data = CreateDeduplicationSetSerializer(DeduplicationSetFactory.build()).data
data[field] = None

response = api_client.post(
reverse(DEDUPLICATION_SET_LIST_VIEW), data=data, format=JSON
)

assert response.status_code == status.HTTP_400_BAD_REQUEST
errors = response.json()
assert len(errors) == 1
assert "reference_pk" in errors
assert field in errors


def test_can_set_deduplication_set_without_config(api_client: APIClient) -> None:
data = CreateDeduplicationSetSerializer(DeduplicationSetFactory.build()).data
del data["config"]

response = api_client.post(
reverse(DEDUPLICATION_SET_LIST_VIEW), data=data, format=JSON
)

assert response.status_code == status.HTTP_201_CREATED
9 changes: 9 additions & 0 deletions tests/extras/testutils/factories/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from hope_dedup_engine.apps.api.models import DeduplicationSet, HDEToken
from hope_dedup_engine.apps.api.models.deduplication import (
Config,
Duplicate,
IgnoredKeyPair,
Image,
Expand All @@ -17,11 +18,19 @@ class Meta:
model = HDEToken


class ConfigFactory(DjangoModelFactory):
face_distance_threshold = fuzzy.FuzzyFloat(low=0.1, high=1.0)

class Meta:
model = Config


class DeduplicationSetFactory(DjangoModelFactory):
reference_pk = fuzzy.FuzzyText()
external_system = SubFactory(ExternalSystemFactory)
state = DeduplicationSet.State.CLEAN
notification_url = fuzzy.FuzzyText(prefix="https://")
config = SubFactory(ConfigFactory)

class Meta:
model = DeduplicationSet
Expand Down
Loading

0 comments on commit e15ffc5

Please sign in to comment.