-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Adding chat messages to the DB #121
Changes from 7 commits
99ccaa2
c419fde
79917a5
dae5df8
23fd013
352340e
7cdc763
f0b0cf8
4fa9bf5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,7 +35,15 @@ class LearningAssistantMessage(TimeStampedModel): | |
.. pii_retirement: third_party | ||
""" | ||
|
||
USER_ROLE = 'user' | ||
ASSISTANT_ROLE = 'assistant' | ||
|
||
Roles = ( | ||
(USER_ROLE, USER_ROLE), | ||
(ASSISTANT_ROLE, ASSISTANT_ROLE), | ||
) | ||
|
||
course_id = CourseKeyField(max_length=255, db_index=True) | ||
user = models.ForeignKey(USER_MODEL, db_index=True, on_delete=models.CASCADE) | ||
role = models.CharField(max_length=64) | ||
role = models.CharField(choices=Roles, max_length=64) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this doesn't affect the database, so no migration needed. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are you sure? Did you run There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I didn't. I'm sorry for that. |
||
content = models.TextField() |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,8 +25,11 @@ | |
get_message_history, | ||
learning_assistant_enabled, | ||
render_prompt_template, | ||
save_chat_message, | ||
) | ||
from learning_assistant.models import LearningAssistantMessage | ||
from learning_assistant.serializers import MessageSerializer | ||
from learning_assistant.toggles import chat_history_enabled | ||
from learning_assistant.utils import get_chat_response, user_role_is_staff | ||
|
||
log = logging.getLogger(__name__) | ||
|
@@ -81,6 +84,20 @@ def post(self, request, course_run_id): | |
unit_id = request.query_params.get('unit_id') | ||
|
||
message_list = request.data | ||
|
||
# Check that the last message in the list corresponds to a user | ||
new_user_message = message_list[-1] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't really understand the code block and the comment. Can you explain more to help someone like to understand the logic better? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What if the message_list array is empty? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I assumed the endpoint was to process a new user message. Maybe I misunderstood it but if it is, then it should include at least one user message. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just to point out: We discussed this offline and came to the conclusion that the approach is correct. |
||
if new_user_message['role'] != LearningAssistantMessage.USER_ROLE: | ||
return Response( | ||
status=http_status.HTTP_400_BAD_REQUEST, | ||
data={'detail': "Expects user role on last message."} | ||
) | ||
|
||
user_id = request.user.id | ||
|
||
if chat_history_enabled(courserun_key): | ||
save_chat_message(user_id, LearningAssistantMessage.USER_ROLE, new_user_message['content']) | ||
|
||
serializer = MessageSerializer(data=message_list, many=True) | ||
|
||
# serializer will not be valid in the case that the message list contains any roles other than | ||
|
@@ -108,6 +125,9 @@ def post(self, request, course_run_id): | |
) | ||
status_code, message = get_chat_response(prompt_template, message_list) | ||
|
||
if chat_history_enabled(courserun_key): | ||
save_chat_message(user_id, LearningAssistantMessage.ASSISTANT_ROLE, message['content']) | ||
|
||
return Response(status=status_code, data=message) | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ | |
import json | ||
import sys | ||
from importlib import import_module | ||
from unittest.mock import MagicMock, patch | ||
from unittest.mock import MagicMock, call, patch | ||
|
||
import ddt | ||
from django.conf import settings | ||
|
@@ -19,7 +19,7 @@ | |
User = get_user_model() | ||
|
||
|
||
class TestClient(Client): | ||
class FakeClient(Client): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This class was renamed to remove the following test warning:
|
||
""" | ||
Allows for 'fake logins' of a user so we don't need to expose a 'login' HTTP endpoint. | ||
""" | ||
|
@@ -66,14 +66,14 @@ def setUp(self): | |
Setup for tests. | ||
""" | ||
super().setUp() | ||
self.client = TestClient() | ||
self.client = FakeClient() | ||
self.user = User(username='tester', email='[email protected]', is_staff=True) | ||
self.user.save() | ||
self.client.login_user(self.user) | ||
|
||
|
||
@ddt.ddt | ||
class CourseChatViewTests(LoggedInTestCase): | ||
class TestCourseChatView(LoggedInTestCase): | ||
""" | ||
Test for the CourseChatView | ||
""" | ||
|
@@ -153,15 +153,27 @@ def test_invalid_messages(self, mock_role, mock_waffle, mock_render): | |
) | ||
self.assertEqual(response.status_code, 400) | ||
|
||
@ddt.data(False, True) | ||
@patch('learning_assistant.views.render_prompt_template') | ||
@patch('learning_assistant.views.get_chat_response') | ||
@patch('learning_assistant.views.learning_assistant_enabled') | ||
@patch('learning_assistant.views.get_user_role') | ||
@patch('learning_assistant.views.CourseEnrollment.get_enrollment') | ||
@patch('learning_assistant.views.CourseMode') | ||
@patch('learning_assistant.views.save_chat_message') | ||
@patch('learning_assistant.views.chat_history_enabled') | ||
@override_settings(LEARNING_ASSISTANT_PROMPT_TEMPLATE='This is the default template') | ||
def test_chat_response_default( | ||
self, mock_mode, mock_enrollment, mock_role, mock_waffle, mock_chat_response, mock_render | ||
self, | ||
enabled_flag, | ||
mock_chat_history_enabled, | ||
mock_save_chat_message, | ||
mock_mode, | ||
mock_enrollment, | ||
mock_role, | ||
mock_waffle, | ||
mock_chat_response, | ||
mock_render, | ||
): | ||
mock_waffle.return_value = True | ||
mock_role.return_value = 'student' | ||
|
@@ -171,9 +183,12 @@ def test_chat_response_default( | |
mock_render.return_value = 'Rendered template mock' | ||
test_unit_id = 'test-unit-id' | ||
|
||
mock_chat_history_enabled.return_value = enabled_flag | ||
|
||
test_data = [ | ||
{'role': 'user', 'content': 'What is 2+2?'}, | ||
{'role': 'assistant', 'content': 'It is 4'} | ||
{'role': 'assistant', 'content': 'It is 4'}, | ||
{'role': 'user', 'content': 'And what else?'}, | ||
Comment on lines
+192
to
+193
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've added a check that validates the last message from the list to be from a user. |
||
] | ||
|
||
response = self.client.post( | ||
|
@@ -192,6 +207,14 @@ def test_chat_response_default( | |
test_data, | ||
) | ||
|
||
if enabled_flag: | ||
mock_save_chat_message.assert_has_calls([ | ||
call(self.user.id, LearningAssistantMessage.USER_ROLE, test_data[-1]['content']), | ||
call(self.user.id, LearningAssistantMessage.ASSISTANT_ROLE, 'Something else') | ||
]) | ||
else: | ||
mock_save_chat_message.assert_not_called() | ||
|
||
|
||
@ddt.ddt | ||
class LearningAssistantEnabledViewTests(LoggedInTestCase): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function should also take in a course_id, and use that as a field in the message being saved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice catch. I wonder why this wasn't detected on the unit tests, maybe there's a constraint missing?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just updated
save_chat_message()
to take the Courserun Key and save it to the model.