Fix flake8 failures and warnings, then enforce PEP8
This fixes all flake8 warnings across the tests and bot.py.

NOTE: While here, I also improved test_check_rate_limit.py.
johndotpub committed Jan 8, 2024
1 parent a7ed9bb commit be8eb5e
Showing 3 changed files with 82 additions and 36 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/flake8-pytest.yml
@@ -33,7 +33,7 @@ jobs:
# stop the build if there are Python syntax errors or undefined names
# flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
flake8 . --count --max-complexity=10 --statistics
- name: Test with pytest
run: |
pytest
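
With --exit-zero removed, any flake8 finding now fails the build, and dropping --max-line-length=127 falls back to flake8's default 79-character limit, which is what enforces plain PEP 8. The same check can be reproduced before pushing; a minimal local helper is sketched below (the script and its use of subprocess are illustrative, not part of the repository):

import subprocess
import sys


def run_flake8() -> int:
    # Mirror the CI invocation: no --exit-zero, default line length of 79.
    completed = subprocess.run(
        ["flake8", ".", "--count", "--max-complexity=10", "--statistics"]
    )
    return completed.returncode


if __name__ == "__main__":
    sys.exit(run_flake8())
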
36 changes: 28 additions & 8 deletions bot.py
@@ -12,6 +12,7 @@
from openai import OpenAI
from websockets.exceptions import ConnectionClosed


# Define the function to parse command-line arguments
def parse_arguments() -> argparse.Namespace:
parser = argparse.ArgumentParser(description='GPT-based Discord bot.')
@@ -36,7 +37,10 @@ def load_configuration(config_file: str) -> configparser.ConfigParser:
return config


def set_activity_status(activity_type: str, activity_status: str) -> discord.Activity:
def set_activity_status(
activity_type: str,
activity_status: str
) -> discord.Activity:
"""
Return discord.Activity object with specified activity type and status
"""
@@ -79,7 +83,12 @@ def get_conversation_summary(conversation: list[dict]) -> list[dict]:
return summary


async def check_rate_limit(user: discord.User) -> bool:
async def check_rate_limit(
user: discord.User,
logger: logging.Logger = None
) -> bool:
if logger is None:
logger = logging.getLogger(__name__)
"""
Check if a user has exceeded the rate limit for sending messages.
"""
@@ -99,7 +108,11 @@ async def check_rate_limit(user: discord.User) -> bool:
return False


async def process_input_message(input_message: str, user: discord.User, conversation_summary: list[dict]) -> str:
async def process_input_message(
input_message: str,
user: discord.User,
conversation_summary: list[dict]
) -> str:
"""
Process an input message using OpenAI's GPT model.
"""
@@ -144,7 +157,10 @@ async def process_input_message(input_message: str, user: discord.User, conversa
else:
response_content = None
except AttributeError:
logger.error("Failed to get response from OpenAI API: Invalid response format.")
logger.error(
"Failed to get response from OpenAI API: "
"Invalid response format."
)
return "Sorry, an error occurred while processing the message."

if response_content:
@@ -153,13 +169,15 @@ async def process_input_message(input_message: str, user: discord.User, conversa
# logger.info(f"Raw API response: {response}")
logger.info(f"Sent the response: {response_content}")

conversation.append({"role": "assistant", "content": response_content})
conversation.append(
{"role": "assistant", "content": response_content}
)
conversation_history[user.id] = conversation

return response_content
else:
logger.error("Failed to get response from OpenAI API: No text in response.")
return "Sorry, I didn't get that. Could you rephrase or ask something else?"
logger.error("OpenAI API error: No response text.")
return "Sorry, I didn't get that. Can you rephrase or ask again?"

except ConnectionClosed as error:
logger.error(f"WebSocket connection closed: {error}")
@@ -197,7 +215,9 @@ async def process_input_message(input_message: str, user: discord.User, conversa

OPENAI_API_KEY = config.get('OpenAI', 'OPENAI_API_KEY')
OPENAI_TIMEOUT = config.getint('OpenAI', 'OPENAI_TIMEOUT', fallback='30')
GPT_MODEL = config.get('OpenAI', 'GPT_MODEL', fallback='gpt-3.5-turbo-1106')
GPT_MODEL = config.get(
'OpenAI', 'GPT_MODEL', fallback='gpt-3.5-turbo-1106'
)
GPT_TOKENS = config.getint('OpenAI', 'GPT_TOKENS', fallback=4096)
SYSTEM_MESSAGE = config.get(
'OpenAI', 'SYSTEM_MESSAGE', fallback='You are a helpful assistant.'
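
The check_rate_limit signature now takes an optional logger and falls back to the module logger when none is supplied, which is what lets the updated test inject its own logging.Logger. A minimal sketch of that pattern in isolation (the function name and body below are hypothetical, not code from bot.py):

import logging


def do_work(data: str, logger: logging.Logger = None) -> str:
    # Fall back to the module-level logger when the caller does not inject one.
    if logger is None:
        logger = logging.getLogger(__name__)
    logger.info("Processing %s", data)
    return data.upper()

Callers that care about log routing, such as the tests, pass a preconfigured logger; everyone else gets the module default.
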
80 changes: 53 additions & 27 deletions tests/test_check_rate_limit.py
@@ -1,39 +1,25 @@
from contextlib import contextmanager
import time
from unittest.mock import AsyncMock
import asyncio
import pytest
import logging

import bot
from bot import check_rate_limit

from contextlib import contextmanager


@pytest.mark.asyncio
async def test_check_rate_limit():
user = AsyncMock()
user.id = 123

async def run_test():
with patch_variables(
bot, 'last_command_timestamps', {user.id: time.time() - 60}
), patch_variables(bot, 'last_command_count', {user.id: 0}), \
patch_variables(bot, 'RATE_LIMIT_PER', RATE_LIMIT_PER), \
patch_variables(bot, 'RATE_LIMIT', RATE_LIMIT):
result = await check_rate_limit(user)
assert result is True
assert last_command_count[user.id] == 1

last_command_count[user.id] = 3
result = await check_rate_limit(user)
assert result is False
assert last_command_count[user.id] == 3
# Define a placeholder logger
logger = logging.getLogger('pytest_logger')
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler()
handler.setLevel(logging.DEBUG)
logger.addHandler(handler)

last_command_timestamps[user.id] = time.time() - RATE_LIMIT_PER - 1
result = await check_rate_limit(user)
assert result is True
assert last_command_count[user.id] == 1
RATE_LIMIT = 10
RATE_LIMIT_PER = 60

await run_test()
last_command_count = {}
last_command_timestamps = {}


@contextmanager
@@ -45,3 +31,43 @@ def patch_variables(module, variable_name, value):
yield
finally:
setattr(module, variable_name, original_value)


async def run_test(user):
with patch_variables(
bot,
'last_command_timestamps',
{user.id: time.time() - 60}
), patch_variables(
bot,
'last_command_count',
{user.id: 0}
), patch_variables(
bot,
'RATE_LIMIT_PER',
RATE_LIMIT_PER
), patch_variables(
bot,
'RATE_LIMIT',
RATE_LIMIT
):
result = await check_rate_limit(user, logger)
assert result is True
assert bot.last_command_count.get(user.id, 0) == 1

bot.last_command_count[user.id] = RATE_LIMIT
result = await check_rate_limit(user, logger)
assert result is False
assert bot.last_command_count.get(user.id, 0) == RATE_LIMIT

bot.last_command_timestamps[user.id] = time.time() - RATE_LIMIT_PER - 1
result = await check_rate_limit(user, logger)
assert result is True
assert bot.last_command_count.get(user.id, 0) == 1


@pytest.mark.asyncio
async def test_check_rate_limit():
user = AsyncMock()
user.id = 123
await run_test(user)

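The patch_variables helper is only partly visible in the hunk above; a reconstruction consistent with the visible yield/finally lines (an assumption about the unshown lines, not the verbatim file contents) looks like this:

from contextlib import contextmanager


@contextmanager
def patch_variables(module, variable_name, value):
    # Swap in a temporary value on the module and restore the original on exit.
    original_value = getattr(module, variable_name)
    setattr(module, variable_name, value)
    try:
        yield
    finally:
        setattr(module, variable_name, original_value)

Stacking several of these in a single with statement, as run_test does, keeps bot's rate-limit globals isolated per test run; unittest.mock.patch.object provides the same set-and-restore behaviour if the hand-rolled helper is ever retired.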