-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #121 from edx/rijuma/save-messages-to-db
feat: Adding chat messages to the DB
- Loading branch information
Showing
8 changed files
with
147 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
18 changes: 18 additions & 0 deletions
18
learning_assistant/migrations/0008_alter_learningassistantmessage_role.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Generated by Django 4.2.14 on 2024-11-04 08:52 | ||
|
||
from django.db import migrations, models | ||
|
||
|
||
class Migration(migrations.Migration): | ||
|
||
dependencies = [ | ||
('learning_assistant', '0007_learningassistantmessage'), | ||
] | ||
|
||
operations = [ | ||
migrations.AlterField( | ||
model_name='learningassistantmessage', | ||
name='role', | ||
field=models.CharField(choices=[('user', 'user'), ('assistant', 'assistant')], max_length=64), | ||
), | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ | |
import json | ||
import sys | ||
from importlib import import_module | ||
from unittest.mock import MagicMock, patch | ||
from unittest.mock import MagicMock, call, patch | ||
|
||
import ddt | ||
from django.conf import settings | ||
|
@@ -13,13 +13,14 @@ | |
from django.test import TestCase, override_settings | ||
from django.test.client import Client | ||
from django.urls import reverse | ||
from opaque_keys.edx.keys import CourseKey | ||
|
||
from learning_assistant.models import LearningAssistantMessage | ||
|
||
User = get_user_model() | ||
|
||
|
||
class TestClient(Client): | ||
class FakeClient(Client): | ||
""" | ||
Allows for 'fake logins' of a user so we don't need to expose a 'login' HTTP endpoint. | ||
""" | ||
|
@@ -66,14 +67,14 @@ def setUp(self): | |
Setup for tests. | ||
""" | ||
super().setUp() | ||
self.client = TestClient() | ||
self.client = FakeClient() | ||
self.user = User(username='tester', email='[email protected]', is_staff=True) | ||
self.user.save() | ||
self.client.login_user(self.user) | ||
|
||
|
||
@ddt.ddt | ||
class CourseChatViewTests(LoggedInTestCase): | ||
class TestCourseChatView(LoggedInTestCase): | ||
""" | ||
Test for the CourseChatView | ||
""" | ||
|
@@ -85,6 +86,7 @@ class CourseChatViewTests(LoggedInTestCase): | |
def setUp(self): | ||
super().setUp() | ||
self.course_id = 'course-v1:edx+test+23' | ||
self.course_run_key = CourseKey.from_string(self.course_id) | ||
|
||
self.patcher = patch( | ||
'learning_assistant.api.get_cache_course_run_data', | ||
|
@@ -153,15 +155,27 @@ def test_invalid_messages(self, mock_role, mock_waffle, mock_render): | |
) | ||
self.assertEqual(response.status_code, 400) | ||
|
||
@ddt.data(False, True) | ||
@patch('learning_assistant.views.render_prompt_template') | ||
@patch('learning_assistant.views.get_chat_response') | ||
@patch('learning_assistant.views.learning_assistant_enabled') | ||
@patch('learning_assistant.views.get_user_role') | ||
@patch('learning_assistant.views.CourseEnrollment.get_enrollment') | ||
@patch('learning_assistant.views.CourseMode') | ||
@patch('learning_assistant.views.save_chat_message') | ||
@patch('learning_assistant.views.chat_history_enabled') | ||
@override_settings(LEARNING_ASSISTANT_PROMPT_TEMPLATE='This is the default template') | ||
def test_chat_response_default( | ||
self, mock_mode, mock_enrollment, mock_role, mock_waffle, mock_chat_response, mock_render | ||
self, | ||
enabled_flag, | ||
mock_chat_history_enabled, | ||
mock_save_chat_message, | ||
mock_mode, | ||
mock_enrollment, | ||
mock_role, | ||
mock_waffle, | ||
mock_chat_response, | ||
mock_render, | ||
): | ||
mock_waffle.return_value = True | ||
mock_role.return_value = 'student' | ||
|
@@ -171,9 +185,12 @@ def test_chat_response_default( | |
mock_render.return_value = 'Rendered template mock' | ||
test_unit_id = 'test-unit-id' | ||
|
||
mock_chat_history_enabled.return_value = enabled_flag | ||
|
||
test_data = [ | ||
{'role': 'user', 'content': 'What is 2+2?'}, | ||
{'role': 'assistant', 'content': 'It is 4'} | ||
{'role': 'assistant', 'content': 'It is 4'}, | ||
{'role': 'user', 'content': 'And what else?'}, | ||
] | ||
|
||
response = self.client.post( | ||
|
@@ -192,6 +209,14 @@ def test_chat_response_default( | |
test_data, | ||
) | ||
|
||
if enabled_flag: | ||
mock_save_chat_message.assert_has_calls([ | ||
call(self.course_run_key, self.user.id, LearningAssistantMessage.USER_ROLE, test_data[-1]['content']), | ||
call(self.course_run_key, self.user.id, LearningAssistantMessage.ASSISTANT_ROLE, 'Something else') | ||
]) | ||
else: | ||
mock_save_chat_message.assert_not_called() | ||
|
||
|
||
@ddt.ddt | ||
class LearningAssistantEnabledViewTests(LoggedInTestCase): | ||
|