Skip to content

Commit

Permalink
Split Long Responses in Two and Improve Logging
Browse files Browse the repository at this point in the history
This fixes #14 by splitting any response that is long into two
pieces from a newline closest to the middle.

While here I've improved the logging capabilities and added the
streamhandler logging config.

Flake8 line length test is set to 99 overriding the default as well
  • Loading branch information
johndotpub committed Feb 2, 2024
1 parent 35d7bc1 commit 7357bf3
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 97 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/flake8-pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
# stop the build if there are Python syntax errors or undefined names
# flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --max-complexity=10 --statistics
flake8 . --count --max-complexity=10 --statistics --max-line-length=99
- name: Test with pytest
run: |
pytest
254 changes: 158 additions & 96 deletions bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import configparser
import logging
import os
import sys
import time
from logging.handlers import RotatingFileHandler

Expand All @@ -15,26 +16,34 @@

# Define the function to parse command-line arguments
def parse_arguments() -> argparse.Namespace:
parser = argparse.ArgumentParser(description='GPT-based Discord bot.')
parser.add_argument('--conf', help='Configuration file path')
args = parser.parse_args()
return args
try:
parser = argparse.ArgumentParser(description='GPT-based Discord bot.')
parser.add_argument('--conf', help='Configuration file path')
args = parser.parse_args()
return args
except Exception as e:
logger.error(f"Error parsing arguments: {e}")
raise


# Define the function to load the configuration
def load_configuration(config_file: str) -> configparser.ConfigParser:
config = configparser.ConfigParser()

# Check if the configuration file exists
if os.path.exists(config_file):
config.read(config_file)
else:
# Fall back to environment variables
config.read_dict(
{section: dict(os.environ) for section in config.sections()}
)
try:
config = configparser.ConfigParser()

# Check if the configuration file exists
if os.path.exists(config_file):
config.read(config_file)
else:
# Fall back to environment variables
config.read_dict(
{section: dict(os.environ) for section in config.sections()}
)

return config
return config
except Exception as e:
logger.error(f"Error loading configuration: {e}")
raise


def set_activity_status(
Expand All @@ -44,68 +53,80 @@ def set_activity_status(
"""
Return discord.Activity object with specified activity type and status
"""
activity_types = {
'playing': discord.ActivityType.playing,
'streaming': discord.ActivityType.streaming,
'listening': discord.ActivityType.listening,
'watching': discord.ActivityType.watching,
'custom': discord.ActivityType.custom,
'competing': discord.ActivityType.competing
}
return discord.Activity(
type=activity_types.get(
activity_type, discord.ActivityType.listening
),
name=activity_status
)
try:
activity_types = {
'playing': discord.ActivityType.playing,
'streaming': discord.ActivityType.streaming,
'listening': discord.ActivityType.listening,
'watching': discord.ActivityType.watching,
'custom': discord.ActivityType.custom,
'competing': discord.ActivityType.competing
}
return discord.Activity(
type=activity_types.get(
activity_type, discord.ActivityType.listening
),
name=activity_status
)
except Exception as e:
logger.error(f"Error setting activity status: {e}")
raise


# Define the function to get the conversation summary
def get_conversation_summary(conversation: list[dict]) -> list[dict]:
"""
Conversation summary from combining user messages and assistant responses
"""
summary = []
user_messages = [
message for message in conversation if message["role"] == "user"
]
assistant_responses = [
message for message in conversation if message["role"] == "assistant"
]

# Combine user messages and assistant responses into a summary
for user_message, assistant_response in zip(
user_messages, assistant_responses
):
summary.append(user_message)
summary.append(assistant_response)
try:
summary = []
user_messages = [
message for message in conversation if message["role"] == "user"
]
assistant_responses = [
message for message in conversation if message["role"] == "assistant"
]

# Combine user messages and assistant responses into a summary
for user_message, assistant_response in zip(
user_messages, assistant_responses
):
summary.append(user_message)
summary.append(assistant_response)

return summary
return summary
except Exception as e:
logger.error(f"Error getting conversation summary: {e}")
raise


async def check_rate_limit(
user: discord.User,
logger: logging.Logger = None
) -> bool:
if logger is None:
logger = logging.getLogger(__name__)
"""
Check if a user has exceeded the rate limit for sending messages.
"""
current_time = time.time()
last_command_timestamp = last_command_timestamps.get(user.id, 0)
last_command_count_user = last_command_count.get(user.id, 0)
if current_time - last_command_timestamp > RATE_LIMIT_PER:
last_command_timestamps[user.id] = current_time
last_command_count[user.id] = 1
logger.info(f"Rate limit passed for user: {user}")
return True
if last_command_count_user < RATE_LIMIT:
last_command_count[user.id] += 1
logger.info(f"Rate limit passed for user: {user}")
return True
logger.info(f"Rate limit exceeded for user: {user}")
return False
try:
if logger is None:
logger = logging.getLogger(__name__)
"""
Check if a user has exceeded the rate limit for sending messages.
"""
current_time = time.time()
last_command_timestamp = last_command_timestamps.get(user.id, 0)
last_command_count_user = last_command_count.get(user.id, 0)
if current_time - last_command_timestamp > RATE_LIMIT_PER:
last_command_timestamps[user.id] = current_time
last_command_count[user.id] = 1
logger.info(f"Rate limit passed for user: {user}")
return True
if last_command_count_user < RATE_LIMIT:
last_command_count[user.id] += 1
logger.info(f"Rate limit passed for user: {user}")
return True
logger.info(f"Rate limit exceeded for user: {user}")
return False
except Exception as e:
logger.error(f"Error checking rate limit: {e}")
raise


async def process_input_message(
Expand Down Expand Up @@ -227,22 +248,39 @@ async def process_input_message(
RATE_LIMIT_PER = config.getint('Limits', 'RATE_LIMIT_PER', fallback=60)

LOG_FILE = config.get('Logging', 'LOG_FILE', fallback='bot.log')
LOG_LEVEL = config.get('Logging', 'LOG_LEVEL', fallback='INFO')

# Set up logging
logger = logging.getLogger('discord')
logger.setLevel(logging.INFO)
logger.setLevel(getattr(logging, LOG_LEVEL.upper()))

# File handler
file_handler = RotatingFileHandler(
LOG_FILE, maxBytes=5 * 1024 * 1024, backupCount=5
)
file_handler.setLevel(logging.WARNING)
file_handler.setLevel(getattr(logging, LOG_LEVEL.upper()))
file_formatter = logging.Formatter(
'%(asctime)s [%(levelname)s] %(name)s: %(message)s'
)
file_handler.setFormatter(file_formatter)
logger.addHandler(file_handler)

# Add a StreamHandler to the logger
stream_handler = logging.StreamHandler()
stream_handler.setLevel(getattr(logging, LOG_LEVEL.upper()))
stream_handler.setFormatter(file_formatter)
logger.addHandler(stream_handler)

# Set a global exception handler
def handle_unhandled_exception(exc_type, exc_value, exc_traceback):
if issubclass(exc_type, KeyboardInterrupt):
sys.__excepthook__(exc_type, exc_value, exc_traceback)
return

logger.error("Unhandled exception", exc_info=(exc_type, exc_value, exc_traceback))

sys.excepthook = handle_unhandled_exception

# Set the intents for the bot
intents = discord.Intents.default()
intents.typing = False
Expand Down Expand Up @@ -283,43 +321,19 @@ async def on_message(message):
"""
Event handler for when a message is received.
"""
if message.author == bot.user:
return

if isinstance(message.channel, discord.DMChannel):
# Process DM messages without the @botname requirement
logger.info(
f'Received DM: {message.content} | Author: {message.author}'
)

if not await check_rate_limit(message.author):
await message.channel.send(
"Command on cooldown. Please wait before using it again."
)
try:
if message.author == bot.user:
return

conversation_summary = get_conversation_summary(
conversation_history.get(message.author.id, [])
)
response = await process_input_message(
message.content, message.author, conversation_summary
)
await message.channel.send(response)
elif (
isinstance(message.channel, discord.TextChannel)
and message.channel.name in ALLOWED_CHANNELS
):
if bot.user in message.mentions:
if isinstance(message.channel, discord.DMChannel):
# Process DM messages without the @botname requirement
logger.info(
'Received message: ' + message.content
+ ' | Channel: ' + str(message.channel)
+ ' | Author: ' + str(message.author)
f'Received DM: {message.content} | Author: {message.author}'
)

if not await check_rate_limit(message.author):
await message.channel.send(
"Command on cooldown. "
"Please wait before using it again."
"Command on cooldown. Please wait before using it again."
)
return

Expand All @@ -329,7 +343,55 @@ async def on_message(message):
response = await process_input_message(
message.content, message.author, conversation_summary
)
await message.channel.send(response)
await send_split_message(message.channel, response)
elif (
isinstance(message.channel, discord.TextChannel)
and message.channel.name in ALLOWED_CHANNELS
):
if bot.user in message.mentions:
logger.info(
'Received message: ' + message.content
+ ' | Channel: ' + str(message.channel)
+ ' | Author: ' + str(message.author)
)

if not await check_rate_limit(message.author):
await message.channel.send(
"Command on cooldown. "
"Please wait before using it again."
)
return

conversation_summary = get_conversation_summary(
conversation_history.get(message.author.id, [])
)
response = await process_input_message(
message.content, message.author, conversation_summary
)
await send_split_message(message.channel, response)
except Exception as e:
logger.error(f"An error occurred in on_message: {e}")

async def send_split_message(channel, message):
"""
Send a message to a channel. If the message is longer than 2000 characters,
it is split into multiple messages at the nearest newline character around
the middle of the message.
"""
if len(message) <= 2000:
await channel.send(message)
else:
# Find the nearest newline character around the middle of the message
middle_index = len(message) // 2
split_index = message.rfind('\n', 0, middle_index)
if split_index == -1: # No newline character found
split_index = middle_index # Split at the middle of the message
# Split the message into two parts
message_part1 = message[:split_index]
message_part2 = message[split_index:]
# Send the two parts as separate messages
await channel.send(message_part1)
await channel.send(message_part2)

# Run the bot
bot.run(DISCORD_TOKEN)

0 comments on commit 7357bf3

Please sign in to comment.