diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 9b9fd36..3cf14ca 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -13,7 +13,32 @@ Change Log Unreleased ********** + +4.4.5 - 2024-11-12 +****************** +* Updated Learning Assistant History payload to return in ascending order + +4.4.4 - 2024-11-06 +****************** +* Fixed Learning Assistant History endpoint +* Added timestamp to the Learning Assistant History payload + +4.4.3 - 2024-11-06 +****************** +* Fixed package version + +4.4.2 - 2024-11-04 +****************** +* Added chat messages to the DB + +4.4.1 - 2024-10-31 +****************** +* Add management command to remove expired messages + +4.4.0 - 2024-10-30 +****************** * Add LearningAssistantMessage model +* Add new GET endpoint to retrieve a user's message history in a given course. 4.4.0 - 2024-10-25 ****************** diff --git a/learning_assistant/__init__.py b/learning_assistant/__init__.py index f654ddd..8ef32fa 100644 --- a/learning_assistant/__init__.py +++ b/learning_assistant/__init__.py @@ -2,6 +2,6 @@ Plugin for a learning assistant backend, intended for use within edx-platform. """ -__version__ = '4.3.3' +__version__ = '4.4.5' default_app_config = 'learning_assistant.apps.LearningAssistantConfig' # pylint: disable=invalid-name diff --git a/learning_assistant/api.py b/learning_assistant/api.py index 3e0b5e0..55c73c1 100644 --- a/learning_assistant/api.py +++ b/learning_assistant/api.py @@ -4,6 +4,7 @@ import logging from django.conf import settings +from django.contrib.auth import get_user_model from django.core.cache import cache from edx_django_utils.cache import get_cache_key from jinja2 import BaseLoader, Environment @@ -11,7 +12,7 @@ from learning_assistant.constants import ACCEPTED_CATEGORY_TYPES, CATEGORY_TYPE_MAP from learning_assistant.data import LearningAssistantCourseEnabledData -from learning_assistant.models import LearningAssistantCourseEnabled +from learning_assistant.models import LearningAssistantCourseEnabled, LearningAssistantMessage from learning_assistant.platform_imports import ( block_get_children, block_leaf_filter, @@ -24,6 +25,7 @@ from learning_assistant.text_utils import html_to_text log = logging.getLogger(__name__) +User = get_user_model() def _extract_block_contents(child, category): @@ -187,3 +189,38 @@ def get_course_id(course_run_id): course_data = get_cache_course_run_data(course_run_id, ['course']) course_key = course_data['course'] return course_key + + +def save_chat_message(courserun_key, user_id, chat_role, message): + """ + Save the chat message to the database. + """ + user = None + try: + user = User.objects.get(id=user_id) + except User.DoesNotExist as exc: + raise Exception("User does not exists.") from exc + + # Save the user message to the database. + LearningAssistantMessage.objects.create( + course_id=courserun_key, + user=user, + role=chat_role, + content=message, + + ) + + +def get_message_history(courserun_key, user, message_count): + """ + Given a courserun key (CourseKey), user (User), and message count (int), return the associated message history. + + Returns a number of messages equal to the message_count value. + """ + # Explanation over the double reverse: This fetches the last message_count elements ordered by creating order DESC. + # Slicing the list in the model is an equivalent of adding LIMIT on the query. + # The result is the last chat messages for that user and course but in inversed order, so in order to flip them + # its first turn into a list and then reversed. + message_history = list(LearningAssistantMessage.objects.filter( + course_id=courserun_key, user=user).order_by('-created')[:message_count])[::-1] + return message_history diff --git a/learning_assistant/management/commands/retire_user_messages.py b/learning_assistant/management/commands/retire_user_messages.py new file mode 100644 index 0000000..8b90082 --- /dev/null +++ b/learning_assistant/management/commands/retire_user_messages.py @@ -0,0 +1,68 @@ +"""" +Django management command to remove LearningAssistantMessage objects +if they have reached their expiration date. +""" +import logging +import time +from datetime import datetime, timedelta + +from django.conf import settings +from django.core.management.base import BaseCommand + +from learning_assistant.models import LearningAssistantMessage + +log = logging.getLogger(__name__) + + +class Command(BaseCommand): + """ + Django Management command to remove expired messages. + """ + + def add_arguments(self, parser): + parser.add_argument( + '--batch_size', + action='store', + dest='batch_size', + type=int, + default=300, + help='Maximum number of messages to remove. ' + 'This helps avoid overloading the database while updating large amount of data.' + ) + parser.add_argument( + '--sleep_time', + action='store', + dest='sleep_time', + type=int, + default=10, + help='Sleep time in seconds between update of batches' + ) + + def handle(self, *args, **options): + """ + Management command entry point. + """ + batch_size = options['batch_size'] + sleep_time = options['sleep_time'] + + expiry_date = datetime.now() - timedelta(days=getattr(settings, 'LEARNING_ASSISTANT_MESSAGES_EXPIRY', 30)) + + total_deleted = 0 + deleted_count = None + + while deleted_count != 0: + ids_to_delete = LearningAssistantMessage.objects.filter( + created__lte=expiry_date + ).values_list('id', flat=True)[:batch_size] + + ids_to_delete = list(ids_to_delete) + delete_queryset = LearningAssistantMessage.objects.filter( + id__in=ids_to_delete + ) + deleted_count, _ = delete_queryset.delete() + + total_deleted += deleted_count + log.info(f'{deleted_count} messages deleted.') + time.sleep(sleep_time) + + log.info(f'Job completed. {total_deleted} messages deleted.') diff --git a/learning_assistant/management/commands/tests/test_retire_user_messages.py b/learning_assistant/management/commands/tests/test_retire_user_messages.py new file mode 100644 index 0000000..afd720f --- /dev/null +++ b/learning_assistant/management/commands/tests/test_retire_user_messages.py @@ -0,0 +1,68 @@ +""" +Tests for the retire_user_messages management command +""" +from datetime import datetime, timedelta + +from django.contrib.auth import get_user_model +from django.core.management import call_command +from django.test import TestCase + +from learning_assistant.models import LearningAssistantMessage + +User = get_user_model() + + +class RetireUserMessagesTests(TestCase): + """ + Tests for the retire_user_messages command. + """ + + def setUp(self): + """ + Build up test data + """ + super().setUp() + self.user = User(username='tester', email='tester@test.com') + self.user.save() + + self.course_id = 'course-v1:edx+test+23' + + LearningAssistantMessage.objects.create( + user=self.user, + course_id=self.course_id, + role='user', + content='Hello', + created=datetime.now() - timedelta(days=60) + ) + + LearningAssistantMessage.objects.create( + user=self.user, + course_id=self.course_id, + role='user', + content='Hello', + created=datetime.now() - timedelta(days=2) + ) + + LearningAssistantMessage.objects.create( + user=self.user, + course_id=self.course_id, + role='user', + content='Hello', + created=datetime.now() - timedelta(days=4) + ) + + def test_run_command(self): + """ + Run the management command + """ + current_messages = LearningAssistantMessage.objects.filter() + self.assertEqual(len(current_messages), 3) + + call_command( + 'retire_user_messages', + batch_size=2, + sleep_time=0, + ) + + current_messages = LearningAssistantMessage.objects.filter() + self.assertEqual(len(current_messages), 2) diff --git a/learning_assistant/migrations/0008_alter_learningassistantmessage_role.py b/learning_assistant/migrations/0008_alter_learningassistantmessage_role.py new file mode 100644 index 0000000..bd699b3 --- /dev/null +++ b/learning_assistant/migrations/0008_alter_learningassistantmessage_role.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.14 on 2024-11-04 08:52 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('learning_assistant', '0007_learningassistantmessage'), + ] + + operations = [ + migrations.AlterField( + model_name='learningassistantmessage', + name='role', + field=models.CharField(choices=[('user', 'user'), ('assistant', 'assistant')], max_length=64), + ), + ] diff --git a/learning_assistant/migrations/0009_learningassistantaudittrial.py b/learning_assistant/migrations/0009_learningassistantaudittrial.py new file mode 100644 index 0000000..30068a2 --- /dev/null +++ b/learning_assistant/migrations/0009_learningassistantaudittrial.py @@ -0,0 +1,31 @@ +# Generated by Django 4.2.16 on 2024-11-14 13:55 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion +import django.utils.timezone +import model_utils.fields + + +class Migration(migrations.Migration): + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ('learning_assistant', '0008_alter_learningassistantmessage_role'), + ] + + operations = [ + migrations.CreateModel( + name='LearningAssistantAuditTrial', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('created', model_utils.fields.AutoCreatedField(default=django.utils.timezone.now, editable=False, verbose_name='created')), + ('modified', model_utils.fields.AutoLastModifiedField(default=django.utils.timezone.now, editable=False, verbose_name='modified')), + ('start_date', models.DateTimeField()), + ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL, unique=True)), + ], + options={ + 'abstract': False, + }, + ), + ] diff --git a/learning_assistant/models.py b/learning_assistant/models.py index c890087..bd05ceb 100644 --- a/learning_assistant/models.py +++ b/learning_assistant/models.py @@ -35,7 +35,30 @@ class LearningAssistantMessage(TimeStampedModel): .. pii_retirement: third_party """ + USER_ROLE = 'user' + ASSISTANT_ROLE = 'assistant' + + Roles = ( + (USER_ROLE, USER_ROLE), + (ASSISTANT_ROLE, ASSISTANT_ROLE), + ) + course_id = CourseKeyField(max_length=255, db_index=True) user = models.ForeignKey(USER_MODEL, db_index=True, on_delete=models.CASCADE) - role = models.CharField(max_length=64) + role = models.CharField(choices=Roles, max_length=64) content = models.TextField() + + +class LearningAssistantAuditTrial(TimeStampedModel): + """ + This model stores the trial period for an audit learner using the learning assistant. + + A LearningAssistantAuditTrial instance will be created on a per user basis, + when an audit learner first sends a message using Xpert LA. + + .. no_pii: This model has no PII. + """ + + # Unique constraint since each user should only have one trial + user = models.ForeignKey(USER_MODEL, db_index=True, on_delete=models.CASCADE, unique=True) + start_date = models.DateTimeField() diff --git a/learning_assistant/serializers.py b/learning_assistant/serializers.py index a212654..1896182 100644 --- a/learning_assistant/serializers.py +++ b/learning_assistant/serializers.py @@ -3,6 +3,8 @@ """ from rest_framework import serializers +from learning_assistant.models import LearningAssistantMessage + class MessageSerializer(serializers.Serializer): # pylint: disable=abstract-method """ @@ -11,12 +13,25 @@ class MessageSerializer(serializers.Serializer): # pylint: disable=abstract-met role = serializers.CharField(required=True) content = serializers.CharField(required=True) + timestamp = serializers.DateTimeField(required=False, source='created') + + class Meta: + """ + Serializer metadata. + """ + + model = LearningAssistantMessage + fields = ( + 'role', + 'content', + 'timestamp', + ) def validate_role(self, value): """ Validate that role is one of two acceptable values. """ - valid_roles = ['user', 'assistant'] + valid_roles = [LearningAssistantMessage.USER_ROLE, LearningAssistantMessage.ASSISTANT_ROLE] if value not in valid_roles: raise serializers.ValidationError('Must be valid role.') return value diff --git a/learning_assistant/text_utils.py b/learning_assistant/text_utils.py index a4d0846..a1795e9 100644 --- a/learning_assistant/text_utils.py +++ b/learning_assistant/text_utils.py @@ -20,7 +20,7 @@ def cleanup_text(text): return stripped -class _HTMLToTextHelper(HTMLParser): +class _HTMLToTextHelper(HTMLParser): # lint-amnesty """ Helper function for html_to_text below. """ diff --git a/learning_assistant/toggles.py b/learning_assistant/toggles.py index 88d5ae6..61d515e 100644 --- a/learning_assistant/toggles.py +++ b/learning_assistant/toggles.py @@ -14,6 +14,16 @@ # .. toggle_tickets: COSMO-80 ENABLE_COURSE_CONTENT = 'enable_course_content' +# .. toggle_name: learning_assistant.enable_chat_history +# .. toggle_implementation: CourseWaffleFlag +# .. toggle_default: False +# .. toggle_description: Waffle flag to enable the chat history with the learning assistant +# .. toggle_use_cases: temporary +# .. toggle_creation_date: 2024-10-30 +# .. toggle_target_removal_date: 2024-12-31 +# .. toggle_tickets: COSMO-436 +ENABLE_CHAT_HISTORY = 'enable_chat_history' + def _is_learning_assistant_waffle_flag_enabled(flag_name, course_key): """ @@ -32,3 +42,10 @@ def course_content_enabled(course_key): Return whether the learning_assistant.enable_course_content WaffleFlag is on. """ return _is_learning_assistant_waffle_flag_enabled(ENABLE_COURSE_CONTENT, course_key) + + +def chat_history_enabled(course_key): + """ + Return whether the learning_assistant.enable_chat_history WaffleFlag is on. + """ + return _is_learning_assistant_waffle_flag_enabled(ENABLE_CHAT_HISTORY, course_key) diff --git a/learning_assistant/urls.py b/learning_assistant/urls.py index a5d3bbe..b0dfb48 100644 --- a/learning_assistant/urls.py +++ b/learning_assistant/urls.py @@ -4,7 +4,7 @@ from django.urls import re_path from learning_assistant.constants import COURSE_ID_PATTERN -from learning_assistant.views import CourseChatView, LearningAssistantEnabledView +from learning_assistant.views import CourseChatView, LearningAssistantEnabledView, LearningAssistantMessageHistoryView app_name = 'learning_assistant' @@ -19,4 +19,9 @@ LearningAssistantEnabledView.as_view(), name='enabled', ), + re_path( + fr'learning_assistant/v1/course_id/{COURSE_ID_PATTERN}/history', + LearningAssistantMessageHistoryView.as_view(), + name='message-history', + ), ] diff --git a/learning_assistant/views.py b/learning_assistant/views.py index 7739240..b1b5c03 100644 --- a/learning_assistant/views.py +++ b/learning_assistant/views.py @@ -20,8 +20,16 @@ except ImportError: pass -from learning_assistant.api import get_course_id, learning_assistant_enabled, render_prompt_template +from learning_assistant.api import ( + get_course_id, + get_message_history, + learning_assistant_enabled, + render_prompt_template, + save_chat_message, +) +from learning_assistant.models import LearningAssistantMessage from learning_assistant.serializers import MessageSerializer +from learning_assistant.toggles import chat_history_enabled from learning_assistant.utils import get_chat_response, user_role_is_staff log = logging.getLogger(__name__) @@ -76,6 +84,20 @@ def post(self, request, course_run_id): unit_id = request.query_params.get('unit_id') message_list = request.data + + # Check that the last message in the list corresponds to a user + new_user_message = message_list[-1] + if new_user_message['role'] != LearningAssistantMessage.USER_ROLE: + return Response( + status=http_status.HTTP_400_BAD_REQUEST, + data={'detail': "Expects user role on last message."} + ) + + user_id = request.user.id + + if chat_history_enabled(courserun_key): + save_chat_message(courserun_key, user_id, LearningAssistantMessage.USER_ROLE, new_user_message['content']) + serializer = MessageSerializer(data=message_list, many=True) # serializer will not be valid in the case that the message list contains any roles other than @@ -103,6 +125,9 @@ def post(self, request, course_run_id): ) status_code, message = get_chat_response(prompt_template, message_list) + if chat_history_enabled(courserun_key): + save_chat_message(courserun_key, user_id, LearningAssistantMessage.ASSISTANT_ROLE, message['content']) + return Response(status=status_code, data=message) @@ -149,3 +174,68 @@ def get(self, request, course_run_id): } return Response(status=http_status.HTTP_200_OK, data=data) + + +class LearningAssistantMessageHistoryView(APIView): + """ + View to retrieve the message history for user in a course. + + This endpoint returns the message history stored in the LearningAssistantMessage table in a course + represented by the course_key, which is provided in the URL. + + Accepts: [GET] + + Path: /learning_assistant/v1/course_id/{course_key}/history + + Parameters: + * course_key: the ID of the course + + Responses: + * 200: OK + * 400: Malformed Request - Course ID is not a valid course ID. + """ + + authentication_classes = (SessionAuthentication, JwtAuthentication,) + permission_classes = (IsAuthenticated,) + + def get(self, request, course_run_id): + """ + Given a course run ID, retrieve the message history for the corresponding user. + + The response will be in the following format. + + [{'role': 'assistant', 'content': 'something'}] + """ + try: + courserun_key = CourseKey.from_string(course_run_id) + except InvalidKeyError: + return Response( + status=http_status.HTTP_400_BAD_REQUEST, + data={'detail': 'Course ID is not a valid course ID.'} + ) + + if not learning_assistant_enabled(courserun_key): + return Response( + status=http_status.HTTP_403_FORBIDDEN, + data={'detail': 'Learning assistant not enabled for course.'} + ) + + # If user does not have an enrollment record, or is not staff, they should not have access + user_role = get_user_role(request.user, courserun_key) + enrollment_object = CourseEnrollment.get_enrollment(request.user, courserun_key) + enrollment_mode = enrollment_object.mode if enrollment_object else None + if ( + (enrollment_mode not in CourseMode.VERIFIED_MODES) + and not user_role_is_staff(user_role) + ): + return Response( + status=http_status.HTTP_403_FORBIDDEN, + data={'detail': 'Must be staff or have valid enrollment.'} + ) + + user = request.user + + message_count = int(request.GET.get('message_count', 50)) + message_history = get_message_history(courserun_key, user, message_count) + data = MessageSerializer(message_history, many=True).data + return Response(status=http_status.HTTP_200_OK, data=data) diff --git a/tests/test_api.py b/tests/test_api.py index 5cb2a43..1344dea 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -6,6 +6,7 @@ import ddt from django.conf import settings +from django.contrib.auth import get_user_model from django.core.cache import cache from django.test import TestCase, override_settings from opaque_keys import InvalidKeyError @@ -16,15 +17,18 @@ _get_children_contents, _leaf_filter, get_block_content, + get_message_history, learning_assistant_available, learning_assistant_enabled, render_prompt_template, + save_chat_message, set_learning_assistant_enabled, ) from learning_assistant.data import LearningAssistantCourseEnabledData -from learning_assistant.models import LearningAssistantCourseEnabled +from learning_assistant.models import LearningAssistantCourseEnabled, LearningAssistantMessage fake_transcript = 'This is the text version from the transcript' +User = get_user_model() class FakeChild: @@ -49,6 +53,7 @@ def get_html(self): class FakeBlock: "Fake block for testing, returns given children" + def __init__(self, children): self.children = children self.scope_ids = lambda: None @@ -231,11 +236,38 @@ def test_render_prompt_template_invalid_unit_key(self, mock_get_content): self.assertNotIn('The following text is useful.', prompt_text) +@ddt.ddt +class TestLearningAssistantCourseEnabledApi(TestCase): + """ + Test suite for save_chat_message. + """ + def setUp(self): + super().setUp() + + self.test_user = User.objects.create(username='username', password='password') + self.course_run_key = CourseKey.from_string('course-v1:edx+test+23') + + @ddt.data( + (LearningAssistantMessage.USER_ROLE, 'What is the meaning of life, the universe and everything?'), + (LearningAssistantMessage.ASSISTANT_ROLE, '42'), + ) + @ddt.unpack + def test_save_chat_message(self, chat_role, message): + save_chat_message(self.course_run_key, self.test_user.id, chat_role, message) + + row = LearningAssistantMessage.objects.all().last() + + self.assertEqual(row.course_id, self.course_run_key) + self.assertEqual(row.role, chat_role) + self.assertEqual(row.content, message) + + @ddt.ddt class LearningAssistantCourseEnabledApiTests(TestCase): """ Test suite for learning_assistant_available, learning_assistant_enabled, and set_learning_assistant_enabled. """ + def setUp(self): super().setUp() self.course_key = CourseKey.from_string('course-v1:edx+fake+1') @@ -305,3 +337,139 @@ def test_learning_assistant_available(self, learning_assistant_available_setting expected_value = learning_assistant_available_setting_value self.assertEqual(return_value, expected_value) + + +@ddt.ddt +class GetMessageHistoryTests(TestCase): + """ + Test suite for get_message_history. + """ + + def setUp(self): + super().setUp() + self.course_key = CourseKey.from_string('course-v1:edx+fake+1') + self.user = User(username='tester', email='tester@test.com') + self.user.save() + + self.role = 'verified' + + def test_get_message_history(self): + message_count = 5 + for i in range(1, message_count + 1): + LearningAssistantMessage.objects.create( + course_id=self.course_key, + user=self.user, + role=self.role, + content=f'Content of message {i}', + ) + + return_value = get_message_history(self.course_key, self.user, message_count) + + expected_value = list(LearningAssistantMessage.objects.filter( + course_id=self.course_key, user=self.user).order_by('-created')[:message_count])[::-1] + + # Ensure same number of entries + self.assertEqual(len(return_value), len(expected_value)) + + # Ensure values are as expected for all LearningAssistantMessage instances + for i, return_value in enumerate(return_value): + self.assertEqual(return_value.course_id, expected_value[i].course_id) + self.assertEqual(return_value.user, expected_value[i].user) + self.assertEqual(return_value.role, expected_value[i].role) + self.assertEqual(return_value.content, expected_value[i].content) + + @ddt.data( + 0, 1, 5, 10, 50 + ) + def test_get_message_history_message_count(self, actual_message_count): + for i in range(1, actual_message_count + 1): + LearningAssistantMessage.objects.create( + course_id=self.course_key, + user=self.user, + role=self.role, + content=f'Content of message {i}', + ) + + message_count_parameter = 5 + return_value = get_message_history(self.course_key, self.user, message_count_parameter) + + expected_value = LearningAssistantMessage.objects.filter( + course_id=self.course_key, user=self.user).order_by('-created')[:message_count_parameter] + + # Ensure same number of entries + self.assertEqual(len(return_value), len(expected_value)) + + def test_get_message_history_user_difference(self): + # Default Message + LearningAssistantMessage.objects.create( + course_id=self.course_key, + user=self.user, + role=self.role, + content='Expected content of message', + ) + + # New message w/ new user + new_user = User(username='not_tester', email='not_tester@test.com') + new_user.save() + LearningAssistantMessage.objects.create( + course_id=self.course_key, + user=new_user, + role=self.role, + content='Expected content of message', + ) + + message_count = 2 + return_value = get_message_history(self.course_key, self.user, message_count) + + expected_value = LearningAssistantMessage.objects.filter( + course_id=self.course_key, user=self.user).order_by('-created')[:message_count] + + # Ensure we filtered one of the two present messages + self.assertNotEqual(len(return_value), LearningAssistantMessage.objects.count()) + + # Ensure same number of entries + self.assertEqual(len(return_value), len(expected_value)) + + # Ensure values are as expected for all LearningAssistantMessage instances + for i, return_value in enumerate(return_value): + self.assertEqual(return_value.course_id, expected_value[i].course_id) + self.assertEqual(return_value.user, expected_value[i].user) + self.assertEqual(return_value.role, expected_value[i].role) + self.assertEqual(return_value.content, expected_value[i].content) + + def test_get_message_course_id_differences(self): + # Default Message + LearningAssistantMessage.objects.create( + course_id=self.course_key, + user=self.user, + role=self.role, + content='Expected content of message', + ) + + # New message + wrong_course_id = 'course-v1:wrong+id+1' + LearningAssistantMessage.objects.create( + course_id=wrong_course_id, + user=self.user, + role=self.role, + content='Expected content of message', + ) + + message_count = 2 + return_value = get_message_history(self.course_key, self.user, message_count) + + expected_value = LearningAssistantMessage.objects.filter( + course_id=self.course_key, user=self.user).order_by('-created')[:message_count] + + # Ensure we filtered one of the two present messages + self.assertNotEqual(len(return_value), LearningAssistantMessage.objects.count()) + + # Ensure same number of entries + self.assertEqual(len(return_value), len(expected_value)) + + # Ensure values are as expected for all LearningAssistantMessage instances + for i, return_value in enumerate(return_value): + self.assertEqual(return_value.course_id, expected_value[i].course_id) + self.assertEqual(return_value.user, expected_value[i].user) + self.assertEqual(return_value.role, expected_value[i].role) + self.assertEqual(return_value.content, expected_value[i].content) diff --git a/tests/test_plugins_api.py b/tests/test_plugins_api.py index 076335d..4e30539 100644 --- a/tests/test_plugins_api.py +++ b/tests/test_plugins_api.py @@ -19,6 +19,7 @@ class PluginApiTests(TestCase): """ Test suite for the plugins_api module. """ + def setUp(self): super().setUp() self.course_key = CourseKey.from_string('course-v1:edx+fake+1') diff --git a/tests/test_views.py b/tests/test_views.py index 7ff440e..1d567b8 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -1,10 +1,11 @@ """ Tests for the learning assistant views. """ +import datetime import json import sys from importlib import import_module -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, call, patch import ddt from django.conf import settings @@ -13,14 +14,18 @@ from django.test import TestCase, override_settings from django.test.client import Client from django.urls import reverse +from opaque_keys.edx.keys import CourseKey + +from learning_assistant.models import LearningAssistantMessage User = get_user_model() -class TestClient(Client): +class FakeClient(Client): """ Allows for 'fake logins' of a user so we don't need to expose a 'login' HTTP endpoint. """ + def login_user(self, user): """ Login as specified user. @@ -63,14 +68,14 @@ def setUp(self): Setup for tests. """ super().setUp() - self.client = TestClient() - self.user = User(username='tester', email='tester@test.com') + self.client = FakeClient() + self.user = User(username='tester', email='tester@test.com', is_staff=True) self.user.save() self.client.login_user(self.user) @ddt.ddt -class CourseChatViewTests(LoggedInTestCase): +class TestCourseChatView(LoggedInTestCase): """ Test for the CourseChatView """ @@ -82,6 +87,7 @@ class CourseChatViewTests(LoggedInTestCase): def setUp(self): super().setUp() self.course_id = 'course-v1:edx+test+23' + self.course_run_key = CourseKey.from_string(self.course_id) self.patcher = patch( 'learning_assistant.api.get_cache_course_run_data', @@ -150,15 +156,27 @@ def test_invalid_messages(self, mock_role, mock_waffle, mock_render): ) self.assertEqual(response.status_code, 400) + @ddt.data(False, True) @patch('learning_assistant.views.render_prompt_template') @patch('learning_assistant.views.get_chat_response') @patch('learning_assistant.views.learning_assistant_enabled') @patch('learning_assistant.views.get_user_role') @patch('learning_assistant.views.CourseEnrollment.get_enrollment') @patch('learning_assistant.views.CourseMode') + @patch('learning_assistant.views.save_chat_message') + @patch('learning_assistant.views.chat_history_enabled') @override_settings(LEARNING_ASSISTANT_PROMPT_TEMPLATE='This is the default template') def test_chat_response_default( - self, mock_mode, mock_enrollment, mock_role, mock_waffle, mock_chat_response, mock_render + self, + enabled_flag, + mock_chat_history_enabled, + mock_save_chat_message, + mock_mode, + mock_enrollment, + mock_role, + mock_waffle, + mock_chat_response, + mock_render, ): mock_waffle.return_value = True mock_role.return_value = 'student' @@ -168,9 +186,12 @@ def test_chat_response_default( mock_render.return_value = 'Rendered template mock' test_unit_id = 'test-unit-id' + mock_chat_history_enabled.return_value = enabled_flag + test_data = [ {'role': 'user', 'content': 'What is 2+2?'}, - {'role': 'assistant', 'content': 'It is 4'} + {'role': 'assistant', 'content': 'It is 4'}, + {'role': 'user', 'content': 'And what else?'}, ] response = self.client.post( @@ -189,6 +210,14 @@ def test_chat_response_default( test_data, ) + if enabled_flag: + mock_save_chat_message.assert_has_calls([ + call(self.course_run_key, self.user.id, LearningAssistantMessage.USER_ROLE, test_data[-1]['content']), + call(self.course_run_key, self.user.id, LearningAssistantMessage.ASSISTANT_ROLE, 'Something else') + ]) + else: + mock_save_chat_message.assert_not_called() + @ddt.ddt class LearningAssistantEnabledViewTests(LoggedInTestCase): @@ -210,12 +239,12 @@ def setUp(self): ) @ddt.unpack @patch('learning_assistant.views.learning_assistant_enabled') - def test_learning_assistant_enabled(self, mock_value, expected_value, mock_learning_assistant_enabled): + def test_learning_assistant_enabled(self, mock_value, message, mock_learning_assistant_enabled): mock_learning_assistant_enabled.return_value = mock_value response = self.client.get(reverse('enabled', kwargs={'course_run_id': self.course_id})) self.assertEqual(response.status_code, 200) - self.assertEqual(response.data, {'enabled': expected_value}) + self.assertEqual(response.data, {'enabled': message}) @patch('learning_assistant.views.learning_assistant_enabled') def test_invalid_course_id(self, mock_learning_assistant_enabled): @@ -223,3 +252,129 @@ def test_invalid_course_id(self, mock_learning_assistant_enabled): response = self.client.get(reverse('enabled', kwargs={'course_run_id': self.course_id+'+invalid'})) self.assertEqual(response.status_code, 400) + + +@ddt.ddt +class LearningAssistantMessageHistoryViewTests(LoggedInTestCase): + """ + Tests for the LearningAssistantMessageHistoryView + """ + + def setUp(self): + super().setUp() + self.course_id = 'course-v1:edx+test+23' + + @patch('learning_assistant.views.learning_assistant_enabled') + def test_invalid_course_id(self, mock_learning_assistant_enabled): + mock_learning_assistant_enabled.return_value = True + response = self.client.get(reverse('enabled', kwargs={'course_run_id': self.course_id+'+invalid'})) + + self.assertEqual(response.status_code, 400) + + @patch('learning_assistant.views.learning_assistant_enabled') + def test_course_waffle_inactive(self, mock_waffle): + mock_waffle.return_value = False + message_count = 5 + response = self.client.get( + reverse('message-history', kwargs={'course_run_id': self.course_id})+f'?message_count={message_count}', + content_type='application/json' + ) + self.assertEqual(response.status_code, 403) + + @patch('learning_assistant.views.learning_assistant_enabled') + def test_learning_assistant_not_enabled(self, mock_learning_assistant_enabled): + mock_learning_assistant_enabled.return_value = False + message_count = 5 + response = self.client.get( + reverse('message-history', kwargs={'course_run_id': self.course_id})+f'?message_count={message_count}', + content_type='application/json' + ) + + self.assertEqual(response.status_code, 403) + + @patch('learning_assistant.views.learning_assistant_enabled') + @patch('learning_assistant.views.get_user_role') + @patch('learning_assistant.views.CourseEnrollment.get_enrollment') + @patch('learning_assistant.views.CourseMode') + def test_user_no_enrollment_not_staff(self, mock_mode, mock_enrollment, mock_role, mock_waffle): + mock_waffle.return_value = True + mock_role.return_value = 'student' + mock_mode.VERIFIED_MODES = ['verified'] + mock_enrollment.return_value = None + + message_count = 5 + response = self.client.get( + reverse('message-history', kwargs={'course_run_id': self.course_id})+f'?message_count={message_count}', + content_type='application/json' + ) + self.assertEqual(response.status_code, 403) + + @patch('learning_assistant.views.learning_assistant_enabled') + @patch('learning_assistant.views.get_user_role') + @patch('learning_assistant.views.CourseEnrollment.get_enrollment') + @patch('learning_assistant.views.CourseMode') + def test_user_audit_enrollment_not_staff(self, mock_mode, mock_enrollment, mock_role, mock_waffle): + mock_waffle.return_value = True + mock_role.return_value = 'student' + mock_mode.VERIFIED_MODES = ['verified'] + mock_enrollment.return_value = MagicMock(mode='audit') + + message_count = 5 + response = self.client.get( + reverse('message-history', kwargs={'course_run_id': self.course_id})+f'?message_count={message_count}', + content_type='application/json' + ) + self.assertEqual(response.status_code, 403) + + @patch('learning_assistant.views.learning_assistant_enabled') + @patch('learning_assistant.views.get_user_role') + @patch('learning_assistant.views.CourseEnrollment.get_enrollment') + @patch('learning_assistant.views.CourseMode') + @patch('learning_assistant.views.get_course_id') + def test_learning_message_history_view_get( + self, + mock_get_course_id, + mock_mode, + mock_enrollment, + mock_role, + mock_waffle + ): + mock_waffle.return_value = True + mock_role.return_value = 'student' + mock_mode.VERIFIED_MODES = ['verified'] + mock_enrollment.return_value = MagicMock(mode='verified') + + LearningAssistantMessage.objects.create( + course_id=self.course_id, + user=self.user, + role='staff', + content='Older message', + created=datetime.date(2024, 10, 1) + ) + + LearningAssistantMessage.objects.create( + course_id=self.course_id, + user=self.user, + role='staff', + content='Newer message', + created=datetime.date(2024, 10, 3) + ) + + db_messages = LearningAssistantMessage.objects.all().order_by('created') + db_messages_count = len(db_messages) + + mock_get_course_id.return_value = self.course_id + response = self.client.get( + reverse('message-history', kwargs={'course_run_id': self.course_id})+f'?message_count={db_messages_count}', + content_type='application/json' + ) + data = response.data + + # Ensure same number of entries + self.assertEqual(len(data), db_messages_count) + + # Ensure values are as expected + for i, message in enumerate(data): + self.assertEqual(message['role'], db_messages[i].role) + self.assertEqual(message['content'], db_messages[i].content) + self.assertEqual(message['timestamp'], db_messages[i].created.isoformat())