From 6740a0ab414240801686fd992d04f73170a7518b Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sat, 23 Nov 2024 17:09:42 -0300 Subject: [PATCH] Add unit tests. --- tests/model_fields_/test_objectidfield.py | 7 ++ tests/queries_/models.py | 17 ++++ tests/queries_/test_objectid.py | 95 +++++++++++++++++++++++ 3 files changed, 119 insertions(+) create mode 100644 tests/queries_/test_objectid.py diff --git a/tests/model_fields_/test_objectidfield.py b/tests/model_fields_/test_objectidfield.py index 7889ea47..e11406b5 100644 --- a/tests/model_fields_/test_objectidfield.py +++ b/tests/model_fields_/test_objectidfield.py @@ -32,3 +32,10 @@ def test_to_python_invalid_value(self): msg = f"['“{invalid_value}” value must be an Object Id.']" with self.assertRaisesMessage(exceptions.ValidationError, msg): ObjectIdField().to_python(invalid_value) + + def test_get_prep_value_invalud_values(self): + for invalid_value in [3, "None", {}, []]: + with self.subTest(invalid_value=invalid_value): + msg = f"Field '{None}' expected an ObjectId but got {invalid_value!r}." + with self.assertRaisesMessage(ValueError, msg): + ObjectIdField().get_prep_value(invalid_value) diff --git a/tests/queries_/models.py b/tests/queries_/models.py index 61b93890..7d375ad7 100644 --- a/tests/queries_/models.py +++ b/tests/queries_/models.py @@ -1,5 +1,7 @@ from django.db import models +from django_mongodb.fields import ObjectIdField + class Author(models.Model): name = models.CharField(max_length=10) @@ -14,3 +16,18 @@ class Book(models.Model): def __str__(self): return self.title + + +class Tag(models.Model): + name = models.CharField(max_length=10) + parent = models.ForeignKey( + "self", + models.SET_NULL, + blank=True, + null=True, + related_name="children", + ) + group_id = ObjectIdField(null=True) + + def __str__(self): + return self.name diff --git a/tests/queries_/test_objectid.py b/tests/queries_/test_objectid.py new file mode 100644 index 00000000..c4a641f2 --- /dev/null +++ b/tests/queries_/test_objectid.py @@ -0,0 +1,95 @@ +from bson import ObjectId +from django.test import TestCase + +from .models import Tag + + +class QueriesFilterByObjectId(TestCase): + @classmethod + def setUpTestData(cls): + cls.group_id_str_1 = "1" * 24 + cls.group_id_obj_1 = ObjectId(cls.group_id_str_1) + cls.group_id_str_2 = "2" * 24 + cls.group_id_obj_2 = ObjectId(cls.group_id_str_2) + + cls.t1 = Tag.objects.create(name="t1") + cls.t2 = Tag.objects.create(name="t2", parent=cls.t1) + cls.t3 = Tag.objects.create(name="t3", parent=cls.t1, group_id=cls.group_id_str_1) + cls.t4 = Tag.objects.create(name="t4", parent=cls.t3, group_id=cls.group_id_obj_2) + cls.t5 = Tag.objects.create(name="t5", parent=cls.t3) + + def test_filter_group_id_is_null(self): + """Filter objects where group_id is not null""" + for value, expected in [(False, [self.t3, self.t4]), (True, [self.t1, self.t2, self.t5])]: + with self.subTest(object_id=value): + qs = Tag.objects.filter(group_id__isnull=value).order_by("name") + self.assertSequenceEqual(qs, expected) + + def test_filter_group_id_equal_value(self): + """Filter by group_id with a specific value""" + for value in [self.group_id_str_1, self.group_id_obj_1]: + with self.subTest(object_id=value): + qs = Tag.objects.filter(group_id=value).order_by("name") + self.assertSequenceEqual(qs, [self.t3]) + + def test_filter_group_id_in_value(self): + """Filter by group_id where value is in a list""" + test_cases = [ + [self.group_id_str_1, self.group_id_str_2], + [self.group_id_obj_1, self.group_id_obj_2], + ] + for values in test_cases: + with self.subTest(values=values): + qs = Tag.objects.filter(group_id__in=values).order_by("name") + self.assertSequenceEqual(qs, [self.t3, self.t4]) + + def test_filter_group_id_equal_subquery(self): + """Filter by group_id using a subquery""" + subquery = Tag.objects.filter(name="t3").values("group_id") + for value in [self.group_id_str_1, self.group_id_obj_1]: + with self.subTest(object_id=value): + qs = Tag.objects.filter(group_id__in=subquery).order_by("name") + self.assertSequenceEqual(qs, [self.t3]) + + def test_filter_group_id_in_subquery(self): + """Filter by group_id using a subquery with multiple values""" + subquery = Tag.objects.filter(name__in=["t3", "t4"]).values("group_id") + test_cases = [ + [self.group_id_str_1, self.group_id_str_2], + [self.group_id_obj_1, self.group_id_obj_2], + ] + for values in test_cases: + with self.subTest(values=values): + qs = Tag.objects.filter(group_id__in=subquery).order_by("name") + self.assertSequenceEqual(qs, [self.t3, self.t4]) + + def test_union_children_to_select_parents(self): + """Union query to select parents of children based on group_id""" + child_group_ids = [self.group_id_str_1, self.group_id_obj_1] + for group_id in child_group_ids: + with self.subTest(group_id=group_id): + child_ids = Tag.objects.filter(group_id=group_id).values_list("id", flat=True) + parent_qs = ( + Tag.objects.filter(children__id__in=child_ids).distinct().order_by("name") + ) + self.assertSequenceEqual(parent_qs, [self.t1]) + + def test_filter_group_id_union_with(self): + """Combine queries using union""" + test_cases = [ + (self.group_id_str_1, self.group_id_str_2), + (self.group_id_obj_1, self.group_id_obj_2), + ] + for value1, value2 in test_cases: + with self.subTest(value1=value1, value2=value2): + qs_a = Tag.objects.filter(group_id=value1) + qs_b = Tag.objects.filter(group_id=value2) + union_qs = qs_a.union(qs_b).order_by("name") + self.assertSequenceEqual(union_qs, [self.t3, self.t4]) + + def test_invalid_object_id(self): + """Combine queries using union""" + value = "value1" + msg = f"Field 'group_id' expected an ObjectId but got '{value}'." + with self.assertRaisesMessage(ValueError, msg): + Tag.objects.filter(group_id=value)