-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: GET endpoint to retrieve message history #119
Changes from 11 commits
77b1f78
bd4102c
eca0b80
7ac52f4
8df9f11
43a106a
5816c33
d1157de
2d9fe12
a7982d8
2b27837
fb79e58
a506d86
4041df4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,7 +11,7 @@ | |
|
||
from learning_assistant.constants import ACCEPTED_CATEGORY_TYPES, CATEGORY_TYPE_MAP | ||
from learning_assistant.data import LearningAssistantCourseEnabledData | ||
from learning_assistant.models import LearningAssistantCourseEnabled | ||
from learning_assistant.models import LearningAssistantCourseEnabled, LearningAssistantMessage | ||
from learning_assistant.platform_imports import ( | ||
block_get_children, | ||
block_leaf_filter, | ||
|
@@ -187,3 +187,21 @@ def get_course_id(course_run_id): | |
course_data = get_cache_course_run_data(course_run_id, ['course']) | ||
course_key = course_data['course'] | ||
return course_key | ||
|
||
|
||
def get_message_history(course_id, user, message_count): | ||
""" | ||
Given a course run id (str), user (User), and message count (int), return the associated message history. | ||
|
||
Returns a number of messages equal to the message_count value. | ||
""" | ||
# If the received message count exceeds the number of messages present, we return all the messages queried. | ||
actual_message_count = LearningAssistantMessage.objects.count() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I actually don't know if this is strictly necessary. Have you tested the code without this, and just indexing using the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh that actually works. I'll get rid of this code. |
||
if message_count > actual_message_count: | ||
amount_to_get = actual_message_count | ||
else: | ||
amount_to_get = message_count | ||
|
||
message_history = LearningAssistantMessage.objects.filter( | ||
course_id=course_id, user=user).order_by('-created')[:amount_to_get] | ||
return message_history |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ | |
from django.urls import re_path | ||
|
||
from learning_assistant.constants import COURSE_ID_PATTERN | ||
from learning_assistant.views import CourseChatView, LearningAssistantEnabledView | ||
from learning_assistant.views import CourseChatView, LearningAssistantEnabledView, LearningAssistantMessageHistoryView | ||
|
||
app_name = 'learning_assistant' | ||
|
||
|
@@ -19,4 +19,9 @@ | |
LearningAssistantEnabledView.as_view(), | ||
name='enabled', | ||
), | ||
re_path( | ||
fr'learning_assistant/v1/course_id/{COURSE_ID_PATTERN}/history/<int:message_count>', | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the message count shouldn't be specified in the path, in my opinion. Happy to be convinced otherwise, but it feels more appropriate, as the message count param is used to filter on the resource being specified. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just realizing query parameters don't need to be in the URL... I'll get rid of that now |
||
LearningAssistantMessageHistoryView.as_view(), | ||
name='message-history', | ||
), | ||
] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,7 +20,12 @@ | |
except ImportError: | ||
pass | ||
|
||
from learning_assistant.api import get_course_id, learning_assistant_enabled, render_prompt_template | ||
from learning_assistant.api import ( | ||
get_course_id, | ||
get_message_history, | ||
learning_assistant_enabled, | ||
render_prompt_template, | ||
) | ||
from learning_assistant.serializers import MessageSerializer | ||
from learning_assistant.utils import get_chat_response, user_role_is_staff | ||
|
||
|
@@ -149,3 +154,69 @@ | |
} | ||
|
||
return Response(status=http_status.HTTP_200_OK, data=data) | ||
|
||
|
||
class LearningAssistantMessageHistoryView(APIView): | ||
""" | ||
View to retrieve the message history for user in a course. | ||
|
||
This endpoint returns the message history stored in the LearningAssistantMessage table in a course | ||
represented by the course_key, which is provided in the URL. | ||
|
||
Accepts: [GET] | ||
|
||
Path: /learning_assistant/v1/course_id/{course_key}/history/{message_count} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. NIT: Could you update the path to reflect that message_count is no longer a path parameter? |
||
|
||
Parameters: | ||
* course_key: the ID of the course | ||
|
||
Responses: | ||
* 200: OK | ||
* 400: Malformed Request - Course ID is not a valid course ID. | ||
""" | ||
|
||
authentication_classes = (SessionAuthentication, JwtAuthentication,) | ||
permission_classes = (IsAuthenticated,) | ||
|
||
def get(self, request, course_run_id, message_count=50): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I still think this should be a query parameter, not a path parameter. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I keep mixing these up- I think I've got it right now, gonna use `request.GET.get('message_count', 50) instead. Side note, what do you think the default value should be? Is 50 good? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, that should be the right syntax! And 50 seems good. |
||
""" | ||
Given a course run ID, retrieve the message history for the corresponding user. | ||
|
||
The response will be in the following format. | ||
|
||
[{'role': 'assistant', 'content': 'something'}] | ||
""" | ||
try: | ||
courserun_key = CourseKey.from_string(course_run_id) | ||
except InvalidKeyError: | ||
return Response( | ||
status=http_status.HTTP_400_BAD_REQUEST, | ||
data={'detail': 'Course ID is not a valid course ID.'} | ||
) | ||
|
||
if not learning_assistant_enabled(courserun_key): | ||
return Response( | ||
status=http_status.HTTP_403_FORBIDDEN, | ||
data={'detail': 'Learning assistant not enabled for course.'} | ||
) | ||
|
||
# If user does not have an enrollment record, or is not staff, they should not have access | ||
# NOTE: This will likely be removed once work is done to allow audit learners to access xpert | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. NIT: Should this note be left here? |
||
user_role = get_user_role(request.user, courserun_key) | ||
enrollment_object = CourseEnrollment.get_enrollment(request.user, courserun_key) | ||
enrollment_mode = enrollment_object.mode if enrollment_object else None | ||
if ( | ||
(enrollment_mode not in CourseMode.VERIFIED_MODES) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are you sure this logic condition is correct? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not quite sure, I copy/pasted this from another view so this could be wrong. @alangsto Do you have an opinion on this? |
||
and not user_role_is_staff(user_role) | ||
): | ||
return Response( | ||
status=http_status.HTTP_403_FORBIDDEN, | ||
data={'detail': 'Must be staff or have valid enrollment.'} | ||
) | ||
|
||
course_id = get_course_id(course_run_id) | ||
user = request.user | ||
|
||
message_history = get_message_history(course_id, user, message_count) | ||
data = MessageSerializer(message_history, many=True).data | ||
return Response(status=http_status.HTTP_200_OK, data=data) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
|
||
import ddt | ||
from django.conf import settings | ||
from django.contrib.auth import get_user_model | ||
from django.core.cache import cache | ||
from django.test import TestCase, override_settings | ||
from opaque_keys import InvalidKeyError | ||
|
@@ -16,16 +17,19 @@ | |
_get_children_contents, | ||
_leaf_filter, | ||
get_block_content, | ||
get_message_history, | ||
learning_assistant_available, | ||
learning_assistant_enabled, | ||
render_prompt_template, | ||
set_learning_assistant_enabled, | ||
) | ||
from learning_assistant.data import LearningAssistantCourseEnabledData | ||
from learning_assistant.models import LearningAssistantCourseEnabled | ||
from learning_assistant.models import LearningAssistantCourseEnabled, LearningAssistantMessage | ||
|
||
fake_transcript = 'This is the text version from the transcript' | ||
|
||
User = get_user_model() | ||
|
||
|
||
class FakeChild: | ||
"""Fake child block for testing""" | ||
|
@@ -49,6 +53,7 @@ def get_html(self): | |
|
||
class FakeBlock: | ||
"Fake block for testing, returns given children" | ||
|
||
def __init__(self, children): | ||
self.children = children | ||
self.scope_ids = lambda: None | ||
|
@@ -236,6 +241,7 @@ class LearningAssistantCourseEnabledApiTests(TestCase): | |
""" | ||
Test suite for learning_assistant_available, learning_assistant_enabled, and set_learning_assistant_enabled. | ||
""" | ||
|
||
def setUp(self): | ||
super().setUp() | ||
self.course_key = CourseKey.from_string('course-v1:edx+fake+1') | ||
|
@@ -305,3 +311,140 @@ def test_learning_assistant_available(self, learning_assistant_available_setting | |
|
||
expected_value = learning_assistant_available_setting_value | ||
self.assertEqual(return_value, expected_value) | ||
|
||
|
||
@ddt.ddt | ||
class GetMessageHistoryTests(TestCase): | ||
""" | ||
Test suite for get_message_history. | ||
""" | ||
|
||
def setUp(self): | ||
super().setUp() | ||
self.course_id = 'course-v1:edx+fake+1' | ||
self.course_key = CourseKey.from_string(self.course_id) | ||
self.user = User(username='tester', email='[email protected]') | ||
self.user.save() | ||
|
||
self.role = 'verified' | ||
|
||
def test_get_message_history(self): | ||
message_count = 5 | ||
for i in range(1, message_count + 1): | ||
LearningAssistantMessage.objects.create( | ||
course_id=self.course_id, | ||
user=self.user, | ||
role=self.role, | ||
content=f'Content of message {i}', | ||
) | ||
|
||
return_value = get_message_history(self.course_id, self.user, message_count) | ||
|
||
expected_value = LearningAssistantMessage.objects.filter( | ||
course_id=self.course_id, user=self.user).order_by('-created')[:message_count] | ||
|
||
# Ensure same number of entries | ||
self.assertEqual(len(return_value), len(expected_value)) | ||
|
||
# Ensure values are as expected for all LearningAssistantMessage instances | ||
for i, return_value in enumerate(return_value): | ||
self.assertEqual(return_value.course_id, expected_value[i].course_id) | ||
self.assertEqual(return_value.user, expected_value[i].user) | ||
self.assertEqual(return_value.role, expected_value[i].role) | ||
self.assertEqual(return_value.content, expected_value[i].content) | ||
|
||
@ddt.data( | ||
0, 1, 5, 10, 50 | ||
) | ||
def test_get_message_history_message_count(self, actual_message_count): | ||
for i in range(1, actual_message_count + 1): | ||
LearningAssistantMessage.objects.create( | ||
course_id=self.course_id, | ||
user=self.user, | ||
role=self.role, | ||
content=f'Content of message {i}', | ||
) | ||
|
||
message_count_parameter = 5 | ||
return_value = get_message_history(self.course_id, self.user, message_count_parameter) | ||
|
||
expected_value = LearningAssistantMessage.objects.filter( | ||
course_id=self.course_id, user=self.user).order_by('-created')[:message_count_parameter] | ||
|
||
# Ensure same number of entries | ||
self.assertEqual(len(return_value), len(expected_value)) | ||
|
||
def test_get_message_history_user_difference(self): | ||
# Default Message | ||
LearningAssistantMessage.objects.create( | ||
course_id=self.course_id, | ||
user=self.user, | ||
role=self.role, | ||
content='Expected content of message', | ||
) | ||
|
||
# New message w/ new user | ||
new_user = User(username='not_tester', email='[email protected]') | ||
new_user.save() | ||
LearningAssistantMessage.objects.create( | ||
course_id=self.course_id, | ||
user=new_user, | ||
role=self.role, | ||
content='Expected content of message', | ||
) | ||
|
||
message_count = 2 | ||
return_value = get_message_history(self.course_id, self.user, message_count) | ||
|
||
expected_value = LearningAssistantMessage.objects.filter( | ||
course_id=self.course_id, user=self.user).order_by('-created')[:message_count] | ||
|
||
# Ensure we filtered one of the two present messages | ||
self.assertNotEqual(len(return_value), LearningAssistantMessage.objects.count()) | ||
|
||
# Ensure same number of entries | ||
self.assertEqual(len(return_value), len(expected_value)) | ||
|
||
# Ensure values are as expected for all LearningAssistantMessage instances | ||
for i, return_value in enumerate(return_value): | ||
self.assertEqual(return_value.course_id, expected_value[i].course_id) | ||
self.assertEqual(return_value.user, expected_value[i].user) | ||
self.assertEqual(return_value.role, expected_value[i].role) | ||
self.assertEqual(return_value.content, expected_value[i].content) | ||
|
||
def test_get_message_course_id_differences(self): | ||
# Default Message | ||
LearningAssistantMessage.objects.create( | ||
course_id=self.course_id, | ||
user=self.user, | ||
role=self.role, | ||
content='Expected content of message', | ||
) | ||
|
||
# New message | ||
wrong_course_id = 'course-v1:wrong+id+1' | ||
LearningAssistantMessage.objects.create( | ||
course_id=wrong_course_id, | ||
user=self.user, | ||
role=self.role, | ||
content='Expected content of message', | ||
) | ||
|
||
message_count = 2 | ||
return_value = get_message_history(self.course_id, self.user, message_count) | ||
|
||
expected_value = LearningAssistantMessage.objects.filter( | ||
course_id=self.course_id, user=self.user).order_by('-created')[:message_count] | ||
|
||
# Ensure we filtered one of the two present messages | ||
self.assertNotEqual(len(return_value), LearningAssistantMessage.objects.count()) | ||
|
||
# Ensure same number of entries | ||
self.assertEqual(len(return_value), len(expected_value)) | ||
|
||
# Ensure values are as expected for all LearningAssistantMessage instances | ||
for i, return_value in enumerate(return_value): | ||
self.assertEqual(return_value.course_id, expected_value[i].course_id) | ||
self.assertEqual(return_value.user, expected_value[i].user) | ||
self.assertEqual(return_value.role, expected_value[i].role) | ||
self.assertEqual(return_value.content, expected_value[i].content) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: If the message count exceeds the number of messages, we could return messages that are less than this value (in other words, the number of messages cannot exceed this value but could be less than this value).