From 99ccaa22046beeb7d7c925d6c41b7b7966c2ecf0 Mon Sep 17 00:00:00 2001 From: Marcos Date: Wed, 30 Oct 2024 09:40:54 -0300 Subject: [PATCH 1/9] feat: Adding chat messages to the DB --- learning_assistant/models.py | 10 +++++++++- learning_assistant/serializers.py | 4 +++- learning_assistant/views.py | 28 ++++++++++++++++++++++++++++ tests/test_views.py | 15 ++++++++++++++- 4 files changed, 54 insertions(+), 3 deletions(-) diff --git a/learning_assistant/models.py b/learning_assistant/models.py index c890087..482bfe2 100644 --- a/learning_assistant/models.py +++ b/learning_assistant/models.py @@ -35,7 +35,15 @@ class LearningAssistantMessage(TimeStampedModel): .. pii_retirement: third_party """ + USER_ROLE = 'user' + ASSISTANT_ROLE = 'assistant' + + Roles = ( + (USER_ROLE, USER_ROLE), + (ASSISTANT_ROLE, ASSISTANT_ROLE), + ) + course_id = CourseKeyField(max_length=255, db_index=True) user = models.ForeignKey(USER_MODEL, db_index=True, on_delete=models.CASCADE) - role = models.CharField(max_length=64) + role = models.CharField(choices=Roles, max_length=64) content = models.TextField() diff --git a/learning_assistant/serializers.py b/learning_assistant/serializers.py index a212654..62141ef 100644 --- a/learning_assistant/serializers.py +++ b/learning_assistant/serializers.py @@ -3,6 +3,8 @@ """ from rest_framework import serializers +from learning_assistant.models import LearningAssistantMessage + class MessageSerializer(serializers.Serializer): # pylint: disable=abstract-method """ @@ -16,7 +18,7 @@ def validate_role(self, value): """ Validate that role is one of two acceptable values. """ - valid_roles = ['user', 'assistant'] + valid_roles = [LearningAssistantMessage.USER_ROLE, LearningAssistantMessage.ASSISTANT_ROLE] if value not in valid_roles: raise serializers.ValidationError('Must be valid role.') return value diff --git a/learning_assistant/views.py b/learning_assistant/views.py index 68b1c0a..aa9dc26 100644 --- a/learning_assistant/views.py +++ b/learning_assistant/views.py @@ -4,6 +4,7 @@ import logging from django.conf import settings +from django.contrib.auth import get_user_model from edx_rest_framework_extensions.auth.jwt.authentication import JwtAuthentication from opaque_keys import InvalidKeyError from opaque_keys.edx.keys import CourseKey @@ -26,10 +27,12 @@ learning_assistant_enabled, render_prompt_template, ) +from learning_assistant.models import LearningAssistantMessage from learning_assistant.serializers import MessageSerializer from learning_assistant.utils import get_chat_response, user_role_is_staff log = logging.getLogger(__name__) +User = get_user_model() class CourseChatView(APIView): @@ -81,6 +84,15 @@ def post(self, request, course_run_id): unit_id = request.query_params.get('unit_id') message_list = request.data + + # Check that the last message in the list corresponds to a user + new_user_message = message_list[-1] + if new_user_message['role'] != LearningAssistantMessage.USER_ROLE: + return Response( + status=http_status.HTTP_400_BAD_REQUEST, + data={'detail': "Expects user role on last message."} + ) + serializer = MessageSerializer(data=message_list, many=True) # serializer will not be valid in the case that the message list contains any roles other than @@ -108,6 +120,22 @@ def post(self, request, course_run_id): ) status_code, message = get_chat_response(prompt_template, message_list) + user = User.objects.get(id=request.user.id) # Based on the previous code, user exists. + + # Save the user message to the database. + LearningAssistantMessage.objects.create( + user=user, + role=LearningAssistantMessage.USER_ROLE, + content=new_user_message['content'], + ) + + # Save the assistant response to the database. + LearningAssistantMessage.objects.create( + user=user, + role=LearningAssistantMessage.ASSISTANT_ROLE, + content=message['content'], + ) + return Response(status=status_code, data=message) diff --git a/tests/test_views.py b/tests/test_views.py index 492a10d..35e90d3 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -173,7 +173,8 @@ def test_chat_response_default( test_data = [ {'role': 'user', 'content': 'What is 2+2?'}, - {'role': 'assistant', 'content': 'It is 4'} + {'role': 'assistant', 'content': 'It is 4'}, + {'role': 'user', 'content': 'And what else?'}, ] response = self.client.post( @@ -181,8 +182,20 @@ def test_chat_response_default( data=json.dumps(test_data), content_type='application/json' ) + self.assertEqual(response.status_code, 200) + last_rows = LearningAssistantMessage.objects.all().order_by('-created').values()[:2][::-1] + + user_msg = last_rows[0] + assistant_msg = last_rows[1] + + self.assertEqual(user_msg['role'], LearningAssistantMessage.USER_ROLE) + self.assertEqual(user_msg['content'], test_data[2]['content']) + + self.assertEqual(assistant_msg['role'], LearningAssistantMessage.ASSISTANT_ROLE) + self.assertEqual(assistant_msg['content'], 'Something else') + render_args = mock_render.call_args.args self.assertIn(test_unit_id, render_args) self.assertIn('This is the default template', render_args) From c419fdefefd96045d68f610956e6cf91462adcce Mon Sep 17 00:00:00 2001 From: Marcos Date: Wed, 30 Oct 2024 12:23:49 -0300 Subject: [PATCH 2/9] chore: Moved code to save the message to its own method --- learning_assistant/views.py | 39 ++++++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/learning_assistant/views.py b/learning_assistant/views.py index aa9dc26..792aee3 100644 --- a/learning_assistant/views.py +++ b/learning_assistant/views.py @@ -43,6 +43,27 @@ class CourseChatView(APIView): authentication_classes = (SessionAuthentication, JwtAuthentication,) permission_classes = (IsAuthenticated,) + def __save_user_interaction(self, user_id, user_message, assistant_message): + """ + Saves the last question/response to the database. + """ + user = User.objects.get(id=user_id) + + # Save the user message to the database. + LearningAssistantMessage.objects.create( + user=user, + role=LearningAssistantMessage.USER_ROLE, + content=user_message, + ) + + # Save the assistant response to the database. + LearningAssistantMessage.objects.create( + user=user, + role=LearningAssistantMessage.ASSISTANT_ROLE, + content=assistant_message, + ) + + def post(self, request, course_run_id): """ Given a course run ID, retrieve a chat response for that course. @@ -120,20 +141,10 @@ def post(self, request, course_run_id): ) status_code, message = get_chat_response(prompt_template, message_list) - user = User.objects.get(id=request.user.id) # Based on the previous code, user exists. - - # Save the user message to the database. - LearningAssistantMessage.objects.create( - user=user, - role=LearningAssistantMessage.USER_ROLE, - content=new_user_message['content'], - ) - - # Save the assistant response to the database. - LearningAssistantMessage.objects.create( - user=user, - role=LearningAssistantMessage.ASSISTANT_ROLE, - content=message['content'], + self.__save_user_interaction( + user_id=request.user.id, + user_message=new_user_message['content'], + assistant_message=message['content'] ) return Response(status=status_code, data=message) From 79917a5863ee026ff794f2c4cbf0cc744fe3c82f Mon Sep 17 00:00:00 2001 From: Marcos Date: Wed, 30 Oct 2024 15:35:04 -0300 Subject: [PATCH 3/9] feat: Added learning_assistant.enable_chat_history waffle flag --- learning_assistant/toggles.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/learning_assistant/toggles.py b/learning_assistant/toggles.py index 88d5ae6..61d515e 100644 --- a/learning_assistant/toggles.py +++ b/learning_assistant/toggles.py @@ -14,6 +14,16 @@ # .. toggle_tickets: COSMO-80 ENABLE_COURSE_CONTENT = 'enable_course_content' +# .. toggle_name: learning_assistant.enable_chat_history +# .. toggle_implementation: CourseWaffleFlag +# .. toggle_default: False +# .. toggle_description: Waffle flag to enable the chat history with the learning assistant +# .. toggle_use_cases: temporary +# .. toggle_creation_date: 2024-10-30 +# .. toggle_target_removal_date: 2024-12-31 +# .. toggle_tickets: COSMO-436 +ENABLE_CHAT_HISTORY = 'enable_chat_history' + def _is_learning_assistant_waffle_flag_enabled(flag_name, course_key): """ @@ -32,3 +42,10 @@ def course_content_enabled(course_key): Return whether the learning_assistant.enable_course_content WaffleFlag is on. """ return _is_learning_assistant_waffle_flag_enabled(ENABLE_COURSE_CONTENT, course_key) + + +def chat_history_enabled(course_key): + """ + Return whether the learning_assistant.enable_chat_history WaffleFlag is on. + """ + return _is_learning_assistant_waffle_flag_enabled(ENABLE_CHAT_HISTORY, course_key) From dae5df8a781347a5df4080823fa7fb695c02df08 Mon Sep 17 00:00:00 2001 From: Marcos Date: Wed, 30 Oct 2024 15:36:05 -0300 Subject: [PATCH 4/9] feat: Added save_chat_message() to API --- learning_assistant/api.py | 20 ++++++++++++++++++++ tests/test_api.py | 26 +++++++++++++++++++++++++- 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/learning_assistant/api.py b/learning_assistant/api.py index 0790059..7ddaa4b 100644 --- a/learning_assistant/api.py +++ b/learning_assistant/api.py @@ -4,6 +4,7 @@ import logging from django.conf import settings +from django.contrib.auth import get_user_model from django.core.cache import cache from edx_django_utils.cache import get_cache_key from jinja2 import BaseLoader, Environment @@ -24,6 +25,7 @@ from learning_assistant.text_utils import html_to_text log = logging.getLogger(__name__) +User = get_user_model() def _extract_block_contents(child, category): @@ -188,6 +190,24 @@ def get_course_id(course_run_id): course_key = course_data['course'] return course_key +def save_chat_message(user_id, chat_role, message): + """ + Saves the chat message to the database. + """ + + user = None + try: + user = User.objects.get(id=user_id) + except User.DoesNotExist: + raise Exception("User does not exists.") + + # Save the user message to the database. + LearningAssistantMessage.objects.create( + user=user, + role=chat_role, + content=message, + ) + def get_message_history(course_id, user, message_count): """ diff --git a/tests/test_api.py b/tests/test_api.py index 3ca365d..a2f51c3 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -21,13 +21,13 @@ learning_assistant_available, learning_assistant_enabled, render_prompt_template, + save_chat_message, set_learning_assistant_enabled, ) from learning_assistant.data import LearningAssistantCourseEnabledData from learning_assistant.models import LearningAssistantCourseEnabled, LearningAssistantMessage fake_transcript = 'This is the text version from the transcript' - User = get_user_model() @@ -235,6 +235,30 @@ def test_render_prompt_template_invalid_unit_key(self, mock_get_content): self.assertNotIn('The following text is useful.', prompt_text) +@ddt.ddt +class TestLearningAssistantCourseEnabledApi(TestCase): + """ + Test suite for save_chat_message. + """ + def setUp(self): + super().setUp() + + self.test_user = User.objects.create(username='username', password='password') + + @ddt.data( + (LearningAssistantMessage.USER_ROLE, 'What is the meaning of life, the universe and everything?'), + (LearningAssistantMessage.ASSISTANT_ROLE, '42'), + ) + @ddt.unpack + def test_save_chat_message(self, chat_role, message): + save_chat_message(self.test_user.id, chat_role, message) + + row = LearningAssistantMessage.objects.all().last() + + self.assertEqual(row.role, chat_role) + self.assertEqual(row.content, message) + + @ddt.ddt class LearningAssistantCourseEnabledApiTests(TestCase): From 23fd0131484c975cb8fab81b1276010820213534 Mon Sep 17 00:00:00 2001 From: Marcos Date: Wed, 30 Oct 2024 17:53:09 -0300 Subject: [PATCH 5/9] chore: Fixed coverage for CourseChatView --- learning_assistant/api.py | 8 +++--- learning_assistant/views.py | 40 ++++++++--------------------- tests/test_api.py | 2 +- tests/test_views.py | 50 ++++++++++++++++++++++++------------- 4 files changed, 48 insertions(+), 52 deletions(-) diff --git a/learning_assistant/api.py b/learning_assistant/api.py index 7ddaa4b..e6fa870 100644 --- a/learning_assistant/api.py +++ b/learning_assistant/api.py @@ -190,16 +190,16 @@ def get_course_id(course_run_id): course_key = course_data['course'] return course_key + def save_chat_message(user_id, chat_role, message): """ - Saves the chat message to the database. + Save the chat message to the database. """ - user = None try: user = User.objects.get(id=user_id) - except User.DoesNotExist: - raise Exception("User does not exists.") + except User.DoesNotExist as exc: + raise Exception("User does not exists.") from exc # Save the user message to the database. LearningAssistantMessage.objects.create( diff --git a/learning_assistant/views.py b/learning_assistant/views.py index 792aee3..44eb8d0 100644 --- a/learning_assistant/views.py +++ b/learning_assistant/views.py @@ -4,7 +4,6 @@ import logging from django.conf import settings -from django.contrib.auth import get_user_model from edx_rest_framework_extensions.auth.jwt.authentication import JwtAuthentication from opaque_keys import InvalidKeyError from opaque_keys.edx.keys import CourseKey @@ -26,13 +25,14 @@ get_message_history, learning_assistant_enabled, render_prompt_template, + save_chat_message, ) from learning_assistant.models import LearningAssistantMessage from learning_assistant.serializers import MessageSerializer +from learning_assistant.toggles import chat_history_enabled from learning_assistant.utils import get_chat_response, user_role_is_staff log = logging.getLogger(__name__) -User = get_user_model() class CourseChatView(APIView): @@ -43,27 +43,6 @@ class CourseChatView(APIView): authentication_classes = (SessionAuthentication, JwtAuthentication,) permission_classes = (IsAuthenticated,) - def __save_user_interaction(self, user_id, user_message, assistant_message): - """ - Saves the last question/response to the database. - """ - user = User.objects.get(id=user_id) - - # Save the user message to the database. - LearningAssistantMessage.objects.create( - user=user, - role=LearningAssistantMessage.USER_ROLE, - content=user_message, - ) - - # Save the assistant response to the database. - LearningAssistantMessage.objects.create( - user=user, - role=LearningAssistantMessage.ASSISTANT_ROLE, - content=assistant_message, - ) - - def post(self, request, course_run_id): """ Given a course run ID, retrieve a chat response for that course. @@ -114,6 +93,12 @@ def post(self, request, course_run_id): data={'detail': "Expects user role on last message."} ) + course_id = get_course_id(course_run_id) + user_id = request.user.id + + if chat_history_enabled(course_id): + save_chat_message(user_id, LearningAssistantMessage.USER_ROLE, new_user_message['content']) + serializer = MessageSerializer(data=message_list, many=True) # serializer will not be valid in the case that the message list contains any roles other than @@ -132,8 +117,6 @@ def post(self, request, course_run_id): } ) - course_id = get_course_id(course_run_id) - template_string = getattr(settings, 'LEARNING_ASSISTANT_PROMPT_TEMPLATE', '') prompt_template = render_prompt_template( @@ -141,11 +124,8 @@ def post(self, request, course_run_id): ) status_code, message = get_chat_response(prompt_template, message_list) - self.__save_user_interaction( - user_id=request.user.id, - user_message=new_user_message['content'], - assistant_message=message['content'] - ) + if chat_history_enabled(course_id): + save_chat_message(user_id, LearningAssistantMessage.ASSISTANT_ROLE, message['content']) return Response(status=status_code, data=message) diff --git a/tests/test_api.py b/tests/test_api.py index a2f51c3..6969af8 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -235,6 +235,7 @@ def test_render_prompt_template_invalid_unit_key(self, mock_get_content): self.assertNotIn('The following text is useful.', prompt_text) + @ddt.ddt class TestLearningAssistantCourseEnabledApi(TestCase): """ @@ -259,7 +260,6 @@ def test_save_chat_message(self, chat_role, message): self.assertEqual(row.content, message) - @ddt.ddt class LearningAssistantCourseEnabledApiTests(TestCase): """ diff --git a/tests/test_views.py b/tests/test_views.py index 35e90d3..c521a30 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -4,7 +4,7 @@ import json import sys from importlib import import_module -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, call, patch import ddt from django.conf import settings @@ -19,7 +19,7 @@ User = get_user_model() -class TestClient(Client): +class FakeClient(Client): """ Allows for 'fake logins' of a user so we don't need to expose a 'login' HTTP endpoint. """ @@ -66,14 +66,14 @@ def setUp(self): Setup for tests. """ super().setUp() - self.client = TestClient() + self.client = FakeClient() self.user = User(username='tester', email='tester@test.com', is_staff=True) self.user.save() self.client.login_user(self.user) @ddt.ddt -class CourseChatViewTests(LoggedInTestCase): +class TestCourseChatView(LoggedInTestCase): """ Test for the CourseChatView """ @@ -153,15 +153,27 @@ def test_invalid_messages(self, mock_role, mock_waffle, mock_render): ) self.assertEqual(response.status_code, 400) + @ddt.data(True, False) # TODO: Fix this - See below. @patch('learning_assistant.views.render_prompt_template') @patch('learning_assistant.views.get_chat_response') @patch('learning_assistant.views.learning_assistant_enabled') @patch('learning_assistant.views.get_user_role') @patch('learning_assistant.views.CourseEnrollment.get_enrollment') @patch('learning_assistant.views.CourseMode') + @patch('learning_assistant.api.save_chat_message') + @patch('learning_assistant.toggles.chat_history_enabled') @override_settings(LEARNING_ASSISTANT_PROMPT_TEMPLATE='This is the default template') def test_chat_response_default( - self, mock_mode, mock_enrollment, mock_role, mock_waffle, mock_chat_response, mock_render + self, + enabled_flag, + mock_chat_history_enabled, + mock_save_chat_message, + mock_mode, + mock_enrollment, + mock_role, + mock_waffle, + mock_chat_response, + mock_render, ): mock_waffle.return_value = True mock_role.return_value = 'student' @@ -171,6 +183,14 @@ def test_chat_response_default( mock_render.return_value = 'Rendered template mock' test_unit_id = 'test-unit-id' + # TODO: Fix this... + # For some reason this only works the first time. The 2nd time (enabled_flag = False) + # Doesn't actually work since the mocked chat_history_enabled() will return False no matter what. + # Swap the order of the @ddt.data() above by: @ddt.data(False, True) and watch it fail. + # The value for enabled_flag is corrct on this scope, but the mocked method doesn't update. + # It even happens if we split the test cases into two different methods. + mock_chat_history_enabled.return_value = enabled_flag + test_data = [ {'role': 'user', 'content': 'What is 2+2?'}, {'role': 'assistant', 'content': 'It is 4'}, @@ -182,20 +202,8 @@ def test_chat_response_default( data=json.dumps(test_data), content_type='application/json' ) - self.assertEqual(response.status_code, 200) - last_rows = LearningAssistantMessage.objects.all().order_by('-created').values()[:2][::-1] - - user_msg = last_rows[0] - assistant_msg = last_rows[1] - - self.assertEqual(user_msg['role'], LearningAssistantMessage.USER_ROLE) - self.assertEqual(user_msg['content'], test_data[2]['content']) - - self.assertEqual(assistant_msg['role'], LearningAssistantMessage.ASSISTANT_ROLE) - self.assertEqual(assistant_msg['content'], 'Something else') - render_args = mock_render.call_args.args self.assertIn(test_unit_id, render_args) self.assertIn('This is the default template', render_args) @@ -205,6 +213,14 @@ def test_chat_response_default( test_data, ) + if enabled_flag: + mock_save_chat_message.assert_has_calls([ + call(self.user.id, LearningAssistantMessage.USER_ROLE, test_data[-1]['content']), + call(self.user.id, LearningAssistantMessage.ASSISTANT_ROLE, 'Something else') + ]) + else: + mock_save_chat_message.assert_not_called() + @ddt.ddt class LearningAssistantEnabledViewTests(LoggedInTestCase): From 352340ea71c5ec7fae32cee3346827755e25a828 Mon Sep 17 00:00:00 2001 From: Marcos Date: Thu, 31 Oct 2024 13:25:21 -0300 Subject: [PATCH 6/9] fix: Swapped course id for course key --- learning_assistant/views.py | 7 ++++--- tests/test_views.py | 12 +++++++----- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/learning_assistant/views.py b/learning_assistant/views.py index 44eb8d0..3e136fe 100644 --- a/learning_assistant/views.py +++ b/learning_assistant/views.py @@ -93,10 +93,9 @@ def post(self, request, course_run_id): data={'detail': "Expects user role on last message."} ) - course_id = get_course_id(course_run_id) user_id = request.user.id - if chat_history_enabled(course_id): + if chat_history_enabled(courserun_key): save_chat_message(user_id, LearningAssistantMessage.USER_ROLE, new_user_message['content']) serializer = MessageSerializer(data=message_list, many=True) @@ -117,6 +116,8 @@ def post(self, request, course_run_id): } ) + course_id = get_course_id(course_run_id) + template_string = getattr(settings, 'LEARNING_ASSISTANT_PROMPT_TEMPLATE', '') prompt_template = render_prompt_template( @@ -124,7 +125,7 @@ def post(self, request, course_run_id): ) status_code, message = get_chat_response(prompt_template, message_list) - if chat_history_enabled(course_id): + if chat_history_enabled(courserun_key): save_chat_message(user_id, LearningAssistantMessage.ASSISTANT_ROLE, message['content']) return Response(status=status_code, data=message) diff --git a/tests/test_views.py b/tests/test_views.py index c521a30..4e791b7 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -184,11 +184,13 @@ def test_chat_response_default( test_unit_id = 'test-unit-id' # TODO: Fix this... - # For some reason this only works the first time. The 2nd time (enabled_flag = False) - # Doesn't actually work since the mocked chat_history_enabled() will return False no matter what. - # Swap the order of the @ddt.data() above by: @ddt.data(False, True) and watch it fail. - # The value for enabled_flag is corrct on this scope, but the mocked method doesn't update. - # It even happens if we split the test cases into two different methods. + # For some reason this assignment only works the first iteration. The 2nd time onwards the return value is + # always falsy. Swap the order of the @ddt.data() above by: @ddt.data(False, True) to see it fail. + # I'm leaving it like this because we are testing the False return in the second iteration, but it's important + # to consider whenever this test needs to be updated. + # It even happens if we split the test cases into two different methods (instead of @ddt.data()), so there's + # probably some scoping issues in how the test is set up. + # Note: There's a similar test for LearningAssistantEnabledView in this file that works just fine. mock_chat_history_enabled.return_value = enabled_flag test_data = [ From 7cdc76331a83a515f91ad6b1b2a88810775fa570 Mon Sep 17 00:00:00 2001 From: Marcos Date: Mon, 4 Nov 2024 11:16:58 -0300 Subject: [PATCH 7/9] chore: Fixed patching on the CourseChatView unit tests --- tests/test_views.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/tests/test_views.py b/tests/test_views.py index 4e791b7..9e07e55 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -153,15 +153,15 @@ def test_invalid_messages(self, mock_role, mock_waffle, mock_render): ) self.assertEqual(response.status_code, 400) - @ddt.data(True, False) # TODO: Fix this - See below. + @ddt.data(False, True) @patch('learning_assistant.views.render_prompt_template') @patch('learning_assistant.views.get_chat_response') @patch('learning_assistant.views.learning_assistant_enabled') @patch('learning_assistant.views.get_user_role') @patch('learning_assistant.views.CourseEnrollment.get_enrollment') @patch('learning_assistant.views.CourseMode') - @patch('learning_assistant.api.save_chat_message') - @patch('learning_assistant.toggles.chat_history_enabled') + @patch('learning_assistant.views.save_chat_message') + @patch('learning_assistant.views.chat_history_enabled') @override_settings(LEARNING_ASSISTANT_PROMPT_TEMPLATE='This is the default template') def test_chat_response_default( self, @@ -183,14 +183,6 @@ def test_chat_response_default( mock_render.return_value = 'Rendered template mock' test_unit_id = 'test-unit-id' - # TODO: Fix this... - # For some reason this assignment only works the first iteration. The 2nd time onwards the return value is - # always falsy. Swap the order of the @ddt.data() above by: @ddt.data(False, True) to see it fail. - # I'm leaving it like this because we are testing the False return in the second iteration, but it's important - # to consider whenever this test needs to be updated. - # It even happens if we split the test cases into two different methods (instead of @ddt.data()), so there's - # probably some scoping issues in how the test is set up. - # Note: There's a similar test for LearningAssistantEnabledView in this file that works just fine. mock_chat_history_enabled.return_value = enabled_flag test_data = [ From f0b0cf8d3ec4393d4b1aeae7bdd9519c94a65864 Mon Sep 17 00:00:00 2001 From: Marcos Date: Mon, 4 Nov 2024 11:57:09 -0300 Subject: [PATCH 8/9] fix: Added migration after updating role in model --- ...0008_alter_learningassistantmessage_role.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 learning_assistant/migrations/0008_alter_learningassistantmessage_role.py diff --git a/learning_assistant/migrations/0008_alter_learningassistantmessage_role.py b/learning_assistant/migrations/0008_alter_learningassistantmessage_role.py new file mode 100644 index 0000000..bd699b3 --- /dev/null +++ b/learning_assistant/migrations/0008_alter_learningassistantmessage_role.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.14 on 2024-11-04 08:52 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('learning_assistant', '0007_learningassistantmessage'), + ] + + operations = [ + migrations.AlterField( + model_name='learningassistantmessage', + name='role', + field=models.CharField(choices=[('user', 'user'), ('assistant', 'assistant')], max_length=64), + ), + ] From 4fa9bf5148bc68dc6962bb935ea39f33ecb0b181 Mon Sep 17 00:00:00 2001 From: Marcos Date: Mon, 4 Nov 2024 15:23:31 -0300 Subject: [PATCH 9/9] fix: Added course run key to save_chat_message() --- learning_assistant/api.py | 4 +++- learning_assistant/views.py | 4 ++-- tests/test_api.py | 4 +++- tests/test_views.py | 6 ++++-- 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/learning_assistant/api.py b/learning_assistant/api.py index e6fa870..eaf4b10 100644 --- a/learning_assistant/api.py +++ b/learning_assistant/api.py @@ -191,7 +191,7 @@ def get_course_id(course_run_id): return course_key -def save_chat_message(user_id, chat_role, message): +def save_chat_message(courserun_key, user_id, chat_role, message): """ Save the chat message to the database. """ @@ -203,9 +203,11 @@ def save_chat_message(user_id, chat_role, message): # Save the user message to the database. LearningAssistantMessage.objects.create( + course_id=courserun_key, user=user, role=chat_role, content=message, + ) diff --git a/learning_assistant/views.py b/learning_assistant/views.py index 3e136fe..583aed1 100644 --- a/learning_assistant/views.py +++ b/learning_assistant/views.py @@ -96,7 +96,7 @@ def post(self, request, course_run_id): user_id = request.user.id if chat_history_enabled(courserun_key): - save_chat_message(user_id, LearningAssistantMessage.USER_ROLE, new_user_message['content']) + save_chat_message(courserun_key, user_id, LearningAssistantMessage.USER_ROLE, new_user_message['content']) serializer = MessageSerializer(data=message_list, many=True) @@ -126,7 +126,7 @@ def post(self, request, course_run_id): status_code, message = get_chat_response(prompt_template, message_list) if chat_history_enabled(courserun_key): - save_chat_message(user_id, LearningAssistantMessage.ASSISTANT_ROLE, message['content']) + save_chat_message(courserun_key, user_id, LearningAssistantMessage.ASSISTANT_ROLE, message['content']) return Response(status=status_code, data=message) diff --git a/tests/test_api.py b/tests/test_api.py index 6969af8..0470e32 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -245,6 +245,7 @@ def setUp(self): super().setUp() self.test_user = User.objects.create(username='username', password='password') + self.course_run_key = CourseKey.from_string('course-v1:edx+test+23') @ddt.data( (LearningAssistantMessage.USER_ROLE, 'What is the meaning of life, the universe and everything?'), @@ -252,10 +253,11 @@ def setUp(self): ) @ddt.unpack def test_save_chat_message(self, chat_role, message): - save_chat_message(self.test_user.id, chat_role, message) + save_chat_message(self.course_run_key, self.test_user.id, chat_role, message) row = LearningAssistantMessage.objects.all().last() + self.assertEqual(row.course_id, self.course_run_key) self.assertEqual(row.role, chat_role) self.assertEqual(row.content, message) diff --git a/tests/test_views.py b/tests/test_views.py index 9e07e55..5125d27 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -13,6 +13,7 @@ from django.test import TestCase, override_settings from django.test.client import Client from django.urls import reverse +from opaque_keys.edx.keys import CourseKey from learning_assistant.models import LearningAssistantMessage @@ -85,6 +86,7 @@ class TestCourseChatView(LoggedInTestCase): def setUp(self): super().setUp() self.course_id = 'course-v1:edx+test+23' + self.course_run_key = CourseKey.from_string(self.course_id) self.patcher = patch( 'learning_assistant.api.get_cache_course_run_data', @@ -209,8 +211,8 @@ def test_chat_response_default( if enabled_flag: mock_save_chat_message.assert_has_calls([ - call(self.user.id, LearningAssistantMessage.USER_ROLE, test_data[-1]['content']), - call(self.user.id, LearningAssistantMessage.ASSISTANT_ROLE, 'Something else') + call(self.course_run_key, self.user.id, LearningAssistantMessage.USER_ROLE, test_data[-1]['content']), + call(self.course_run_key, self.user.id, LearningAssistantMessage.ASSISTANT_ROLE, 'Something else') ]) else: mock_save_chat_message.assert_not_called()