Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add audit gate to CourseChatView #134

Merged
merged 8 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions learning_assistant/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Library for the learning_assistant app.
"""
import logging
from datetime import datetime, timedelta

from django.conf import settings
from django.contrib.auth import get_user_model
Expand All @@ -10,9 +11,13 @@
from jinja2 import BaseLoader, Environment
from opaque_keys import InvalidKeyError

from learning_assistant.constants import ACCEPTED_CATEGORY_TYPES, CATEGORY_TYPE_MAP
from learning_assistant.constants import ACCEPTED_CATEGORY_TYPES, AUDIT_TRIAL_MAX_DAYS, CATEGORY_TYPE_MAP
from learning_assistant.data import LearningAssistantCourseEnabledData
from learning_assistant.models import LearningAssistantCourseEnabled, LearningAssistantMessage
from learning_assistant.models import (
LearningAssistantAuditTrial,
LearningAssistantCourseEnabled,
LearningAssistantMessage,
)
from learning_assistant.platform_imports import (
block_get_children,
block_leaf_filter,
Expand Down Expand Up @@ -224,3 +229,24 @@ def get_message_history(courserun_key, user, message_count):
message_history = list(LearningAssistantMessage.objects.filter(
course_id=courserun_key, user=user).order_by('-created')[:message_count])[::-1]
return message_history


def audit_trial_is_expired(user, upgrade_deadline):
"""
Given a user (User), get or create the corresponding LearningAssistantAuditTrial trial object.
"""
# If the upgrade deadline has passed, return "True" for expired
DAYS_SINCE_UPGRADE_DEADLINE = datetime.now() - upgrade_deadline
if DAYS_SINCE_UPGRADE_DEADLINE >= timedelta(days=0):
return True

audit_trial, _ = LearningAssistantAuditTrial.objects.get_or_create(
user=user,
defaults={
"start_date": datetime.now(),
},
)

# If the user's trial is past its expiry date, return "True" for expired. Else, return False
DAYS_SINCE_TRIAL_START_DATE = datetime.now() - audit_trial.start_date
return DAYS_SINCE_TRIAL_START_DATE >= timedelta(days=AUDIT_TRIAL_MAX_DAYS)
2 changes: 2 additions & 0 deletions learning_assistant/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,5 @@
"html": "TEXT",
"video": "VIDEO",
}

AUDIT_TRIAL_MAX_DAYS = 14
16 changes: 8 additions & 8 deletions learning_assistant/platform_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

def get_text_transcript(video_block):
"""Get the transcript for a video block in text format, or None."""
# pylint: disable=import-error, import-outside-toplevel
# pylint: disable=import-outside-toplevel
from xmodule.exceptions import NotFoundError
from xmodule.video_block.transcripts_utils import get_transcript
try:
Expand All @@ -21,28 +21,28 @@ def get_text_transcript(video_block):

def get_single_block(request, user_id, course_id, usage_key_string, course=None):
"""Load a single xblock."""
# pylint: disable=import-error, import-outside-toplevel
# pylint: disable=import-outside-toplevel
from lms.djangoapps.courseware.block_render import load_single_xblock
return load_single_xblock(request, user_id, course_id, usage_key_string, course)


def traverse_block_pre_order(start_node, get_children, filter_func=None):
"""Traverse a DAG or tree in pre-order."""
# pylint: disable=import-error, import-outside-toplevel
# pylint: disable=import-outside-toplevel
from openedx.core.lib.graph_traversals import traverse_pre_order
return traverse_pre_order(start_node, get_children, filter_func)


def block_leaf_filter(block):
"""Return only leaf nodes."""
# pylint: disable=import-error, import-outside-toplevel
# pylint: disable=import-outside-toplevel
from openedx.core.lib.graph_traversals import leaf_filter
return leaf_filter(block)


def block_get_children(block):
"""Return children of a given block."""
# pylint: disable=import-error, import-outside-toplevel
# pylint: disable=import-outside-toplevel
from openedx.core.lib.graph_traversals import get_children
return get_children(block)

Expand All @@ -54,7 +54,7 @@ def get_cache_course_run_data(course_run_id, fields):
This function makes use of the course run cache in the LMS, which caches data from the discovery service. This is
necessary because only the discovery service stores the relation between courseruns and courses.
"""
# pylint: disable=import-error, import-outside-toplevel
# pylint: disable=import-outside-toplevel
from openedx.core.djangoapps.catalog.utils import get_course_run_data
return get_course_run_data(course_run_id, fields)

Expand All @@ -66,7 +66,7 @@ def get_cache_course_data(course_id, fields):
This function makes use of the course cache in the LMS, which caches data from the discovery service. This is
necessary because only the discovery service stores course skills data.
"""
# pylint: disable=import-error, import-outside-toplevel
# pylint: disable=import-outside-toplevel
from openedx.core.djangoapps.catalog.utils import get_course_data
return get_course_data(course_id, fields)

Expand All @@ -82,6 +82,6 @@ def get_user_role(user, course_key):
Returns:
* str: the user's role
"""
# pylint: disable=import-error, import-outside-toplevel
# pylint: disable=import-outside-toplevel
from lms.djangoapps.courseware.access import get_user_role as platform_get_user_role
return platform_get_user_role(user, course_key)
105 changes: 66 additions & 39 deletions learning_assistant/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
pass

from learning_assistant.api import (
audit_trial_is_expired,
get_course_id,
get_message_history,
learning_assistant_enabled,
Expand All @@ -43,46 +44,10 @@
authentication_classes = (SessionAuthentication, JwtAuthentication,)
permission_classes = (IsAuthenticated,)

def post(self, request, course_run_id):
def _get_next_message(self, request, courserun_key, course_run_id):
"""
Given a course run ID, retrieve a chat response for that course.

Expected POST data: {
[
{'role': 'user', 'content': 'What is 2+2?'},
{'role': 'assistant', 'content': '4'}
]
}
Generate the next message to be returned by the learning assistant.
"""
try:
courserun_key = CourseKey.from_string(course_run_id)
except InvalidKeyError:
return Response(
status=http_status.HTTP_400_BAD_REQUEST,
data={'detail': 'Course ID is not a valid course ID.'}
)

if not learning_assistant_enabled(courserun_key):
return Response(
status=http_status.HTTP_403_FORBIDDEN,
data={'detail': 'Learning assistant not enabled for course.'}
)

# If user does not have an enrollment record, or is not staff, they should not have access
user_role = get_user_role(request.user, courserun_key)
enrollment_object = CourseEnrollment.get_enrollment(request.user, courserun_key)
enrollment_mode = enrollment_object.mode if enrollment_object else None
if (
(enrollment_mode not in CourseMode.VERIFIED_MODES)
and not user_role_is_staff(user_role)
):
return Response(
status=http_status.HTTP_403_FORBIDDEN,
data={'detail': 'Must be staff or have valid enrollment.'}
)

unit_id = request.query_params.get('unit_id')

message_list = request.data

# Check that the last message in the list corresponds to a user
Expand Down Expand Up @@ -117,8 +82,8 @@
)

course_id = get_course_id(course_run_id)

template_string = getattr(settings, 'LEARNING_ASSISTANT_PROMPT_TEMPLATE', '')
unit_id = request.query_params.get('unit_id')

prompt_template = render_prompt_template(
request, request.user.id, course_run_id, unit_id, course_id, template_string
Expand All @@ -130,6 +95,68 @@

return Response(status=status_code, data=message)

def post(self, request, course_run_id):
"""
Given a course run ID, retrieve a chat response for that course.

Expected POST data: {
[
{'role': 'user', 'content': 'What is 2+2?'},
{'role': 'assistant', 'content': '4'}
]
}
"""
try:
courserun_key = CourseKey.from_string(course_run_id)
except InvalidKeyError:
return Response(

Check failure on line 112 in learning_assistant/views.py

View workflow job for this annotation

GitHub Actions / tests (ubuntu-20.04, 3.12, django42)

Missing coverage

Missing coverage on lines 111-112
status=http_status.HTTP_400_BAD_REQUEST,
data={'detail': 'Course ID is not a valid course ID.'}
)

if not learning_assistant_enabled(courserun_key):
return Response(
status=http_status.HTTP_403_FORBIDDEN,
data={'detail': 'Learning assistant not enabled for course.'}
)

# If user does not have a verified enrollment record, or is not staff, they should not have full access
user_role = get_user_role(request.user, courserun_key)
enrollment_object = CourseEnrollment.get_enrollment(request.user, courserun_key)
enrollment_mode = enrollment_object.mode if enrollment_object else None

# If the user is in a verified course mode or is staff, return the next message
if (
# Here we include CREDIT_MODE and NO_ID_PROFESSIONAL_MODE, as CourseMode.VERIFIED_MODES on its own
# doesn't match what we count as "verified modes" in the frontend component.
enrollment_mode in CourseMode.VERIFIED_MODES + CourseMode.CREDIT_MODE + CourseMode.NO_ID_PROFESSIONAL_MODE
or user_role_is_staff(user_role)
):
return self._get_next_message(request, courserun_key, course_run_id)

# If user has an audit enrollment record, get or create their trial. If the trial is not expired, return the
# next message. Otherwise, return 403
elif enrollment_mode in CourseMode.UPSELL_TO_VERIFIED_MODES: # AUDIT, HONOR
course_mode = CourseMode.objects.get(course=courserun_key)
upgrade_deadline = course_mode.expiration_datetime()

user_audit_trial_expired = audit_trial_is_expired(request.user, upgrade_deadline)
if user_audit_trial_expired:
return Response(
status=http_status.HTTP_403_FORBIDDEN,
data={'detail': 'The audit trial for this user has expired.'}
)
else:
return self._get_next_message(request, courserun_key, course_run_id)

Check failure on line 150 in learning_assistant/views.py

View workflow job for this annotation

GitHub Actions / tests (ubuntu-20.04, 3.12, django42)

Missing coverage

Missing coverage on line 150
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like a somewhat important case to cover, could a test be added to ensure that this code is still called?


# If user has a course mode that is not verified & not meant to access to the learning assistant, return 403
# This covers the other course modes: UNPAID_EXECUTIVE_EDUCATION & UNPAID_BOOTCAMP
else:
return Response(
status=http_status.HTTP_403_FORBIDDEN,
data={'detail': 'Must be staff or have valid enrollment.'}
)


class LearningAssistantEnabledView(APIView):
"""
Expand Down
44 changes: 43 additions & 1 deletion tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Test cases for the learning-assistant api module.
"""
import itertools
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch

import ddt
Expand All @@ -16,6 +17,7 @@
_extract_block_contents,
_get_children_contents,
_leaf_filter,
audit_trial_is_expired,
get_block_content,
get_message_history,
learning_assistant_available,
Expand All @@ -24,8 +26,13 @@
save_chat_message,
set_learning_assistant_enabled,
)
from learning_assistant.constants import AUDIT_TRIAL_MAX_DAYS
from learning_assistant.data import LearningAssistantCourseEnabledData
from learning_assistant.models import LearningAssistantCourseEnabled, LearningAssistantMessage
from learning_assistant.models import (
LearningAssistantAuditTrial,
LearningAssistantCourseEnabled,
LearningAssistantMessage,
)

fake_transcript = 'This is the text version from the transcript'
User = get_user_model()
Expand Down Expand Up @@ -241,6 +248,7 @@ class TestLearningAssistantCourseEnabledApi(TestCase):
"""
Test suite for save_chat_message.
"""

def setUp(self):
super().setUp()

Expand Down Expand Up @@ -473,3 +481,37 @@ def test_get_message_course_id_differences(self):
self.assertEqual(return_value.user, expected_value[i].user)
self.assertEqual(return_value.role, expected_value[i].role)
self.assertEqual(return_value.content, expected_value[i].content)


@ddt.ddt
class CheckIfAuditTrialIsExpiredTests(TestCase):
"""
Test suite for audit_trial_is_expired.
"""

def setUp(self):
super().setUp()
self.course_key = CourseKey.from_string('course-v1:edx+fake+1')
self.user = User(username='tester', email='[email protected]')
self.user.save()

self.role = 'verified'
self.upgrade_deadline = datetime.now() + timedelta(days=1) # 1 day from now

def test_check_if_past_upgrade_deadline(self):
upgrade_deadline = datetime.now() - timedelta(days=1) # yesterday
self.assertEqual(audit_trial_is_expired(self.user, upgrade_deadline), True)

def test_audit_trial_is_expired_audit_trial_expired(self):
LearningAssistantAuditTrial.objects.create(
user=self.user,
start_date=datetime.now() - timedelta(days=AUDIT_TRIAL_MAX_DAYS + 1), # 1 day more than trial deadline
)
self.assertEqual(audit_trial_is_expired(self.user, self.upgrade_deadline), True)

def test_audit_trial_is_expired_audit_trial_unexpired(self):
LearningAssistantAuditTrial.objects.create(
user=self.user,
start_date=datetime.now() - timedelta(days=AUDIT_TRIAL_MAX_DAYS - 0.99), # 0.99 days less than deadline
)
self.assertEqual(audit_trial_is_expired(self.user, self.upgrade_deadline), False)
Loading
Loading