From ec3f867fe33a2ac7d6e2f3c5daabf0384169499e Mon Sep 17 00:00:00 2001 From: Daniel Perrefort Date: Thu, 29 Aug 2024 20:32:08 -0400 Subject: [PATCH] Refactor slurm tasks for testability (#406) --- keystone_api/apps/allocations/managers.py | 125 ++++++++++++++++++ keystone_api/apps/allocations/models.py | 3 + keystone_api/apps/allocations/tasks/limits.py | 30 ++--- .../apps/allocations/tasks/notifications.py | 15 +-- .../tests/test_managers/__init__.py | 0 .../test_managers/test_AllocationManager.py | 117 ++++++++++++++++ keystone_api/apps/notifications/models.py | 20 ++- 7 files changed, 280 insertions(+), 30 deletions(-) create mode 100644 keystone_api/apps/allocations/managers.py create mode 100644 keystone_api/apps/allocations/tests/test_managers/__init__.py create mode 100644 keystone_api/apps/allocations/tests/test_managers/test_AllocationManager.py diff --git a/keystone_api/apps/allocations/managers.py b/keystone_api/apps/allocations/managers.py new file mode 100644 index 00000000..7b70c961 --- /dev/null +++ b/keystone_api/apps/allocations/managers.py @@ -0,0 +1,125 @@ +"""Custom database managers for encapsulating repeatable table queries. + +Manager classes encapsulate common database operations at the table level (as +opposed to the level of individual records). At least one Manager exists for +every database model. Managers are commonly exposed as an attribute of the +associated model class called `objects`. +""" + +from datetime import date +from typing import TYPE_CHECKING + +from django.db.models import Manager, QuerySet, Sum + +from apps.users.models import ResearchGroup + +if TYPE_CHECKING: # pragma: nocover + from apps.allocations.models import Cluster + +__all__ = ['AllocationManager'] + + +class AllocationManager(Manager): + """Custom manager for the `Allocation` model. + + Provides query methods for fetching approved, active, and expired allocations, + as well as calculating service units and historical usage. + """ + + def approved_allocations(self, account: ResearchGroup, cluster: 'Cluster') -> QuerySet: + """Retrieve all approved allocations for a specific account and cluster. + + Args: + account: object representing the account. + cluster: object representing the cluster. + + Returns: + A queryset of approved Allocation objects. + """ + + return self.filter(request__group=account, cluster=cluster, request__status='AP') + + def active_allocations(self, account: ResearchGroup, cluster: 'Cluster') -> QuerySet: + """Retrieve all active allocations for a specific account and cluster. + + Active allocations have been approved and are currently within their start/end date. + + Args: + account: object representing the account. + cluster: object representing the cluster. + + Returns: + A queryset of active Allocation objects. + """ + + return self.approved_allocations(account, cluster).filter( + request__active__lte=date.today(), request__expire__gt=date.today() + ) + + def expiring_allocations(self, account: ResearchGroup, cluster: 'Cluster') -> QuerySet: + """Retrieve all expiring allocations for a specific account and cluster. + + Expiring allocations have been approved and have passed their expiration date + but do not yet have a final usage value set. + + Args: + account: object representing the account. + cluster: object representing the cluster. + + Returns: + A queryset of expired Allocation objects ordered by expiration date. + """ + + return self.approved_allocations(account, cluster).filter( + final=None, request__expire__lte=date.today() + ).order_by("request__expire") + + def active_service_units(self, account: ResearchGroup, cluster: 'Cluster') -> int: + """Calculate the total service units across all active allocations for an account and cluster. + + Active allocations have been approved and are currently within their start/end date. + + Args: + account: object representing the account. + cluster: object representing the cluster. + + Returns: + Total service units from active allocations. + """ + + return self.active_allocations(account, cluster).aggregate( + Sum("awarded") + )['awarded__sum'] or 0 + + def expiring_service_units(self, account: ResearchGroup, cluster: 'Cluster') -> int: + """Calculate the total service units across all expiring allocations for an account and cluster. + + Expiring allocations have been approved and have passed their expiration date + but do not yet have a final usage value set. + + Args: + account: object representing the account. + cluster: object representing the cluster. + + Returns: + Total service units from expired allocations. + """ + + return self.expiring_allocations(account, cluster).aggregate( + Sum("awarded") + )['awarded__sum'] or 0 + + def historical_usage(self, account: ResearchGroup, cluster: 'Cluster') -> int: + """Calculate the total final usage for expired allocations of a specific account and cluster. + + Args: + account: object representing the account. + cluster: object representing the cluster. + + Returns: + Total historical usage calculated from expired allocations. + """ + + return self.approved_allocations(account, cluster).filter( + request__expire__lte=date.today() + ).aggregate(Sum("final"))['final__sum'] or 0 diff --git a/keystone_api/apps/allocations/models.py b/keystone_api/apps/allocations/models.py index 2044dfe6..affb7712 100644 --- a/keystone_api/apps/allocations/models.py +++ b/keystone_api/apps/allocations/models.py @@ -15,6 +15,7 @@ from django.db import models from django.template.defaultfilters import truncatechars +from apps.allocations.managers import AllocationManager from apps.users.models import ResearchGroup, User __all__ = [ @@ -45,6 +46,8 @@ class Allocation(RGModelInterface, models.Model): cluster: Cluster = models.ForeignKey('Cluster', on_delete=models.CASCADE) request: AllocationRequest = models.ForeignKey('AllocationRequest', on_delete=models.CASCADE) + objects = AllocationManager() + def get_research_group(self) -> ResearchGroup: """Return the research group tied to the current record.""" diff --git a/keystone_api/apps/allocations/tasks/limits.py b/keystone_api/apps/allocations/tasks/limits.py index 05feabaf..98faaa85 100644 --- a/keystone_api/apps/allocations/tasks/limits.py +++ b/keystone_api/apps/allocations/tasks/limits.py @@ -1,10 +1,8 @@ """Background tasks for updating/enforcing slurm usage limits.""" import logging -from datetime import date from celery import shared_task -from django.db.models import Sum from apps.allocations.models import * from apps.users.models import * @@ -49,28 +47,21 @@ def update_limits_for_cluster(cluster: Cluster) -> None: @shared_task() def update_limit_for_account(account: ResearchGroup, cluster: Cluster) -> None: - """Update the TRES billing usage limits for an individual Slurm account, closing out any expired allocations. + """Update the allocation limits for an individual Slurm account and close out any expired allocations. Args: account: ResearchGroup object for the account. cluster: Cluster object corresponding to the Slurm cluster. """ - # Base query for approved Allocations under the given account on the given cluster - approved_query = Allocation.objects.filter(request__group=account, cluster=cluster, request__status='AP') - - # Query for allocations that have expired but do not have a final usage value, determine their SU contribution - closing_query = approved_query.filter(final=None, request__expire__lte=date.today()).order_by("request__expire") - closing_sus = closing_query.aggregate(Sum("awarded"))['awarded__sum'] or 0 - - # Query for allocations that are active, and determine their total service unit contribution - active_query = approved_query.filter(request__active__lte=date.today(), request__expire__gt=date.today()) - active_sus = active_query.aggregate(Sum("awarded"))['awarded__sum'] or 0 + # Calculate service units for expired and active allocations + closing_sus = Allocation.objects.expiring_service_units(account, cluster) + active_sus = Allocation.objects.active_service_units(account, cluster) # Determine the historical contribution to the current limit current_limit = slurm.get_cluster_limit(account.name, cluster.name) - historical_usage = current_limit - active_sus - closing_sus + if historical_usage < 0: log.warning(f"Negative Historical usage found for {account.name} on {cluster.name}:\n" f"historical: {historical_usage}, current: {current_limit}, active: {active_sus}, closing: {closing_sus}\n" @@ -88,7 +79,7 @@ def update_limit_for_account(account: ResearchGroup, cluster: Cluster) -> None: closing_summary = (f"Summary of closing allocations:\n" f"> Current Usage before closing: {current_usage}\n") - for allocation in closing_query.all(): + for allocation in Allocation.objects.expiring_allocations(account, cluster): allocation.final = min(current_usage, allocation.awarded) closing_summary += f"> Allocation {allocation.id}: {current_usage} - {allocation.final} -> {current_usage - allocation.final}\n" current_usage -= allocation.final @@ -100,17 +91,14 @@ def update_limit_for_account(account: ResearchGroup, cluster: Cluster) -> None: log.warning(f"The current usage is somehow higher than the limit for {account.name}!") # Set the new account usage limit using the updated historical usage after closing any expired allocations - expired_requests = approved_query.filter(request__expire__lte=date.today()) - updated_historical_usage = expired_requests.aggregate(Sum("final"))['final__sum'] or 0 - + updated_historical_usage = Allocation.objects.historical_usage(account, cluster) updated_limit = updated_historical_usage + active_sus slurm.set_cluster_limit(account.name, cluster.name, updated_limit) # Log summary of changes during limits update for this Slurm account on this cluster log.debug(f"Summary of limits update for {account.name} on {cluster.name}:\n" - f"> Approved allocations found: {len(approved_query)}\n" - f"> Service units from {len(active_query)} active allocations: {active_sus}\n" - f"> Service units from {len(closing_query)} closing allocations: {closing_sus}\n" + f"> Service units from active allocations: {active_sus}\n" + f"> Service units from closing allocations: {closing_sus}\n" f"> {closing_summary}" f"> historical usage change: {historical_usage} -> {updated_historical_usage}\n" f"> limit change: {current_limit} -> {updated_limit}") diff --git a/keystone_api/apps/allocations/tasks/notifications.py b/keystone_api/apps/allocations/tasks/notifications.py index d0d6b29f..36fa3e64 100644 --- a/keystone_api/apps/allocations/tasks/notifications.py +++ b/keystone_api/apps/allocations/tasks/notifications.py @@ -1,7 +1,6 @@ """Background tasks for issuing user notifications.""" import logging -from datetime import timedelta from celery import shared_task from django.utils import timezone @@ -13,10 +12,10 @@ log = logging.getLogger(__name__) -__all__ = ['send_expiry_notifications', 'send_expiry_notification_for_request'] +__all__ = ['send_notifications', 'send_expiry_notification'] -def send_expiry_notification_for_request(user: User, request: AllocationRequest) -> None: +def send_expiry_notification(user: User, request: AllocationRequest) -> None: """Send any pending expiration notices to the given user. A notification is only generated if warranted by the user's notification preferences. @@ -72,20 +71,20 @@ def send_expiry_notification_for_request(user: User, request: AllocationRequest) @shared_task() -def send_expiry_notifications() -> None: +def send_notifications() -> None: """Send any pending expiration notices to all users.""" - expiring_requests = AllocationRequest.objects.filter( + active_requests = AllocationRequest.objects.filter( status=AllocationRequest.StatusChoices.APPROVED, - expire__gte=timezone.now() - timedelta(days=7) + expire__gte=timezone.now() ).all() failed = False - for request in expiring_requests: + for request in active_requests: for user in request.group.get_all_members(): try: - send_expiry_notification_for_request(user, request) + send_expiry_notification(user, request) except Exception as error: log.exception(f'Error notifying user {user.username} for request {request.id}: {error}') diff --git a/keystone_api/apps/allocations/tests/test_managers/__init__.py b/keystone_api/apps/allocations/tests/test_managers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/keystone_api/apps/allocations/tests/test_managers/test_AllocationManager.py b/keystone_api/apps/allocations/tests/test_managers/test_AllocationManager.py new file mode 100644 index 00000000..ef919c33 --- /dev/null +++ b/keystone_api/apps/allocations/tests/test_managers/test_AllocationManager.py @@ -0,0 +1,117 @@ +"""Unit tests for the `AllocationManager` class.""" + +from django.test import TestCase +from django.utils import timezone + +from apps.allocations.models import * +from apps.users.models import * + + +class GetAllocationData(TestCase): + """Test get methods used to retrieve allocation metadata/status.""" + + def setUp(self) -> None: + """Create test data.""" + + self.user = User.objects.create(username="user", password='foobar123!') + self.group = ResearchGroup.objects.create(name="Research Group 1", pi=self.user) + self.cluster = Cluster.objects.create(name="Test Cluster") + + # An allocation request pending review + self.request1 = AllocationRequest.objects.create( + group=self.group, + status='PD', + active=timezone.now().date(), + expire=timezone.now().date() + timezone.timedelta(days=30) + ) + self.allocation1 = Allocation.objects.create( + requested=100, + awarded=80, + final=None, + cluster=self.cluster, + request=self.request1 + ) + + # An approved allocation request that is active + self.request2 = AllocationRequest.objects.create( + group=self.group, + status='AP', + active=timezone.now().date(), + expire=timezone.now().date() + timezone.timedelta(days=30) + ) + self.allocation2 = Allocation.objects.create( + requested=100, + awarded=80, + final=None, + cluster=self.cluster, + request=self.request2 + ) + + # An approved allocation request that is expired without final usage + self.request3 = AllocationRequest.objects.create( + group=self.group, + status='AP', + active=timezone.now().date() - timezone.timedelta(days=60), + expire=timezone.now().date() - timezone.timedelta(days=30) + ) + self.allocation3 = Allocation.objects.create( + requested=100, + awarded=70, + final=None, + cluster=self.cluster, + request=self.request3 + ) + + # An approved allocation request that is expired with final usage + self.request4 = AllocationRequest.objects.create( + group=self.group, + status='AP', + active=timezone.now().date() - timezone.timedelta(days=30), + expire=timezone.now().date() + ) + self.allocation4 = Allocation.objects.create( + requested=100, + awarded=60, + final=60, + cluster=self.cluster, + request=self.request4 + ) + + def test_approved_allocations(self) -> None: + """Test the `approved_allocations` method returns only approved allocations.""" + + approved_allocations = Allocation.objects.approved_allocations(self.group, self.cluster) + expected_allocations = [self.allocation2, self.allocation3, self.allocation4] + self.assertQuerySetEqual(expected_allocations, approved_allocations, ordered=False) + + def test_active_allocations(self) -> None: + """Test the `active_allocations` method returns only active allocations.""" + + active_allocations = Allocation.objects.active_allocations(self.group, self.cluster) + expected_allocations = [self.allocation2] + self.assertQuerySetEqual(expected_allocations, active_allocations, ordered=False) + + def test_expired_allocations(self) -> None: + """Test the `expired_allocations` method returns only expired allocations.""" + + expiring_allocations = Allocation.objects.expiring_allocations(self.group, self.cluster) + expected_allocations = [self.allocation3] + self.assertQuerySetEqual(expected_allocations, expiring_allocations, ordered=False) + + def test_active_service_units(self) -> None: + """Test the `active_service_units` method returns the total awarded service units for active allocations.""" + + active_su = Allocation.objects.active_service_units(self.group, self.cluster) + self.assertEqual(80, active_su) + + def test_expired_service_units(self) -> None: + """Test the `expired_service_units` method returns the total awarded service units for expired allocations.""" + + expiring_su = Allocation.objects.expiring_service_units(self.group, self.cluster) + self.assertEqual(70, expiring_su) + + def test_historical_usage(self) -> None: + """Test the `historical_usage` method returns the total final usage for expired allocations.""" + + historical_usage = Allocation.objects.historical_usage(self.group, self.cluster) + self.assertEqual(60, historical_usage) diff --git a/keystone_api/apps/notifications/models.py b/keystone_api/apps/notifications/models.py index 79fa63b9..9a6d92a6 100644 --- a/keystone_api/apps/notifications/models.py +++ b/keystone_api/apps/notifications/models.py @@ -77,7 +77,7 @@ def set_user_preference(cls, user: settings.AUTH_USER_MODEL, **kwargs) -> None: cls.objects.update_or_create(user=user, defaults=kwargs) def get_next_expiration_threshold(self, days_until_expire: int) -> int | None: - """Return the next threshold att which an expiration notification should be sent + """Return the next threshold at which an expiration notification should be sent The next notification occurs at the smallest threshold that is greater than or equal the days until expiration @@ -93,3 +93,21 @@ def get_next_expiration_threshold(self, days_until_expire: int) -> int | None: filter(lambda x: x >= days_until_expire, self.request_expiry_thresholds), default=None ) + + def get_next_usage_threshold(self, usage_percentage: int) -> int | None: + """Return the next threshold at which a usage notification should be sent + + The next notification occurs at the largest threshold that is + less than or equal the days until expiration + + Args: + usage_percentage: An allocation's percent utilization + + Return: + The next notification threshold in percent + """ + + return max( + filter(lambda x: x <= usage_percentage, self.request_expiry_thresholds), + default=None + )