Skip to content

Commit

Permalink
Add unit tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
WaVEV committed Nov 23, 2024
1 parent 2a59bdc commit 6740a0a
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 0 deletions.
7 changes: 7 additions & 0 deletions tests/model_fields_/test_objectidfield.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,10 @@ def test_to_python_invalid_value(self):
msg = f"['“{invalid_value}” value must be an Object Id.']"
with self.assertRaisesMessage(exceptions.ValidationError, msg):
ObjectIdField().to_python(invalid_value)

def test_get_prep_value_invalud_values(self):
for invalid_value in [3, "None", {}, []]:
with self.subTest(invalid_value=invalid_value):
msg = f"Field '{None}' expected an ObjectId but got {invalid_value!r}."
with self.assertRaisesMessage(ValueError, msg):
ObjectIdField().get_prep_value(invalid_value)
17 changes: 17 additions & 0 deletions tests/queries_/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from django.db import models

from django_mongodb.fields import ObjectIdField


class Author(models.Model):
name = models.CharField(max_length=10)
Expand All @@ -14,3 +16,18 @@ class Book(models.Model):

def __str__(self):
return self.title


class Tag(models.Model):
name = models.CharField(max_length=10)
parent = models.ForeignKey(
"self",
models.SET_NULL,
blank=True,
null=True,
related_name="children",
)
group_id = ObjectIdField(null=True)

def __str__(self):
return self.name
95 changes: 95 additions & 0 deletions tests/queries_/test_objectid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from bson import ObjectId
from django.test import TestCase

from .models import Tag


class QueriesFilterByObjectId(TestCase):
@classmethod
def setUpTestData(cls):
cls.group_id_str_1 = "1" * 24
cls.group_id_obj_1 = ObjectId(cls.group_id_str_1)
cls.group_id_str_2 = "2" * 24
cls.group_id_obj_2 = ObjectId(cls.group_id_str_2)

cls.t1 = Tag.objects.create(name="t1")
cls.t2 = Tag.objects.create(name="t2", parent=cls.t1)
cls.t3 = Tag.objects.create(name="t3", parent=cls.t1, group_id=cls.group_id_str_1)
cls.t4 = Tag.objects.create(name="t4", parent=cls.t3, group_id=cls.group_id_obj_2)
cls.t5 = Tag.objects.create(name="t5", parent=cls.t3)

def test_filter_group_id_is_null(self):
"""Filter objects where group_id is not null"""
for value, expected in [(False, [self.t3, self.t4]), (True, [self.t1, self.t2, self.t5])]:
with self.subTest(object_id=value):
qs = Tag.objects.filter(group_id__isnull=value).order_by("name")
self.assertSequenceEqual(qs, expected)

def test_filter_group_id_equal_value(self):
"""Filter by group_id with a specific value"""
for value in [self.group_id_str_1, self.group_id_obj_1]:
with self.subTest(object_id=value):
qs = Tag.objects.filter(group_id=value).order_by("name")
self.assertSequenceEqual(qs, [self.t3])

def test_filter_group_id_in_value(self):
"""Filter by group_id where value is in a list"""
test_cases = [
[self.group_id_str_1, self.group_id_str_2],
[self.group_id_obj_1, self.group_id_obj_2],
]
for values in test_cases:
with self.subTest(values=values):
qs = Tag.objects.filter(group_id__in=values).order_by("name")
self.assertSequenceEqual(qs, [self.t3, self.t4])

def test_filter_group_id_equal_subquery(self):
"""Filter by group_id using a subquery"""
subquery = Tag.objects.filter(name="t3").values("group_id")
for value in [self.group_id_str_1, self.group_id_obj_1]:
with self.subTest(object_id=value):
qs = Tag.objects.filter(group_id__in=subquery).order_by("name")
self.assertSequenceEqual(qs, [self.t3])

def test_filter_group_id_in_subquery(self):
"""Filter by group_id using a subquery with multiple values"""
subquery = Tag.objects.filter(name__in=["t3", "t4"]).values("group_id")
test_cases = [
[self.group_id_str_1, self.group_id_str_2],
[self.group_id_obj_1, self.group_id_obj_2],
]
for values in test_cases:
with self.subTest(values=values):
qs = Tag.objects.filter(group_id__in=subquery).order_by("name")
self.assertSequenceEqual(qs, [self.t3, self.t4])

def test_union_children_to_select_parents(self):
"""Union query to select parents of children based on group_id"""
child_group_ids = [self.group_id_str_1, self.group_id_obj_1]
for group_id in child_group_ids:
with self.subTest(group_id=group_id):
child_ids = Tag.objects.filter(group_id=group_id).values_list("id", flat=True)
parent_qs = (
Tag.objects.filter(children__id__in=child_ids).distinct().order_by("name")
)
self.assertSequenceEqual(parent_qs, [self.t1])

def test_filter_group_id_union_with(self):
"""Combine queries using union"""
test_cases = [
(self.group_id_str_1, self.group_id_str_2),
(self.group_id_obj_1, self.group_id_obj_2),
]
for value1, value2 in test_cases:
with self.subTest(value1=value1, value2=value2):
qs_a = Tag.objects.filter(group_id=value1)
qs_b = Tag.objects.filter(group_id=value2)
union_qs = qs_a.union(qs_b).order_by("name")
self.assertSequenceEqual(union_qs, [self.t3, self.t4])

def test_invalid_object_id(self):
"""Combine queries using union"""
value = "value1"
msg = f"Field 'group_id' expected an ObjectId but got '{value}'."
with self.assertRaisesMessage(ValueError, msg):
Tag.objects.filter(group_id=value)

0 comments on commit 6740a0a

Please sign in to comment.