Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Adding chat messages to the DB #121

Merged
merged 9 commits into from
Nov 4, 2024
20 changes: 20 additions & 0 deletions learning_assistant/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging

from django.conf import settings
from django.contrib.auth import get_user_model
from django.core.cache import cache
from edx_django_utils.cache import get_cache_key
from jinja2 import BaseLoader, Environment
Expand All @@ -24,6 +25,7 @@
from learning_assistant.text_utils import html_to_text

log = logging.getLogger(__name__)
User = get_user_model()


def _extract_block_contents(child, category):
Expand Down Expand Up @@ -189,6 +191,24 @@
return course_key


def save_chat_message(user_id, chat_role, message):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function should also take in a course_id, and use that as a field in the message being saved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch. I wonder why this wasn't detected on the unit tests, maybe there's a constraint missing?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just updated save_chat_message() to take the Courserun Key and save it to the model.

"""
Save the chat message to the database.
"""
user = None
try:
user = User.objects.get(id=user_id)
except User.DoesNotExist as exc:
raise Exception("User does not exists.") from exc

Check failure on line 202 in learning_assistant/api.py

View workflow job for this annotation

GitHub Actions / tests (ubuntu-20.04, 3.11, django42)

Missing coverage

Missing coverage on lines 201-202

# Save the user message to the database.
LearningAssistantMessage.objects.create(
user=user,
role=chat_role,
content=message,
)


def get_message_history(course_id, user, message_count):
"""
Given a course run id (str), user (User), and message count (int), return the associated message history.
Expand Down
10 changes: 9 additions & 1 deletion learning_assistant/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,15 @@ class LearningAssistantMessage(TimeStampedModel):
.. pii_retirement: third_party
"""

USER_ROLE = 'user'
ASSISTANT_ROLE = 'assistant'

Roles = (
(USER_ROLE, USER_ROLE),
(ASSISTANT_ROLE, ASSISTANT_ROLE),
)

course_id = CourseKeyField(max_length=255, db_index=True)
user = models.ForeignKey(USER_MODEL, db_index=True, on_delete=models.CASCADE)
role = models.CharField(max_length=64)
role = models.CharField(choices=Roles, max_length=64)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this doesn't affect the database, so no migration needed.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure? Did you run makemigrations?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't. I'm sorry for that.
Just added the migration file.

content = models.TextField()
4 changes: 3 additions & 1 deletion learning_assistant/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
"""
from rest_framework import serializers

from learning_assistant.models import LearningAssistantMessage


class MessageSerializer(serializers.Serializer): # pylint: disable=abstract-method
"""
Expand All @@ -16,7 +18,7 @@ def validate_role(self, value):
"""
Validate that role is one of two acceptable values.
"""
valid_roles = ['user', 'assistant']
valid_roles = [LearningAssistantMessage.USER_ROLE, LearningAssistantMessage.ASSISTANT_ROLE]
if value not in valid_roles:
raise serializers.ValidationError('Must be valid role.')
return value
17 changes: 17 additions & 0 deletions learning_assistant/toggles.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@
# .. toggle_tickets: COSMO-80
ENABLE_COURSE_CONTENT = 'enable_course_content'

# .. toggle_name: learning_assistant.enable_chat_history
# .. toggle_implementation: CourseWaffleFlag
# .. toggle_default: False
# .. toggle_description: Waffle flag to enable the chat history with the learning assistant
# .. toggle_use_cases: temporary
# .. toggle_creation_date: 2024-10-30
# .. toggle_target_removal_date: 2024-12-31
# .. toggle_tickets: COSMO-436
ENABLE_CHAT_HISTORY = 'enable_chat_history'


def _is_learning_assistant_waffle_flag_enabled(flag_name, course_key):
"""
Expand All @@ -32,3 +42,10 @@
Return whether the learning_assistant.enable_course_content WaffleFlag is on.
"""
return _is_learning_assistant_waffle_flag_enabled(ENABLE_COURSE_CONTENT, course_key)


def chat_history_enabled(course_key):
"""
Return whether the learning_assistant.enable_chat_history WaffleFlag is on.
"""
return _is_learning_assistant_waffle_flag_enabled(ENABLE_CHAT_HISTORY, course_key)

Check failure on line 51 in learning_assistant/toggles.py

View workflow job for this annotation

GitHub Actions / tests (ubuntu-20.04, 3.11, django42)

Missing coverage

Missing coverage on line 51
20 changes: 20 additions & 0 deletions learning_assistant/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,11 @@
get_message_history,
learning_assistant_enabled,
render_prompt_template,
save_chat_message,
)
from learning_assistant.models import LearningAssistantMessage
from learning_assistant.serializers import MessageSerializer
from learning_assistant.toggles import chat_history_enabled
from learning_assistant.utils import get_chat_response, user_role_is_staff

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -81,6 +84,20 @@ def post(self, request, course_run_id):
unit_id = request.query_params.get('unit_id')

message_list = request.data

# Check that the last message in the list corresponds to a user
new_user_message = message_list[-1]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really understand the code block and the comment. Can you explain more to help someone like to understand the logic better?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if the message_list array is empty?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assumed the endpoint was to process a new user message. Maybe I misunderstood it but if it is, then it should include at least one user message.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to point out: We discussed this offline and came to the conclusion that the approach is correct.

if new_user_message['role'] != LearningAssistantMessage.USER_ROLE:
return Response(
status=http_status.HTTP_400_BAD_REQUEST,
data={'detail': "Expects user role on last message."}
)

user_id = request.user.id

if chat_history_enabled(courserun_key):
save_chat_message(user_id, LearningAssistantMessage.USER_ROLE, new_user_message['content'])

serializer = MessageSerializer(data=message_list, many=True)

# serializer will not be valid in the case that the message list contains any roles other than
Expand Down Expand Up @@ -108,6 +125,9 @@ def post(self, request, course_run_id):
)
status_code, message = get_chat_response(prompt_template, message_list)

if chat_history_enabled(courserun_key):
save_chat_message(user_id, LearningAssistantMessage.ASSISTANT_ROLE, message['content'])

return Response(status=status_code, data=message)


Expand Down
26 changes: 25 additions & 1 deletion tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@
learning_assistant_available,
learning_assistant_enabled,
render_prompt_template,
save_chat_message,
set_learning_assistant_enabled,
)
from learning_assistant.data import LearningAssistantCourseEnabledData
from learning_assistant.models import LearningAssistantCourseEnabled, LearningAssistantMessage

fake_transcript = 'This is the text version from the transcript'

User = get_user_model()


Expand Down Expand Up @@ -236,6 +236,30 @@ def test_render_prompt_template_invalid_unit_key(self, mock_get_content):
self.assertNotIn('The following text is useful.', prompt_text)


@ddt.ddt
class TestLearningAssistantCourseEnabledApi(TestCase):
"""
Test suite for save_chat_message.
"""
def setUp(self):
super().setUp()

self.test_user = User.objects.create(username='username', password='password')

@ddt.data(
(LearningAssistantMessage.USER_ROLE, 'What is the meaning of life, the universe and everything?'),
(LearningAssistantMessage.ASSISTANT_ROLE, '42'),
)
@ddt.unpack
def test_save_chat_message(self, chat_role, message):
save_chat_message(self.test_user.id, chat_role, message)

row = LearningAssistantMessage.objects.all().last()

self.assertEqual(row.role, chat_role)
self.assertEqual(row.content, message)


@ddt.ddt
class LearningAssistantCourseEnabledApiTests(TestCase):
"""
Expand Down
35 changes: 29 additions & 6 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
import sys
from importlib import import_module
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, call, patch

import ddt
from django.conf import settings
Expand All @@ -19,7 +19,7 @@
User = get_user_model()


class TestClient(Client):
class FakeClient(Client):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This class was renamed to remove the following test warning:

tests/test_views.py:22
  /Users/(ueser)/edx/src/learning-assistant/tests/test_views.py:22: PytestCollectionWarning: cannot collect test class 'TestClient' because it has a __init__ constructor (from: tests/test_views.py)
    class TestClient(Client):

"""
Allows for 'fake logins' of a user so we don't need to expose a 'login' HTTP endpoint.
"""
Expand Down Expand Up @@ -66,14 +66,14 @@ def setUp(self):
Setup for tests.
"""
super().setUp()
self.client = TestClient()
self.client = FakeClient()
self.user = User(username='tester', email='[email protected]', is_staff=True)
self.user.save()
self.client.login_user(self.user)


@ddt.ddt
class CourseChatViewTests(LoggedInTestCase):
class TestCourseChatView(LoggedInTestCase):
"""
Test for the CourseChatView
"""
Expand Down Expand Up @@ -153,15 +153,27 @@ def test_invalid_messages(self, mock_role, mock_waffle, mock_render):
)
self.assertEqual(response.status_code, 400)

@ddt.data(False, True)
@patch('learning_assistant.views.render_prompt_template')
@patch('learning_assistant.views.get_chat_response')
@patch('learning_assistant.views.learning_assistant_enabled')
@patch('learning_assistant.views.get_user_role')
@patch('learning_assistant.views.CourseEnrollment.get_enrollment')
@patch('learning_assistant.views.CourseMode')
@patch('learning_assistant.views.save_chat_message')
@patch('learning_assistant.views.chat_history_enabled')
@override_settings(LEARNING_ASSISTANT_PROMPT_TEMPLATE='This is the default template')
def test_chat_response_default(
self, mock_mode, mock_enrollment, mock_role, mock_waffle, mock_chat_response, mock_render
self,
enabled_flag,
mock_chat_history_enabled,
mock_save_chat_message,
mock_mode,
mock_enrollment,
mock_role,
mock_waffle,
mock_chat_response,
mock_render,
):
mock_waffle.return_value = True
mock_role.return_value = 'student'
Expand All @@ -171,9 +183,12 @@ def test_chat_response_default(
mock_render.return_value = 'Rendered template mock'
test_unit_id = 'test-unit-id'

mock_chat_history_enabled.return_value = enabled_flag

test_data = [
{'role': 'user', 'content': 'What is 2+2?'},
{'role': 'assistant', 'content': 'It is 4'}
{'role': 'assistant', 'content': 'It is 4'},
{'role': 'user', 'content': 'And what else?'},
Comment on lines +192 to +193
Copy link
Member Author

@rijuma rijuma Oct 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added a check that validates the last message from the list to be from a user.

]

response = self.client.post(
Expand All @@ -192,6 +207,14 @@ def test_chat_response_default(
test_data,
)

if enabled_flag:
mock_save_chat_message.assert_has_calls([
call(self.user.id, LearningAssistantMessage.USER_ROLE, test_data[-1]['content']),
call(self.user.id, LearningAssistantMessage.ASSISTANT_ROLE, 'Something else')
])
else:
mock_save_chat_message.assert_not_called()


@ddt.ddt
class LearningAssistantEnabledViewTests(LoggedInTestCase):
Expand Down
Loading