Skip to content

Commit

Permalink
Merge pull request #25 from rlywtf/flake8_fixes
Browse files Browse the repository at this point in the history
Fix flake8 failures and warnings then enforce PEP8
  • Loading branch information
johndotpub authored Jan 8, 2024
2 parents a7ed9bb + be8eb5e commit 7e4cefe
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 36 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/flake8-pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
# stop the build if there are Python syntax errors or undefined names
# flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
flake8 . --count --max-complexity=10 --statistics
- name: Test with pytest
run: |
pytest
36 changes: 28 additions & 8 deletions bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from openai import OpenAI
from websockets.exceptions import ConnectionClosed


# Define the function to parse command-line arguments
def parse_arguments() -> argparse.Namespace:
parser = argparse.ArgumentParser(description='GPT-based Discord bot.')
Expand All @@ -36,7 +37,10 @@ def load_configuration(config_file: str) -> configparser.ConfigParser:
return config


def set_activity_status(activity_type: str, activity_status: str) -> discord.Activity:
def set_activity_status(
activity_type: str,
activity_status: str
) -> discord.Activity:
"""
Return discord.Activity object with specified activity type and status
"""
Expand Down Expand Up @@ -79,7 +83,12 @@ def get_conversation_summary(conversation: list[dict]) -> list[dict]:
return summary


async def check_rate_limit(user: discord.User) -> bool:
async def check_rate_limit(
user: discord.User,
logger: logging.Logger = None
) -> bool:
if logger is None:
logger = logging.getLogger(__name__)
"""
Check if a user has exceeded the rate limit for sending messages.
"""
Expand All @@ -99,7 +108,11 @@ async def check_rate_limit(user: discord.User) -> bool:
return False


async def process_input_message(input_message: str, user: discord.User, conversation_summary: list[dict]) -> str:
async def process_input_message(
input_message: str,
user: discord.User,
conversation_summary: list[dict]
) -> str:
"""
Process an input message using OpenAI's GPT model.
"""
Expand Down Expand Up @@ -144,7 +157,10 @@ async def process_input_message(input_message: str, user: discord.User, conversa
else:
response_content = None
except AttributeError:
logger.error("Failed to get response from OpenAI API: Invalid response format.")
logger.error(
"Failed to get response from OpenAI API: "
"Invalid response format."
)
return "Sorry, an error occurred while processing the message."

if response_content:
Expand All @@ -153,13 +169,15 @@ async def process_input_message(input_message: str, user: discord.User, conversa
# logger.info(f"Raw API response: {response}")
logger.info(f"Sent the response: {response_content}")

conversation.append({"role": "assistant", "content": response_content})
conversation.append(
{"role": "assistant", "content": response_content}
)
conversation_history[user.id] = conversation

return response_content
else:
logger.error("Failed to get response from OpenAI API: No text in response.")
return "Sorry, I didn't get that. Could you rephrase or ask something else?"
logger.error("OpenAI API error: No response text.")
return "Sorry, I didn't get that. Can you rephrase or ask again?"

except ConnectionClosed as error:
logger.error(f"WebSocket connection closed: {error}")
Expand Down Expand Up @@ -197,7 +215,9 @@ async def process_input_message(input_message: str, user: discord.User, conversa

OPENAI_API_KEY = config.get('OpenAI', 'OPENAI_API_KEY')
OPENAI_TIMEOUT = config.getint('OpenAI', 'OPENAI_TIMEOUT', fallback='30')
GPT_MODEL = config.get('OpenAI', 'GPT_MODEL', fallback='gpt-3.5-turbo-1106')
GPT_MODEL = config.get(
'OpenAI', 'GPT_MODEL', fallback='gpt-3.5-turbo-1106'
)
GPT_TOKENS = config.getint('OpenAI', 'GPT_TOKENS', fallback=4096)
SYSTEM_MESSAGE = config.get(
'OpenAI', 'SYSTEM_MESSAGE', fallback='You are a helpful assistant.'
Expand Down
80 changes: 53 additions & 27 deletions tests/test_check_rate_limit.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,25 @@
from contextlib import contextmanager
import time
from unittest.mock import AsyncMock
import asyncio
import pytest
import logging

import bot
from bot import check_rate_limit

from contextlib import contextmanager


@pytest.mark.asyncio
async def test_check_rate_limit():
user = AsyncMock()
user.id = 123

async def run_test():
with patch_variables(
bot, 'last_command_timestamps', {user.id: time.time() - 60}
), patch_variables(bot, 'last_command_count', {user.id: 0}), \
patch_variables(bot, 'RATE_LIMIT_PER', RATE_LIMIT_PER), \
patch_variables(bot, 'RATE_LIMIT', RATE_LIMIT):
result = await check_rate_limit(user)
assert result is True
assert last_command_count[user.id] == 1

last_command_count[user.id] = 3
result = await check_rate_limit(user)
assert result is False
assert last_command_count[user.id] == 3
# Define a placeholder logger
logger = logging.getLogger('pytest_logger')
logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler()
handler.setLevel(logging.DEBUG)
logger.addHandler(handler)

last_command_timestamps[user.id] = time.time() - RATE_LIMIT_PER - 1
result = await check_rate_limit(user)
assert result is True
assert last_command_count[user.id] == 1
RATE_LIMIT = 10
RATE_LIMIT_PER = 60

await run_test()
last_command_count = {}
last_command_timestamps = {}


@contextmanager
Expand All @@ -45,3 +31,43 @@ def patch_variables(module, variable_name, value):
yield
finally:
setattr(module, variable_name, original_value)


async def run_test(user):
with patch_variables(
bot,
'last_command_timestamps',
{user.id: time.time() - 60}
), patch_variables(
bot,
'last_command_count',
{user.id: 0}
), patch_variables(
bot,
'RATE_LIMIT_PER',
RATE_LIMIT_PER
), patch_variables(
bot,
'RATE_LIMIT',
RATE_LIMIT
):
result = await check_rate_limit(user, logger)
assert result is True
assert bot.last_command_count.get(user.id, 0) == 1

bot.last_command_count[user.id] = RATE_LIMIT
result = await check_rate_limit(user, logger)
assert result is False
assert bot.last_command_count.get(user.id, 0) == RATE_LIMIT

bot.last_command_timestamps[user.id] = time.time() - RATE_LIMIT_PER - 1
result = await check_rate_limit(user, logger)
assert result is True
assert bot.last_command_count.get(user.id, 0) == 1


@pytest.mark.asyncio
async def test_check_rate_limit():
user = AsyncMock()
user.id = 123
await run_test(user)

0 comments on commit 7e4cefe

Please sign in to comment.