Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: GET endpoint to retrieve message history #119

Merged
merged 14 commits into from
Oct 31, 2024
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ Unreleased
**********
* Add LearningAssistantMessage model

4.4.0 - 2024-10-30
******************
* Add new GET endpoint to retrieve a user's message history in a given course.

4.3.3 - 2024-10-15
******************
* Use `LEARNING_ASSISTANT_PROMPT_TEMPLATE` for prompt
Expand Down
2 changes: 1 addition & 1 deletion learning_assistant/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
Plugin for a learning assistant backend, intended for use within edx-platform.
"""

__version__ = '4.3.3'
__version__ = '4.4.0'

default_app_config = 'learning_assistant.apps.LearningAssistantConfig' # pylint: disable=invalid-name
13 changes: 12 additions & 1 deletion learning_assistant/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from learning_assistant.constants import ACCEPTED_CATEGORY_TYPES, CATEGORY_TYPE_MAP
from learning_assistant.data import LearningAssistantCourseEnabledData
from learning_assistant.models import LearningAssistantCourseEnabled
from learning_assistant.models import LearningAssistantCourseEnabled, LearningAssistantMessage
from learning_assistant.platform_imports import (
block_get_children,
block_leaf_filter,
Expand Down Expand Up @@ -187,3 +187,14 @@ def get_course_id(course_run_id):
course_data = get_cache_course_run_data(course_run_id, ['course'])
course_key = course_data['course']
return course_key


def get_message_history(course_id, user, message_count):
"""
Given a course run id (str), user (User), and message count (int), return the associated message history.

Returns a number of messages equal to the message_count value.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: If the message count exceeds the number of messages, we could return messages that are less than this value (in other words, the number of messages cannot exceed this value but could be less than this value).

"""
message_history = LearningAssistantMessage.objects.filter(
course_id=course_id, user=user).order_by('-created')[:message_count]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would we need the handle the case here if the message count exceeds the number of messages? (i.e. to prevent and index error)

return message_history
2 changes: 1 addition & 1 deletion learning_assistant/text_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def cleanup_text(text):
return stripped


class _HTMLToTextHelper(HTMLParser): # lint-amnesty, pylint: disable=abstract-method
class _HTMLToTextHelper(HTMLParser): # lint-amnesty
"""
Helper function for html_to_text below.
"""
Expand Down
7 changes: 6 additions & 1 deletion learning_assistant/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from django.urls import re_path

from learning_assistant.constants import COURSE_ID_PATTERN
from learning_assistant.views import CourseChatView, LearningAssistantEnabledView
from learning_assistant.views import CourseChatView, LearningAssistantEnabledView, LearningAssistantMessageHistoryView

app_name = 'learning_assistant'

Expand All @@ -19,4 +19,9 @@
LearningAssistantEnabledView.as_view(),
name='enabled',
),
re_path(
fr'learning_assistant/v1/course_id/{COURSE_ID_PATTERN}/history',
LearningAssistantMessageHistoryView.as_view(),
name='message-history',
),
]
73 changes: 72 additions & 1 deletion learning_assistant/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
except ImportError:
pass

from learning_assistant.api import get_course_id, learning_assistant_enabled, render_prompt_template
from learning_assistant.api import (
get_course_id,
get_message_history,
learning_assistant_enabled,
render_prompt_template,
)
from learning_assistant.serializers import MessageSerializer
from learning_assistant.utils import get_chat_response, user_role_is_staff

Expand Down Expand Up @@ -149,3 +154,69 @@
}

return Response(status=http_status.HTTP_200_OK, data=data)


class LearningAssistantMessageHistoryView(APIView):
"""
View to retrieve the message history for user in a course.

This endpoint returns the message history stored in the LearningAssistantMessage table in a course
represented by the course_key, which is provided in the URL.

Accepts: [GET]

Path: /learning_assistant/v1/course_id/{course_key}/history

Parameters:
* course_key: the ID of the course

Responses:
* 200: OK
* 400: Malformed Request - Course ID is not a valid course ID.
"""

authentication_classes = (SessionAuthentication, JwtAuthentication,)
permission_classes = (IsAuthenticated,)

def get(self, request, course_run_id):
"""
Given a course run ID, retrieve the message history for the corresponding user.

The response will be in the following format.

[{'role': 'assistant', 'content': 'something'}]
"""
try:
courserun_key = CourseKey.from_string(course_run_id)
except InvalidKeyError:
return Response(

Check failure on line 192 in learning_assistant/views.py

View workflow job for this annotation

GitHub Actions / tests (ubuntu-20.04, 3.11, django42)

Missing coverage

Missing coverage on lines 191-192
status=http_status.HTTP_400_BAD_REQUEST,
data={'detail': 'Course ID is not a valid course ID.'}
)

if not learning_assistant_enabled(courserun_key):
return Response(
status=http_status.HTTP_403_FORBIDDEN,
data={'detail': 'Learning assistant not enabled for course.'}
)

# If user does not have an enrollment record, or is not staff, they should not have access
user_role = get_user_role(request.user, courserun_key)
enrollment_object = CourseEnrollment.get_enrollment(request.user, courserun_key)
enrollment_mode = enrollment_object.mode if enrollment_object else None
if (
(enrollment_mode not in CourseMode.VERIFIED_MODES)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure this logic condition is correct?
We should return if user_role_is_staff, no?
Are you sure we don't want to show message history to audit mode learners?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not quite sure, I copy/pasted this from another view so this could be wrong. @alangsto Do you have an opinion on this?

and not user_role_is_staff(user_role)
):
return Response(
status=http_status.HTTP_403_FORBIDDEN,
data={'detail': 'Must be staff or have valid enrollment.'}
)

course_id = get_course_id(course_run_id)
user = request.user

message_count = int(request.GET.get('message_count', 50))
message_history = get_message_history(course_id, user, message_count)
data = MessageSerializer(message_history, many=True).data
return Response(status=http_status.HTTP_200_OK, data=data)
145 changes: 144 additions & 1 deletion tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import ddt
from django.conf import settings
from django.contrib.auth import get_user_model
from django.core.cache import cache
from django.test import TestCase, override_settings
from opaque_keys import InvalidKeyError
Expand All @@ -16,16 +17,19 @@
_get_children_contents,
_leaf_filter,
get_block_content,
get_message_history,
learning_assistant_available,
learning_assistant_enabled,
render_prompt_template,
set_learning_assistant_enabled,
)
from learning_assistant.data import LearningAssistantCourseEnabledData
from learning_assistant.models import LearningAssistantCourseEnabled
from learning_assistant.models import LearningAssistantCourseEnabled, LearningAssistantMessage

fake_transcript = 'This is the text version from the transcript'

User = get_user_model()


class FakeChild:
"""Fake child block for testing"""
Expand All @@ -49,6 +53,7 @@ def get_html(self):

class FakeBlock:
"Fake block for testing, returns given children"

def __init__(self, children):
self.children = children
self.scope_ids = lambda: None
Expand Down Expand Up @@ -236,6 +241,7 @@ class LearningAssistantCourseEnabledApiTests(TestCase):
"""
Test suite for learning_assistant_available, learning_assistant_enabled, and set_learning_assistant_enabled.
"""

def setUp(self):
super().setUp()
self.course_key = CourseKey.from_string('course-v1:edx+fake+1')
Expand Down Expand Up @@ -305,3 +311,140 @@ def test_learning_assistant_available(self, learning_assistant_available_setting

expected_value = learning_assistant_available_setting_value
self.assertEqual(return_value, expected_value)


@ddt.ddt
class GetMessageHistoryTests(TestCase):
"""
Test suite for get_message_history.
"""

def setUp(self):
super().setUp()
self.course_id = 'course-v1:edx+fake+1'
self.course_key = CourseKey.from_string(self.course_id)
self.user = User(username='tester', email='[email protected]')
self.user.save()

self.role = 'verified'

def test_get_message_history(self):
message_count = 5
for i in range(1, message_count + 1):
LearningAssistantMessage.objects.create(
course_id=self.course_id,
user=self.user,
role=self.role,
content=f'Content of message {i}',
)

return_value = get_message_history(self.course_id, self.user, message_count)

expected_value = LearningAssistantMessage.objects.filter(
course_id=self.course_id, user=self.user).order_by('-created')[:message_count]

# Ensure same number of entries
self.assertEqual(len(return_value), len(expected_value))

# Ensure values are as expected for all LearningAssistantMessage instances
for i, return_value in enumerate(return_value):
self.assertEqual(return_value.course_id, expected_value[i].course_id)
self.assertEqual(return_value.user, expected_value[i].user)
self.assertEqual(return_value.role, expected_value[i].role)
self.assertEqual(return_value.content, expected_value[i].content)

@ddt.data(
0, 1, 5, 10, 50
)
def test_get_message_history_message_count(self, actual_message_count):
for i in range(1, actual_message_count + 1):
LearningAssistantMessage.objects.create(
course_id=self.course_id,
user=self.user,
role=self.role,
content=f'Content of message {i}',
)

message_count_parameter = 5
return_value = get_message_history(self.course_id, self.user, message_count_parameter)

expected_value = LearningAssistantMessage.objects.filter(
course_id=self.course_id, user=self.user).order_by('-created')[:message_count_parameter]

# Ensure same number of entries
self.assertEqual(len(return_value), len(expected_value))

def test_get_message_history_user_difference(self):
# Default Message
LearningAssistantMessage.objects.create(
course_id=self.course_id,
user=self.user,
role=self.role,
content='Expected content of message',
)

# New message w/ new user
new_user = User(username='not_tester', email='[email protected]')
new_user.save()
LearningAssistantMessage.objects.create(
course_id=self.course_id,
user=new_user,
role=self.role,
content='Expected content of message',
)

message_count = 2
return_value = get_message_history(self.course_id, self.user, message_count)

expected_value = LearningAssistantMessage.objects.filter(
course_id=self.course_id, user=self.user).order_by('-created')[:message_count]

# Ensure we filtered one of the two present messages
self.assertNotEqual(len(return_value), LearningAssistantMessage.objects.count())

# Ensure same number of entries
self.assertEqual(len(return_value), len(expected_value))

# Ensure values are as expected for all LearningAssistantMessage instances
for i, return_value in enumerate(return_value):
self.assertEqual(return_value.course_id, expected_value[i].course_id)
self.assertEqual(return_value.user, expected_value[i].user)
self.assertEqual(return_value.role, expected_value[i].role)
self.assertEqual(return_value.content, expected_value[i].content)

def test_get_message_course_id_differences(self):
# Default Message
LearningAssistantMessage.objects.create(
course_id=self.course_id,
user=self.user,
role=self.role,
content='Expected content of message',
)

# New message
wrong_course_id = 'course-v1:wrong+id+1'
LearningAssistantMessage.objects.create(
course_id=wrong_course_id,
user=self.user,
role=self.role,
content='Expected content of message',
)

message_count = 2
return_value = get_message_history(self.course_id, self.user, message_count)

expected_value = LearningAssistantMessage.objects.filter(
course_id=self.course_id, user=self.user).order_by('-created')[:message_count]

# Ensure we filtered one of the two present messages
self.assertNotEqual(len(return_value), LearningAssistantMessage.objects.count())

# Ensure same number of entries
self.assertEqual(len(return_value), len(expected_value))

# Ensure values are as expected for all LearningAssistantMessage instances
for i, return_value in enumerate(return_value):
self.assertEqual(return_value.course_id, expected_value[i].course_id)
self.assertEqual(return_value.user, expected_value[i].user)
self.assertEqual(return_value.role, expected_value[i].role)
self.assertEqual(return_value.content, expected_value[i].content)
1 change: 1 addition & 0 deletions tests/test_plugins_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class PluginApiTests(TestCase):
"""
Test suite for the plugins_api module.
"""

def setUp(self):
super().setUp()
self.course_key = CourseKey.from_string('course-v1:edx+fake+1')
Expand Down
Loading
Loading