Merge pull request #21 from edx/alangsto/update_message_length
feat: use reduced message to avoid maxing out tokens
alangsto authored Sep 11, 2023
2 parents fe8f309 + e7ffd72 commit 2ee7e74
Showing 5 changed files with 90 additions and 12 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
@@ -11,6 +11,10 @@ Change Log

.. There should always be an "Unreleased" section for changes pending release.
1.4.0 - 2023-09-11
******************
* Send a reduced message list if needed to avoid going over the token limit

1.3.3 - 2023-09-07
******************
* Allow any enrolled learner to access API.
2 changes: 1 addition & 1 deletion learning_assistant/__init__.py
@@ -2,6 +2,6 @@
Plugin for a learning assistant backend, intended for use within edx-platform.
"""

__version__ = '1.3.3'
__version__ = '1.4.0'

default_app_config = 'learning_assistant.apps.LearningAssistantConfig' # pylint: disable=invalid-name
41 changes: 39 additions & 2 deletions learning_assistant/utils.py
@@ -12,7 +12,42 @@
log = logging.getLogger(__name__)


def get_chat_response(message_list):
def _estimated_message_tokens(message):
"""
Estimates how many tokens are in a given message.
"""
chars_per_token = 3.5
json_padding = 8

return int((len(message) - message.count(' ')) / chars_per_token) + json_padding


def get_reduced_message_list(system_list, message_list):
"""
    If the messages are larger than the allotted token amount, return a smaller list of messages.
"""
total_system_tokens = sum(_estimated_message_tokens(system_message['content']) for system_message in system_list)

max_tokens = getattr(settings, 'CHAT_COMPLETION_MAX_TOKENS', 16385)
response_tokens = getattr(settings, 'CHAT_COMPLETION_RESPONSE_TOKENS', 1000)
remaining_tokens = max_tokens - response_tokens - total_system_tokens

new_message_list = []
total_message_tokens = 0

while total_message_tokens < remaining_tokens and len(message_list) != 0:
new_message = message_list.pop()
total_message_tokens += _estimated_message_tokens(new_message['content'])
if total_message_tokens >= remaining_tokens:
break

# insert message at beginning of list, because we are traversing the message list from most recent to oldest
new_message_list.insert(0, new_message)

return new_message_list


def get_chat_response(system_list, message_list):
"""
Pass message list to chat endpoint, as defined by the CHAT_COMPLETION_API setting.
"""
@@ -22,7 +57,9 @@ def get_chat_response(message_list):
headers = {'Content-Type': 'application/json', 'x-api-key': completion_endpoint_key}
connect_timeout = getattr(settings, 'CHAT_COMPLETION_API_CONNECT_TIMEOUT', 1)
read_timeout = getattr(settings, 'CHAT_COMPLETION_API_READ_TIMEOUT', 15)
body = {'message_list': message_list}

reduced_messages = get_reduced_message_list(system_list, message_list)
body = {'message_list': system_list + reduced_messages}

try:
response = requests.post(
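
As a quick illustration of the new helpers above (not part of the commit), the sketch below re-implements the token estimate and the reduction loop with stand-in names (estimated_tokens, reduce_history) so the budget arithmetic can be checked in isolation, using the default limits of 16385 total tokens with 1000 reserved for the response.

def estimated_tokens(text):
    # roughly 3.5 characters per token, ignoring spaces, plus 8 tokens of JSON padding
    return int((len(text) - text.count(' ')) / 3.5) + 8


def reduce_history(history, budget):
    # walk the history newest-first and keep messages until the budget is exhausted,
    # mirroring get_reduced_message_list
    kept, used = [], 0
    for message in reversed(history):
        used += estimated_tokens(message['content'])
        if used >= budget:
            break
        kept.insert(0, message)
    return kept


system_list = [
    {'role': 'system', 'content': 'Do this'},
    {'role': 'system', 'content': 'Do that'},
]
history = [
    {'role': 'assistant', 'content': 'Hello'},
    {'role': 'user', 'content': 'Goodbye'},
]

system_tokens = sum(estimated_tokens(m['content']) for m in system_list)  # 9 + 9 = 18

# Default limits: budget = 16385 - 1000 - 18 = 15367, so nothing is trimmed.
print(reduce_history(history, 16385 - 1000 - system_tokens))

# A tight 30-token cap with 1 response token leaves 30 - 1 - 18 = 11,
# so only the most recent message ('Goodbye') survives.
print(reduce_history(history, 30 - 1 - system_tokens))
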
2 changes: 1 addition & 1 deletion learning_assistant/views.py
@@ -94,6 +94,6 @@ def post(self, request, course_id):
'course_id': course_id
}
)
status_code, message = get_chat_response(message_setup + message_list)
status_code, message = get_chat_response(message_setup, message_list)

return Response(status=status_code, data=message)
53 changes: 45 additions & 8 deletions tests/test_utils.py
@@ -1,6 +1,7 @@
"""
Tests for the utils functions
"""
import copy
import json
from unittest.mock import MagicMock, patch

@@ -10,7 +11,7 @@
from django.test import TestCase, override_settings
from requests.exceptions import ConnectTimeout

from learning_assistant.utils import get_chat_response
from learning_assistant.utils import get_chat_response, get_reduced_message_list


@ddt.ddt
@@ -20,20 +21,24 @@ class GetChatResponseTests(TestCase):
"""
def setUp(self):
super().setUp()
self.system_message = [
{'role': 'system', 'content': 'Do this'},
{'role': 'system', 'content': 'Do that'},
]
self.message_list = [
{'role': 'assistant', 'content': 'Hello'},
{'role': 'user', 'content': 'Goodbye'},
]

@override_settings(CHAT_COMPLETION_API=None)
def test_no_endpoint_setting(self):
status_code, message = get_chat_response(self.message_list)
status_code, message = get_chat_response(self.system_message, self.message_list)
self.assertEqual(status_code, 404)
self.assertEqual(message, 'Completion endpoint is not defined.')

@override_settings(CHAT_COMPLETION_API_KEY=None)
def test_no_endpoint_key_setting(self):
status_code, message = get_chat_response(self.message_list)
status_code, message = get_chat_response(self.system_message, self.message_list)
self.assertEqual(status_code, 404)
self.assertEqual(message, 'Completion endpoint is not defined.')

@@ -47,7 +52,7 @@ def test_200_response(self):
body=json.dumps(message_response),
)

status_code, message = get_chat_response(self.message_list)
status_code, message = get_chat_response(self.system_message, self.message_list)
self.assertEqual(status_code, 200)
self.assertEqual(message, message_response)

@@ -61,7 +66,7 @@ def test_non_200_response(self):
body=json.dumps(message_response),
)

status_code, message = get_chat_response(self.message_list)
status_code, message = get_chat_response(self.system_message, self.message_list)
self.assertEqual(status_code, 500)
self.assertEqual(message, message_response)

@@ -72,7 +77,7 @@ def test_non_200_response(self):
@patch('learning_assistant.utils.requests')
def test_timeout(self, exception, mock_requests):
mock_requests.post = MagicMock(side_effect=exception())
status_code, _ = get_chat_response(self.message_list)
status_code, _ = get_chat_response(self.system_message, self.message_list)
self.assertEqual(status_code, 502)

@patch('learning_assistant.utils.requests')
@@ -83,12 +88,44 @@ def test_post_request_structure(self, mock_requests):
connect_timeout = settings.CHAT_COMPLETION_API_CONNECT_TIMEOUT
read_timeout = settings.CHAT_COMPLETION_API_READ_TIMEOUT
headers = {'Content-Type': 'application/json', 'x-api-key': settings.CHAT_COMPLETION_API_KEY}
body = json.dumps({'message_list': self.message_list})
body = json.dumps({'message_list': self.system_message + self.message_list})

get_chat_response(self.message_list)
get_chat_response(self.system_message, self.message_list)
mock_requests.post.assert_called_with(
completion_endpoint,
headers=headers,
data=body,
timeout=(connect_timeout, read_timeout)
)


class GetReducedMessageListTests(TestCase):
"""
    Tests for the get_reduced_message_list helper function
"""
def setUp(self):
super().setUp()
self.system_message = [
{'role': 'system', 'content': 'Do this'},
{'role': 'system', 'content': 'Do that'},
]
self.message_list = [
{'role': 'assistant', 'content': 'Hello'},
{'role': 'user', 'content': 'Goodbye'},
]

@override_settings(CHAT_COMPLETION_MAX_TOKENS=30)
@override_settings(CHAT_COMPLETION_RESPONSE_TOKENS=1)
def test_message_list_reduced(self):
"""
If the number of tokens in the message list is greater than allowed, assert that messages are removed
"""
# pass in copy of list, as it is modified as part of the reduction
reduced_message_list = get_reduced_message_list(self.system_message, copy.deepcopy(self.message_list))
self.assertEqual(len(reduced_message_list), 1)
self.assertEqual(reduced_message_list, self.message_list[-1:])

def test_message_list(self):
reduced_message_list = get_reduced_message_list(self.system_message, copy.deepcopy(self.message_list))
self.assertEqual(len(reduced_message_list), 2)
self.assertEqual(reduced_message_list, self.message_list)
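
For reference (not part of the diff), the arithmetic behind test_message_list_reduced under the overridden 30-token limit; the assertions below simply restate the estimator added in learning_assistant/utils.py:

assert int((len('Do this') - 1) / 3.5) + 8 == 9   # each system prompt: 6 non-space characters
assert 30 - 1 - 2 * 9 == 11                       # budget left for the conversation
assert int(len('Goodbye') / 3.5) + 8 == 10        # newest message: 10 < 11, so it is kept
assert int(len('Hello') / 3.5) + 8 == 9           # would push the running total to 19 >= 11, so it is dropped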
