diff --git a/learning_assistant/api.py b/learning_assistant/api.py index 0790059..eaf4b10 100644 --- a/learning_assistant/api.py +++ b/learning_assistant/api.py @@ -4,6 +4,7 @@ import logging from django.conf import settings +from django.contrib.auth import get_user_model from django.core.cache import cache from edx_django_utils.cache import get_cache_key from jinja2 import BaseLoader, Environment @@ -24,6 +25,7 @@ from learning_assistant.text_utils import html_to_text log = logging.getLogger(__name__) +User = get_user_model() def _extract_block_contents(child, category): @@ -189,6 +191,26 @@ def get_course_id(course_run_id): return course_key +def save_chat_message(courserun_key, user_id, chat_role, message): + """ + Save the chat message to the database. + """ + user = None + try: + user = User.objects.get(id=user_id) + except User.DoesNotExist as exc: + raise Exception("User does not exists.") from exc + + # Save the user message to the database. + LearningAssistantMessage.objects.create( + course_id=courserun_key, + user=user, + role=chat_role, + content=message, + + ) + + def get_message_history(course_id, user, message_count): """ Given a course run id (str), user (User), and message count (int), return the associated message history. diff --git a/learning_assistant/migrations/0008_alter_learningassistantmessage_role.py b/learning_assistant/migrations/0008_alter_learningassistantmessage_role.py new file mode 100644 index 0000000..bd699b3 --- /dev/null +++ b/learning_assistant/migrations/0008_alter_learningassistantmessage_role.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.14 on 2024-11-04 08:52 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('learning_assistant', '0007_learningassistantmessage'), + ] + + operations = [ + migrations.AlterField( + model_name='learningassistantmessage', + name='role', + field=models.CharField(choices=[('user', 'user'), ('assistant', 'assistant')], max_length=64), + ), + ] diff --git a/learning_assistant/models.py b/learning_assistant/models.py index c890087..482bfe2 100644 --- a/learning_assistant/models.py +++ b/learning_assistant/models.py @@ -35,7 +35,15 @@ class LearningAssistantMessage(TimeStampedModel): .. pii_retirement: third_party """ + USER_ROLE = 'user' + ASSISTANT_ROLE = 'assistant' + + Roles = ( + (USER_ROLE, USER_ROLE), + (ASSISTANT_ROLE, ASSISTANT_ROLE), + ) + course_id = CourseKeyField(max_length=255, db_index=True) user = models.ForeignKey(USER_MODEL, db_index=True, on_delete=models.CASCADE) - role = models.CharField(max_length=64) + role = models.CharField(choices=Roles, max_length=64) content = models.TextField() diff --git a/learning_assistant/serializers.py b/learning_assistant/serializers.py index a212654..62141ef 100644 --- a/learning_assistant/serializers.py +++ b/learning_assistant/serializers.py @@ -3,6 +3,8 @@ """ from rest_framework import serializers +from learning_assistant.models import LearningAssistantMessage + class MessageSerializer(serializers.Serializer): # pylint: disable=abstract-method """ @@ -16,7 +18,7 @@ def validate_role(self, value): """ Validate that role is one of two acceptable values. """ - valid_roles = ['user', 'assistant'] + valid_roles = [LearningAssistantMessage.USER_ROLE, LearningAssistantMessage.ASSISTANT_ROLE] if value not in valid_roles: raise serializers.ValidationError('Must be valid role.') return value diff --git a/learning_assistant/toggles.py b/learning_assistant/toggles.py index 88d5ae6..61d515e 100644 --- a/learning_assistant/toggles.py +++ b/learning_assistant/toggles.py @@ -14,6 +14,16 @@ # .. toggle_tickets: COSMO-80 ENABLE_COURSE_CONTENT = 'enable_course_content' +# .. toggle_name: learning_assistant.enable_chat_history +# .. toggle_implementation: CourseWaffleFlag +# .. toggle_default: False +# .. toggle_description: Waffle flag to enable the chat history with the learning assistant +# .. toggle_use_cases: temporary +# .. toggle_creation_date: 2024-10-30 +# .. toggle_target_removal_date: 2024-12-31 +# .. toggle_tickets: COSMO-436 +ENABLE_CHAT_HISTORY = 'enable_chat_history' + def _is_learning_assistant_waffle_flag_enabled(flag_name, course_key): """ @@ -32,3 +42,10 @@ def course_content_enabled(course_key): Return whether the learning_assistant.enable_course_content WaffleFlag is on. """ return _is_learning_assistant_waffle_flag_enabled(ENABLE_COURSE_CONTENT, course_key) + + +def chat_history_enabled(course_key): + """ + Return whether the learning_assistant.enable_chat_history WaffleFlag is on. + """ + return _is_learning_assistant_waffle_flag_enabled(ENABLE_CHAT_HISTORY, course_key) diff --git a/learning_assistant/views.py b/learning_assistant/views.py index 68b1c0a..583aed1 100644 --- a/learning_assistant/views.py +++ b/learning_assistant/views.py @@ -25,8 +25,11 @@ get_message_history, learning_assistant_enabled, render_prompt_template, + save_chat_message, ) +from learning_assistant.models import LearningAssistantMessage from learning_assistant.serializers import MessageSerializer +from learning_assistant.toggles import chat_history_enabled from learning_assistant.utils import get_chat_response, user_role_is_staff log = logging.getLogger(__name__) @@ -81,6 +84,20 @@ def post(self, request, course_run_id): unit_id = request.query_params.get('unit_id') message_list = request.data + + # Check that the last message in the list corresponds to a user + new_user_message = message_list[-1] + if new_user_message['role'] != LearningAssistantMessage.USER_ROLE: + return Response( + status=http_status.HTTP_400_BAD_REQUEST, + data={'detail': "Expects user role on last message."} + ) + + user_id = request.user.id + + if chat_history_enabled(courserun_key): + save_chat_message(courserun_key, user_id, LearningAssistantMessage.USER_ROLE, new_user_message['content']) + serializer = MessageSerializer(data=message_list, many=True) # serializer will not be valid in the case that the message list contains any roles other than @@ -108,6 +125,9 @@ def post(self, request, course_run_id): ) status_code, message = get_chat_response(prompt_template, message_list) + if chat_history_enabled(courserun_key): + save_chat_message(courserun_key, user_id, LearningAssistantMessage.ASSISTANT_ROLE, message['content']) + return Response(status=status_code, data=message) diff --git a/tests/test_api.py b/tests/test_api.py index 3ca365d..0470e32 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -21,13 +21,13 @@ learning_assistant_available, learning_assistant_enabled, render_prompt_template, + save_chat_message, set_learning_assistant_enabled, ) from learning_assistant.data import LearningAssistantCourseEnabledData from learning_assistant.models import LearningAssistantCourseEnabled, LearningAssistantMessage fake_transcript = 'This is the text version from the transcript' - User = get_user_model() @@ -236,6 +236,32 @@ def test_render_prompt_template_invalid_unit_key(self, mock_get_content): self.assertNotIn('The following text is useful.', prompt_text) +@ddt.ddt +class TestLearningAssistantCourseEnabledApi(TestCase): + """ + Test suite for save_chat_message. + """ + def setUp(self): + super().setUp() + + self.test_user = User.objects.create(username='username', password='password') + self.course_run_key = CourseKey.from_string('course-v1:edx+test+23') + + @ddt.data( + (LearningAssistantMessage.USER_ROLE, 'What is the meaning of life, the universe and everything?'), + (LearningAssistantMessage.ASSISTANT_ROLE, '42'), + ) + @ddt.unpack + def test_save_chat_message(self, chat_role, message): + save_chat_message(self.course_run_key, self.test_user.id, chat_role, message) + + row = LearningAssistantMessage.objects.all().last() + + self.assertEqual(row.course_id, self.course_run_key) + self.assertEqual(row.role, chat_role) + self.assertEqual(row.content, message) + + @ddt.ddt class LearningAssistantCourseEnabledApiTests(TestCase): """ diff --git a/tests/test_views.py b/tests/test_views.py index 492a10d..5125d27 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -4,7 +4,7 @@ import json import sys from importlib import import_module -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, call, patch import ddt from django.conf import settings @@ -13,13 +13,14 @@ from django.test import TestCase, override_settings from django.test.client import Client from django.urls import reverse +from opaque_keys.edx.keys import CourseKey from learning_assistant.models import LearningAssistantMessage User = get_user_model() -class TestClient(Client): +class FakeClient(Client): """ Allows for 'fake logins' of a user so we don't need to expose a 'login' HTTP endpoint. """ @@ -66,14 +67,14 @@ def setUp(self): Setup for tests. """ super().setUp() - self.client = TestClient() + self.client = FakeClient() self.user = User(username='tester', email='tester@test.com', is_staff=True) self.user.save() self.client.login_user(self.user) @ddt.ddt -class CourseChatViewTests(LoggedInTestCase): +class TestCourseChatView(LoggedInTestCase): """ Test for the CourseChatView """ @@ -85,6 +86,7 @@ class CourseChatViewTests(LoggedInTestCase): def setUp(self): super().setUp() self.course_id = 'course-v1:edx+test+23' + self.course_run_key = CourseKey.from_string(self.course_id) self.patcher = patch( 'learning_assistant.api.get_cache_course_run_data', @@ -153,15 +155,27 @@ def test_invalid_messages(self, mock_role, mock_waffle, mock_render): ) self.assertEqual(response.status_code, 400) + @ddt.data(False, True) @patch('learning_assistant.views.render_prompt_template') @patch('learning_assistant.views.get_chat_response') @patch('learning_assistant.views.learning_assistant_enabled') @patch('learning_assistant.views.get_user_role') @patch('learning_assistant.views.CourseEnrollment.get_enrollment') @patch('learning_assistant.views.CourseMode') + @patch('learning_assistant.views.save_chat_message') + @patch('learning_assistant.views.chat_history_enabled') @override_settings(LEARNING_ASSISTANT_PROMPT_TEMPLATE='This is the default template') def test_chat_response_default( - self, mock_mode, mock_enrollment, mock_role, mock_waffle, mock_chat_response, mock_render + self, + enabled_flag, + mock_chat_history_enabled, + mock_save_chat_message, + mock_mode, + mock_enrollment, + mock_role, + mock_waffle, + mock_chat_response, + mock_render, ): mock_waffle.return_value = True mock_role.return_value = 'student' @@ -171,9 +185,12 @@ def test_chat_response_default( mock_render.return_value = 'Rendered template mock' test_unit_id = 'test-unit-id' + mock_chat_history_enabled.return_value = enabled_flag + test_data = [ {'role': 'user', 'content': 'What is 2+2?'}, - {'role': 'assistant', 'content': 'It is 4'} + {'role': 'assistant', 'content': 'It is 4'}, + {'role': 'user', 'content': 'And what else?'}, ] response = self.client.post( @@ -192,6 +209,14 @@ def test_chat_response_default( test_data, ) + if enabled_flag: + mock_save_chat_message.assert_has_calls([ + call(self.course_run_key, self.user.id, LearningAssistantMessage.USER_ROLE, test_data[-1]['content']), + call(self.course_run_key, self.user.id, LearningAssistantMessage.ASSISTANT_ROLE, 'Something else') + ]) + else: + mock_save_chat_message.assert_not_called() + @ddt.ddt class LearningAssistantEnabledViewTests(LoggedInTestCase):