diff --git a/lms/djangoapps/grades/rest_api/v1/gradebook_views.py b/lms/djangoapps/grades/rest_api/v1/gradebook_views.py index 26df7ce2583..1d6f6e57c3d 100644 --- a/lms/djangoapps/grades/rest_api/v1/gradebook_views.py +++ b/lms/djangoapps/grades/rest_api/v1/gradebook_views.py @@ -670,18 +670,18 @@ def get(self, request, course_key): # lint-amnesty, pylint: disable=too-many-st serializer = StudentGradebookEntrySerializer(entries, many=True) return self.get_paginated_response(serializer.data, **users_counts) - def _get_user_count(self, query_args, cache_time=3600, annotations=None): + def _get_user_count(self, course_key, query_args, cache_time=600, annotations=None): """ Return the user count for the given query arguments to CourseEnrollment. - caches the count for cache_time seconds. + Caches the count for cache_time seconds, the default value is 10 minutes. """ queryset = CourseEnrollment.objects if annotations: queryset = queryset.annotate(**annotations) queryset = queryset.filter(*query_args) - cache_key = 'usercount.%s' % queryset.query + cache_key = 'usercount.{course_key}.{queryset.query}' user_count = cache.get(cache_key, None) if user_count is None: user_count = queryset.count() @@ -710,7 +710,7 @@ def _get_users_counts(self, course_key, course_enrollment_filters, annotations=N Q(course_id=course_key) & Q(is_active=True) ] - total_users_count = self._get_user_count(filter_args) + total_users_count = self._get_user_count(course_key, filter_args) filter_args.extend(course_enrollment_filters or []) @@ -718,7 +718,7 @@ def _get_users_counts(self, course_key, course_enrollment_filters, annotations=N filtered_users_count = ( total_users_count if not course_enrollment_filters - else self._get_user_count(filter_args, annotations=annotations) + else self._get_user_count(course_key, filter_args, annotations=annotations) ) return { diff --git a/lms/djangoapps/grades/signals/handlers.py b/lms/djangoapps/grades/signals/handlers.py index 58c5fb90aac..89fa3a55a85 100644 --- a/lms/djangoapps/grades/signals/handlers.py +++ b/lms/djangoapps/grades/signals/handlers.py @@ -5,7 +5,10 @@ from contextlib import contextmanager from logging import getLogger +from urllib.parse import unquote +from django.conf import settings +from django.core.cache import cache from django.dispatch import receiver from opaque_keys.edx.keys import LearningContextKey from openedx_events.learning.signals import EXAM_ATTEMPT_REJECTED, EXAM_ATTEMPT_VERIFIED @@ -13,7 +16,7 @@ from xblock.scorable import ScorableXBlockMixin, Score from common.djangoapps.student.models import user_by_anonymous_id -from common.djangoapps.student.signals import ENROLLMENT_TRACK_UPDATED +from common.djangoapps.student.signals import ENROLL_STATUS_CHANGE, ENROLLMENT_TRACK_UPDATED from common.djangoapps.track.event_transaction_utils import get_event_transaction_id, get_event_transaction_type from common.djangoapps.util.date_utils import to_timestamp from lms.djangoapps.courseware.model_data import get_score, set_score @@ -347,3 +350,38 @@ def exam_attempt_rejected_event_handler(sender, signal, **kwargs): # pylint: di overrider=None, comment=None, ) + +@receiver(ENROLL_STATUS_CHANGE) +def invalidate_usercount_in_course_cache(sender, signal, **kwargs): # pylint: disable=unused-argument + """ + Invalidate the cache of get_user_count utility on CourseEnrollment status change. + """ + event_data = kwargs.get('exam_attempt') + course_key = event_data.course_id + + cache_key_prefix = f"usercount.{course_key}" + cache_keys = get_cache_keys(cache_key_prefix) + cache.delete_many(cache_keys) + + +def get_cache_keys(cache_key_prefix): + """ + Get all cache keys for the given cache key prefix. + LocMemCache does not have a keys method, so we need to iterate over the cache + and manually filter out the keys that match the given prefix. + """ + cache_backend = settings.CACHES['default']['BACKEND'] + if cache_backend == 'django_redis.cache.RedisCache': + yield cache.iter_keys(f"{cache_key_prefix}*") + elif cache_backend == 'django.core.cache.backends.locmem.LocMemCache': + for key in cache._cache.keys(): # pylint: disable=protected-access + try: + decoded_key = unquote(key.split(':', 2)[-1], encoding='utf-8') + except IndexError: + continue + + if decoded_key.startswith(cache_key_prefix): + yield decoded_key + else: + log.error(f"Unsupported cache backend: {cache_backend}") + yield