Skip to content

Commit

Permalink
Refactor slurm tasks for testability (#406)
Browse files Browse the repository at this point in the history
  • Loading branch information
djperrefort committed Aug 30, 2024
1 parent 1dda7d0 commit ec3f867
Show file tree
Hide file tree
Showing 7 changed files with 280 additions and 30 deletions.
125 changes: 125 additions & 0 deletions keystone_api/apps/allocations/managers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
"""Custom database managers for encapsulating repeatable table queries.
Manager classes encapsulate common database operations at the table level (as
opposed to the level of individual records). At least one Manager exists for
every database model. Managers are commonly exposed as an attribute of the
associated model class called `objects`.
"""

from datetime import date
from typing import TYPE_CHECKING

from django.db.models import Manager, QuerySet, Sum

from apps.users.models import ResearchGroup

if TYPE_CHECKING: # pragma: nocover
from apps.allocations.models import Cluster

__all__ = ['AllocationManager']


class AllocationManager(Manager):
    """Custom manager for the `Allocation` model.

    Provides query methods for fetching approved, active, and expiring allocations,
    as well as calculating service units and historical usage.
    """

    @staticmethod
    def _sum_field(queryset: QuerySet, field: str) -> int:
        """Sum a single model field over a queryset.

        Args:
            queryset: The queryset to aggregate.
            field: Name of the field to sum.

        Returns:
            The summed value, or `0` when the aggregate is empty.
        """

        # `Sum` evaluates to `None` when there is nothing to add up, so fall back to zero.
        return queryset.aggregate(total=Sum(field))['total'] or 0

    def approved_allocations(self, account: ResearchGroup, cluster: 'Cluster') -> QuerySet:
        """Retrieve all approved allocations for a specific account and cluster.

        Args:
            account: object representing the account.
            cluster: object representing the cluster.

        Returns:
            A queryset of approved Allocation objects.
        """

        return self.filter(request__group=account, cluster=cluster, request__status='AP')

    def active_allocations(self, account: ResearchGroup, cluster: 'Cluster') -> QuerySet:
        """Retrieve all active allocations for a specific account and cluster.

        Active allocations have been approved and are currently within their start/end date.

        Args:
            account: object representing the account.
            cluster: object representing the cluster.

        Returns:
            A queryset of active Allocation objects.
        """

        # Take a single snapshot of the date so both bounds of the window
        # are evaluated against the same day, even if called around midnight.
        today = date.today()
        return self.approved_allocations(account, cluster).filter(
            request__active__lte=today, request__expire__gt=today
        )

    def expiring_allocations(self, account: ResearchGroup, cluster: 'Cluster') -> QuerySet:
        """Retrieve all expiring allocations for a specific account and cluster.

        Expiring allocations have been approved and have passed their expiration date
        but do not yet have a final usage value set.

        Args:
            account: object representing the account.
            cluster: object representing the cluster.

        Returns:
            A queryset of expiring Allocation objects ordered by expiration date.
        """

        return self.approved_allocations(account, cluster).filter(
            final=None, request__expire__lte=date.today()
        ).order_by("request__expire")

    def active_service_units(self, account: ResearchGroup, cluster: 'Cluster') -> int:
        """Calculate the total service units across all active allocations for an account and cluster.

        Active allocations have been approved and are currently within their start/end date.

        Args:
            account: object representing the account.
            cluster: object representing the cluster.

        Returns:
            Total awarded service units from active allocations (0 if there are none).
        """

        return self._sum_field(self.active_allocations(account, cluster), "awarded")

    def expiring_service_units(self, account: ResearchGroup, cluster: 'Cluster') -> int:
        """Calculate the total service units across all expiring allocations for an account and cluster.

        Expiring allocations have been approved and have passed their expiration date
        but do not yet have a final usage value set.

        Args:
            account: object representing the account.
            cluster: object representing the cluster.

        Returns:
            Total awarded service units from expiring allocations (0 if there are none).
        """

        return self._sum_field(self.expiring_allocations(account, cluster), "awarded")

    def historical_usage(self, account: ResearchGroup, cluster: 'Cluster') -> int:
        """Calculate the total final usage for expired allocations of a specific account and cluster.

        Args:
            account: object representing the account.
            cluster: object representing the cluster.

        Returns:
            Total historical usage calculated from expired allocations (0 if there are none).
        """

        expired = self.approved_allocations(account, cluster).filter(
            request__expire__lte=date.today()
        )
        return self._sum_field(expired, "final")
3 changes: 3 additions & 0 deletions keystone_api/apps/allocations/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from django.db import models
from django.template.defaultfilters import truncatechars

from apps.allocations.managers import AllocationManager
from apps.users.models import ResearchGroup, User

__all__ = [
Expand Down Expand Up @@ -45,6 +46,8 @@ class Allocation(RGModelInterface, models.Model):
cluster: Cluster = models.ForeignKey('Cluster', on_delete=models.CASCADE)
request: AllocationRequest = models.ForeignKey('AllocationRequest', on_delete=models.CASCADE)

objects = AllocationManager()

def get_research_group(self) -> ResearchGroup:
"""Return the research group tied to the current record."""

Expand Down
30 changes: 9 additions & 21 deletions keystone_api/apps/allocations/tasks/limits.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
"""Background tasks for updating/enforcing slurm usage limits."""

import logging
from datetime import date

from celery import shared_task
from django.db.models import Sum

from apps.allocations.models import *
from apps.users.models import *
Expand Down Expand Up @@ -49,28 +47,21 @@ def update_limits_for_cluster(cluster: Cluster) -> None:

@shared_task()
def update_limit_for_account(account: ResearchGroup, cluster: Cluster) -> None:
"""Update the TRES billing usage limits for an individual Slurm account, closing out any expired allocations.
"""Update the allocation limits for an individual Slurm account and close out any expired allocations.
Args:
account: ResearchGroup object for the account.
cluster: Cluster object corresponding to the Slurm cluster.
"""

# Base query for approved Allocations under the given account on the given cluster
approved_query = Allocation.objects.filter(request__group=account, cluster=cluster, request__status='AP')

# Query for allocations that have expired but do not have a final usage value, determine their SU contribution
closing_query = approved_query.filter(final=None, request__expire__lte=date.today()).order_by("request__expire")
closing_sus = closing_query.aggregate(Sum("awarded"))['awarded__sum'] or 0

# Query for allocations that are active, and determine their total service unit contribution
active_query = approved_query.filter(request__active__lte=date.today(), request__expire__gt=date.today())
active_sus = active_query.aggregate(Sum("awarded"))['awarded__sum'] or 0
# Calculate service units for expired and active allocations
closing_sus = Allocation.objects.expiring_service_units(account, cluster)
active_sus = Allocation.objects.active_service_units(account, cluster)

# Determine the historical contribution to the current limit
current_limit = slurm.get_cluster_limit(account.name, cluster.name)

historical_usage = current_limit - active_sus - closing_sus

if historical_usage < 0:
log.warning(f"Negative Historical usage found for {account.name} on {cluster.name}:\n"
f"historical: {historical_usage}, current: {current_limit}, active: {active_sus}, closing: {closing_sus}\n"
Expand All @@ -88,7 +79,7 @@ def update_limit_for_account(account: ResearchGroup, cluster: Cluster) -> None:

closing_summary = (f"Summary of closing allocations:\n"
f"> Current Usage before closing: {current_usage}\n")
for allocation in closing_query.all():
for allocation in Allocation.objects.expiring_allocations(account, cluster):
allocation.final = min(current_usage, allocation.awarded)
closing_summary += f"> Allocation {allocation.id}: {current_usage} - {allocation.final} -> {current_usage - allocation.final}\n"
current_usage -= allocation.final
Expand All @@ -100,17 +91,14 @@ def update_limit_for_account(account: ResearchGroup, cluster: Cluster) -> None:
log.warning(f"The current usage is somehow higher than the limit for {account.name}!")

# Set the new account usage limit using the updated historical usage after closing any expired allocations
expired_requests = approved_query.filter(request__expire__lte=date.today())
updated_historical_usage = expired_requests.aggregate(Sum("final"))['final__sum'] or 0

updated_historical_usage = Allocation.objects.historical_usage(account, cluster)
updated_limit = updated_historical_usage + active_sus
slurm.set_cluster_limit(account.name, cluster.name, updated_limit)

# Log summary of changes during limits update for this Slurm account on this cluster
log.debug(f"Summary of limits update for {account.name} on {cluster.name}:\n"
f"> Approved allocations found: {len(approved_query)}\n"
f"> Service units from {len(active_query)} active allocations: {active_sus}\n"
f"> Service units from {len(closing_query)} closing allocations: {closing_sus}\n"
f"> Service units from active allocations: {active_sus}\n"
f"> Service units from closing allocations: {closing_sus}\n"
f"> {closing_summary}"
f"> historical usage change: {historical_usage} -> {updated_historical_usage}\n"
f"> limit change: {current_limit} -> {updated_limit}")
15 changes: 7 additions & 8 deletions keystone_api/apps/allocations/tasks/notifications.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Background tasks for issuing user notifications."""

import logging
from datetime import timedelta

from celery import shared_task
from django.utils import timezone
Expand All @@ -13,10 +12,10 @@

log = logging.getLogger(__name__)

__all__ = ['send_expiry_notifications', 'send_expiry_notification_for_request']
__all__ = ['send_notifications', 'send_expiry_notification']


def send_expiry_notification_for_request(user: User, request: AllocationRequest) -> None:
def send_expiry_notification(user: User, request: AllocationRequest) -> None:
"""Send any pending expiration notices to the given user.
A notification is only generated if warranted by the user's notification preferences.
Expand Down Expand Up @@ -72,20 +71,20 @@ def send_expiry_notification_for_request(user: User, request: AllocationRequest)


@shared_task()
def send_expiry_notifications() -> None:
def send_notifications() -> None:
"""Send any pending expiration notices to all users."""

expiring_requests = AllocationRequest.objects.filter(
active_requests = AllocationRequest.objects.filter(
status=AllocationRequest.StatusChoices.APPROVED,
expire__gte=timezone.now() - timedelta(days=7)
expire__gte=timezone.now()
).all()

failed = False
for request in expiring_requests:
for request in active_requests:
for user in request.group.get_all_members():

try:
send_expiry_notification_for_request(user, request)
send_expiry_notification(user, request)

except Exception as error:
log.exception(f'Error notifying user {user.username} for request {request.id}: {error}')
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""Unit tests for the `AllocationManager` class."""

from django.test import TestCase
from django.utils import timezone

from apps.allocations.models import *
from apps.users.models import *


class GetAllocationData(TestCase):
    """Test get methods used to retrieve allocation metadata/status."""

    def setUp(self) -> None:
        """Create test data.

        Four allocations are created for a single group/cluster pair:
        one pending, one approved and active, one approved and expired
        without a final usage, and one approved and expired with a final usage.
        """

        # Imported locally so this block does not depend on the undocumented
        # `django.utils.timezone.timedelta` re-export.
        from datetime import date, timedelta

        # Use the same clock as `AllocationManager` (`date.today()`) and take a
        # single snapshot so the boundary fixture (`expire == today`) is exact.
        today = date.today()

        self.user = User.objects.create(username="user", password='foobar123!')
        self.group = ResearchGroup.objects.create(name="Research Group 1", pi=self.user)
        self.cluster = Cluster.objects.create(name="Test Cluster")

        # An allocation request pending review
        self.request1 = AllocationRequest.objects.create(
            group=self.group,
            status='PD',
            active=today,
            expire=today + timedelta(days=30)
        )
        self.allocation1 = Allocation.objects.create(
            requested=100,
            awarded=80,
            final=None,
            cluster=self.cluster,
            request=self.request1
        )

        # An approved allocation request that is active
        self.request2 = AllocationRequest.objects.create(
            group=self.group,
            status='AP',
            active=today,
            expire=today + timedelta(days=30)
        )
        self.allocation2 = Allocation.objects.create(
            requested=100,
            awarded=80,
            final=None,
            cluster=self.cluster,
            request=self.request2
        )

        # An approved allocation request that is expired without final usage
        self.request3 = AllocationRequest.objects.create(
            group=self.group,
            status='AP',
            active=today - timedelta(days=60),
            expire=today - timedelta(days=30)
        )
        self.allocation3 = Allocation.objects.create(
            requested=100,
            awarded=70,
            final=None,
            cluster=self.cluster,
            request=self.request3
        )

        # An approved allocation request that is expired with final usage.
        # Expiring *today* exercises the inclusive `expire <= today` boundary.
        self.request4 = AllocationRequest.objects.create(
            group=self.group,
            status='AP',
            active=today - timedelta(days=30),
            expire=today
        )
        self.allocation4 = Allocation.objects.create(
            requested=100,
            awarded=60,
            final=60,
            cluster=self.cluster,
            request=self.request4
        )

    def test_approved_allocations(self) -> None:
        """Test the `approved_allocations` method returns only approved allocations."""

        approved_allocations = Allocation.objects.approved_allocations(self.group, self.cluster)
        expected_allocations = [self.allocation2, self.allocation3, self.allocation4]
        self.assertQuerySetEqual(approved_allocations, expected_allocations, ordered=False)

    def test_active_allocations(self) -> None:
        """Test the `active_allocations` method returns only active allocations."""

        active_allocations = Allocation.objects.active_allocations(self.group, self.cluster)
        expected_allocations = [self.allocation2]
        self.assertQuerySetEqual(active_allocations, expected_allocations, ordered=False)

    def test_expired_allocations(self) -> None:
        """Test the `expiring_allocations` method returns only expired allocations without a final usage."""

        expiring_allocations = Allocation.objects.expiring_allocations(self.group, self.cluster)
        expected_allocations = [self.allocation3]
        self.assertQuerySetEqual(expiring_allocations, expected_allocations, ordered=False)

    def test_active_service_units(self) -> None:
        """Test the `active_service_units` method returns the total awarded service units for active allocations."""

        active_su = Allocation.objects.active_service_units(self.group, self.cluster)
        self.assertEqual(80, active_su)

    def test_expired_service_units(self) -> None:
        """Test the `expiring_service_units` method returns the total awarded service units for expiring allocations."""

        expiring_su = Allocation.objects.expiring_service_units(self.group, self.cluster)
        self.assertEqual(70, expiring_su)

    def test_historical_usage(self) -> None:
        """Test the `historical_usage` method returns the total final usage for expired allocations."""

        historical_usage = Allocation.objects.historical_usage(self.group, self.cluster)
        self.assertEqual(60, historical_usage)
Loading

0 comments on commit ec3f867

Please sign in to comment.