-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: GET endpoint to retrieve message history #119
Changes from all commits
77b1f78
bd4102c
eca0b80
7ac52f4
8df9f11
43a106a
5816c33
d1157de
2d9fe12
a7982d8
2b27837
fb79e58
a506d86
4041df4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,7 +11,7 @@ | |
|
||
from learning_assistant.constants import ACCEPTED_CATEGORY_TYPES, CATEGORY_TYPE_MAP | ||
from learning_assistant.data import LearningAssistantCourseEnabledData | ||
from learning_assistant.models import LearningAssistantCourseEnabled | ||
from learning_assistant.models import LearningAssistantCourseEnabled, LearningAssistantMessage | ||
from learning_assistant.platform_imports import ( | ||
block_get_children, | ||
block_leaf_filter, | ||
|
@@ -187,3 +187,14 @@ def get_course_id(course_run_id): | |
course_data = get_cache_course_run_data(course_run_id, ['course']) | ||
course_key = course_data['course'] | ||
return course_key | ||
|
||
|
||
def get_message_history(course_id, user, message_count): | ||
""" | ||
Given a course run id (str), user (User), and message count (int), return the associated message history. | ||
|
||
Returns a number of messages equal to the message_count value. | ||
""" | ||
message_history = LearningAssistantMessage.objects.filter( | ||
course_id=course_id, user=user).order_by('-created')[:message_count] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would we need the handle the case here if the message count exceeds the number of messages? (i.e. to prevent and index error) |
||
return message_history |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,7 +20,12 @@ | |
except ImportError: | ||
pass | ||
|
||
from learning_assistant.api import get_course_id, learning_assistant_enabled, render_prompt_template | ||
from learning_assistant.api import ( | ||
get_course_id, | ||
get_message_history, | ||
learning_assistant_enabled, | ||
render_prompt_template, | ||
) | ||
from learning_assistant.serializers import MessageSerializer | ||
from learning_assistant.utils import get_chat_response, user_role_is_staff | ||
|
||
|
@@ -149,3 +154,69 @@ | |
} | ||
|
||
return Response(status=http_status.HTTP_200_OK, data=data) | ||
|
||
|
||
class LearningAssistantMessageHistoryView(APIView): | ||
""" | ||
View to retrieve the message history for user in a course. | ||
|
||
This endpoint returns the message history stored in the LearningAssistantMessage table in a course | ||
represented by the course_key, which is provided in the URL. | ||
|
||
Accepts: [GET] | ||
|
||
Path: /learning_assistant/v1/course_id/{course_key}/history | ||
|
||
Parameters: | ||
* course_key: the ID of the course | ||
|
||
Responses: | ||
* 200: OK | ||
* 400: Malformed Request - Course ID is not a valid course ID. | ||
""" | ||
|
||
authentication_classes = (SessionAuthentication, JwtAuthentication,) | ||
permission_classes = (IsAuthenticated,) | ||
|
||
def get(self, request, course_run_id): | ||
""" | ||
Given a course run ID, retrieve the message history for the corresponding user. | ||
|
||
The response will be in the following format. | ||
|
||
[{'role': 'assistant', 'content': 'something'}] | ||
""" | ||
try: | ||
courserun_key = CourseKey.from_string(course_run_id) | ||
except InvalidKeyError: | ||
return Response( | ||
status=http_status.HTTP_400_BAD_REQUEST, | ||
data={'detail': 'Course ID is not a valid course ID.'} | ||
) | ||
|
||
if not learning_assistant_enabled(courserun_key): | ||
return Response( | ||
status=http_status.HTTP_403_FORBIDDEN, | ||
data={'detail': 'Learning assistant not enabled for course.'} | ||
) | ||
|
||
# If user does not have an enrollment record, or is not staff, they should not have access | ||
user_role = get_user_role(request.user, courserun_key) | ||
enrollment_object = CourseEnrollment.get_enrollment(request.user, courserun_key) | ||
enrollment_mode = enrollment_object.mode if enrollment_object else None | ||
if ( | ||
(enrollment_mode not in CourseMode.VERIFIED_MODES) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are you sure this logic condition is correct? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not quite sure, I copy/pasted this from another view so this could be wrong. @alangsto Do you have an opinion on this? |
||
and not user_role_is_staff(user_role) | ||
): | ||
return Response( | ||
status=http_status.HTTP_403_FORBIDDEN, | ||
data={'detail': 'Must be staff or have valid enrollment.'} | ||
) | ||
|
||
course_id = get_course_id(course_run_id) | ||
user = request.user | ||
|
||
message_count = int(request.GET.get('message_count', 50)) | ||
message_history = get_message_history(course_id, user, message_count) | ||
data = MessageSerializer(message_history, many=True).data | ||
return Response(status=http_status.HTTP_200_OK, data=data) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
|
||
import ddt | ||
from django.conf import settings | ||
from django.contrib.auth import get_user_model | ||
from django.core.cache import cache | ||
from django.test import TestCase, override_settings | ||
from opaque_keys import InvalidKeyError | ||
|
@@ -16,16 +17,19 @@ | |
_get_children_contents, | ||
_leaf_filter, | ||
get_block_content, | ||
get_message_history, | ||
learning_assistant_available, | ||
learning_assistant_enabled, | ||
render_prompt_template, | ||
set_learning_assistant_enabled, | ||
) | ||
from learning_assistant.data import LearningAssistantCourseEnabledData | ||
from learning_assistant.models import LearningAssistantCourseEnabled | ||
from learning_assistant.models import LearningAssistantCourseEnabled, LearningAssistantMessage | ||
|
||
fake_transcript = 'This is the text version from the transcript' | ||
|
||
User = get_user_model() | ||
|
||
|
||
class FakeChild: | ||
"""Fake child block for testing""" | ||
|
@@ -49,6 +53,7 @@ def get_html(self): | |
|
||
class FakeBlock: | ||
"Fake block for testing, returns given children" | ||
|
||
def __init__(self, children): | ||
self.children = children | ||
self.scope_ids = lambda: None | ||
|
@@ -236,6 +241,7 @@ class LearningAssistantCourseEnabledApiTests(TestCase): | |
""" | ||
Test suite for learning_assistant_available, learning_assistant_enabled, and set_learning_assistant_enabled. | ||
""" | ||
|
||
def setUp(self): | ||
super().setUp() | ||
self.course_key = CourseKey.from_string('course-v1:edx+fake+1') | ||
|
@@ -305,3 +311,140 @@ def test_learning_assistant_available(self, learning_assistant_available_setting | |
|
||
expected_value = learning_assistant_available_setting_value | ||
self.assertEqual(return_value, expected_value) | ||
|
||
|
||
@ddt.ddt | ||
class GetMessageHistoryTests(TestCase): | ||
""" | ||
Test suite for get_message_history. | ||
""" | ||
|
||
def setUp(self): | ||
super().setUp() | ||
self.course_id = 'course-v1:edx+fake+1' | ||
self.course_key = CourseKey.from_string(self.course_id) | ||
self.user = User(username='tester', email='[email protected]') | ||
self.user.save() | ||
|
||
self.role = 'verified' | ||
|
||
def test_get_message_history(self): | ||
message_count = 5 | ||
for i in range(1, message_count + 1): | ||
LearningAssistantMessage.objects.create( | ||
course_id=self.course_id, | ||
user=self.user, | ||
role=self.role, | ||
content=f'Content of message {i}', | ||
) | ||
|
||
return_value = get_message_history(self.course_id, self.user, message_count) | ||
|
||
expected_value = LearningAssistantMessage.objects.filter( | ||
course_id=self.course_id, user=self.user).order_by('-created')[:message_count] | ||
|
||
# Ensure same number of entries | ||
self.assertEqual(len(return_value), len(expected_value)) | ||
|
||
# Ensure values are as expected for all LearningAssistantMessage instances | ||
for i, return_value in enumerate(return_value): | ||
self.assertEqual(return_value.course_id, expected_value[i].course_id) | ||
self.assertEqual(return_value.user, expected_value[i].user) | ||
self.assertEqual(return_value.role, expected_value[i].role) | ||
self.assertEqual(return_value.content, expected_value[i].content) | ||
|
||
@ddt.data( | ||
0, 1, 5, 10, 50 | ||
) | ||
def test_get_message_history_message_count(self, actual_message_count): | ||
for i in range(1, actual_message_count + 1): | ||
LearningAssistantMessage.objects.create( | ||
course_id=self.course_id, | ||
user=self.user, | ||
role=self.role, | ||
content=f'Content of message {i}', | ||
) | ||
|
||
message_count_parameter = 5 | ||
return_value = get_message_history(self.course_id, self.user, message_count_parameter) | ||
|
||
expected_value = LearningAssistantMessage.objects.filter( | ||
course_id=self.course_id, user=self.user).order_by('-created')[:message_count_parameter] | ||
|
||
# Ensure same number of entries | ||
self.assertEqual(len(return_value), len(expected_value)) | ||
|
||
def test_get_message_history_user_difference(self): | ||
# Default Message | ||
LearningAssistantMessage.objects.create( | ||
course_id=self.course_id, | ||
user=self.user, | ||
role=self.role, | ||
content='Expected content of message', | ||
) | ||
|
||
# New message w/ new user | ||
new_user = User(username='not_tester', email='[email protected]') | ||
new_user.save() | ||
LearningAssistantMessage.objects.create( | ||
course_id=self.course_id, | ||
user=new_user, | ||
role=self.role, | ||
content='Expected content of message', | ||
) | ||
|
||
message_count = 2 | ||
return_value = get_message_history(self.course_id, self.user, message_count) | ||
|
||
expected_value = LearningAssistantMessage.objects.filter( | ||
course_id=self.course_id, user=self.user).order_by('-created')[:message_count] | ||
|
||
# Ensure we filtered one of the two present messages | ||
self.assertNotEqual(len(return_value), LearningAssistantMessage.objects.count()) | ||
|
||
# Ensure same number of entries | ||
self.assertEqual(len(return_value), len(expected_value)) | ||
|
||
# Ensure values are as expected for all LearningAssistantMessage instances | ||
for i, return_value in enumerate(return_value): | ||
self.assertEqual(return_value.course_id, expected_value[i].course_id) | ||
self.assertEqual(return_value.user, expected_value[i].user) | ||
self.assertEqual(return_value.role, expected_value[i].role) | ||
self.assertEqual(return_value.content, expected_value[i].content) | ||
|
||
def test_get_message_course_id_differences(self): | ||
# Default Message | ||
LearningAssistantMessage.objects.create( | ||
course_id=self.course_id, | ||
user=self.user, | ||
role=self.role, | ||
content='Expected content of message', | ||
) | ||
|
||
# New message | ||
wrong_course_id = 'course-v1:wrong+id+1' | ||
LearningAssistantMessage.objects.create( | ||
course_id=wrong_course_id, | ||
user=self.user, | ||
role=self.role, | ||
content='Expected content of message', | ||
) | ||
|
||
message_count = 2 | ||
return_value = get_message_history(self.course_id, self.user, message_count) | ||
|
||
expected_value = LearningAssistantMessage.objects.filter( | ||
course_id=self.course_id, user=self.user).order_by('-created')[:message_count] | ||
|
||
# Ensure we filtered one of the two present messages | ||
self.assertNotEqual(len(return_value), LearningAssistantMessage.objects.count()) | ||
|
||
# Ensure same number of entries | ||
self.assertEqual(len(return_value), len(expected_value)) | ||
|
||
# Ensure values are as expected for all LearningAssistantMessage instances | ||
for i, return_value in enumerate(return_value): | ||
self.assertEqual(return_value.course_id, expected_value[i].course_id) | ||
self.assertEqual(return_value.user, expected_value[i].user) | ||
self.assertEqual(return_value.role, expected_value[i].role) | ||
self.assertEqual(return_value.content, expected_value[i].content) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: If the message count exceeds the number of messages, we could return messages that are less than this value (in other words, the number of messages cannot exceed this value but could be less than this value).