From 95206e74d64f56d3fa99b7e9e5e25d127a64521a Mon Sep 17 00:00:00 2001 From: ilee2u Date: Mon, 18 Nov 2024 13:29:54 -0500 Subject: [PATCH 1/8] feat: add audit gate to CourseChatView --- learning_assistant/api.py | 27 +++++++++++++++++++++++++-- learning_assistant/constants.py | 2 ++ learning_assistant/views.py | 19 +++++++++++++------ 3 files changed, 40 insertions(+), 8 deletions(-) diff --git a/learning_assistant/api.py b/learning_assistant/api.py index 55c73c1..701fca0 100644 --- a/learning_assistant/api.py +++ b/learning_assistant/api.py @@ -10,9 +10,15 @@ from jinja2 import BaseLoader, Environment from opaque_keys import InvalidKeyError -from learning_assistant.constants import ACCEPTED_CATEGORY_TYPES, CATEGORY_TYPE_MAP +from datetime import datetime + +from learning_assistant.constants import AUDIT_TRIAL_MAX_DAYS, ACCEPTED_CATEGORY_TYPES, CATEGORY_TYPE_MAP from learning_assistant.data import LearningAssistantCourseEnabledData -from learning_assistant.models import LearningAssistantCourseEnabled, LearningAssistantMessage +from learning_assistant.models import ( + LearningAssistantAuditTrial, + LearningAssistantCourseEnabled, + LearningAssistantMessage, +) from learning_assistant.platform_imports import ( block_get_children, block_leaf_filter, @@ -224,3 +230,20 @@ def get_message_history(courserun_key, user, message_count): message_history = list(LearningAssistantMessage.objects.filter( course_id=courserun_key, user=user).order_by('-created')[:message_count])[::-1] return message_history + + +def check_if_audit_trial_is_expired(user_id): + """ + Given a user (User), get the corresponding LearningAssistantAuditTrial trial object, + or create one if one does not exist yet. + """ + audit_trial, created = LearningAssistantAuditTrial.objects.get_or_create(user_id) + + # If the trial was just created, then it definitely isn't expired, so return False + if created: + return False + + # If the user's trial is expired, return True. Else, return False + if (datetime.now() - audit_trial.start_date) < AUDIT_TRIAL_MAX_DAYS: + return True + return False diff --git a/learning_assistant/constants.py b/learning_assistant/constants.py index 7027a28..92e3717 100644 --- a/learning_assistant/constants.py +++ b/learning_assistant/constants.py @@ -14,3 +14,5 @@ "html": "TEXT", "video": "VIDEO", } + +AUDIT_TRIAL_MAX_DAYS = 14 diff --git a/learning_assistant/views.py b/learning_assistant/views.py index b1b5c03..672f063 100644 --- a/learning_assistant/views.py +++ b/learning_assistant/views.py @@ -21,6 +21,7 @@ pass from learning_assistant.api import ( + check_if_audit_trial_is_expired, get_course_id, get_message_history, learning_assistant_enabled, @@ -68,18 +69,24 @@ def post(self, request, course_run_id): data={'detail': 'Learning assistant not enabled for course.'} ) - # If user does not have an enrollment record, or is not staff, they should not have access + # If user does not have a verified enrollment record, or is not staff, they should not have full access user_role = get_user_role(request.user, courserun_key) enrollment_object = CourseEnrollment.get_enrollment(request.user, courserun_key) enrollment_mode = enrollment_object.mode if enrollment_object else None if ( - (enrollment_mode not in CourseMode.VERIFIED_MODES) + # NOTE: Will there ever be a case where the user has a course mod that's + # in neither VERIFIED_MODES nor AUDIT_MODES that we need to worry about? + enrollment_mode not in CourseMode.VERIFIED_MODES + and enrollment_mode in CourseMode.AUDIT_MODES and not user_role_is_staff(user_role) ): - return Response( - status=http_status.HTTP_403_FORBIDDEN, - data={'detail': 'Must be staff or have valid enrollment.'} - ) + # If user has an audit enrollment record, get or create their trial + user_audit_trial_expired = check_if_audit_trial_is_expired(user_id=request.user.id) + if user_audit_trial_expired: + return Response( + status=http_status.HTTP_403_FORBIDDEN, + data={'detail': 'Must be staff or have valid enrollment.'} + ) unit_id = request.query_params.get('unit_id') From 2466783dd54975db1771edf2ebd876b6d7e59c57 Mon Sep 17 00:00:00 2001 From: ilee2u Date: Mon, 18 Nov 2024 14:35:26 -0500 Subject: [PATCH 2/8] test: add tests --- learning_assistant/api.py | 20 +++++++----- learning_assistant/platform_imports.py | 16 +++++----- learning_assistant/views.py | 2 +- tests/__init__.py | 0 tests/test_api.py | 42 +++++++++++++++++++++++++- tests/test_views.py | 21 ++++--------- 6 files changed, 68 insertions(+), 33 deletions(-) delete mode 100644 tests/__init__.py diff --git a/learning_assistant/api.py b/learning_assistant/api.py index 701fca0..00dff18 100644 --- a/learning_assistant/api.py +++ b/learning_assistant/api.py @@ -2,6 +2,7 @@ Library for the learning_assistant app. """ import logging +from datetime import datetime, timedelta from django.conf import settings from django.contrib.auth import get_user_model @@ -10,9 +11,7 @@ from jinja2 import BaseLoader, Environment from opaque_keys import InvalidKeyError -from datetime import datetime - -from learning_assistant.constants import AUDIT_TRIAL_MAX_DAYS, ACCEPTED_CATEGORY_TYPES, CATEGORY_TYPE_MAP +from learning_assistant.constants import ACCEPTED_CATEGORY_TYPES, AUDIT_TRIAL_MAX_DAYS, CATEGORY_TYPE_MAP from learning_assistant.data import LearningAssistantCourseEnabledData from learning_assistant.models import ( LearningAssistantAuditTrial, @@ -232,18 +231,23 @@ def get_message_history(courserun_key, user, message_count): return message_history -def check_if_audit_trial_is_expired(user_id): +def check_if_audit_trial_is_expired(user): """ - Given a user (User), get the corresponding LearningAssistantAuditTrial trial object, - or create one if one does not exist yet. + Given a user (User), get or create the corresponding LearningAssistantAuditTrial trial object. """ - audit_trial, created = LearningAssistantAuditTrial.objects.get_or_create(user_id) + audit_trial, created = LearningAssistantAuditTrial.objects.get_or_create( + user=user, + defaults={ + "start_date": datetime.now(), + }, + ) # If the trial was just created, then it definitely isn't expired, so return False if created: return False # If the user's trial is expired, return True. Else, return False - if (datetime.now() - audit_trial.start_date) < AUDIT_TRIAL_MAX_DAYS: + DAYS_SINCE_TRIAL_START_DATE = datetime.now() - audit_trial.start_date + if DAYS_SINCE_TRIAL_START_DATE > timedelta(days=AUDIT_TRIAL_MAX_DAYS): return True return False diff --git a/learning_assistant/platform_imports.py b/learning_assistant/platform_imports.py index 8c2411c..55f13d7 100644 --- a/learning_assistant/platform_imports.py +++ b/learning_assistant/platform_imports.py @@ -8,7 +8,7 @@ def get_text_transcript(video_block): """Get the transcript for a video block in text format, or None.""" - # pylint: disable=import-error, import-outside-toplevel + # pylint: disable=import-outside-toplevel from xmodule.exceptions import NotFoundError from xmodule.video_block.transcripts_utils import get_transcript try: @@ -21,28 +21,28 @@ def get_text_transcript(video_block): def get_single_block(request, user_id, course_id, usage_key_string, course=None): """Load a single xblock.""" - # pylint: disable=import-error, import-outside-toplevel + # pylint: disable=import-outside-toplevel from lms.djangoapps.courseware.block_render import load_single_xblock return load_single_xblock(request, user_id, course_id, usage_key_string, course) def traverse_block_pre_order(start_node, get_children, filter_func=None): """Traverse a DAG or tree in pre-order.""" - # pylint: disable=import-error, import-outside-toplevel + # pylint: disable=import-outside-toplevel from openedx.core.lib.graph_traversals import traverse_pre_order return traverse_pre_order(start_node, get_children, filter_func) def block_leaf_filter(block): """Return only leaf nodes.""" - # pylint: disable=import-error, import-outside-toplevel + # pylint: disable=import-outside-toplevel from openedx.core.lib.graph_traversals import leaf_filter return leaf_filter(block) def block_get_children(block): """Return children of a given block.""" - # pylint: disable=import-error, import-outside-toplevel + # pylint: disable=import-outside-toplevel from openedx.core.lib.graph_traversals import get_children return get_children(block) @@ -54,7 +54,7 @@ def get_cache_course_run_data(course_run_id, fields): This function makes use of the course run cache in the LMS, which caches data from the discovery service. This is necessary because only the discovery service stores the relation between courseruns and courses. """ - # pylint: disable=import-error, import-outside-toplevel + # pylint: disable=import-outside-toplevel from openedx.core.djangoapps.catalog.utils import get_course_run_data return get_course_run_data(course_run_id, fields) @@ -66,7 +66,7 @@ def get_cache_course_data(course_id, fields): This function makes use of the course cache in the LMS, which caches data from the discovery service. This is necessary because only the discovery service stores course skills data. """ - # pylint: disable=import-error, import-outside-toplevel + # pylint: disable=import-outside-toplevel from openedx.core.djangoapps.catalog.utils import get_course_data return get_course_data(course_id, fields) @@ -82,6 +82,6 @@ def get_user_role(user, course_key): Returns: * str: the user's role """ - # pylint: disable=import-error, import-outside-toplevel + # pylint: disable=import-outside-toplevel from lms.djangoapps.courseware.access import get_user_role as platform_get_user_role return platform_get_user_role(user, course_key) diff --git a/learning_assistant/views.py b/learning_assistant/views.py index 672f063..2a5447b 100644 --- a/learning_assistant/views.py +++ b/learning_assistant/views.py @@ -81,7 +81,7 @@ def post(self, request, course_run_id): and not user_role_is_staff(user_role) ): # If user has an audit enrollment record, get or create their trial - user_audit_trial_expired = check_if_audit_trial_is_expired(user_id=request.user.id) + user_audit_trial_expired = check_if_audit_trial_is_expired(user=request.user) if user_audit_trial_expired: return Response( status=http_status.HTTP_403_FORBIDDEN, diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_api.py b/tests/test_api.py index 1344dea..1d7c41e 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -2,6 +2,7 @@ Test cases for the learning-assistant api module. """ import itertools +from datetime import datetime, timedelta from unittest.mock import MagicMock, patch import ddt @@ -16,6 +17,7 @@ _extract_block_contents, _get_children_contents, _leaf_filter, + check_if_audit_trial_is_expired, get_block_content, get_message_history, learning_assistant_available, @@ -24,8 +26,13 @@ save_chat_message, set_learning_assistant_enabled, ) +from learning_assistant.constants import AUDIT_TRIAL_MAX_DAYS from learning_assistant.data import LearningAssistantCourseEnabledData -from learning_assistant.models import LearningAssistantCourseEnabled, LearningAssistantMessage +from learning_assistant.models import ( + LearningAssistantAuditTrial, + LearningAssistantCourseEnabled, + LearningAssistantMessage, +) fake_transcript = 'This is the text version from the transcript' User = get_user_model() @@ -241,6 +248,7 @@ class TestLearningAssistantCourseEnabledApi(TestCase): """ Test suite for save_chat_message. """ + def setUp(self): super().setUp() @@ -473,3 +481,35 @@ def test_get_message_course_id_differences(self): self.assertEqual(return_value.user, expected_value[i].user) self.assertEqual(return_value.role, expected_value[i].role) self.assertEqual(return_value.content, expected_value[i].content) + + +@ddt.ddt +class CheckIfAuditTrialIsExpiredTests(TestCase): + """ + Test suite for check_if_audit_trial_is_expired. + """ + + def setUp(self): + super().setUp() + self.course_key = CourseKey.from_string('course-v1:edx+fake+1') + self.user = User(username='tester', email='tester@test.com') + self.user.save() + + self.role = 'verified' + + def test_check_if_audit_trial_is_expired_audit_trial_created(self): + self.assertEqual(check_if_audit_trial_is_expired(self.user), False) + + def test_check_if_audit_trial_is_expired_audit_trial_expired(self): + LearningAssistantAuditTrial.objects.create( + user=self.user, + start_date=datetime.now() - timedelta(days=AUDIT_TRIAL_MAX_DAYS + 1), + ) + self.assertEqual(check_if_audit_trial_is_expired(self.user), True) + + def test_check_if_audit_trial_is_expired_audit_trial_unexpired(self): + LearningAssistantAuditTrial.objects.create( + user=self.user, + start_date=datetime.now() - timedelta(days=AUDIT_TRIAL_MAX_DAYS - 0.99), + ) + self.assertEqual(check_if_audit_trial_is_expired(self.user), False) diff --git a/tests/test_views.py b/tests/test_views.py index 1d567b8..f0cd9ef 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -75,7 +75,7 @@ def setUp(self): @ddt.ddt -class TestCourseChatView(LoggedInTestCase): +class CourseChatViewTests(LoggedInTestCase): """ Test for the CourseChatView """ @@ -108,28 +108,19 @@ def test_course_waffle_inactive(self, mock_waffle): response = self.client.post(reverse('chat', kwargs={'course_run_id': self.course_id})) self.assertEqual(response.status_code, 403) + @patch('learning_assistant.views.check_if_audit_trial_is_expired') @patch('learning_assistant.views.learning_assistant_enabled') @patch('learning_assistant.views.get_user_role') @patch('learning_assistant.views.CourseEnrollment.get_enrollment') @patch('learning_assistant.views.CourseMode') - def test_user_no_enrollment_not_staff(self, mock_mode, mock_enrollment, mock_role, mock_waffle): - mock_waffle.return_value = True - mock_role.return_value = 'student' - mock_mode.VERIFIED_MODES = ['verified'] - mock_enrollment.return_value = None - - response = self.client.post(reverse('chat', kwargs={'course_run_id': self.course_id})) - self.assertEqual(response.status_code, 403) - - @patch('learning_assistant.views.learning_assistant_enabled') - @patch('learning_assistant.views.get_user_role') - @patch('learning_assistant.views.CourseEnrollment.get_enrollment') - @patch('learning_assistant.views.CourseMode') - def test_user_audit_enrollment_not_staff(self, mock_mode, mock_enrollment, mock_role, mock_waffle): + def test_user_audit_enrollment_not_staff_trial_expired(self, mock_mode, mock_enrollment, mock_role, + mock_waffle, mock_expired): mock_waffle.return_value = True mock_role.return_value = 'student' mock_mode.VERIFIED_MODES = ['verified'] + mock_mode.AUDIT_MODES = ['audit'] mock_enrollment.return_value = MagicMock(mode='audit') + mock_expired.return_value = True response = self.client.post(reverse('chat', kwargs={'course_run_id': self.course_id})) self.assertEqual(response.status_code, 403) From 046d13f300eed4c0a305ba783d06ee08f7299d86 Mon Sep 17 00:00:00 2001 From: ilee2u Date: Mon, 18 Nov 2024 15:21:58 -0500 Subject: [PATCH 3/8] feat: add upgrade deadline gate --- learning_assistant/api.py | 10 ++++++++-- learning_assistant/views.py | 6 +++++- tests/test_api.py | 15 ++++++++++----- tests/test_views.py | 8 +++++--- 4 files changed, 28 insertions(+), 11 deletions(-) diff --git a/learning_assistant/api.py b/learning_assistant/api.py index 00dff18..5797139 100644 --- a/learning_assistant/api.py +++ b/learning_assistant/api.py @@ -231,10 +231,15 @@ def get_message_history(courserun_key, user, message_count): return message_history -def check_if_audit_trial_is_expired(user): +def check_if_audit_trial_is_expired(user, upgrade_deadline): """ Given a user (User), get or create the corresponding LearningAssistantAuditTrial trial object. """ + # If the upgrade deadline has passed, return "True" for expired + DAYS_SINCE_UPGRADE_DEADLINE = datetime.now() - upgrade_deadline + if DAYS_SINCE_UPGRADE_DEADLINE >= timedelta(days=0): + return True + audit_trial, created = LearningAssistantAuditTrial.objects.get_or_create( user=user, defaults={ @@ -246,8 +251,9 @@ def check_if_audit_trial_is_expired(user): if created: return False - # If the user's trial is expired, return True. Else, return False + # If the user's trial is past its expiry date, return "True" for expired. Else, return False DAYS_SINCE_TRIAL_START_DATE = datetime.now() - audit_trial.start_date + print("DAYS_SINCE_TRIAL_START_DATE:", DAYS_SINCE_TRIAL_START_DATE) if DAYS_SINCE_TRIAL_START_DATE > timedelta(days=AUDIT_TRIAL_MAX_DAYS): return True return False diff --git a/learning_assistant/views.py b/learning_assistant/views.py index 2a5447b..7e69ee8 100644 --- a/learning_assistant/views.py +++ b/learning_assistant/views.py @@ -80,8 +80,12 @@ def post(self, request, course_run_id): and enrollment_mode in CourseMode.AUDIT_MODES and not user_role_is_staff(user_role) ): + # TODO: Add logic to make sure upgrade deadline has not passed. + course_mode = CourseMode.objects.get(course=courserun_key) + upgrade_deadline = course_mode.expiration_datetime() + # If user has an audit enrollment record, get or create their trial - user_audit_trial_expired = check_if_audit_trial_is_expired(user=request.user) + user_audit_trial_expired = check_if_audit_trial_is_expired(request.user, upgrade_deadline) if user_audit_trial_expired: return Response( status=http_status.HTTP_403_FORBIDDEN, diff --git a/tests/test_api.py b/tests/test_api.py index 1d7c41e..df74fac 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -496,20 +496,25 @@ def setUp(self): self.user.save() self.role = 'verified' + self.upgrade_deadline = datetime.now() + timedelta(days=1) # 1 day from now + + def test_check_if_past_upgrade_deadline(self): + upgrade_deadline = datetime.now() - timedelta(days=1) # yesterday + self.assertEqual(check_if_audit_trial_is_expired(self.user, upgrade_deadline), True) def test_check_if_audit_trial_is_expired_audit_trial_created(self): - self.assertEqual(check_if_audit_trial_is_expired(self.user), False) + self.assertEqual(check_if_audit_trial_is_expired(self.user, self.upgrade_deadline), False) def test_check_if_audit_trial_is_expired_audit_trial_expired(self): LearningAssistantAuditTrial.objects.create( user=self.user, - start_date=datetime.now() - timedelta(days=AUDIT_TRIAL_MAX_DAYS + 1), + start_date=datetime.now() - timedelta(days=AUDIT_TRIAL_MAX_DAYS + 1), # 1 day more than trial deadline ) - self.assertEqual(check_if_audit_trial_is_expired(self.user), True) + self.assertEqual(check_if_audit_trial_is_expired(self.user, self.upgrade_deadline), True) def test_check_if_audit_trial_is_expired_audit_trial_unexpired(self): LearningAssistantAuditTrial.objects.create( user=self.user, - start_date=datetime.now() - timedelta(days=AUDIT_TRIAL_MAX_DAYS - 0.99), + start_date=datetime.now() - timedelta(days=AUDIT_TRIAL_MAX_DAYS - 0.99), # 0.99 days less than deadline ) - self.assertEqual(check_if_audit_trial_is_expired(self.user), False) + self.assertEqual(check_if_audit_trial_is_expired(self.user, self.upgrade_deadline), False) diff --git a/tests/test_views.py b/tests/test_views.py index f0cd9ef..3b47cf7 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -1,9 +1,9 @@ """ Tests for the learning assistant views. """ -import datetime import json import sys +from datetime import date, datetime, timedelta from importlib import import_module from unittest.mock import MagicMock, call, patch @@ -119,6 +119,8 @@ def test_user_audit_enrollment_not_staff_trial_expired(self, mock_mode, mock_enr mock_role.return_value = 'student' mock_mode.VERIFIED_MODES = ['verified'] mock_mode.AUDIT_MODES = ['audit'] + mock_mode.objects.get.return_value = MagicMock() + mock_mode.expiration_datetime.return_value = datetime.now() - timedelta(days=1) mock_enrollment.return_value = MagicMock(mode='audit') mock_expired.return_value = True @@ -340,7 +342,7 @@ def test_learning_message_history_view_get( user=self.user, role='staff', content='Older message', - created=datetime.date(2024, 10, 1) + created=date(2024, 10, 1) ) LearningAssistantMessage.objects.create( @@ -348,7 +350,7 @@ def test_learning_message_history_view_get( user=self.user, role='staff', content='Newer message', - created=datetime.date(2024, 10, 3) + created=date(2024, 10, 3) ) db_messages = LearningAssistantMessage.objects.all().order_by('created') From 04e0710cda3624ee3b5724f7faa3f9c8ec4dbea1 Mon Sep 17 00:00:00 2001 From: ilee2u Date: Mon, 18 Nov 2024 15:58:33 -0500 Subject: [PATCH 4/8] fix: remove "created" test for trial expiration --- learning_assistant/api.py | 9 ++------- tests/test_api.py | 3 --- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/learning_assistant/api.py b/learning_assistant/api.py index 5797139..0d1c67f 100644 --- a/learning_assistant/api.py +++ b/learning_assistant/api.py @@ -240,20 +240,15 @@ def check_if_audit_trial_is_expired(user, upgrade_deadline): if DAYS_SINCE_UPGRADE_DEADLINE >= timedelta(days=0): return True - audit_trial, created = LearningAssistantAuditTrial.objects.get_or_create( + audit_trial, _ = LearningAssistantAuditTrial.objects.get_or_create( user=user, defaults={ "start_date": datetime.now(), }, ) - # If the trial was just created, then it definitely isn't expired, so return False - if created: - return False - # If the user's trial is past its expiry date, return "True" for expired. Else, return False DAYS_SINCE_TRIAL_START_DATE = datetime.now() - audit_trial.start_date - print("DAYS_SINCE_TRIAL_START_DATE:", DAYS_SINCE_TRIAL_START_DATE) - if DAYS_SINCE_TRIAL_START_DATE > timedelta(days=AUDIT_TRIAL_MAX_DAYS): + if DAYS_SINCE_TRIAL_START_DATE >= timedelta(days=AUDIT_TRIAL_MAX_DAYS): return True return False diff --git a/tests/test_api.py b/tests/test_api.py index df74fac..d0abc97 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -502,9 +502,6 @@ def test_check_if_past_upgrade_deadline(self): upgrade_deadline = datetime.now() - timedelta(days=1) # yesterday self.assertEqual(check_if_audit_trial_is_expired(self.user, upgrade_deadline), True) - def test_check_if_audit_trial_is_expired_audit_trial_created(self): - self.assertEqual(check_if_audit_trial_is_expired(self.user, self.upgrade_deadline), False) - def test_check_if_audit_trial_is_expired_audit_trial_expired(self): LearningAssistantAuditTrial.objects.create( user=self.user, From 0b01a7bf2e49ba85bebd00b89ff3d395ff2ae9c6 Mon Sep 17 00:00:00 2001 From: ilee2u Date: Wed, 20 Nov 2024 15:22:23 -0500 Subject: [PATCH 5/8] chore: nits --- learning_assistant/api.py | 6 ++---- learning_assistant/views.py | 5 ++--- tests/test_api.py | 14 +++++++------- tests/test_views.py | 2 +- 4 files changed, 12 insertions(+), 15 deletions(-) diff --git a/learning_assistant/api.py b/learning_assistant/api.py index 0d1c67f..3857945 100644 --- a/learning_assistant/api.py +++ b/learning_assistant/api.py @@ -231,7 +231,7 @@ def get_message_history(courserun_key, user, message_count): return message_history -def check_if_audit_trial_is_expired(user, upgrade_deadline): +def audit_trial_is_expired(user, upgrade_deadline): """ Given a user (User), get or create the corresponding LearningAssistantAuditTrial trial object. """ @@ -249,6 +249,4 @@ def check_if_audit_trial_is_expired(user, upgrade_deadline): # If the user's trial is past its expiry date, return "True" for expired. Else, return False DAYS_SINCE_TRIAL_START_DATE = datetime.now() - audit_trial.start_date - if DAYS_SINCE_TRIAL_START_DATE >= timedelta(days=AUDIT_TRIAL_MAX_DAYS): - return True - return False + return DAYS_SINCE_TRIAL_START_DATE >= timedelta(days=AUDIT_TRIAL_MAX_DAYS) diff --git a/learning_assistant/views.py b/learning_assistant/views.py index 7e69ee8..bf468fd 100644 --- a/learning_assistant/views.py +++ b/learning_assistant/views.py @@ -21,7 +21,7 @@ pass from learning_assistant.api import ( - check_if_audit_trial_is_expired, + audit_trial_is_expired, get_course_id, get_message_history, learning_assistant_enabled, @@ -80,12 +80,11 @@ def post(self, request, course_run_id): and enrollment_mode in CourseMode.AUDIT_MODES and not user_role_is_staff(user_role) ): - # TODO: Add logic to make sure upgrade deadline has not passed. course_mode = CourseMode.objects.get(course=courserun_key) upgrade_deadline = course_mode.expiration_datetime() # If user has an audit enrollment record, get or create their trial - user_audit_trial_expired = check_if_audit_trial_is_expired(request.user, upgrade_deadline) + user_audit_trial_expired = audit_trial_is_expired(request.user, upgrade_deadline) if user_audit_trial_expired: return Response( status=http_status.HTTP_403_FORBIDDEN, diff --git a/tests/test_api.py b/tests/test_api.py index d0abc97..d73af07 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -17,7 +17,7 @@ _extract_block_contents, _get_children_contents, _leaf_filter, - check_if_audit_trial_is_expired, + audit_trial_is_expired, get_block_content, get_message_history, learning_assistant_available, @@ -486,7 +486,7 @@ def test_get_message_course_id_differences(self): @ddt.ddt class CheckIfAuditTrialIsExpiredTests(TestCase): """ - Test suite for check_if_audit_trial_is_expired. + Test suite for audit_trial_is_expired. """ def setUp(self): @@ -500,18 +500,18 @@ def setUp(self): def test_check_if_past_upgrade_deadline(self): upgrade_deadline = datetime.now() - timedelta(days=1) # yesterday - self.assertEqual(check_if_audit_trial_is_expired(self.user, upgrade_deadline), True) + self.assertEqual(audit_trial_is_expired(self.user, upgrade_deadline), True) - def test_check_if_audit_trial_is_expired_audit_trial_expired(self): + def test_audit_trial_is_expired_audit_trial_expired(self): LearningAssistantAuditTrial.objects.create( user=self.user, start_date=datetime.now() - timedelta(days=AUDIT_TRIAL_MAX_DAYS + 1), # 1 day more than trial deadline ) - self.assertEqual(check_if_audit_trial_is_expired(self.user, self.upgrade_deadline), True) + self.assertEqual(audit_trial_is_expired(self.user, self.upgrade_deadline), True) - def test_check_if_audit_trial_is_expired_audit_trial_unexpired(self): + def test_audit_trial_is_expired_audit_trial_unexpired(self): LearningAssistantAuditTrial.objects.create( user=self.user, start_date=datetime.now() - timedelta(days=AUDIT_TRIAL_MAX_DAYS - 0.99), # 0.99 days less than deadline ) - self.assertEqual(check_if_audit_trial_is_expired(self.user, self.upgrade_deadline), False) + self.assertEqual(audit_trial_is_expired(self.user, self.upgrade_deadline), False) diff --git a/tests/test_views.py b/tests/test_views.py index 3b47cf7..bd5aa5b 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -108,7 +108,7 @@ def test_course_waffle_inactive(self, mock_waffle): response = self.client.post(reverse('chat', kwargs={'course_run_id': self.course_id})) self.assertEqual(response.status_code, 403) - @patch('learning_assistant.views.check_if_audit_trial_is_expired') + @patch('learning_assistant.views.audit_trial_is_expired') @patch('learning_assistant.views.learning_assistant_enabled') @patch('learning_assistant.views.get_user_role') @patch('learning_assistant.views.CourseEnrollment.get_enrollment') From 4de25d821cb42b9ef7a590453803b8fd74f43e10 Mon Sep 17 00:00:00 2001 From: ilee2u Date: Wed, 20 Nov 2024 17:31:35 -0500 Subject: [PATCH 6/8] fix: distinguish between course modes correctly --- learning_assistant/constants.py | 6 ++ learning_assistant/views.py | 117 +++++++++++++++++++------------- tests/__init__.py | 0 tests/test_views.py | 61 +++++++++++------ 4 files changed, 117 insertions(+), 67 deletions(-) create mode 100644 tests/__init__.py diff --git a/learning_assistant/constants.py b/learning_assistant/constants.py index 92e3717..a44e611 100644 --- a/learning_assistant/constants.py +++ b/learning_assistant/constants.py @@ -1,6 +1,12 @@ """ Constants for the learning_assistant app. """ + +try: + from common.djangoapps.course_modes.models import CourseMode +except ImportError: + pass + # Pulled from edx-platform. Will correctly capture both old- and new-style # course ID strings. INTERNAL_COURSE_KEY_PATTERN = r'([^/+]+(/|\+)[^/+]+(/|\+)[^/?]+)' diff --git a/learning_assistant/views.py b/learning_assistant/views.py index bf468fd..8891c4b 100644 --- a/learning_assistant/views.py +++ b/learning_assistant/views.py @@ -44,55 +44,10 @@ class CourseChatView(APIView): authentication_classes = (SessionAuthentication, JwtAuthentication,) permission_classes = (IsAuthenticated,) - def post(self, request, course_run_id): + def _get_next_message(self, request, courserun_key, course_run_id): """ - Given a course run ID, retrieve a chat response for that course. - - Expected POST data: { - [ - {'role': 'user', 'content': 'What is 2+2?'}, - {'role': 'assistant', 'content': '4'} - ] - } + Generate the next message to be returned by the learning assistant. """ - try: - courserun_key = CourseKey.from_string(course_run_id) - except InvalidKeyError: - return Response( - status=http_status.HTTP_400_BAD_REQUEST, - data={'detail': 'Course ID is not a valid course ID.'} - ) - - if not learning_assistant_enabled(courserun_key): - return Response( - status=http_status.HTTP_403_FORBIDDEN, - data={'detail': 'Learning assistant not enabled for course.'} - ) - - # If user does not have a verified enrollment record, or is not staff, they should not have full access - user_role = get_user_role(request.user, courserun_key) - enrollment_object = CourseEnrollment.get_enrollment(request.user, courserun_key) - enrollment_mode = enrollment_object.mode if enrollment_object else None - if ( - # NOTE: Will there ever be a case where the user has a course mod that's - # in neither VERIFIED_MODES nor AUDIT_MODES that we need to worry about? - enrollment_mode not in CourseMode.VERIFIED_MODES - and enrollment_mode in CourseMode.AUDIT_MODES - and not user_role_is_staff(user_role) - ): - course_mode = CourseMode.objects.get(course=courserun_key) - upgrade_deadline = course_mode.expiration_datetime() - - # If user has an audit enrollment record, get or create their trial - user_audit_trial_expired = audit_trial_is_expired(request.user, upgrade_deadline) - if user_audit_trial_expired: - return Response( - status=http_status.HTTP_403_FORBIDDEN, - data={'detail': 'Must be staff or have valid enrollment.'} - ) - - unit_id = request.query_params.get('unit_id') - message_list = request.data # Check that the last message in the list corresponds to a user @@ -127,8 +82,8 @@ def post(self, request, course_run_id): ) course_id = get_course_id(course_run_id) - template_string = getattr(settings, 'LEARNING_ASSISTANT_PROMPT_TEMPLATE', '') + unit_id = request.query_params.get('unit_id') prompt_template = render_prompt_template( request, request.user.id, course_run_id, unit_id, course_id, template_string @@ -140,6 +95,72 @@ def post(self, request, course_run_id): return Response(status=status_code, data=message) + def post(self, request, course_run_id): + """ + Given a course run ID, retrieve a chat response for that course. + + Expected POST data: { + [ + {'role': 'user', 'content': 'What is 2+2?'}, + {'role': 'assistant', 'content': '4'} + ] + } + """ + try: + courserun_key = CourseKey.from_string(course_run_id) + except InvalidKeyError: + return Response( + status=http_status.HTTP_400_BAD_REQUEST, + data={'detail': 'Course ID is not a valid course ID.'} + ) + + if not learning_assistant_enabled(courserun_key): + return Response( + status=http_status.HTTP_403_FORBIDDEN, + data={'detail': 'Learning assistant not enabled for course.'} + ) + + # If user does not have a verified enrollment record, or is not staff, they should not have full access + user_role = get_user_role(request.user, courserun_key) + enrollment_object = CourseEnrollment.get_enrollment(request.user, courserun_key) + enrollment_mode = enrollment_object.mode if enrollment_object else None + + print("\nenrollment_mode:", enrollment_mode) + # If the user is in a verified course mode or is staff, return the next message + if ( + # Here we include CREDIT_MODE and NO_ID_PROFESSIONAL_MODE, as CourseMode.VERIFIED_MODES on its own + # doesn't match what we count as "verified modes" in the frontend component. + enrollment_mode in CourseMode.VERIFIED_MODES + CourseMode.CREDIT_MODE + CourseMode.NO_ID_PROFESSIONAL_MODE + or user_role_is_staff(user_role) + ): + print("\n\nVERIFIED\n") + return self._get_next_message(request, courserun_key, course_run_id) + + # If user has an audit enrollment record, get or create their trial. If the trial is not expired, return the + # next message. Otherwise, return 403 + elif enrollment_mode in CourseMode.UPSELL_TO_VERIFIED_MODES: # AUDIT, HONOR + print("\n\nAUDIT\n") + course_mode = CourseMode.objects.get(course=courserun_key) + upgrade_deadline = course_mode.expiration_datetime() + + user_audit_trial_expired = audit_trial_is_expired(request.user, upgrade_deadline) + if user_audit_trial_expired: + return Response( + status=http_status.HTTP_403_FORBIDDEN, + data={'detail': 'The audit trial for this user has expired.'} + ) + else: + return self._get_next_message(request, courserun_key, course_run_id) + + # If user has a course mode that is not verified & not meant to access to the learning assistant, return 403 + # This covers the other course modes: UNPAID_EXECUTIVE_EDUCATION & UNPAID_BOOTCAMP + else: + print("\n\nHUH????\n") + return Response( + status=http_status.HTTP_403_FORBIDDEN, + data={'detail': 'Must be staff or have valid enrollment.'} + ) + class LearningAssistantEnabledView(APIView): """ diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_views.py b/tests/test_views.py index bd5aa5b..30ed88a 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -108,25 +108,6 @@ def test_course_waffle_inactive(self, mock_waffle): response = self.client.post(reverse('chat', kwargs={'course_run_id': self.course_id})) self.assertEqual(response.status_code, 403) - @patch('learning_assistant.views.audit_trial_is_expired') - @patch('learning_assistant.views.learning_assistant_enabled') - @patch('learning_assistant.views.get_user_role') - @patch('learning_assistant.views.CourseEnrollment.get_enrollment') - @patch('learning_assistant.views.CourseMode') - def test_user_audit_enrollment_not_staff_trial_expired(self, mock_mode, mock_enrollment, mock_role, - mock_waffle, mock_expired): - mock_waffle.return_value = True - mock_role.return_value = 'student' - mock_mode.VERIFIED_MODES = ['verified'] - mock_mode.AUDIT_MODES = ['audit'] - mock_mode.objects.get.return_value = MagicMock() - mock_mode.expiration_datetime.return_value = datetime.now() - timedelta(days=1) - mock_enrollment.return_value = MagicMock(mode='audit') - mock_expired.return_value = True - - response = self.client.post(reverse('chat', kwargs={'course_run_id': self.course_id})) - self.assertEqual(response.status_code, 403) - @patch('learning_assistant.views.render_prompt_template') @patch('learning_assistant.views.learning_assistant_enabled') @patch('learning_assistant.views.get_user_role') @@ -149,6 +130,46 @@ def test_invalid_messages(self, mock_role, mock_waffle, mock_render): ) self.assertEqual(response.status_code, 400) + @patch('learning_assistant.views.audit_trial_is_expired') + @patch('learning_assistant.views.learning_assistant_enabled') + @patch('learning_assistant.views.get_user_role') + @patch('learning_assistant.views.CourseEnrollment.get_enrollment') + @patch('learning_assistant.views.CourseMode') + def test_audit_trial_expired(self, mock_mode, mock_enrollment, mock_role, + mock_waffle, mock_trial_expired): + mock_waffle.return_value = True + mock_role.return_value = 'student' + mock_mode.VERIFIED_MODES = ['verified'] + mock_mode.CREDIT_MODE = ['credit'] + mock_mode.NO_ID_PROFESSIONAL_MODE = ['no-id'] + mock_mode.UPSELL_TO_VERIFIED_MODES = ['audit'] + mock_mode.objects.get.return_value = MagicMock() + mock_mode.expiration_datetime.return_value = datetime.now() - timedelta(days=1) + mock_enrollment.return_value = MagicMock(mode='audit') + mock_trial_expired.return_value = True + + response = self.client.post(reverse('chat', kwargs={'course_run_id': self.course_id})) + self.assertEqual(response.status_code, 403) + mock_trial_expired.assert_called_once() + + @patch('learning_assistant.views.learning_assistant_enabled') + @patch('learning_assistant.views.get_user_role') + @patch('learning_assistant.views.CourseEnrollment.get_enrollment') + @patch('learning_assistant.views.CourseMode') + def test_invalid_enrollment_mode(self, mock_mode, mock_enrollment, mock_role, mock_waffle): + mock_waffle.return_value = True + mock_role.return_value = 'student' + mock_mode.VERIFIED_MODES = ['verified'] + mock_mode.CREDIT_MODE = ['credit'] + mock_mode.NO_ID_PROFESSIONAL_MODE = ['no-id'] + mock_mode.UPSELL_TO_VERIFIED_MODES = ['audit'] + mock_mode.objects.get.return_value = MagicMock() + mock_mode.expiration_datetime.return_value = datetime.now() - timedelta(days=1) + mock_enrollment.return_value = MagicMock(mode='unpaid_executive_education') + + response = self.client.post(reverse('chat', kwargs={'course_run_id': self.course_id})) + self.assertEqual(response.status_code, 403) + @ddt.data(False, True) @patch('learning_assistant.views.render_prompt_template') @patch('learning_assistant.views.get_chat_response') @@ -174,6 +195,8 @@ def test_chat_response_default( mock_waffle.return_value = True mock_role.return_value = 'student' mock_mode.VERIFIED_MODES = ['verified'] + mock_mode.CREDIT_MODE = ['credit'] + mock_mode.NO_ID_PROFESSIONAL_MODE = ['no-id'] mock_enrollment.return_value = MagicMock(mode='verified') mock_chat_response.return_value = (200, {'role': 'assistant', 'content': 'Something else'}) mock_render.return_value = 'Rendered template mock' From 07e446d9f30d2113d0825201c966fc3b2b833fa1 Mon Sep 17 00:00:00 2001 From: ilee2u Date: Wed, 20 Nov 2024 17:33:49 -0500 Subject: [PATCH 7/8] chore: lint --- learning_assistant/constants.py | 6 ------ learning_assistant/views.py | 4 ---- tests/test_views.py | 3 +-- 3 files changed, 1 insertion(+), 12 deletions(-) diff --git a/learning_assistant/constants.py b/learning_assistant/constants.py index a44e611..92e3717 100644 --- a/learning_assistant/constants.py +++ b/learning_assistant/constants.py @@ -1,12 +1,6 @@ """ Constants for the learning_assistant app. """ - -try: - from common.djangoapps.course_modes.models import CourseMode -except ImportError: - pass - # Pulled from edx-platform. Will correctly capture both old- and new-style # course ID strings. INTERNAL_COURSE_KEY_PATTERN = r'([^/+]+(/|\+)[^/+]+(/|\+)[^/?]+)' diff --git a/learning_assistant/views.py b/learning_assistant/views.py index 8891c4b..b474984 100644 --- a/learning_assistant/views.py +++ b/learning_assistant/views.py @@ -125,7 +125,6 @@ def post(self, request, course_run_id): enrollment_object = CourseEnrollment.get_enrollment(request.user, courserun_key) enrollment_mode = enrollment_object.mode if enrollment_object else None - print("\nenrollment_mode:", enrollment_mode) # If the user is in a verified course mode or is staff, return the next message if ( # Here we include CREDIT_MODE and NO_ID_PROFESSIONAL_MODE, as CourseMode.VERIFIED_MODES on its own @@ -133,13 +132,11 @@ def post(self, request, course_run_id): enrollment_mode in CourseMode.VERIFIED_MODES + CourseMode.CREDIT_MODE + CourseMode.NO_ID_PROFESSIONAL_MODE or user_role_is_staff(user_role) ): - print("\n\nVERIFIED\n") return self._get_next_message(request, courserun_key, course_run_id) # If user has an audit enrollment record, get or create their trial. If the trial is not expired, return the # next message. Otherwise, return 403 elif enrollment_mode in CourseMode.UPSELL_TO_VERIFIED_MODES: # AUDIT, HONOR - print("\n\nAUDIT\n") course_mode = CourseMode.objects.get(course=courserun_key) upgrade_deadline = course_mode.expiration_datetime() @@ -155,7 +152,6 @@ def post(self, request, course_run_id): # If user has a course mode that is not verified & not meant to access to the learning assistant, return 403 # This covers the other course modes: UNPAID_EXECUTIVE_EDUCATION & UNPAID_BOOTCAMP else: - print("\n\nHUH????\n") return Response( status=http_status.HTTP_403_FORBIDDEN, data={'detail': 'Must be staff or have valid enrollment.'} diff --git a/tests/test_views.py b/tests/test_views.py index 30ed88a..34ef198 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -135,8 +135,7 @@ def test_invalid_messages(self, mock_role, mock_waffle, mock_render): @patch('learning_assistant.views.get_user_role') @patch('learning_assistant.views.CourseEnrollment.get_enrollment') @patch('learning_assistant.views.CourseMode') - def test_audit_trial_expired(self, mock_mode, mock_enrollment, mock_role, - mock_waffle, mock_trial_expired): + def test_audit_trial_expired(self, mock_mode, mock_enrollment, mock_role, mock_waffle, mock_trial_expired): mock_waffle.return_value = True mock_role.return_value = 'student' mock_mode.VERIFIED_MODES = ['verified'] From 09d0c5576ca547fe9c172e79d6436d4ca637bd77 Mon Sep 17 00:00:00 2001 From: ilee2u Date: Thu, 21 Nov 2024 13:08:40 -0500 Subject: [PATCH 8/8] fix: correct test to cover unexpired audit trial --- tests/test_views.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/tests/test_views.py b/tests/test_views.py index 34ef198..98c0ab6 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -135,7 +135,8 @@ def test_invalid_messages(self, mock_role, mock_waffle, mock_render): @patch('learning_assistant.views.get_user_role') @patch('learning_assistant.views.CourseEnrollment.get_enrollment') @patch('learning_assistant.views.CourseMode') - def test_audit_trial_expired(self, mock_mode, mock_enrollment, mock_role, mock_waffle, mock_trial_expired): + def test_audit_trial_expired(self, mock_mode, mock_enrollment, + mock_role, mock_waffle, mock_trial_expired): mock_waffle.return_value = True mock_role.return_value = 'student' mock_mode.VERIFIED_MODES = ['verified'] @@ -145,12 +146,17 @@ def test_audit_trial_expired(self, mock_mode, mock_enrollment, mock_role, mock_w mock_mode.objects.get.return_value = MagicMock() mock_mode.expiration_datetime.return_value = datetime.now() - timedelta(days=1) mock_enrollment.return_value = MagicMock(mode='audit') - mock_trial_expired.return_value = True response = self.client.post(reverse('chat', kwargs={'course_run_id': self.course_id})) self.assertEqual(response.status_code, 403) mock_trial_expired.assert_called_once() + mock_waffle.reset_mock() + mock_role.reset_mock() + mock_mode.reset_mock() + mock_enrollment.reset_mock() + mock_trial_expired.reset_mock() + @patch('learning_assistant.views.learning_assistant_enabled') @patch('learning_assistant.views.get_user_role') @patch('learning_assistant.views.CourseEnrollment.get_enrollment') @@ -169,7 +175,11 @@ def test_invalid_enrollment_mode(self, mock_mode, mock_enrollment, mock_role, mo response = self.client.post(reverse('chat', kwargs={'course_run_id': self.course_id})) self.assertEqual(response.status_code, 403) - @ddt.data(False, True) + # Test that unexpired audit trials + vierfied track learners get the default chat response + @ddt.data((False, 'verified'), + (True, 'audit')) + @ddt.unpack + @patch('learning_assistant.views.audit_trial_is_expired') @patch('learning_assistant.views.render_prompt_template') @patch('learning_assistant.views.get_chat_response') @patch('learning_assistant.views.learning_assistant_enabled') @@ -182,6 +192,7 @@ def test_invalid_enrollment_mode(self, mock_mode, mock_enrollment, mock_role, mo def test_chat_response_default( self, enabled_flag, + enrollment_mode, mock_chat_history_enabled, mock_save_chat_message, mock_mode, @@ -190,15 +201,18 @@ def test_chat_response_default( mock_waffle, mock_chat_response, mock_render, + mock_trial_expired, ): mock_waffle.return_value = True mock_role.return_value = 'student' mock_mode.VERIFIED_MODES = ['verified'] mock_mode.CREDIT_MODE = ['credit'] mock_mode.NO_ID_PROFESSIONAL_MODE = ['no-id'] - mock_enrollment.return_value = MagicMock(mode='verified') + mock_mode.UPSELL_TO_VERIFIED_MODES = ['audit'] + mock_enrollment.return_value = MagicMock(mode=enrollment_mode) mock_chat_response.return_value = (200, {'role': 'assistant', 'content': 'Something else'}) mock_render.return_value = 'Rendered template mock' + mock_trial_expired.return_value = False test_unit_id = 'test-unit-id' mock_chat_history_enabled.return_value = enabled_flag