From 9bca4caf66478945a8571e46f708a0f0992105eb Mon Sep 17 00:00:00 2001 From: rlywtf Date: Fri, 2 Feb 2024 18:16:49 -0800 Subject: [PATCH] Fix session handling We need to handle session disconnects, resuming, and changes in shard states. This fixes some of the errors we were seeing with: - Shard ID None WebSocket closed - Shard ID None heartbeat blocked for more than This also makes a new `call_openai_api` function definition and turns the response into an async call on a thread. The rate limit function was cleaned up while I was here. The Flake8 complexity was reduced to it's default of 10 and the main function was commented to avoid Flake8 complexity issues. --- .github/workflows/flake8-pytest.yml | 2 +- bot.py | 69 +++++++++++++++++++++-------- 2 files changed, 51 insertions(+), 20 deletions(-) diff --git a/.github/workflows/flake8-pytest.yml b/.github/workflows/flake8-pytest.yml index 5cd6494..2ffd110 100644 --- a/.github/workflows/flake8-pytest.yml +++ b/.github/workflows/flake8-pytest.yml @@ -33,7 +33,7 @@ jobs: # stop the build if there are Python syntax errors or undefined names # flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --count --max-complexity=20 --statistics --max-line-length=99 + flake8 . --count --max-complexity=10 --statistics --max-line-length=99 - name: Test with pytest run: | pytest diff --git a/bot.py b/bot.py index dcf1eef..2058b8c 100644 --- a/bot.py +++ b/bot.py @@ -105,26 +105,31 @@ async def check_rate_limit( user: discord.User, logger: logging.Logger = None ) -> bool: + """ + Check if a user has exceeded the rate limit for sending messages. + """ + if logger is None: + logger = logging.getLogger(__name__) + try: - if logger is None: - logger = logging.getLogger(__name__) - """ - Check if a user has exceeded the rate limit for sending messages. - """ current_time = time.time() last_command_timestamp = last_command_timestamps.get(user.id, 0) last_command_count_user = last_command_count.get(user.id, 0) + if current_time - last_command_timestamp > RATE_LIMIT_PER: last_command_timestamps[user.id] = current_time last_command_count[user.id] = 1 logger.info(f"Rate limit passed for user: {user}") return True + if last_command_count_user < RATE_LIMIT: last_command_count[user.id] += 1 logger.info(f"Rate limit passed for user: {user}") return True + logger.info(f"Rate limit exceeded for user: {user}") return False + except Exception as e: logger.error(f"Error checking rate limit: {e}") raise @@ -161,16 +166,19 @@ async def process_input_message( # Log the current conversation history # logger.info(f"Current conversation history: {conversation}") - response = client.chat.completions.create( - model=GPT_MODEL, - messages=[ - {"role": "system", "content": SYSTEM_MESSAGE}, - *conversation_summary, - {"role": "user", "content": input_message} - ], - max_tokens=max_tokens, - temperature=0.7 - ) + def call_openai_api(): + return client.chat.completions.create( + model=GPT_MODEL, + messages=[ + {"role": "system", "content": SYSTEM_MESSAGE}, + *conversation_summary, + {"role": "user", "content": input_message} + ], + max_tokens=max_tokens, + temperature=0.7 + ) + + response = await asyncio.to_thread(call_openai_api) try: # Extracting the response content from the new API response format @@ -212,8 +220,8 @@ async def process_input_message( return "An error occurred while processing the message." -# Execute the argparse code only when the file is run directly -if __name__ == "__main__": +# Executes the argparse code only when the file is run directly +if __name__ == "__main__": # noqa: C901 (ignore complexity in main function) # Parse command-line arguments args = parse_arguments() @@ -311,6 +319,27 @@ async def on_ready(): status=discord.Status(BOT_PRESENCE) ) + @bot.event + async def on_disconnect(): + """ + Event handler for when the bot disconnects from the Discord server. + """ + logger.info('Bot has disconnected') + + @bot.event + async def on_resumed(): + """ + Event handler for when the bot resumes its session. + """ + logger.info('Bot has resumed session') + + @bot.event + async def on_shard_ready(shard_id): + """ + Event handler for when a shard is ready. + """ + logger.info(f'Shard {shard_id} is ready') + @bot.event async def on_message(message): """ @@ -341,8 +370,9 @@ async def process_dm_message(message): if not await check_rate_limit(message.author): await message.channel.send( - "Command on cooldown. Please wait before using it again." + f"{message.author.mention} Exceeded the Rate Limit! Please slow down!" ) + logger.warning(f"Rate Limit Exceed by DM from {message.author}") return conversation_summary = get_conversation_summary( @@ -367,8 +397,9 @@ async def process_channel_message(message): if not await check_rate_limit(message.author): await message.channel.send( - "Command on cooldown. Please wait before using it again." + f"{message.author.mention} Exceeded the Rate Limit! Please slow down!" ) + logger.warning(f"Rate Limit Exceeded in {message.channel} by {message.author}") return conversation_summary = get_conversation_summary(