diff --git a/keystone_api/apps/allocations/tasks/notifications.py b/keystone_api/apps/allocations/tasks/notifications.py index 36fa3e64..8d552a38 100644 --- a/keystone_api/apps/allocations/tasks/notifications.py +++ b/keystone_api/apps/allocations/tasks/notifications.py @@ -1,9 +1,9 @@ """Background tasks for issuing user notifications.""" import logging +from datetime import date, timedelta from celery import shared_task -from django.utils import timezone from apps.allocations.models import AllocationRequest from apps.notifications.models import Notification, Preference @@ -30,17 +30,23 @@ def send_expiry_notification(user: User, request: AllocationRequest) -> None: # There are no notifications if the allocation does not expire days_until_expire = request.get_days_until_expire() if days_until_expire is None: - log.debug(f'Request {request.id} does not expire') + log.debug(f'Skipping expiry notification for user {user.username}: Request {request.id} does not expire') return - elif days_until_expire <= 0: - log.debug(f'Request {request.id} has already expired') + # Skip proposals that are already expired + if days_until_expire <= 0: + log.debug(f'Skipping expiry notification for user {user.username}: Request {request.id} has already expired') return # Exit early if we have not hit a notification threshold yet next_threshold = Preference.get_user_preference(user).get_next_expiration_threshold(days_until_expire) - log.debug(f'Request {request.id} expires in {days_until_expire} days. Next threshold at {next_threshold} days.') if next_threshold is None: + log.debug(f'Skipping expiry notification for user {user.username}: No notification threshold has been hit yet.') + return + + # Don't bombard new user's with outdated notifications + if user.date_joined >= date.today() - timedelta(days=next_threshold): + log.debug(f'Skipping expiry notification for user {user.username}: User account created after notification threshold.') return # Check if a notification has already been sent for the next threshold or any smaller threshold @@ -50,9 +56,10 @@ def send_expiry_notification(user: User, request: AllocationRequest) -> None: metadata__request_id=request.id, metadata__days_to_expire__lte=next_threshold ).exists(): + log.debug(f'Skipping expiry notification for user {user.username}: Notification already sent for threshold.') return - log.debug(f'Sending new notification for request {request.id} to user {user.username}.') + log.debug(f'Sending expiry notification for request {request.id} to user {user.username}.') send_notification_template( user=user, subject=f'Allocation Expires on {request.expire}', @@ -76,13 +83,12 @@ def send_notifications() -> None: active_requests = AllocationRequest.objects.filter( status=AllocationRequest.StatusChoices.APPROVED, - expire__gte=timezone.now() + expire__gt=date.today() ).all() failed = False for request in active_requests: - for user in request.group.get_all_members(): - + for user in request.group.get_all_members().filter(is_active=True): try: send_expiry_notification(user, request) diff --git a/keystone_api/apps/allocations/tests/test_tasks/__init__.py b/keystone_api/apps/allocations/tests/test_tasks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/keystone_api/apps/allocations/tests/test_tasks/test_notifications/__init__.py b/keystone_api/apps/allocations/tests/test_tasks/test_notifications/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/keystone_api/apps/allocations/tests/test_tasks/test_notifications/test_send_expiry_notification.py b/keystone_api/apps/allocations/tests/test_tasks/test_notifications/test_send_expiry_notification.py new file mode 100644 index 00000000..61aab67d --- /dev/null +++ b/keystone_api/apps/allocations/tests/test_tasks/test_notifications/test_send_expiry_notification.py @@ -0,0 +1,118 @@ +"""Unit tests for the `send_expiry_notification` function.""" + +from datetime import date, timedelta +from unittest.mock import MagicMock, Mock, patch + +from django.test import TestCase + +from apps.allocations.models import AllocationRequest +from apps.allocations.tasks import send_expiry_notification +from apps.users.models import User + + +class NotificationSending(TestCase): + """Test the sending/skipping of notifications.""" + + def setUp(self): + """Set up test data.""" + + self.user = MagicMock(spec=User) + self.user.username = 'testuser' + self.user.date_joined = date.today() - timedelta(days=10) + self.user.is_active = True + + self.request = MagicMock(spec=AllocationRequest) + + @patch('apps.notifications.shortcuts.send_notification_template') + def test_no_notification_if_request_does_not_expire(self, mock_send_template: Mock) -> None: + """Test no notification is sent if the request does not expire.""" + + self.request.get_days_until_expire.return_value = None + with self.assertLogs('apps.allocations.tasks', level='DEBUG') as log: + send_expiry_notification(self.user, self.request) + self.assertRegex(log.output[-1], '.*Skipping expiry notification .* does not expire.*') + + mock_send_template.assert_not_called() + + @patch('apps.notifications.shortcuts.send_notification_template') + def test_no_notification_if_request_already_expired(self, mock_send_template: Mock) -> None: + """Test no notification is sent if the request has already expired.""" + + self.request.get_days_until_expire.return_value = 0 + with self.assertLogs('apps.allocations.tasks', level='DEBUG') as log: + send_expiry_notification(self.user, self.request) + self.assertRegex(log.output[-1], '.*Skipping expiry notification .* has already expired.*') + + mock_send_template.assert_not_called() + + @patch('apps.notifications.shortcuts.send_notification_template') + @patch('apps.notifications.models.Preference.get_user_preference') + def test_no_notification_if_no_threshold_reached( + self, mock_get_user_preference: Mock, mock_send_template: Mock + ) -> None: + """Test no notification is sent if no threshold is reached.""" + + mock_preference = MagicMock() + mock_preference.get_next_expiration_threshold.return_value = None + mock_get_user_preference.return_value = mock_preference + + self.request.get_days_until_expire.return_value = 15 + with self.assertLogs('apps.allocations.tasks', level='DEBUG') as log: + send_expiry_notification(self.user, self.request) + self.assertRegex( + log.output[-1], + '.*Skipping expiry notification .* No notification threshold has been hit yet.*' + ) + + mock_send_template.assert_not_called() + + @patch('apps.notifications.shortcuts.send_notification_template') + @patch('apps.notifications.models.Preference.get_user_preference') + def test_no_notification_if_user_recently_joined( + self, mock_get_user_preference: Mock, mock_send_template: Mock + ) -> None: + """Test no notification is sent if the user recently joined.""" + + mock_preference = MagicMock() + mock_preference.get_next_expiration_threshold.return_value = 10 + mock_get_user_preference.return_value = mock_preference + + self.user.date_joined = date.today() - timedelta(days=5) + self.request.get_days_until_expire.return_value = 15 + + with self.assertLogs('apps.allocations.tasks', level='DEBUG') as log: + send_expiry_notification(self.user, self.request) + self.assertRegex( + log.output[-1], + '.*Skipping expiry notification .* User account created after notification threshold.*' + ) + + mock_send_template.assert_not_called() + + @patch('apps.notifications.shortcuts.send_notification_template') + @patch('apps.notifications.models.Notification.objects.filter') + @patch('apps.notifications.models.Preference.get_user_preference') + def test_no_duplicate_notification( + self, mock_get_user_preference: Mock, + mock_notification_filter: Mock, + mock_send_template: Mock + ) -> None: + """Test no duplicate notification is sent.""" + + mock_preference = MagicMock() + mock_preference.get_next_expiration_threshold.return_value = 10 + mock_get_user_preference.return_value = mock_preference + + mock_notification_filter.return_value.exists.return_value = True + + self.user.date_joined = date.today() - timedelta(days=20) + self.request.get_days_until_expire.return_value = 15 + + with self.assertLogs('apps.allocations.tasks', level='DEBUG') as log: + send_expiry_notification(self.user, self.request) + self.assertRegex( + log.output[-1], + '.*Skipping expiry notification .* Notification already sent for threshold.*' + ) + + mock_send_template.assert_not_called() diff --git a/keystone_api/apps/users/models.py b/keystone_api/apps/users/models.py index 8e686762..ba4bdbcf 100644 --- a/keystone_api/apps/users/models.py +++ b/keystone_api/apps/users/models.py @@ -54,15 +54,22 @@ class ResearchGroup(models.Model): objects = ResearchGroupManager() - def get_all_members(self) -> tuple[User, ...]: - """Return all research group members.""" - - return (self.pi,) + tuple(self.admins.all()) + tuple(self.members.all()) - - def get_privileged_members(self) -> tuple[User, ...]: - """Return all research group members with admin privileges.""" - - return (self.pi,) + tuple(self.admins.all()) + def get_all_members(self) -> models.QuerySet: + """Return a queryset of all research group members.""" + + return User.objects.filter( + models.Q(pk=self.pi.pk) | + models.Q(pk__in=self.admins.values_list('pk', flat=True)) | + models.Q(pk__in=self.members.values_list('pk', flat=True)) + ) + + def get_privileged_members(self) -> models.QuerySet: + """Return a queryset of all research group members with admin privileges.""" + + return User.objects.filter( + models.Q(pk=self.pi.pk) | + models.Q(pk__in=self.admins.values_list('pk', flat=True)) + ) def __str__(self) -> str: # pragma: nocover # pragma: nocover """Return the research group's account name.""" diff --git a/keystone_api/apps/users/tests/test_models/test_ResearchGroup.py b/keystone_api/apps/users/tests/test_models/test_ResearchGroup.py index f9dd1c16..5767b9bb 100644 --- a/keystone_api/apps/users/tests/test_models/test_ResearchGroup.py +++ b/keystone_api/apps/users/tests/test_models/test_ResearchGroup.py @@ -19,7 +19,7 @@ def setUp(self) -> None: self.member2 = create_test_user(username='unprivileged2') def test_all_accounts_returned(self) -> None: - """Test all group members are included in the returned list.""" + """Test all group members are included in the returned queryset.""" group = ResearchGroup.objects.create(pi=self.pi) group.admins.add(self.admin1) @@ -27,8 +27,13 @@ def test_all_accounts_returned(self) -> None: group.members.add(self.member1) group.members.add(self.member2) - expected_members = (self.pi, self.admin1, self.admin2, self.member1, self.member2) - self.assertEqual(expected_members, group.get_all_members()) + expected_members = [self.pi, self.admin1, self.admin2, self.member1, self.member2] + + self.assertQuerySetEqual( + expected_members, + group.get_all_members(), + ordered=False + ) class GetPrivilegedMembers(TestCase): @@ -48,7 +53,7 @@ def test_pi_only(self) -> None: group = ResearchGroup.objects.create(pi=self.pi) expected_members = (self.pi,) - self.assertEqual(expected_members, group.get_privileged_members()) + self.assertQuerySetEqual(expected_members, group.get_privileged_members(), ordered=False) def test_pi_with_admins(self) -> None: """Test returned group members for a group with a PI and admins.""" @@ -58,7 +63,7 @@ def test_pi_with_admins(self) -> None: group.admins.add(self.admin2) expected_members = (self.pi, self.admin1, self.admin2) - self.assertEqual(expected_members, group.get_privileged_members()) + self.assertQuerySetEqual(expected_members, group.get_privileged_members(), ordered=False) def test_pi_with_members(self) -> None: """Test returned group members for a group with a PI and unprivileged members.""" @@ -68,7 +73,7 @@ def test_pi_with_members(self) -> None: group.members.add(self.member2) expected_members = (self.pi,) - self.assertEqual(expected_members, group.get_privileged_members()) + self.assertQuerySetEqual(expected_members, group.get_privileged_members(), ordered=False) def test_pi_with_admin_and_members(self) -> None: """Test returned group members for a group with a PI, admins, and unprivileged members.""" @@ -80,4 +85,4 @@ def test_pi_with_admin_and_members(self) -> None: group.members.add(self.member2) expected_members = (self.pi, self.admin1, self.admin2) - self.assertEqual(expected_members, group.get_privileged_members()) + self.assertQuerySetEqual(expected_members, group.get_privileged_members(), ordered=False)