From e7ffd7245ea9f603f8d2b0c2adcf7e96b3dc2b90 Mon Sep 17 00:00:00 2001 From: Alie Langston Date: Mon, 11 Sep 2023 10:41:19 -0400 Subject: [PATCH] feat: use reduced message to avoid maxing out tokens --- CHANGELOG.rst | 4 +++ learning_assistant/__init__.py | 2 +- learning_assistant/utils.py | 41 ++++++++++++++++++++++++-- learning_assistant/views.py | 2 +- tests/test_utils.py | 53 +++++++++++++++++++++++++++++----- 5 files changed, 90 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index beb5678..6841ee5 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -11,6 +11,10 @@ Change Log .. There should always be an "Unreleased" section for changes pending release. +1.4.0 - 2023-09-11 +****************** +* Send reduced message list if needed to avoid going over token limit + 1.3.3 - 2023-09-07 ****************** * Allow any enrolled learner to access API. diff --git a/learning_assistant/__init__.py b/learning_assistant/__init__.py index 5f9a684..60ac235 100644 --- a/learning_assistant/__init__.py +++ b/learning_assistant/__init__.py @@ -2,6 +2,6 @@ Plugin for a learning assistant backend, intended for use within edx-platform. """ -__version__ = '1.3.3' +__version__ = '1.4.0' default_app_config = 'learning_assistant.apps.LearningAssistantConfig' # pylint: disable=invalid-name diff --git a/learning_assistant/utils.py b/learning_assistant/utils.py index 5920cd0..110f449 100644 --- a/learning_assistant/utils.py +++ b/learning_assistant/utils.py @@ -12,7 +12,42 @@ log = logging.getLogger(__name__) -def get_chat_response(message_list): +def _estimated_message_tokens(message): + """ + Estimates how many tokens are in a given message. + """ + chars_per_token = 3.5 + json_padding = 8 + + return int((len(message) - message.count(' ')) / chars_per_token) + json_padding + + +def get_reduced_message_list(system_list, message_list): + """ + If messages are larger than allotted token amount, return a smaller list of messages. + """ + total_system_tokens = sum(_estimated_message_tokens(system_message['content']) for system_message in system_list) + + max_tokens = getattr(settings, 'CHAT_COMPLETION_MAX_TOKENS', 16385) + response_tokens = getattr(settings, 'CHAT_COMPLETION_RESPONSE_TOKENS', 1000) + remaining_tokens = max_tokens - response_tokens - total_system_tokens + + new_message_list = [] + total_message_tokens = 0 + + while total_message_tokens < remaining_tokens and len(message_list) != 0: + new_message = message_list.pop() + total_message_tokens += _estimated_message_tokens(new_message['content']) + if total_message_tokens >= remaining_tokens: + break + + # insert message at beginning of list, because we are traversing the message list from most recent to oldest + new_message_list.insert(0, new_message) + + return new_message_list + + +def get_chat_response(system_list, message_list): """ Pass message list to chat endpoint, as defined by the CHAT_COMPLETION_API setting. """ @@ -22,7 +57,9 @@ def get_chat_response(message_list): headers = {'Content-Type': 'application/json', 'x-api-key': completion_endpoint_key} connect_timeout = getattr(settings, 'CHAT_COMPLETION_API_CONNECT_TIMEOUT', 1) read_timeout = getattr(settings, 'CHAT_COMPLETION_API_READ_TIMEOUT', 15) - body = {'message_list': message_list} + + reduced_messages = get_reduced_message_list(system_list, message_list) + body = {'message_list': system_list + reduced_messages} try: response = requests.post( diff --git a/learning_assistant/views.py b/learning_assistant/views.py index 5a7d5fa..4749008 100644 --- a/learning_assistant/views.py +++ b/learning_assistant/views.py @@ -94,6 +94,6 @@ def post(self, request, course_id): 'course_id': course_id } ) - status_code, message = get_chat_response(message_setup + message_list) + status_code, message = get_chat_response(message_setup, message_list) return Response(status=status_code, data=message) diff --git a/tests/test_utils.py b/tests/test_utils.py index fad8966..27afa8e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,7 @@ """ Tests for the utils functions """ +import copy import json from unittest.mock import MagicMock, patch @@ -10,7 +11,7 @@ from django.test import TestCase, override_settings from requests.exceptions import ConnectTimeout -from learning_assistant.utils import get_chat_response +from learning_assistant.utils import get_chat_response, get_reduced_message_list @ddt.ddt @@ -20,6 +21,10 @@ class GetChatResponseTests(TestCase): """ def setUp(self): super().setUp() + self.system_message = [ + {'role': 'system', 'content': 'Do this'}, + {'role': 'system', 'content': 'Do that'}, + ] self.message_list = [ {'role': 'assistant', 'content': 'Hello'}, {'role': 'user', 'content': 'Goodbye'}, @@ -27,13 +32,13 @@ def setUp(self): @override_settings(CHAT_COMPLETION_API=None) def test_no_endpoint_setting(self): - status_code, message = get_chat_response(self.message_list) + status_code, message = get_chat_response(self.system_message, self.message_list) self.assertEqual(status_code, 404) self.assertEqual(message, 'Completion endpoint is not defined.') @override_settings(CHAT_COMPLETION_API_KEY=None) def test_no_endpoint_key_setting(self): - status_code, message = get_chat_response(self.message_list) + status_code, message = get_chat_response(self.system_message, self.message_list) self.assertEqual(status_code, 404) self.assertEqual(message, 'Completion endpoint is not defined.') @@ -47,7 +52,7 @@ def test_200_response(self): body=json.dumps(message_response), ) - status_code, message = get_chat_response(self.message_list) + status_code, message = get_chat_response(self.system_message, self.message_list) self.assertEqual(status_code, 200) self.assertEqual(message, message_response) @@ -61,7 +66,7 @@ def test_non_200_response(self): body=json.dumps(message_response), ) - status_code, message = get_chat_response(self.message_list) + status_code, message = get_chat_response(self.system_message, self.message_list) self.assertEqual(status_code, 500) self.assertEqual(message, message_response) @@ -72,7 +77,7 @@ def test_non_200_response(self): @patch('learning_assistant.utils.requests') def test_timeout(self, exception, mock_requests): mock_requests.post = MagicMock(side_effect=exception()) - status_code, _ = get_chat_response(self.message_list) + status_code, _ = get_chat_response(self.system_message, self.message_list) self.assertEqual(status_code, 502) @patch('learning_assistant.utils.requests') @@ -83,12 +88,44 @@ def test_post_request_structure(self, mock_requests): connect_timeout = settings.CHAT_COMPLETION_API_CONNECT_TIMEOUT read_timeout = settings.CHAT_COMPLETION_API_READ_TIMEOUT headers = {'Content-Type': 'application/json', 'x-api-key': settings.CHAT_COMPLETION_API_KEY} - body = json.dumps({'message_list': self.message_list}) + body = json.dumps({'message_list': self.system_message + self.message_list}) - get_chat_response(self.message_list) + get_chat_response(self.system_message, self.message_list) mock_requests.post.assert_called_with( completion_endpoint, headers=headers, data=body, timeout=(connect_timeout, read_timeout) ) + + +class GetReducedMessageListTests(TestCase): + """ + Tests for the _reduced_message_list helper function + """ + def setUp(self): + super().setUp() + self.system_message = [ + {'role': 'system', 'content': 'Do this'}, + {'role': 'system', 'content': 'Do that'}, + ] + self.message_list = [ + {'role': 'assistant', 'content': 'Hello'}, + {'role': 'user', 'content': 'Goodbye'}, + ] + + @override_settings(CHAT_COMPLETION_MAX_TOKENS=30) + @override_settings(CHAT_COMPLETION_RESPONSE_TOKENS=1) + def test_message_list_reduced(self): + """ + If the number of tokens in the message list is greater than allowed, assert that messages are removed + """ + # pass in copy of list, as it is modified as part of the reduction + reduced_message_list = get_reduced_message_list(self.system_message, copy.deepcopy(self.message_list)) + self.assertEqual(len(reduced_message_list), 1) + self.assertEqual(reduced_message_list, self.message_list[-1:]) + + def test_message_list(self): + reduced_message_list = get_reduced_message_list(self.system_message, copy.deepcopy(self.message_list)) + self.assertEqual(len(reduced_message_list), 2) + self.assertEqual(reduced_message_list, self.message_list)