diff --git a/bot.py b/bot.py index 4b44d64..382dadf 100644 --- a/bot.py +++ b/bot.py @@ -15,27 +15,34 @@ # Define the function to parse command-line arguments def parse_arguments() -> argparse.Namespace: - parser = argparse.ArgumentParser(description='GPT-based Discord bot.') - parser.add_argument('--conf', help='Configuration file path') - args = parser.parse_args() - return args + try: + parser = argparse.ArgumentParser(description='GPT-based Discord bot.') + parser.add_argument('--conf', help='Configuration file path') + args = parser.parse_args() + return args + except Exception as e: + logger.error(f"Error parsing arguments: {e}") + raise # Define the function to load the configuration def load_configuration(config_file: str) -> configparser.ConfigParser: - config = configparser.ConfigParser() - - # Check if the configuration file exists - if os.path.exists(config_file): - config.read(config_file) - else: - # Fall back to environment variables - config.read_dict( - {section: dict(os.environ) for section in config.sections()} - ) + try: + config = configparser.ConfigParser() - return config + # Check if the configuration file exists + if os.path.exists(config_file): + config.read(config_file) + else: + # Fall back to environment variables + config.read_dict( + {section: dict(os.environ) for section in config.sections()} + ) + return config + except Exception as e: + logger.error(f"Error loading configuration: {e}") + raise def set_activity_status( activity_type: str, @@ -44,20 +51,24 @@ def set_activity_status( """ Return discord.Activity object with specified activity type and status """ - activity_types = { - 'playing': discord.ActivityType.playing, - 'streaming': discord.ActivityType.streaming, - 'listening': discord.ActivityType.listening, - 'watching': discord.ActivityType.watching, - 'custom': discord.ActivityType.custom, - 'competing': discord.ActivityType.competing - } - return discord.Activity( - type=activity_types.get( - activity_type, discord.ActivityType.listening - ), - name=activity_status - ) + try: + activity_types = { + 'playing': discord.ActivityType.playing, + 'streaming': discord.ActivityType.streaming, + 'listening': discord.ActivityType.listening, + 'watching': discord.ActivityType.watching, + 'custom': discord.ActivityType.custom, + 'competing': discord.ActivityType.competing + } + return discord.Activity( + type=activity_types.get( + activity_type, discord.ActivityType.listening + ), + name=activity_status + ) + except Exception as e: + logger.error(f"Error setting activity status: {e}") + raise # Define the function to get the conversation summary @@ -65,47 +76,54 @@ def get_conversation_summary(conversation: list[dict]) -> list[dict]: """ Conversation summary from combining user messages and assistant responses """ - summary = [] - user_messages = [ - message for message in conversation if message["role"] == "user" - ] - assistant_responses = [ - message for message in conversation if message["role"] == "assistant" - ] - - # Combine user messages and assistant responses into a summary - for user_message, assistant_response in zip( - user_messages, assistant_responses - ): - summary.append(user_message) - summary.append(assistant_response) - - return summary + try: + summary = [] + user_messages = [ + message for message in conversation if message["role"] == "user" + ] + assistant_responses = [ + message for message in conversation if message["role"] == "assistant" + ] + + # Combine user messages and assistant responses into a summary + for user_message, assistant_response in zip( + user_messages, assistant_responses + ): + summary.append(user_message) + summary.append(assistant_response) + return summary + except Exception as e: + logger.error(f"Error getting conversation summary: {e}") + raise async def check_rate_limit( user: discord.User, logger: logging.Logger = None ) -> bool: - if logger is None: - logger = logging.getLogger(__name__) - """ - Check if a user has exceeded the rate limit for sending messages. - """ - current_time = time.time() - last_command_timestamp = last_command_timestamps.get(user.id, 0) - last_command_count_user = last_command_count.get(user.id, 0) - if current_time - last_command_timestamp > RATE_LIMIT_PER: - last_command_timestamps[user.id] = current_time - last_command_count[user.id] = 1 - logger.info(f"Rate limit passed for user: {user}") - return True - if last_command_count_user < RATE_LIMIT: - last_command_count[user.id] += 1 - logger.info(f"Rate limit passed for user: {user}") - return True - logger.info(f"Rate limit exceeded for user: {user}") - return False + try: + if logger is None: + logger = logging.getLogger(__name__) + """ + Check if a user has exceeded the rate limit for sending messages. + """ + current_time = time.time() + last_command_timestamp = last_command_timestamps.get(user.id, 0) + last_command_count_user = last_command_count.get(user.id, 0) + if current_time - last_command_timestamp > RATE_LIMIT_PER: + last_command_timestamps[user.id] = current_time + last_command_count[user.id] = 1 + logger.info(f"Rate limit passed for user: {user}") + return True + if last_command_count_user < RATE_LIMIT: + last_command_count[user.id] += 1 + logger.info(f"Rate limit passed for user: {user}") + return True + logger.info(f"Rate limit exceeded for user: {user}") + return False + except Exception as e: + logger.error(f"Error checking rate limit: {e}") + raise async def process_input_message( @@ -227,22 +245,39 @@ async def process_input_message( RATE_LIMIT_PER = config.getint('Limits', 'RATE_LIMIT_PER', fallback=60) LOG_FILE = config.get('Logging', 'LOG_FILE', fallback='bot.log') + LOG_LEVEL = config.get('Logging', 'LOG_LEVEL', fallback='INFO') # Set up logging logger = logging.getLogger('discord') - logger.setLevel(logging.INFO) + logger.setLevel(getattr(logging, LOG_LEVEL.upper())) # File handler file_handler = RotatingFileHandler( LOG_FILE, maxBytes=5 * 1024 * 1024, backupCount=5 ) - file_handler.setLevel(logging.WARNING) + file_handler.setLevel(getattr(logging, LOG_LEVEL.upper())) file_formatter = logging.Formatter( '%(asctime)s [%(levelname)s] %(name)s: %(message)s' ) file_handler.setFormatter(file_formatter) logger.addHandler(file_handler) + # Add a StreamHandler to the logger + stream_handler = logging.StreamHandler() + stream_handler.setLevel(getattr(logging, LOG_LEVEL.upper())) + stream_handler.setFormatter(file_formatter) + logger.addHandler(stream_handler) + + # Set a global exception handler + def handle_unhandled_exception(exc_type, exc_value, exc_traceback): + if issubclass(exc_type, KeyboardInterrupt): + sys.__excepthook__(exc_type, exc_value, exc_traceback) + return + + logger.error("Unhandled exception", exc_info=(exc_type, exc_value, exc_traceback)) + + sys.excepthook = handle_unhandled_exception + # Set the intents for the bot intents = discord.Intents.default() intents.typing = False @@ -283,43 +318,19 @@ async def on_message(message): """ Event handler for when a message is received. """ - if message.author == bot.user: - return - - if isinstance(message.channel, discord.DMChannel): - # Process DM messages without the @botname requirement - logger.info( - f'Received DM: {message.content} | Author: {message.author}' - ) - - if not await check_rate_limit(message.author): - await message.channel.send( - "Command on cooldown. Please wait before using it again." - ) + try: + if message.author == bot.user: return - conversation_summary = get_conversation_summary( - conversation_history.get(message.author.id, []) - ) - response = await process_input_message( - message.content, message.author, conversation_summary - ) - await message.channel.send(response) - elif ( - isinstance(message.channel, discord.TextChannel) - and message.channel.name in ALLOWED_CHANNELS - ): - if bot.user in message.mentions: + if isinstance(message.channel, discord.DMChannel): + # Process DM messages without the @botname requirement logger.info( - 'Received message: ' + message.content - + ' | Channel: ' + str(message.channel) - + ' | Author: ' + str(message.author) + f'Received DM: {message.content} | Author: {message.author}' ) if not await check_rate_limit(message.author): await message.channel.send( - "Command on cooldown. " - "Please wait before using it again." + "Command on cooldown. Please wait before using it again." ) return @@ -329,7 +340,55 @@ async def on_message(message): response = await process_input_message( message.content, message.author, conversation_summary ) - await message.channel.send(response) + await send_split_message(message.channel, response) + elif ( + isinstance(message.channel, discord.TextChannel) + and message.channel.name in ALLOWED_CHANNELS + ): + if bot.user in message.mentions: + logger.info( + 'Received message: ' + message.content + + ' | Channel: ' + str(message.channel) + + ' | Author: ' + str(message.author) + ) + + if not await check_rate_limit(message.author): + await message.channel.send( + "Command on cooldown. " + "Please wait before using it again." + ) + return + + conversation_summary = get_conversation_summary( + conversation_history.get(message.author.id, []) + ) + response = await process_input_message( + message.content, message.author, conversation_summary + ) + await send_split_message(message.channel, response) + except Exception as e: + logger.error(f"An error occurred in on_message: {e}") + + async def send_split_message(channel, message): + """ + Send a message to a channel. If the message is longer than 2000 characters, + it is split into multiple messages at the nearest newline character around + the middle of the message. + """ + if len(message) <= 2000: + await channel.send(message) + else: + # Find the nearest newline character around the middle of the message + middle_index = len(message) // 2 + split_index = message.rfind('\n', 0, middle_index) + if split_index == -1: # No newline character found + split_index = middle_index # Split at the middle of the message + # Split the message into two parts + message_part1 = message[:split_index] + message_part2 = message[split_index:] + # Send the two parts as separate messages + await channel.send(message_part1) + await channel.send(message_part2) # Run the bot bot.run(DISCORD_TOKEN)