Skip to content

Commit

Permalink
Merge pull request #121 from edx/rijuma/save-messages-to-db
Browse files Browse the repository at this point in the history
feat: Adding chat messages to the DB
  • Loading branch information
rijuma authored Nov 4, 2024
2 parents dbcaa39 + 4fa9bf5 commit 758b657
Show file tree
Hide file tree
Showing 8 changed files with 147 additions and 9 deletions.
22 changes: 22 additions & 0 deletions learning_assistant/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging

from django.conf import settings
from django.contrib.auth import get_user_model
from django.core.cache import cache
from edx_django_utils.cache import get_cache_key
from jinja2 import BaseLoader, Environment
Expand All @@ -24,6 +25,7 @@
from learning_assistant.text_utils import html_to_text

log = logging.getLogger(__name__)
User = get_user_model()


def _extract_block_contents(child, category):
Expand Down Expand Up @@ -189,6 +191,26 @@ def get_course_id(course_run_id):
return course_key


def save_chat_message(courserun_key, user_id, chat_role, message):
"""
Save the chat message to the database.
"""
user = None
try:
user = User.objects.get(id=user_id)
except User.DoesNotExist as exc:
raise Exception("User does not exists.") from exc

# Save the user message to the database.
LearningAssistantMessage.objects.create(
course_id=courserun_key,
user=user,
role=chat_role,
content=message,

)


def get_message_history(course_id, user, message_count):
"""
Given a course run id (str), user (User), and message count (int), return the associated message history.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Generated by Django 4.2.14 on 2024-11-04 08:52

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('learning_assistant', '0007_learningassistantmessage'),
]

operations = [
migrations.AlterField(
model_name='learningassistantmessage',
name='role',
field=models.CharField(choices=[('user', 'user'), ('assistant', 'assistant')], max_length=64),
),
]
10 changes: 9 additions & 1 deletion learning_assistant/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,15 @@ class LearningAssistantMessage(TimeStampedModel):
.. pii_retirement: third_party
"""

USER_ROLE = 'user'
ASSISTANT_ROLE = 'assistant'

Roles = (
(USER_ROLE, USER_ROLE),
(ASSISTANT_ROLE, ASSISTANT_ROLE),
)

course_id = CourseKeyField(max_length=255, db_index=True)
user = models.ForeignKey(USER_MODEL, db_index=True, on_delete=models.CASCADE)
role = models.CharField(max_length=64)
role = models.CharField(choices=Roles, max_length=64)
content = models.TextField()
4 changes: 3 additions & 1 deletion learning_assistant/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
"""
from rest_framework import serializers

from learning_assistant.models import LearningAssistantMessage


class MessageSerializer(serializers.Serializer): # pylint: disable=abstract-method
"""
Expand All @@ -16,7 +18,7 @@ def validate_role(self, value):
"""
Validate that role is one of two acceptable values.
"""
valid_roles = ['user', 'assistant']
valid_roles = [LearningAssistantMessage.USER_ROLE, LearningAssistantMessage.ASSISTANT_ROLE]
if value not in valid_roles:
raise serializers.ValidationError('Must be valid role.')
return value
17 changes: 17 additions & 0 deletions learning_assistant/toggles.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@
# .. toggle_tickets: COSMO-80
ENABLE_COURSE_CONTENT = 'enable_course_content'

# .. toggle_name: learning_assistant.enable_chat_history
# .. toggle_implementation: CourseWaffleFlag
# .. toggle_default: False
# .. toggle_description: Waffle flag to enable the chat history with the learning assistant
# .. toggle_use_cases: temporary
# .. toggle_creation_date: 2024-10-30
# .. toggle_target_removal_date: 2024-12-31
# .. toggle_tickets: COSMO-436
ENABLE_CHAT_HISTORY = 'enable_chat_history'


def _is_learning_assistant_waffle_flag_enabled(flag_name, course_key):
"""
Expand All @@ -32,3 +42,10 @@ def course_content_enabled(course_key):
Return whether the learning_assistant.enable_course_content WaffleFlag is on.
"""
return _is_learning_assistant_waffle_flag_enabled(ENABLE_COURSE_CONTENT, course_key)


def chat_history_enabled(course_key):
"""
Return whether the learning_assistant.enable_chat_history WaffleFlag is on.
"""
return _is_learning_assistant_waffle_flag_enabled(ENABLE_CHAT_HISTORY, course_key)
20 changes: 20 additions & 0 deletions learning_assistant/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,11 @@
get_message_history,
learning_assistant_enabled,
render_prompt_template,
save_chat_message,
)
from learning_assistant.models import LearningAssistantMessage
from learning_assistant.serializers import MessageSerializer
from learning_assistant.toggles import chat_history_enabled
from learning_assistant.utils import get_chat_response, user_role_is_staff

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -81,6 +84,20 @@ def post(self, request, course_run_id):
unit_id = request.query_params.get('unit_id')

message_list = request.data

# Check that the last message in the list corresponds to a user
new_user_message = message_list[-1]
if new_user_message['role'] != LearningAssistantMessage.USER_ROLE:
return Response(
status=http_status.HTTP_400_BAD_REQUEST,
data={'detail': "Expects user role on last message."}
)

user_id = request.user.id

if chat_history_enabled(courserun_key):
save_chat_message(courserun_key, user_id, LearningAssistantMessage.USER_ROLE, new_user_message['content'])

serializer = MessageSerializer(data=message_list, many=True)

# serializer will not be valid in the case that the message list contains any roles other than
Expand Down Expand Up @@ -108,6 +125,9 @@ def post(self, request, course_run_id):
)
status_code, message = get_chat_response(prompt_template, message_list)

if chat_history_enabled(courserun_key):
save_chat_message(courserun_key, user_id, LearningAssistantMessage.ASSISTANT_ROLE, message['content'])

return Response(status=status_code, data=message)


Expand Down
28 changes: 27 additions & 1 deletion tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@
learning_assistant_available,
learning_assistant_enabled,
render_prompt_template,
save_chat_message,
set_learning_assistant_enabled,
)
from learning_assistant.data import LearningAssistantCourseEnabledData
from learning_assistant.models import LearningAssistantCourseEnabled, LearningAssistantMessage

fake_transcript = 'This is the text version from the transcript'

User = get_user_model()


Expand Down Expand Up @@ -236,6 +236,32 @@ def test_render_prompt_template_invalid_unit_key(self, mock_get_content):
self.assertNotIn('The following text is useful.', prompt_text)


@ddt.ddt
class TestLearningAssistantCourseEnabledApi(TestCase):
"""
Test suite for save_chat_message.
"""
def setUp(self):
super().setUp()

self.test_user = User.objects.create(username='username', password='password')
self.course_run_key = CourseKey.from_string('course-v1:edx+test+23')

@ddt.data(
(LearningAssistantMessage.USER_ROLE, 'What is the meaning of life, the universe and everything?'),
(LearningAssistantMessage.ASSISTANT_ROLE, '42'),
)
@ddt.unpack
def test_save_chat_message(self, chat_role, message):
save_chat_message(self.course_run_key, self.test_user.id, chat_role, message)

row = LearningAssistantMessage.objects.all().last()

self.assertEqual(row.course_id, self.course_run_key)
self.assertEqual(row.role, chat_role)
self.assertEqual(row.content, message)


@ddt.ddt
class LearningAssistantCourseEnabledApiTests(TestCase):
"""
Expand Down
37 changes: 31 additions & 6 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
import sys
from importlib import import_module
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, call, patch

import ddt
from django.conf import settings
Expand All @@ -13,13 +13,14 @@
from django.test import TestCase, override_settings
from django.test.client import Client
from django.urls import reverse
from opaque_keys.edx.keys import CourseKey

from learning_assistant.models import LearningAssistantMessage

User = get_user_model()


class TestClient(Client):
class FakeClient(Client):
"""
Allows for 'fake logins' of a user so we don't need to expose a 'login' HTTP endpoint.
"""
Expand Down Expand Up @@ -66,14 +67,14 @@ def setUp(self):
Setup for tests.
"""
super().setUp()
self.client = TestClient()
self.client = FakeClient()
self.user = User(username='tester', email='[email protected]', is_staff=True)
self.user.save()
self.client.login_user(self.user)


@ddt.ddt
class CourseChatViewTests(LoggedInTestCase):
class TestCourseChatView(LoggedInTestCase):
"""
Test for the CourseChatView
"""
Expand All @@ -85,6 +86,7 @@ class CourseChatViewTests(LoggedInTestCase):
def setUp(self):
super().setUp()
self.course_id = 'course-v1:edx+test+23'
self.course_run_key = CourseKey.from_string(self.course_id)

self.patcher = patch(
'learning_assistant.api.get_cache_course_run_data',
Expand Down Expand Up @@ -153,15 +155,27 @@ def test_invalid_messages(self, mock_role, mock_waffle, mock_render):
)
self.assertEqual(response.status_code, 400)

@ddt.data(False, True)
@patch('learning_assistant.views.render_prompt_template')
@patch('learning_assistant.views.get_chat_response')
@patch('learning_assistant.views.learning_assistant_enabled')
@patch('learning_assistant.views.get_user_role')
@patch('learning_assistant.views.CourseEnrollment.get_enrollment')
@patch('learning_assistant.views.CourseMode')
@patch('learning_assistant.views.save_chat_message')
@patch('learning_assistant.views.chat_history_enabled')
@override_settings(LEARNING_ASSISTANT_PROMPT_TEMPLATE='This is the default template')
def test_chat_response_default(
self, mock_mode, mock_enrollment, mock_role, mock_waffle, mock_chat_response, mock_render
self,
enabled_flag,
mock_chat_history_enabled,
mock_save_chat_message,
mock_mode,
mock_enrollment,
mock_role,
mock_waffle,
mock_chat_response,
mock_render,
):
mock_waffle.return_value = True
mock_role.return_value = 'student'
Expand All @@ -171,9 +185,12 @@ def test_chat_response_default(
mock_render.return_value = 'Rendered template mock'
test_unit_id = 'test-unit-id'

mock_chat_history_enabled.return_value = enabled_flag

test_data = [
{'role': 'user', 'content': 'What is 2+2?'},
{'role': 'assistant', 'content': 'It is 4'}
{'role': 'assistant', 'content': 'It is 4'},
{'role': 'user', 'content': 'And what else?'},
]

response = self.client.post(
Expand All @@ -192,6 +209,14 @@ def test_chat_response_default(
test_data,
)

if enabled_flag:
mock_save_chat_message.assert_has_calls([
call(self.course_run_key, self.user.id, LearningAssistantMessage.USER_ROLE, test_data[-1]['content']),
call(self.course_run_key, self.user.id, LearningAssistantMessage.ASSISTANT_ROLE, 'Something else')
])
else:
mock_save_chat_message.assert_not_called()


@ddt.ddt
class LearningAssistantEnabledViewTests(LoggedInTestCase):
Expand Down

0 comments on commit 758b657

Please sign in to comment.