From 1fc3972623389550b706cb32e5389526bb82f878 Mon Sep 17 00:00:00 2001 From: michaelroytman Date: Tue, 26 Nov 2024 10:51:23 -0500 Subject: [PATCH] feat: add BFFE endpoint for Learning Assistant to get all necessary data to function This commit adds a back-end-for-frontend (BFFE) endpoint for the Learning Assistant to get all the necessary data it needs to function. The response from this endpoint includes the following information. * whether the Learning Assistant is enabled * message history information, if the learner is eligible to use the Learning Assistant * audit trial information --- CHANGELOG.rst | 5 + learning_assistant/__init__.py | 2 +- learning_assistant/api.py | 104 ++++++++- learning_assistant/data.py | 13 ++ learning_assistant/urls.py | 12 +- learning_assistant/views.py | 165 +++++++++++-- requirements/dev.txt | 7 + requirements/doc.txt | 17 +- requirements/quality.txt | 11 +- requirements/test.in | 1 + requirements/test.txt | 6 + test_settings.py | 2 + tests/__init__.py | 0 tests/test_api.py | 194 ++++++++++++++-- tests/test_views.py | 410 ++++++++++++++++++++++++++++++--- 15 files changed, 864 insertions(+), 85 deletions(-) delete mode 100644 tests/__init__.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index cf031d8..96feb25 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -13,7 +13,12 @@ Change Log Unreleased ********** + +4.5.0 - 2024-12-04 +****************** * Add local setup to readme +* Add a BFFE chat summary endpoint for Learning Assistant, including information about whether the Learning Assistant is + enabled, Learning Assistant message history, and Learning Assistant audit trial data. 4.4.7 - 2024-11-25 ****************** diff --git a/learning_assistant/__init__.py b/learning_assistant/__init__.py index 25bfe14..b14e6c9 100644 --- a/learning_assistant/__init__.py +++ b/learning_assistant/__init__.py @@ -2,6 +2,6 @@ Plugin for a learning assistant backend, intended for use within edx-platform. """ -__version__ = '4.4.7' +__version__ = '4.5.0' default_app_config = 'learning_assistant.apps.LearningAssistantConfig' # pylint: disable=invalid-name diff --git a/learning_assistant/api.py b/learning_assistant/api.py index 3857945..9483888 100644 --- a/learning_assistant/api.py +++ b/learning_assistant/api.py @@ -1,6 +1,7 @@ """ Library for the learning_assistant app. """ +import datetime import logging from datetime import datetime, timedelta @@ -11,8 +12,13 @@ from jinja2 import BaseLoader, Environment from opaque_keys import InvalidKeyError -from learning_assistant.constants import ACCEPTED_CATEGORY_TYPES, AUDIT_TRIAL_MAX_DAYS, CATEGORY_TYPE_MAP -from learning_assistant.data import LearningAssistantCourseEnabledData +try: + from common.djangoapps.course_modes.models import CourseMode +except ImportError: + CourseMode = None + +from learning_assistant.constants import ACCEPTED_CATEGORY_TYPES, CATEGORY_TYPE_MAP +from learning_assistant.data import LearningAssistantAuditTrialData, LearningAssistantCourseEnabledData from learning_assistant.models import ( LearningAssistantAuditTrial, LearningAssistantCourseEnabled, @@ -231,15 +237,71 @@ def get_message_history(courserun_key, user, message_count): return message_history -def audit_trial_is_expired(user, upgrade_deadline): +def get_audit_trial_expiration_date(start_date): """ - Given a user (User), get or create the corresponding LearningAssistantAuditTrial trial object. + Given a start date of an audit trial, calculate the expiration date of the audit trial. + + Arguments: + * start_date (datetime): the start date of the audit trial + + Returns: + * expiration_date (datetime): the expiration date of the audit trial """ - # If the upgrade deadline has passed, return "True" for expired - DAYS_SINCE_UPGRADE_DEADLINE = datetime.now() - upgrade_deadline - if DAYS_SINCE_UPGRADE_DEADLINE >= timedelta(days=0): - return True + default_trial_length_days = 14 + + trial_length_days = getattr(settings, 'LEARNING_ASSISTANT_AUDIT_TRIAL_LENGTH_DAYS', default_trial_length_days) + + if trial_length_days is None: + trial_length_days = default_trial_length_days + + # If LEARNING_ASSISTANT_AUDIT_TRIAL_LENGTH_DAYS is set to a negative number, assume it should be 0. + # pylint: disable=consider-using-max-builtin + if trial_length_days < 0: + trial_length_days = 0 + + expiration_datetime = start_date + timedelta(days=trial_length_days) + return expiration_datetime + + +def get_audit_trial(user): + """ + Given a user, return the associated audit trial data. + + Arguments: + * user (User): the user + + Returns: + * audit_trial_data (LearningAssistantAuditTrialData): the audit trial data + * user_id (int): the user's id + * start_date (datetime): the start date of the audit trial + * expiration_date (datetime): the expiration date of the audit trial + * None: if no audit trial exists for the user + """ + try: + audit_trial = LearningAssistantAuditTrial.objects.get(user=user) + except LearningAssistantAuditTrial.DoesNotExist: + return None + + return LearningAssistantAuditTrialData( + user_id=user.id, + start_date=audit_trial.start_date, + expiration_date=get_audit_trial_expiration_date(audit_trial.start_date), + ) + + +def get_or_create_audit_trial(user): + """ + Given a user, return the associated audit trial data, creating a new audit trial for the user if one does not exist. + Arguments: + * user (User): the user + + Returns: + * audit_trial_data (LearningAssistantAuditTrialData): the audit trial data + * user_id (int): the user's id + * start_date (datetime): the start date of the audit trial + * expiration_date (datetime): the expiration date of the audit trial + """ audit_trial, _ = LearningAssistantAuditTrial.objects.get_or_create( user=user, defaults={ @@ -247,6 +309,26 @@ def audit_trial_is_expired(user, upgrade_deadline): }, ) - # If the user's trial is past its expiry date, return "True" for expired. Else, return False - DAYS_SINCE_TRIAL_START_DATE = datetime.now() - audit_trial.start_date - return DAYS_SINCE_TRIAL_START_DATE >= timedelta(days=AUDIT_TRIAL_MAX_DAYS) + return LearningAssistantAuditTrialData( + user_id=user.id, + start_date=audit_trial.start_date, + expiration_date=get_audit_trial_expiration_date(audit_trial.start_date), + ) + + +def audit_trial_is_expired(audit_trial_data, courserun_key): + """ + Given a user (User), get or create the corresponding LearningAssistantAuditTrial trial object. + """ + course_mode = CourseMode.objects.get(course=courserun_key) + + upgrade_deadline = course_mode.expiration_datetime() + + # If the upgrade deadline has passed, return True for expired. Upgrade deadline is an optional attribute of a + # CourseMode, so if it's None, then do not return True. + days_until_upgrade_deadline = datetime.now() - upgrade_deadline if upgrade_deadline else None + if days_until_upgrade_deadline is not None and days_until_upgrade_deadline >= timedelta(days=0): + return True + + # If the user's trial is past its expiry date, return True for expired. Else, return False. + return audit_trial_data is None or audit_trial_data.expiration_date <= datetime.now() diff --git a/learning_assistant/data.py b/learning_assistant/data.py index 2e89c27..e9e923b 100644 --- a/learning_assistant/data.py +++ b/learning_assistant/data.py @@ -1,6 +1,8 @@ """ Data classes for the Learning Assistant application. """ +from datetime import datetime + from attrs import field, frozen, validators from opaque_keys.edx.keys import CourseKey @@ -13,3 +15,14 @@ class LearningAssistantCourseEnabledData: course_key: CourseKey = field(validator=validators.instance_of(CourseKey)) enabled: bool = field(validator=validators.instance_of(bool)) + + +@frozen +class LearningAssistantAuditTrialData: + """ + Data class representing an audit learner's trial of the Learning Assistant. + """ + + user_id: int = field(validator=validators.instance_of(int)) + start_date: datetime = field(validator=validators.optional(validators.instance_of(datetime))) + expiration_date: datetime = field(validator=validators.optional(validators.instance_of(datetime))) diff --git a/learning_assistant/urls.py b/learning_assistant/urls.py index b0dfb48..31914b8 100644 --- a/learning_assistant/urls.py +++ b/learning_assistant/urls.py @@ -4,7 +4,12 @@ from django.urls import re_path from learning_assistant.constants import COURSE_ID_PATTERN -from learning_assistant.views import CourseChatView, LearningAssistantEnabledView, LearningAssistantMessageHistoryView +from learning_assistant.views import ( + CourseChatView, + LearningAssistantChatSummaryView, + LearningAssistantEnabledView, + LearningAssistantMessageHistoryView, +) app_name = 'learning_assistant' @@ -24,4 +29,9 @@ LearningAssistantMessageHistoryView.as_view(), name='message-history', ), + re_path( + fr'learning_assistant/v1/course_id/{COURSE_ID_PATTERN}/chat-summary', + LearningAssistantChatSummaryView.as_view(), + name='chat-summary', + ), ] diff --git a/learning_assistant/views.py b/learning_assistant/views.py index 8d9ed62..58351e1 100644 --- a/learning_assistant/views.py +++ b/learning_assistant/views.py @@ -18,12 +18,16 @@ from common.djangoapps.student.models import CourseEnrollment from lms.djangoapps.courseware.access import get_user_role except ImportError: - pass + CourseMode = None + CourseEnrollment = None + get_user_role = None from learning_assistant.api import ( audit_trial_is_expired, + get_audit_trial, get_course_id, get_message_history, + get_or_create_audit_trial, learning_assistant_enabled, render_prompt_template, save_chat_message, @@ -39,6 +43,19 @@ class CourseChatView(APIView): """ View to retrieve chat response. + + Accepts: [POST] + + Path: /learning_assistant/v1/course_id/{course_run_id} + + Parameters: + * course_run_id: the ID of the course + + Responses: + * 200: OK + * 400: Malformed Request - Course ID is not a valid course ID. + * 403: Forbidden - Learning assistant not enabled for course or learner does not have a valid enrollment or is + not staff. """ authentication_classes = (SessionAuthentication, JwtAuthentication,) @@ -127,7 +144,7 @@ def post(self, request, course_run_id): # If the user is in a verified course mode or is staff, return the next message if ( - # Here we include CREDIT_MODE and NO_ID_PROFESSIONAL_MODE, as CourseMode.VERIFIED_MODES on its own + # Here we include CREDIT_MODES and NO_ID_PROFESSIONAL_MODE, as CourseMode.VERIFIED_MODES on its own # doesn't match what we count as "verified modes" in the frontend component. enrollment_mode in CourseMode.VERIFIED_MODES + CourseMode.CREDIT_MODES + [CourseMode.NO_ID_PROFESSIONAL_MODE] @@ -138,11 +155,9 @@ def post(self, request, course_run_id): # If user has an audit enrollment record, get or create their trial. If the trial is not expired, return the # next message. Otherwise, return 403 elif enrollment_mode in CourseMode.UPSELL_TO_VERIFIED_MODES: # AUDIT, HONOR - course_mode = CourseMode.objects.get(course=courserun_key) - upgrade_deadline = course_mode.expiration_datetime() - - user_audit_trial_expired = audit_trial_is_expired(request.user, upgrade_deadline) - if user_audit_trial_expired: + audit_trial = get_or_create_audit_trial(request.user) + is_user_audit_trial_expired = audit_trial_is_expired(audit_trial, courserun_key) + if is_user_audit_trial_expired: return Response( status=http_status.HTTP_403_FORBIDDEN, data={'detail': 'The audit trial for this user has expired.'} @@ -164,14 +179,14 @@ class LearningAssistantEnabledView(APIView): View to retrieve whether the Learning Assistant is enabled for a course. This endpoint returns a boolean representing whether the Learning Assistant feature is enabled in a course - represented by the course_key, which is provided in the URL. + represented by the course_run_id, which is provided in the URL. Accepts: [GET] - Path: /learning_assistant/v1/course_id/{course_key}/enabled + Path: /learning_assistant/v1/course_id/{course_run_id}/enabled Parameters: - * course_key: the ID of the course + * course_run_id: the ID of the course Responses: * 200: OK @@ -209,18 +224,20 @@ class LearningAssistantMessageHistoryView(APIView): View to retrieve the message history for user in a course. This endpoint returns the message history stored in the LearningAssistantMessage table in a course - represented by the course_key, which is provided in the URL. + represented by the course_run_id, which is provided in the URL. Accepts: [GET] - Path: /learning_assistant/v1/course_id/{course_key}/history + Path: /learning_assistant/v1/course_id/{course_run_id}/history Parameters: - * course_key: the ID of the course + * course_run_id: the ID of the course Responses: * 200: OK * 400: Malformed Request - Course ID is not a valid course ID. + * 403: Forbidden - Learning assistant not enabled for course or learner does not have a valid enrollment or is + not staff. """ authentication_classes = (SessionAuthentication, JwtAuthentication,) @@ -271,3 +288,125 @@ def get(self, request, course_run_id): message_history = get_message_history(courserun_key, user, message_count) data = MessageSerializer(message_history, many=True).data return Response(status=http_status.HTTP_200_OK, data=data) + + +class LearningAssistantChatSummaryView(APIView): + """ + View to retrieve data about a learner's session with the Learning Assistant. + + This endpoint returns all the data necessary for the Learning Assistant to function, including the following + information. + * whether the Learning Assistant is enabled + * message history information, if the learner is eligible to use the Learning Assistant + * audit trial information + + Accepts: [GET] + + Path: /learning_assistant/v1/course_id/{course_run_id}/chat-summary + + Parameters: + * course_run_id: the ID of the course + + Responses: + * 200: OK + * 400: Malformed Request - Course ID is not a valid course ID. + """ + + authentication_classes = (SessionAuthentication, JwtAuthentication,) + permission_classes = (IsAuthenticated,) + + def get(self, request, course_run_id): + """ + Given a course run ID, return all the data necessary for the Learning Assistant to fuction. + + The response will be in the following format. + + { + "enabled": true, + "message_history": [ + { + "role": "user", + "content": "test message from user", + "timestamp": "2024-12-02T15:04:17.495928Z" + }, + { + "role": "assistant", + "content": "test message from assistant", + "timestamp": "2024-12-02T15:04:40.084584Z" + } + ], + "trial": { + "start_date": "2024-12-02T14:59:16.148236Z", + "expiration_date": "2024-12-16T14:59:16.148236Z" + } + } + """ + try: + courserun_key = CourseKey.from_string(course_run_id) + except InvalidKeyError: + return Response( + status=http_status.HTTP_400_BAD_REQUEST, + data={'detail': 'Course ID is not a valid course ID.'} + ) + + data = {} + user = request.user + + # Get whether the Learning Assistant is enabled. + data['enabled'] = learning_assistant_enabled(courserun_key) + + # Get message history. + # If user does not have a verified enrollment record or is does not have an active audit trial, or is not staff, + # then they should not have access to the message history. + user_role = get_user_role(user, courserun_key) + enrollment_object = CourseEnrollment.get_enrollment(request.user, courserun_key) + enrollment_mode = enrollment_object.mode if enrollment_object else None + + # Here we include CREDIT_MODES and NO_ID_PROFESSIONAL_MODE, as CourseMode.VERIFIED_MODES on its own + # doesn't match what we count as "verified modes" in the frontend component. We also include AUDIT and HONOR to + # ensure learners with audit trials see message history if the trial is non-expired. + valid_full_access_modes = ( + CourseMode.VERIFIED_MODES + + CourseMode.CREDIT_MODES + + [CourseMode.NO_ID_PROFESSIONAL_MODE] + ) + valid_trial_access_modes = CourseMode.UPSELL_TO_VERIFIED_MODES + + # Get audit trial. Note that we do not want to create an audit trial when calling this endpoint. + audit_trial = get_audit_trial(request.user) + + # If the learner doesn't meet criteria to use the Learning Assistant, or if the chat history is disabled, we + # return no messages in the response. + message_history_data = [] + + has_trial_access = ( + enrollment_mode in valid_trial_access_modes + and audit_trial + and not audit_trial_is_expired(audit_trial, courserun_key) + ) + + if ( + ( + (enrollment_mode in valid_full_access_modes) + or has_trial_access + or user_role_is_staff(user_role) + ) + and chat_history_enabled(courserun_key) + ): + message_count = int(request.GET.get('message_count', 50)) + message_history = get_message_history(courserun_key, user, message_count) + message_history_data = MessageSerializer(message_history, many=True).data + + data['message_history'] = message_history_data + + # Get audit trial. + trial = get_audit_trial(user) + + trial_data = {} + if trial: + trial_data['start_date'] = trial.start_date + trial_data['expiration_date'] = trial.expiration_date + + data['audit_trial'] = trial_data + + return Response(status=http_status.HTTP_200_OK, data=data) diff --git a/requirements/dev.txt b/requirements/dev.txt index ca3d2d1..5b600f6 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -140,6 +140,8 @@ filelock==3.16.1 # -r requirements/ci.txt # tox # virtualenv +freezegun==1.5.1 + # via -r requirements/quality.txt idna==3.10 # via # -r requirements/quality.txt @@ -274,6 +276,10 @@ pytest-cov==6.0.0 # via -r requirements/quality.txt pytest-django==4.9.0 # via -r requirements/quality.txt +python-dateutil==2.9.0.post0 + # via + # -r requirements/quality.txt + # freezegun python-slugify==8.0.4 # via # -r requirements/quality.txt @@ -300,6 +306,7 @@ six==1.16.0 # via # -r requirements/quality.txt # edx-lint + # python-dateutil snowballstemmer==2.2.0 # via # -r requirements/quality.txt diff --git a/requirements/doc.txt b/requirements/doc.txt index 03e90e8..8bd5308 100644 --- a/requirements/doc.txt +++ b/requirements/doc.txt @@ -46,7 +46,6 @@ cryptography==43.0.3 # via # -r requirements/test.txt # pyjwt - # secretstorage ddt==1.7.2 # via -r requirements/test.txt django==4.2.16 @@ -105,6 +104,8 @@ edx-opaque-keys==2.11.0 # edx-drf-extensions edx-rest-api-client==6.0.0 # via -r requirements/test.txt +freezegun==1.5.1 + # via -r requirements/test.txt idna==3.10 # via # -r requirements/test.txt @@ -125,10 +126,6 @@ jaraco-context==6.0.1 # via keyring jaraco-functools==4.1.0 # via keyring -jeepney==0.8.0 - # via - # keyring - # secretstorage jinja2==3.1.4 # via # -r requirements/test.txt @@ -209,6 +206,10 @@ pytest-cov==6.0.0 # via -r requirements/test.txt pytest-django==4.9.0 # via -r requirements/test.txt +python-dateutil==2.9.0.post0 + # via + # -r requirements/test.txt + # freezegun python-slugify==8.0.4 # via # -r requirements/test.txt @@ -239,12 +240,14 @@ rfc3986==2.0.0 # via twine rich==13.9.4 # via twine -secretstorage==3.3.3 - # via keyring semantic-version==2.10.0 # via # -r requirements/test.txt # edx-drf-extensions +six==1.16.0 + # via + # -r requirements/test.txt + # python-dateutil snowballstemmer==2.2.0 # via sphinx sphinx==8.1.3 diff --git a/requirements/quality.txt b/requirements/quality.txt index eea5bb9..ed97b2c 100644 --- a/requirements/quality.txt +++ b/requirements/quality.txt @@ -102,6 +102,8 @@ edx-opaque-keys==2.11.0 # edx-drf-extensions edx-rest-api-client==6.0.0 # via -r requirements/test.txt +freezegun==1.5.1 + # via -r requirements/test.txt idna==3.10 # via # -r requirements/test.txt @@ -191,6 +193,10 @@ pytest-cov==6.0.0 # via -r requirements/test.txt pytest-django==4.9.0 # via -r requirements/test.txt +python-dateutil==2.9.0.post0 + # via + # -r requirements/test.txt + # freezegun python-slugify==8.0.4 # via # -r requirements/test.txt @@ -213,7 +219,10 @@ semantic-version==2.10.0 # -r requirements/test.txt # edx-drf-extensions six==1.16.0 - # via edx-lint + # via + # -r requirements/test.txt + # edx-lint + # python-dateutil snowballstemmer==2.2.0 # via pydocstyle sqlparse==0.5.2 diff --git a/requirements/test.in b/requirements/test.in index 8b21fc7..bb09779 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -5,6 +5,7 @@ code-annotations # provides commands used by the pii_check make target. ddt +freezegun pytest-cov # pytest extension for code coverage statistics pytest-django # pytest extension for better Django support responses diff --git a/requirements/test.txt b/requirements/test.txt index 35bcf93..ad6ed72 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -85,6 +85,8 @@ edx-opaque-keys==2.11.0 # edx-drf-extensions edx-rest-api-client==6.0.0 # via -r requirements/base.txt +freezegun==1.5.1 + # via -r requirements/test.in idna==3.10 # via # -r requirements/base.txt @@ -141,6 +143,8 @@ pytest-cov==6.0.0 # via -r requirements/test.in pytest-django==4.9.0 # via -r requirements/test.in +python-dateutil==2.9.0.post0 + # via freezegun python-slugify==8.0.4 # via code-annotations pyyaml==6.0.2 @@ -159,6 +163,8 @@ semantic-version==2.10.0 # via # -r requirements/base.txt # edx-drf-extensions +six==1.16.0 + # via python-dateutil sqlparse==0.5.2 # via # -r requirements/base.txt diff --git a/test_settings.py b/test_settings.py index 78d99fc..8717a1c 100644 --- a/test_settings.py +++ b/test_settings.py @@ -88,3 +88,5 @@ def root(*args): ) LEARNING_ASSISTANT_AVAILABLE = True + +LEARNING_ASSISTANT_AUDIT_TRIAL_LENGTH_DAYS = 14 diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_api.py b/tests/test_api.py index d73af07..eda7602 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -10,6 +10,7 @@ from django.contrib.auth import get_user_model from django.core.cache import cache from django.test import TestCase, override_settings +from freezegun import freeze_time from opaque_keys import InvalidKeyError from opaque_keys.edx.keys import CourseKey, UsageKey @@ -18,16 +19,18 @@ _get_children_contents, _leaf_filter, audit_trial_is_expired, + get_audit_trial, + get_audit_trial_expiration_date, get_block_content, get_message_history, + get_or_create_audit_trial, learning_assistant_available, learning_assistant_enabled, render_prompt_template, save_chat_message, set_learning_assistant_enabled, ) -from learning_assistant.constants import AUDIT_TRIAL_MAX_DAYS -from learning_assistant.data import LearningAssistantCourseEnabledData +from learning_assistant.data import LearningAssistantAuditTrialData, LearningAssistantCourseEnabledData from learning_assistant.models import ( LearningAssistantAuditTrial, LearningAssistantCourseEnabled, @@ -483,6 +486,102 @@ def test_get_message_course_id_differences(self): self.assertEqual(return_value.content, expected_value[i].content) +@ddt.ddt +class GetAuditTrialExpirationDateTests(TestCase): + """ + Test suite for get_audit_trial_expiration_date. + """ + @ddt.data( + (datetime(2024, 1, 1, 0, 0, 0), datetime(2024, 1, 15, 0, 0, 0), None), + (datetime(2024, 1, 18, 0, 0, 0), datetime(2024, 2, 1, 0, 0, 0), None), + (datetime(2024, 1, 1, 0, 0, 0), datetime(2024, 1, 15, 0, 0, 0), 14), + (datetime(2024, 1, 18, 0, 0, 0), datetime(2024, 2, 1, 0, 0, 0), 14), + (datetime(2024, 1, 1, 0, 0, 0), datetime(2024, 1, 1, 0, 0, 0), -1), + (datetime(2024, 1, 18, 0, 0, 0), datetime(2024, 1, 18, 0, 0, 0), -1), + (datetime(2024, 1, 1, 0, 0, 0), datetime(2024, 1, 1, 0, 0, 0), 0), + (datetime(2024, 1, 18, 0, 0, 0), datetime(2024, 1, 18, 0, 0, 0), 0), + (datetime(2024, 1, 1, 0, 0, 0), datetime(2024, 1, 4, 0, 0, 0), 3), + (datetime(2024, 1, 18, 0, 0, 0), datetime(2024, 1, 21, 0, 0, 0), 3), + ) + @ddt.unpack + def test_expiration_date(self, start_date, expected_expiration_date, trial_length_days): + with override_settings(LEARNING_ASSISTANT_AUDIT_TRIAL_LENGTH_DAYS=trial_length_days): + expiration_date = get_audit_trial_expiration_date(start_date) + self.assertEqual(expected_expiration_date, expiration_date) + + +class GetAuditTrialTests(TestCase): + """ + Test suite for get_audit_trial. + """ + @freeze_time('2024-01-01') + def setUp(self): + super().setUp() + self.user = User(username='tester', email='tester@test.com') + self.user.save() + + def test_exists(self): + start_date = datetime.now() + + LearningAssistantAuditTrial.objects.create( + user=self.user, + start_date=start_date + ) + + expected_return = LearningAssistantAuditTrialData( + user_id=self.user.id, + start_date=start_date, + expiration_date=start_date + timedelta(days=settings.LEARNING_ASSISTANT_AUDIT_TRIAL_LENGTH_DAYS) + ) + self.assertEqual(expected_return, get_audit_trial(self.user)) + + def test_not_exists(self): + other_user = User(username='other-tester', email='other-tester@test.com') + other_user.save() + + self.assertIsNone(get_audit_trial(self.user)) + + +class GetOrCreateAuditTrialTests(TestCase): + """ + Test suite for get_or_create_audit_trial. + """ + def setUp(self): + super().setUp() + self.user = User(username='tester', email='tester@test.com') + self.user.save() + + @freeze_time('2024-01-01') + def test_exists(self): + start_date = datetime.now() + + LearningAssistantAuditTrial.objects.create( + user=self.user, + start_date=start_date + ) + + expected_return = LearningAssistantAuditTrialData( + user_id=self.user.id, + start_date=start_date, + expiration_date=start_date + timedelta(days=settings.LEARNING_ASSISTANT_AUDIT_TRIAL_LENGTH_DAYS) + ) + self.assertEqual(expected_return, get_or_create_audit_trial(self.user)) + + @freeze_time('2024-01-01') + def test_not_exists(self): + other_user = User(username='other-tester', email='other-tester@test.com') + other_user.save() + + start_date = datetime.now() + expected_return = LearningAssistantAuditTrialData( + user_id=self.user.id, + start_date=start_date, + expiration_date=start_date + timedelta(days=settings.LEARNING_ASSISTANT_AUDIT_TRIAL_LENGTH_DAYS) + ) + + self.assertEqual(expected_return, get_or_create_audit_trial(self.user)) + + @ddt.ddt class CheckIfAuditTrialIsExpiredTests(TestCase): """ @@ -495,23 +594,86 @@ def setUp(self): self.user = User(username='tester', email='tester@test.com') self.user.save() - self.role = 'verified' self.upgrade_deadline = datetime.now() + timedelta(days=1) # 1 day from now - def test_check_if_past_upgrade_deadline(self): - upgrade_deadline = datetime.now() - timedelta(days=1) # yesterday - self.assertEqual(audit_trial_is_expired(self.user, upgrade_deadline), True) + @freeze_time('2024-01-01') + @patch('learning_assistant.api.CourseMode') + def test_upgrade_deadline_expired(self, mock_course_mode): - def test_audit_trial_is_expired_audit_trial_expired(self): - LearningAssistantAuditTrial.objects.create( - user=self.user, - start_date=datetime.now() - timedelta(days=AUDIT_TRIAL_MAX_DAYS + 1), # 1 day more than trial deadline + mock_mode = MagicMock() + mock_mode.expiration_datetime.return_value = datetime.now() - timedelta(days=1) # yesterday + mock_course_mode.objects.get.return_value = mock_mode + + start_date = datetime.now() + audit_trial_data = LearningAssistantAuditTrialData( + user_id=self.user.id, + start_date=start_date, + expiration_date=start_date + timedelta(days=settings.LEARNING_ASSISTANT_AUDIT_TRIAL_LENGTH_DAYS), ) - self.assertEqual(audit_trial_is_expired(self.user, self.upgrade_deadline), True) - def test_audit_trial_is_expired_audit_trial_unexpired(self): - LearningAssistantAuditTrial.objects.create( - user=self.user, - start_date=datetime.now() - timedelta(days=AUDIT_TRIAL_MAX_DAYS - 0.99), # 0.99 days less than deadline + self.assertEqual(audit_trial_is_expired(audit_trial_data, self.course_key), True) + + @freeze_time('2024-01-01') + @patch('learning_assistant.api.CourseMode') + def test_upgrade_deadline_none(self, mock_course_mode): + + mock_mode = MagicMock() + mock_mode.expiration_datetime.return_value = None + mock_course_mode.objects.get.return_value = mock_mode + + # Verify that the audit trial data is considered when determing whether an audit trial is expired and not the + # upgrade deadline. + start_date = datetime.now() + audit_trial_data = LearningAssistantAuditTrialData( + user_id=self.user.id, + start_date=start_date, + expiration_date=start_date + timedelta(days=settings.LEARNING_ASSISTANT_AUDIT_TRIAL_LENGTH_DAYS), + ) + + self.assertEqual(audit_trial_is_expired(audit_trial_data, self.course_key), False) + + start_date = datetime.now() - timedelta(days=settings.LEARNING_ASSISTANT_AUDIT_TRIAL_LENGTH_DAYS + 1) + audit_trial_data = LearningAssistantAuditTrialData( + user_id=self.user.id, + start_date=start_date, + expiration_date=start_date + timedelta(days=settings.LEARNING_ASSISTANT_AUDIT_TRIAL_LENGTH_DAYS), + ) + + self.assertEqual(audit_trial_is_expired(audit_trial_data, self.course_key), True) + + @ddt.data( + # exactly the trial deadline + datetime(year=2024, month=1, day=1) - timedelta(days=settings.LEARNING_ASSISTANT_AUDIT_TRIAL_LENGTH_DAYS), + # 1 day more than trial deadline + datetime(year=2024, month=1, day=1) - timedelta(days=settings.LEARNING_ASSISTANT_AUDIT_TRIAL_LENGTH_DAYS + 1), + ) + @freeze_time('2024-01-01') + @patch('learning_assistant.api.CourseMode') + def test_audit_trial_expired(self, start_date, mock_course_mode): + mock_mode = MagicMock() + mock_mode.expiration_datetime.return_value = datetime.now() + timedelta(days=1) # tomorrow + mock_course_mode.objects.get.return_value = mock_mode + + audit_trial_data = LearningAssistantAuditTrialData( + user_id=self.user.id, + start_date=start_date, + expiration_date=get_audit_trial_expiration_date(start_date), + ) + + self.assertEqual(audit_trial_is_expired(audit_trial_data, self.upgrade_deadline), True) + + @freeze_time('2024-01-01') + @patch('learning_assistant.api.CourseMode') + def test_audit_trial_unexpired(self, mock_course_mode): + mock_mode = MagicMock() + mock_mode.expiration_datetime.return_value = datetime.now() + timedelta(days=1) # tomorrow + mock_course_mode.objects.get.return_value = mock_mode + + start_date = datetime.now() - timedelta(days=settings.LEARNING_ASSISTANT_AUDIT_TRIAL_LENGTH_DAYS - 1) + audit_trial_data = LearningAssistantAuditTrialData( + user_id=self.user.id, + start_date=start_date, + expiration_date=get_audit_trial_expiration_date(start_date), ) - self.assertEqual(audit_trial_is_expired(self.user, self.upgrade_deadline), False) + + self.assertEqual(audit_trial_is_expired(audit_trial_data, self.upgrade_deadline), False) diff --git a/tests/test_views.py b/tests/test_views.py index edde3d6..a70aad5 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -5,7 +5,9 @@ import sys from datetime import date, datetime, timedelta from importlib import import_module +from itertools import product from unittest.mock import MagicMock, call, patch +from urllib.parse import urlencode import ddt from django.conf import settings @@ -14,9 +16,10 @@ from django.test import TestCase, override_settings from django.test.client import Client from django.urls import reverse +from opaque_keys import InvalidKeyError from opaque_keys.edx.keys import CourseKey -from learning_assistant.models import LearningAssistantMessage +from learning_assistant.models import LearningAssistantAuditTrial, LearningAssistantMessage User = get_user_model() @@ -95,13 +98,6 @@ def setUp(self): ) self.patcher.start() - @patch('learning_assistant.views.learning_assistant_enabled') - def test_invalid_course_id(self, mock_learning_assistant_enabled): - mock_learning_assistant_enabled.return_value = True - response = self.client.get(reverse('enabled', kwargs={'course_run_id': self.course_id+'+invalid'})) - - self.assertEqual(response.status_code, 400) - @patch('learning_assistant.views.learning_assistant_enabled') def test_course_waffle_inactive(self, mock_waffle): mock_waffle.return_value = False @@ -111,9 +107,13 @@ def test_course_waffle_inactive(self, mock_waffle): @patch('learning_assistant.views.render_prompt_template') @patch('learning_assistant.views.learning_assistant_enabled') @patch('learning_assistant.views.get_user_role') - def test_invalid_messages(self, mock_role, mock_waffle, mock_render): + @patch('learning_assistant.views.CourseEnrollment') + @patch('learning_assistant.views.CourseMode') + def test_invalid_messages(self, mock_mode, mock_enrollment, mock_get_user_role, mock_waffle, mock_render): mock_waffle.return_value = True - mock_role.return_value = 'staff' + mock_get_user_role.return_value = 'staff' + mock_mode.VERIFIED_MODES = ['verified'] + mock_enrollment.get_enrollment.return_value = MagicMock(mode='verified') mock_render.return_value = 'This is a template' test_unit_id = 'test-unit-id' @@ -175,7 +175,7 @@ def test_invalid_enrollment_mode(self, mock_mode, mock_enrollment, mock_role, mo response = self.client.post(reverse('chat', kwargs={'course_run_id': self.course_id})) self.assertEqual(response.status_code, 403) - # Test that unexpired audit trials + vierfied track learners get the default chat response + # Test that unexpired audit trials + verified track learners get the default chat response @ddt.data((False, 'verified'), (True, 'audit')) @ddt.unpack @@ -184,7 +184,7 @@ def test_invalid_enrollment_mode(self, mock_mode, mock_enrollment, mock_role, mo @patch('learning_assistant.views.get_chat_response') @patch('learning_assistant.views.learning_assistant_enabled') @patch('learning_assistant.views.get_user_role') - @patch('learning_assistant.views.CourseEnrollment.get_enrollment') + @patch('learning_assistant.views.CourseEnrollment') @patch('learning_assistant.views.CourseMode') @patch('learning_assistant.views.save_chat_message') @patch('learning_assistant.views.chat_history_enabled') @@ -197,19 +197,19 @@ def test_chat_response_default( mock_save_chat_message, mock_mode, mock_enrollment, - mock_role, + mock_get_user_role, mock_waffle, mock_chat_response, mock_render, mock_trial_expired, ): mock_waffle.return_value = True - mock_role.return_value = 'student' + mock_get_user_role.return_value = 'student' mock_mode.VERIFIED_MODES = ['verified'] mock_mode.CREDIT_MODES = ['credit'] mock_mode.NO_ID_PROFESSIONAL_MODE = 'no-id' mock_mode.UPSELL_TO_VERIFIED_MODES = ['audit'] - mock_enrollment.return_value = MagicMock(mode=enrollment_mode) + mock_enrollment.get_enrollment.return_value = MagicMock(mode=enrollment_mode) mock_chat_response.return_value = (200, {'role': 'assistant', 'content': 'Something else'}) mock_render.return_value = 'Rendered template mock' mock_trial_expired.return_value = False @@ -283,7 +283,6 @@ def test_invalid_course_id(self, mock_learning_assistant_enabled): self.assertEqual(response.status_code, 400) -@ddt.ddt class LearningAssistantMessageHistoryViewTests(LoggedInTestCase): """ Tests for the LearningAssistantMessageHistoryView @@ -293,13 +292,6 @@ def setUp(self): super().setUp() self.course_id = 'course-v1:edx+test+23' - @patch('learning_assistant.views.learning_assistant_enabled') - def test_invalid_course_id(self, mock_learning_assistant_enabled): - mock_learning_assistant_enabled.return_value = True - response = self.client.get(reverse('enabled', kwargs={'course_run_id': self.course_id+'+invalid'})) - - self.assertEqual(response.status_code, 400) - @patch('learning_assistant.views.learning_assistant_enabled') def test_course_waffle_inactive(self, mock_waffle): mock_waffle.return_value = False @@ -324,21 +316,21 @@ def test_learning_assistant_not_enabled(self, mock_learning_assistant_enabled): @patch('learning_assistant.views.chat_history_enabled') @patch('learning_assistant.views.learning_assistant_enabled') @patch('learning_assistant.views.get_user_role') - @patch('learning_assistant.views.CourseEnrollment.get_enrollment') + @patch('learning_assistant.views.CourseEnrollment') @patch('learning_assistant.views.CourseMode') def test_user_no_enrollment_not_staff( self, mock_mode, mock_enrollment, - mock_role, + mock_get_user_role, mock_assistant_waffle, mock_history_waffle ): mock_assistant_waffle.return_value = True mock_history_waffle.return_value = True - mock_role.return_value = 'student' + mock_get_user_role.return_value = 'student' mock_mode.VERIFIED_MODES = ['verified'] - mock_enrollment.return_value = None + mock_enrollment.get_enrollment = MagicMock(return_value=None) message_count = 5 response = self.client.get( @@ -350,21 +342,21 @@ def test_user_no_enrollment_not_staff( @patch('learning_assistant.views.chat_history_enabled') @patch('learning_assistant.views.learning_assistant_enabled') @patch('learning_assistant.views.get_user_role') - @patch('learning_assistant.views.CourseEnrollment.get_enrollment') + @patch('learning_assistant.views.CourseEnrollment') @patch('learning_assistant.views.CourseMode') def test_user_audit_enrollment_not_staff( self, mock_mode, mock_enrollment, - mock_role, + mock_get_user_role, mock_assistant_waffle, mock_history_waffle ): mock_assistant_waffle.return_value = True mock_history_waffle.return_value = True - mock_role.return_value = 'student' + mock_get_user_role.return_value = 'student' mock_mode.VERIFIED_MODES = ['verified'] - mock_enrollment.return_value = MagicMock(mode='audit') + mock_enrollment.get_enrollment.return_value = MagicMock(mode='audit') message_count = 5 response = self.client.get( @@ -376,7 +368,7 @@ def test_user_audit_enrollment_not_staff( @patch('learning_assistant.views.chat_history_enabled') @patch('learning_assistant.views.learning_assistant_enabled') @patch('learning_assistant.views.get_user_role') - @patch('learning_assistant.views.CourseEnrollment.get_enrollment') + @patch('learning_assistant.views.CourseEnrollment') @patch('learning_assistant.views.CourseMode') @patch('learning_assistant.views.get_course_id') def test_learning_message_history_view_get( @@ -384,15 +376,15 @@ def test_learning_message_history_view_get( mock_get_course_id, mock_mode, mock_enrollment, - mock_role, + mock_get_user_role, mock_assistant_waffle, mock_history_waffle, ): mock_assistant_waffle.return_value = True mock_history_waffle.return_value = True - mock_role.return_value = 'student' + mock_get_user_role.return_value = 'student' mock_mode.VERIFIED_MODES = ['verified'] - mock_enrollment.return_value = MagicMock(mode='verified') + mock_enrollment.get_enrollment.return_value = MagicMock(mode='verified') LearningAssistantMessage.objects.create( course_id=self.course_id, @@ -470,3 +462,351 @@ def test_learning_message_history_view_get_disabled( # Ensure returning an empty list self.assertEqual(len(data), 0) self.assertEqual(data, []) + + +@ddt.ddt +class LearningAssistantChatSummaryViewTests(LoggedInTestCase): + """ + Tests for the LearningAssistantChatSummaryView + """ + sys.modules['lms.djangoapps.courseware.access'] = MagicMock() + sys.modules['lms.djangoapps.courseware.toggles'] = MagicMock() + sys.modules['common.djangoapps.course_modes.models'] = MagicMock() + sys.modules['common.djangoapps.student.models'] = MagicMock() + + def setUp(self): + super().setUp() + self.course_id = 'course-v1:edx+test+23' + + @patch('learning_assistant.views.CourseKey') + def test_invalid_course_id(self, mock_course_key): + mock_course_key.from_string = MagicMock(side_effect=InvalidKeyError('foo', 'bar')) + + response = self.client.get(reverse('chat-summary', kwargs={'course_run_id': self.course_id+'+invalid'})) + + self.assertEqual(response.status_code, 400) + self.assertEqual(response.data['detail'], 'Course ID is not a valid course ID.') + + @ddt.data( + *product( + [True, False], # learning assistant enabled + [True, False], # chat history enabled + ['staff', 'instructor'], # user role + ['verified', 'credit', 'no-id', 'audit', None], # course mode + [True, False], # trial available + [True, False], # trial expired + ) + ) + @ddt.unpack + @patch('learning_assistant.views.audit_trial_is_expired') + @patch('learning_assistant.views.chat_history_enabled') + @patch('learning_assistant.views.learning_assistant_enabled') + @patch('learning_assistant.views.get_user_role') + @patch('learning_assistant.views.CourseEnrollment') + @patch('learning_assistant.views.CourseMode') + def test_chat_summary_with_access_instructor( + self, + learning_assistant_enabled_mock_value, + chat_history_enabled_mock_value, + user_role_mock_value, + course_mode_mock_value, + trial_available, + audit_trial_is_expired_mock_value, + mock_mode, + mock_enrollment, + mock_get_user_role, + mock_learning_assistant_enabled, + mock_chat_history_enabled, + mock_audit_trial_is_expired, + ): + # Set up mocks. + mock_learning_assistant_enabled.return_value = learning_assistant_enabled_mock_value + mock_chat_history_enabled.return_value = chat_history_enabled_mock_value + + mock_get_user_role.return_value = user_role_mock_value + + mock_mode.VERIFIED_MODES = ['verified'] + mock_mode.CREDIT_MODES = ['credit'] + mock_mode.NO_ID_PROFESSIONAL_MODE = 'no-id' + mock_mode.UPSELL_TO_VERIFIED_MODES = ['audit'] + + mock_enrollment.get_enrollment.return_value = MagicMock(mode=course_mode_mock_value) + + # Set up message history data. + if chat_history_enabled_mock_value: + LearningAssistantMessage.objects.create( + course_id=self.course_id, + user=self.user, + role='user', + content='Older message', + created=date(2024, 10, 1) + ) + + LearningAssistantMessage.objects.create( + course_id=self.course_id, + user=self.user, + role='user', + content='Newer message', + created=date(2024, 10, 3) + ) + + db_messages = LearningAssistantMessage.objects.all().order_by('created') + db_messages_count = len(db_messages) + + # Set up audit trial data. + mock_audit_trial_is_expired.return_value = audit_trial_is_expired_mock_value + + trial_start_date = datetime(2024, 1, 1, 0, 0, 0) + if trial_available: + LearningAssistantAuditTrial.objects.create( + user=self.user, + start_date=trial_start_date, + ) + + url_kwargs = {'course_run_id': self.course_id} + url = reverse('chat-summary', kwargs=url_kwargs) + + if chat_history_enabled_mock_value: + query_params = {'message_count': db_messages_count} + url = f"{url}?{urlencode(query_params)}" + + response = self.client.get(url) + + # Assert message history data is correct. + if chat_history_enabled_mock_value: + data = response.data['message_history'] + + # Ensure same number of entries. + self.assertEqual(len(data), db_messages_count) + + # Ensure values are as expected. + for i, message in enumerate(data): + self.assertEqual(message['role'], db_messages[i].role) + self.assertEqual(message['content'], db_messages[i].content) + self.assertEqual(message['timestamp'], db_messages[i].created.isoformat()) + else: + self.assertEqual(response.data['message_history'], []) + + # Assert trial data is correct. + expected_trial_data = {} + if trial_available: + expected_trial_data['start_date'] = trial_start_date + expected_trial_data['expiration_date'] = trial_start_date + timedelta(days=14) + + self.assertEqual(response.data['audit_trial'], expected_trial_data) + + @ddt.data( + *product( + [True, False], # learning assistant enabled + [True, False], # chat history enabled + ['student'], # user role + ['verified', 'credit', 'no-id'], # course mode + [True, False], # trial available + [True, False], # trial expired + ) + ) + @ddt.unpack + @patch('learning_assistant.views.audit_trial_is_expired') + @patch('learning_assistant.views.chat_history_enabled') + @patch('learning_assistant.views.learning_assistant_enabled') + @patch('learning_assistant.views.get_user_role') + @patch('learning_assistant.views.CourseEnrollment') + @patch('learning_assistant.views.CourseMode') + def test_chat_summary_with_full_access_student( + self, + learning_assistant_enabled_mock_value, + chat_history_enabled_mock_value, + user_role_mock_value, + course_mode_mock_value, + trial_available, + audit_trial_is_expired_mock_value, + mock_mode, + mock_enrollment, + mock_get_user_role, + mock_learning_assistant_enabled, + mock_chat_history_enabled, + mock_audit_trial_is_expired, + ): + # Set up mocks. + mock_learning_assistant_enabled.return_value = learning_assistant_enabled_mock_value + mock_chat_history_enabled.return_value = chat_history_enabled_mock_value + + mock_get_user_role.return_value = user_role_mock_value + + mock_mode.VERIFIED_MODES = ['verified'] + mock_mode.CREDIT_MODES = ['credit'] + mock_mode.NO_ID_PROFESSIONAL_MODE = 'no-id' + mock_mode.UPSELL_TO_VERIFIED_MODES = ['audit'] + + mock_enrollment.get_enrollment.return_value = MagicMock(mode=course_mode_mock_value) + + # Set up message history data. + if chat_history_enabled_mock_value: + LearningAssistantMessage.objects.create( + course_id=self.course_id, + user=self.user, + role='user', + content='Older message', + created=date(2024, 10, 1) + ) + + LearningAssistantMessage.objects.create( + course_id=self.course_id, + user=self.user, + role='user', + content='Newer message', + created=date(2024, 10, 3) + ) + + db_messages = LearningAssistantMessage.objects.all().order_by('created') + db_messages_count = len(db_messages) + + # Set up audit trial data. + mock_audit_trial_is_expired.return_value = audit_trial_is_expired_mock_value + + trial_start_date = datetime(2024, 1, 1, 0, 0, 0) + if trial_available: + LearningAssistantAuditTrial.objects.create( + user=self.user, + start_date=trial_start_date, + ) + + url_kwargs = {'course_run_id': self.course_id} + url = reverse('chat-summary', kwargs=url_kwargs) + + if chat_history_enabled_mock_value: + query_params = {'message_count': db_messages_count} + url = f"{url}?{urlencode(query_params)}" + + response = self.client.get(url) + + # Assert message history data is correct. + if chat_history_enabled_mock_value: + data = response.data['message_history'] + + # Ensure same number of entries. + self.assertEqual(len(data), db_messages_count) + + # Ensure values are as expected. + for i, message in enumerate(data): + self.assertEqual(message['role'], db_messages[i].role) + self.assertEqual(message['content'], db_messages[i].content) + self.assertEqual(message['timestamp'], db_messages[i].created.isoformat()) + else: + self.assertEqual(response.data['message_history'], []) + + # Assert trial data is correct. + expected_trial_data = {} + if trial_available: + expected_trial_data['start_date'] = trial_start_date + expected_trial_data['expiration_date'] = trial_start_date + timedelta(days=14) + + self.assertEqual(response.data['audit_trial'], expected_trial_data) + + @ddt.data( + *product( + [True, False], # learning assistant enabled + [True, False], # chat history enabled + ['student'], # user role + ['audit'], # course mode + [True, False], # trial available + [True, False], # trial expired + ) + ) + @ddt.unpack + @patch('learning_assistant.views.audit_trial_is_expired') + @patch('learning_assistant.views.chat_history_enabled') + @patch('learning_assistant.views.learning_assistant_enabled') + @patch('learning_assistant.views.get_user_role') + @patch('learning_assistant.views.CourseEnrollment') + @patch('learning_assistant.views.CourseMode') + def test_chat_summary_with_trial_access_student( + self, + learning_assistant_enabled_mock_value, + chat_history_enabled_mock_value, + user_role_mock_value, + course_mode_mock_value, + trial_available, + audit_trial_is_expired_mock_value, + mock_mode, + mock_enrollment, + mock_get_user_role, + mock_learning_assistant_enabled, + mock_chat_history_enabled, + mock_audit_trial_is_expired, + ): + # Set up mocks. + mock_learning_assistant_enabled.return_value = learning_assistant_enabled_mock_value + mock_chat_history_enabled.return_value = chat_history_enabled_mock_value + + mock_get_user_role.return_value = user_role_mock_value + + mock_mode.VERIFIED_MODES = ['verified'] + mock_mode.CREDIT_MODES = ['credit'] + mock_mode.NO_ID_PROFESSIONAL_MODE = 'no-id' + mock_mode.UPSELL_TO_VERIFIED_MODES = ['audit'] + + mock_enrollment.get_enrollment.return_value = MagicMock(mode=course_mode_mock_value) + + # Set up message history data. + if chat_history_enabled_mock_value: + LearningAssistantMessage.objects.create( + course_id=self.course_id, + user=self.user, + role='user', + content='Older message', + created=date(2024, 10, 1) + ) + + LearningAssistantMessage.objects.create( + course_id=self.course_id, + user=self.user, + role='user', + content='Newer message', + created=date(2024, 10, 3) + ) + + db_messages = LearningAssistantMessage.objects.all().order_by('created') + db_messages_count = len(db_messages) + + # Set up audit trial data. + mock_audit_trial_is_expired.return_value = audit_trial_is_expired_mock_value + + trial_start_date = datetime(2024, 1, 1, 0, 0, 0) + if trial_available: + LearningAssistantAuditTrial.objects.create( + user=self.user, + start_date=trial_start_date, + ) + + url_kwargs = {'course_run_id': self.course_id} + url = reverse('chat-summary', kwargs=url_kwargs) + + if chat_history_enabled_mock_value: + query_params = {'message_count': db_messages_count} + url = f"{url}?{urlencode(query_params)}" + + response = self.client.get(url) + + # Assert message history data is correct. + if chat_history_enabled_mock_value and trial_available and not audit_trial_is_expired_mock_value: + data = response.data['message_history'] + + # Ensure same number of entries. + self.assertEqual(len(data), db_messages_count) + + # Ensure values are as expected. + for i, message in enumerate(data): + self.assertEqual(message['role'], db_messages[i].role) + self.assertEqual(message['content'], db_messages[i].content) + self.assertEqual(message['timestamp'], db_messages[i].created.isoformat()) + else: + self.assertEqual(response.data['message_history'], []) + + # Assert trial data is correct. + expected_trial_data = {} + if trial_available: + expected_trial_data['start_date'] = trial_start_date + expected_trial_data['expiration_date'] = trial_start_date + timedelta(days=14) + + self.assertEqual(response.data['audit_trial'], expected_trial_data)