diff --git a/.github/workflows/flake8-pytest.yml b/.github/workflows/flake8-pytest.yml
index 3ecc369..862e6bc 100644
--- a/.github/workflows/flake8-pytest.yml
+++ b/.github/workflows/flake8-pytest.yml
@@ -33,7 +33,7 @@ jobs:
         # stop the build if there are Python syntax errors or undefined names
         # flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
         # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+        flake8 . --count --max-complexity=10 --statistics
     - name: Test with pytest
       run: |
         pytest
diff --git a/bot.py b/bot.py
index 616ee21..4b44d64 100644
--- a/bot.py
+++ b/bot.py
@@ -12,6 +12,7 @@ from openai import OpenAI
 from websockets.exceptions import ConnectionClosed
 
+
 # Define the function to parse command-line arguments
 def parse_arguments() -> argparse.Namespace:
     parser = argparse.ArgumentParser(description='GPT-based Discord bot.')
@@ -36,7 +37,10 @@ def load_configuration(config_file: str) -> configparser.ConfigParser:
     return config
 
 
-def set_activity_status(activity_type: str, activity_status: str) -> discord.Activity:
+def set_activity_status(
+    activity_type: str,
+    activity_status: str
+) -> discord.Activity:
     """
     Return discord.Activity object with specified activity type and status
     """
@@ -79,7 +83,12 @@ def get_conversation_summary(conversation: list[dict]) -> list[dict]:
     return summary
 
 
-async def check_rate_limit(user: discord.User) -> bool:
+async def check_rate_limit(
+    user: discord.User,
+    logger: logging.Logger = None
+) -> bool:
     """
     Check if a user has exceeded the rate limit for sending messages.
     """
+    if logger is None:
+        logger = logging.getLogger(__name__)
@@ -99,7 +108,11 @@ async def check_rate_limit(user: discord.User) -> bool:
     return False
 
 
-async def process_input_message(input_message: str, user: discord.User, conversation_summary: list[dict]) -> str:
+async def process_input_message(
+    input_message: str,
+    user: discord.User,
+    conversation_summary: list[dict]
+) -> str:
     """
     Process an input message using OpenAI's GPT model.
     """
@@ -144,7 +157,10 @@ async def process_input_message(input_message: str, user: discord.User, conversa
             else:
                 response_content = None
         except AttributeError:
-            logger.error("Failed to get response from OpenAI API: Invalid response format.")
+            logger.error(
+                "Failed to get response from OpenAI API: "
+                "Invalid response format."
+            )
             return "Sorry, an error occurred while processing the message."
 
         if response_content:
@@ -153,13 +169,15 @@ async def process_input_message(input_message: str, user: discord.User, conversa
             # logger.info(f"Raw API response: {response}")
             logger.info(f"Sent the response: {response_content}")
 
-            conversation.append({"role": "assistant", "content": response_content})
+            conversation.append(
+                {"role": "assistant", "content": response_content}
+            )
             conversation_history[user.id] = conversation
 
             return response_content
         else:
-            logger.error("Failed to get response from OpenAI API: No text in response.")
-            return "Sorry, I didn't get that. Could you rephrase or ask something else?"
+            logger.error("OpenAI API error: No response text.")
+            return "Sorry, I didn't get that. Can you rephrase or ask again?"
     except ConnectionClosed as error:
         logger.error(f"WebSocket connection closed: {error}")
@@ -197,7 +215,9 @@
     OPENAI_API_KEY = config.get('OpenAI', 'OPENAI_API_KEY')
     OPENAI_TIMEOUT = config.getint('OpenAI', 'OPENAI_TIMEOUT', fallback='30')
-    GPT_MODEL = config.get('OpenAI', 'GPT_MODEL', fallback='gpt-3.5-turbo-1106')
+    GPT_MODEL = config.get(
+        'OpenAI', 'GPT_MODEL', fallback='gpt-3.5-turbo-1106'
+    )
     GPT_TOKENS = config.getint('OpenAI', 'GPT_TOKENS', fallback=4096)
     SYSTEM_MESSAGE = config.get(
         'OpenAI', 'SYSTEM_MESSAGE', fallback='You are a helpful assistant.'
     )
diff --git a/tests/test_check_rate_limit.py b/tests/test_check_rate_limit.py
index b5e7961..34185c6 100644
--- a/tests/test_check_rate_limit.py
+++ b/tests/test_check_rate_limit.py
@@ -1,39 +1,25 @@
+from contextlib import contextmanager
 import time
 from unittest.mock import AsyncMock
-import asyncio
 
 import pytest
+import logging
 
 import bot
+from bot import check_rate_limit
 
-from contextlib import contextmanager
-
-
-@pytest.mark.asyncio
-async def test_check_rate_limit():
-    user = AsyncMock()
-    user.id = 123
-
-async def run_test():
-    with patch_variables(
-        bot, 'last_command_timestamps', {user.id: time.time() - 60}
-    ), patch_variables(bot, 'last_command_count', {user.id: 0}), \
-        patch_variables(bot, 'RATE_LIMIT_PER', RATE_LIMIT_PER), \
-        patch_variables(bot, 'RATE_LIMIT', RATE_LIMIT):
-        result = await check_rate_limit(user)
-        assert result is True
-        assert last_command_count[user.id] == 1
-        last_command_count[user.id] = 3
-        result = await check_rate_limit(user)
-        assert result is False
-        assert last_command_count[user.id] == 3
+# Define a placeholder logger
+logger = logging.getLogger('pytest_logger')
+logger.setLevel(logging.DEBUG)
+handler = logging.StreamHandler()
+handler.setLevel(logging.DEBUG)
+logger.addHandler(handler)
 
-        last_command_timestamps[user.id] = time.time() - RATE_LIMIT_PER - 1
-        result = await check_rate_limit(user)
-        assert result is True
-        assert last_command_count[user.id] == 1
+RATE_LIMIT = 10
+RATE_LIMIT_PER = 60
 
-    await run_test()
+last_command_count = {}
+last_command_timestamps = {}
 
 
 @contextmanager
@@ -45,3 +31,43 @@ def patch_variables(module, variable_name, value):
         yield
     finally:
         setattr(module, variable_name, original_value)
+
+
+async def run_test(user):
+    with patch_variables(
+        bot,
+        'last_command_timestamps',
+        {user.id: time.time() - 60}
+    ), patch_variables(
+        bot,
+        'last_command_count',
+        {user.id: 0}
+    ), patch_variables(
+        bot,
+        'RATE_LIMIT_PER',
+        RATE_LIMIT_PER
+    ), patch_variables(
+        bot,
+        'RATE_LIMIT',
+        RATE_LIMIT
+    ):
+        result = await check_rate_limit(user, logger)
+        assert result is True
+        assert bot.last_command_count.get(user.id, 0) == 1
+
+        bot.last_command_count[user.id] = RATE_LIMIT
+        result = await check_rate_limit(user, logger)
+        assert result is False
+        assert bot.last_command_count.get(user.id, 0) == RATE_LIMIT
+
+        bot.last_command_timestamps[user.id] = time.time() - RATE_LIMIT_PER - 1
+        result = await check_rate_limit(user, logger)
+        assert result is True
+        assert bot.last_command_count.get(user.id, 0) == 1
+
+
+@pytest.mark.asyncio
+async def test_check_rate_limit():
+    user = AsyncMock()
+    user.id = 123
+    await run_test(user)
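
The test passes its own logger because check_rate_limit now accepts an optional logging.Logger and falls back to a module-level logger when none is supplied. A minimal, self-contained sketch of that injection pattern follows; check_positive and the "demo" logger are illustrative stand-ins, not part of the repository.

import asyncio
import logging


async def check_positive(value: int, logger: logging.Logger = None) -> bool:
    # Fall back to the module logger when the caller does not inject one,
    # mirroring the default added to bot.check_rate_limit in this diff.
    if logger is None:
        logger = logging.getLogger(__name__)
    logger.debug("checking %s", value)
    return value > 0


# Production callers can omit the logger; tests can inject a configured one.
print(asyncio.run(check_positive(3)))                              # True
print(asyncio.run(check_positive(-1, logging.getLogger("demo"))))  # False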
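
The rewritten test leans on the patch_variables context manager to swap module-level attributes such as bot.RATE_LIMIT for the duration of a with block and restore them afterwards. Below is a standalone sketch of that pattern, assuming the usual getattr/setattr setup in the lines the hunk does not show; the SimpleNamespace object stands in for the real bot module purely for illustration.

from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def patch_variables(module, variable_name, value):
    # Save the current value, install the override, and restore the original
    # on exit even if the body raises.
    original_value = getattr(module, variable_name)
    setattr(module, variable_name, value)
    try:
        yield
    finally:
        setattr(module, variable_name, original_value)


fake_bot = SimpleNamespace(RATE_LIMIT=10, last_command_count={})

with patch_variables(fake_bot, 'RATE_LIMIT', 2):
    assert fake_bot.RATE_LIMIT == 2   # overridden inside the block

assert fake_bot.RATE_LIMIT == 10      # original value restored on exit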