diff --git a/cogs/draw_image_generation.py b/cogs/draw_image_generation.py index fa338c3e..64987598 100644 --- a/cogs/draw_image_generation.py +++ b/cogs/draw_image_generation.py @@ -56,7 +56,10 @@ async def encapsulated_send( try: file, image_urls = await self.model.send_image_request( - ctx, prompt, vary=vary if not draw_from_optimizer else None, custom_api_key=custom_api_key + ctx, + prompt, + vary=vary if not draw_from_optimizer else None, + custom_api_key=custom_api_key, ) except ValueError as e: ( @@ -96,7 +99,14 @@ async def encapsulated_send( ) await result_message.edit( - view=SaveView(ctx, image_urls, self, self.converser_cog, result_message, custom_api_key=custom_api_key) + view=SaveView( + ctx, + image_urls, + self, + self.converser_cog, + result_message, + custom_api_key=custom_api_key, + ) ) self.converser_cog.users_to_interactions[user_id] = [] @@ -115,7 +125,14 @@ async def encapsulated_send( file=file, ) await message.edit( - view=SaveView(ctx, image_urls, self, self.converser_cog, message, custom_api_key=custom_api_key) + view=SaveView( + ctx, + image_urls, + self, + self.converser_cog, + message, + custom_api_key=custom_api_key, + ) ) else: # Varying case if not draw_from_optimizer: @@ -144,7 +161,12 @@ async def encapsulated_send( ) await result_message.edit( view=SaveView( - ctx, image_urls, self, self.converser_cog, result_message, custom_api_key=custom_api_key + ctx, + image_urls, + self, + self.converser_cog, + result_message, + custom_api_key=custom_api_key, ) ) @@ -179,7 +201,11 @@ async def draw(self, ctx: discord.ApplicationContext, prompt: str): return try: - asyncio.ensure_future(self.encapsulated_send(user.id, prompt, ctx, custom_api_key=user_api_key)) + asyncio.ensure_future( + self.encapsulated_send( + user.id, prompt, ctx, custom_api_key=user_api_key + ) + ) except Exception as e: print(e) @@ -258,11 +284,21 @@ def __init__( self.add_item(SaveButton(x, image_urls[x - 1])) if not only_save: if not no_retry: - self.add_item(RedoButton(self.cog, converser_cog=self.converser_cog, custom_api_key=self.custom_api_key)) + self.add_item( + RedoButton( + self.cog, + converser_cog=self.converser_cog, + custom_api_key=self.custom_api_key, + ) + ) for x in range(1, len(image_urls) + 1): self.add_item( VaryButton( - x, image_urls[x - 1], self.cog, converser_cog=self.converser_cog, custom_api_key=self.custom_api_key + x, + image_urls[x - 1], + self.cog, + converser_cog=self.converser_cog, + custom_api_key=self.custom_api_key, ) ) @@ -404,5 +440,11 @@ async def callback(self, interaction: discord.Interaction): self.converser_cog.users_to_interactions[user_id].append(message.id) asyncio.ensure_future( - self.cog.encapsulated_send(user_id, prompt, ctx, response_message, custom_api_key=self.custom_api_key) + self.cog.encapsulated_send( + user_id, + prompt, + ctx, + response_message, + custom_api_key=self.custom_api_key, + ) ) diff --git a/cogs/gpt_3_commands_and_converser.py b/cogs/gpt_3_commands_and_converser.py index 6db4804c..35d21bd9 100644 --- a/cogs/gpt_3_commands_and_converser.py +++ b/cogs/gpt_3_commands_and_converser.py @@ -32,12 +32,13 @@ USER_INPUT_API_KEYS = EnvService.get_user_input_api_keys() USER_KEY_DB = None if USER_INPUT_API_KEYS: - print("This server was configured to enforce user input API keys. Doing the required database setup now") + print( + "This server was configured to enforce user input API keys. Doing the required database setup now" + ) USER_KEY_DB = SqliteDict("user_key_db.sqlite") print("Retrieved/created the user key database") - class GPT3ComCon(discord.Cog, name="GPT3ComCon"): def __init__( self, @@ -148,9 +149,13 @@ async def get_user_api_key(user_id, ctx): modal = SetupModal(title="API Key Setup") if isinstance(ctx, discord.ApplicationContext): await ctx.send_modal(modal) - await ctx.send_followup("You must set up your API key before using this command.") + await ctx.send_followup( + "You must set up your API key before using this command." + ) else: - await ctx.reply("You must set up your API key before typing in a GPT3 powered channel, type `/setup` to enter your API key.") + await ctx.reply( + "You must set up your API key before typing in a GPT3 powered channel, type `/setup` to enter your API key." + ) return user_api_key async def load_file(self, file, ctx): @@ -199,7 +204,9 @@ async def on_ready(self): self.DEBUG_CHANNEL ) if USER_INPUT_API_KEYS: - print("This bot was set to use user input API keys. Doing the required SQLite setup now") + print( + "This bot was set to use user input API keys. Doing the required SQLite setup now" + ) await self.bot.sync_commands( commands=None, @@ -644,7 +651,9 @@ async def on_message(self, message): # Extract all the text after the !g and use it as the prompt. user_api_key = None if USER_INPUT_API_KEYS: - user_api_key = await GPT3ComCon.get_user_api_key(message.author.id, message) + user_api_key = await GPT3ComCon.get_user_api_key( + message.author.id, message + ) if not user_api_key: return @@ -790,7 +799,11 @@ async def encapsulated_send( # Create and upsert the embedding for the conversation id, prompt, timestamp embedding = await self.pinecone_service.upsert_conversation_embedding( - self.model, conversation_id, new_prompt, timestamp, custom_api_key=custom_api_key, + self.model, + conversation_id, + new_prompt, + timestamp, + custom_api_key=custom_api_key, ) embedding_prompt_less_author = await self.model.send_embedding_request( @@ -953,7 +966,11 @@ async def encapsulated_send( # Create and upsert the embedding for the conversation id, prompt, timestamp embedding = await self.pinecone_service.upsert_conversation_embedding( - self.model, conversation_id, response_text, timestamp, custom_api_key=custom_api_key + self.model, + conversation_id, + response_text, + timestamp, + custom_api_key=custom_api_key, ) # Cleanse @@ -967,12 +984,16 @@ async def encapsulated_send( response_message = ( await ctx.respond( response_text, - view=ConversationView(ctx, self, ctx.channel.id, custom_api_key=custom_api_key), + view=ConversationView( + ctx, self, ctx.channel.id, custom_api_key=custom_api_key + ), ) if from_context else await ctx.reply( response_text, - view=ConversationView(ctx, self, ctx.channel.id, custom_api_key=custom_api_key), + view=ConversationView( + ctx, self, ctx.channel.id, custom_api_key=custom_api_key + ), ) ) @@ -1368,12 +1389,18 @@ async def help(self, ctx: discord.ApplicationContext): await self.send_help_text(ctx) @discord.slash_command( - name="setup", description="Setup your API key for use with GPT3Discord", guild_ids=ALLOWED_GUILDS + name="setup", + description="Setup your API key for use with GPT3Discord", + guild_ids=ALLOWED_GUILDS, ) @discord.guild_only() async def setup(self, ctx: discord.ApplicationContext): if not USER_INPUT_API_KEYS: - await ctx.respond("This server doesn't support user input API keys.", ephemeral=True, delete_after=30) + await ctx.respond( + "This server doesn't support user input API keys.", + ephemeral=True, + delete_after=30, + ) modal = SetupModal(title="API Key Setup") await ctx.send_modal(modal) @@ -1437,8 +1464,10 @@ def __init__(self, ctx, converser_cog, id, custom_api_key=None): super().__init__(timeout=3600) # 1 hour interval to redo. self.converser_cog = converser_cog self.ctx = ctx - self.custom_api_key= custom_api_key - self.add_item(RedoButton(self.converser_cog, custom_api_key=self.custom_api_key)) + self.custom_api_key = custom_api_key + self.add_item( + RedoButton(self.converser_cog, custom_api_key=self.custom_api_key) + ) if id in self.converser_cog.conversation_threads: self.add_item(EndConvoButton(self.converser_cog)) @@ -1511,7 +1540,11 @@ async def callback(self, interaction: discord.Interaction): ) await self.converser_cog.encapsulated_send( - id=user_id, prompt=prompt, ctx=ctx, response_message=response_message, custom_api_key=self.custom_api_key + id=user_id, + prompt=prompt, + ctx=ctx, + response_message=response_message, + custom_api_key=self.custom_api_key, ) else: await interaction.response.send_message( @@ -1520,37 +1553,63 @@ async def callback(self, interaction: discord.Interaction): delete_after=10, ) + class SetupModal(discord.ui.Modal): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - self.add_item(discord.ui.InputText(label="OpenAI API Key", placeholder="sk--......", )) + self.add_item( + discord.ui.InputText( + label="OpenAI API Key", + placeholder="sk--......", + ) + ) async def callback(self, interaction: discord.Interaction): user = interaction.user api_key = self.children[0].value # Validate that api_key is indeed in this format if not re.match(r"sk-[a-zA-Z0-9]{32}", api_key): - await interaction.response.send_message("Your API key looks invalid, please check that it is correct before proceeding. Please run the /setup command to set your key.", ephemeral=True, delete_after=100) + await interaction.response.send_message( + "Your API key looks invalid, please check that it is correct before proceeding. Please run the /setup command to set your key.", + ephemeral=True, + delete_after=100, + ) else: # We can save the key for the user to the database. # Make a test request using the api key to ensure that it is valid. try: await Model.send_test_request(api_key) - await interaction.response.send_message("Your API key was successfully validated.", ephemeral=True, delete_after=10) + await interaction.response.send_message( + "Your API key was successfully validated.", + ephemeral=True, + delete_after=10, + ) except Exception as e: - await interaction.response.send_message(f"Your API key looks invalid, the API returned: {e}. Please check that your API key is correct before proceeding", ephemeral=True, delete_after=30) + await interaction.response.send_message( + f"Your API key looks invalid, the API returned: {e}. Please check that your API key is correct before proceeding", + ephemeral=True, + delete_after=30, + ) return # Save the key to the database try: USER_KEY_DB[user.id] = api_key USER_KEY_DB.commit() - await interaction.followup.send("Your API key was successfully saved.", ephemeral=True, delete_after=10) + await interaction.followup.send( + "Your API key was successfully saved.", + ephemeral=True, + delete_after=10, + ) except Exception as e: traceback.print_exc() - await interaction.followup.send("There was an error saving your API key.", ephemeral=True, delete_after=30) + await interaction.followup.send( + "There was an error saving your API key.", + ephemeral=True, + delete_after=30, + ) return - pass \ No newline at end of file + pass diff --git a/cogs/image_prompt_optimizer.py b/cogs/image_prompt_optimizer.py index 56af8fc6..ffd1e998 100644 --- a/cogs/image_prompt_optimizer.py +++ b/cogs/image_prompt_optimizer.py @@ -15,6 +15,7 @@ if USER_INPUT_API_KEYS: USER_KEY_DB = SqliteDict("user_key_db.sqlite") + class ImgPromptOptimizer(discord.Cog, name="ImgPromptOptimizer"): _OPTIMIZER_PRETEXT = "Optimize the following text for DALL-E image generation to have the most detailed and realistic image possible. Prompt:" @@ -123,7 +124,10 @@ async def optimize(self, ctx: discord.ApplicationContext, prompt: str): self.converser_cog.redo_users[user.id].add_interaction(response_message.id) await response_message.edit( view=OptimizeView( - self.converser_cog, self.image_service_cog, self.deletion_queue, custom_api_key=user_api_key, + self.converser_cog, + self.image_service_cog, + self.deletion_queue, + custom_api_key=user_api_key, ) ) @@ -142,18 +146,36 @@ async def optimize(self, ctx: discord.ApplicationContext, prompt: str): class OptimizeView(discord.ui.View): - def __init__(self, converser_cog, image_service_cog, deletion_queue, custom_api_key=None): + def __init__( + self, converser_cog, image_service_cog, deletion_queue, custom_api_key=None + ): super().__init__(timeout=None) self.cog = converser_cog self.image_service_cog = image_service_cog self.deletion_queue = deletion_queue self.custom_api_key = custom_api_key - self.add_item(RedoButton(self.cog, self.image_service_cog, self.deletion_queue, custom_api_key=self.custom_api_key)) - self.add_item(DrawButton(self.cog, self.image_service_cog, self.deletion_queue, custom_api_key=self.custom_api_key)) + self.add_item( + RedoButton( + self.cog, + self.image_service_cog, + self.deletion_queue, + custom_api_key=self.custom_api_key, + ) + ) + self.add_item( + DrawButton( + self.cog, + self.image_service_cog, + self.deletion_queue, + custom_api_key=self.custom_api_key, + ) + ) class DrawButton(discord.ui.Button["OptimizeView"]): - def __init__(self, converser_cog, image_service_cog, deletion_queue, custom_api_key): + def __init__( + self, converser_cog, image_service_cog, deletion_queue, custom_api_key + ): super().__init__(style=discord.ButtonStyle.green, label="Draw") self.converser_cog = converser_cog self.image_service_cog = image_service_cog @@ -206,7 +228,9 @@ async def callback(self, interaction: discord.Interaction): class RedoButton(discord.ui.Button["OptimizeView"]): - def __init__(self, converser_cog, image_service_cog, deletion_queue, custom_api_key=None): + def __init__( + self, converser_cog, image_service_cog, deletion_queue, custom_api_key=None + ): super().__init__(style=discord.ButtonStyle.danger, label="Retry") self.converser_cog = converser_cog self.image_service_cog = image_service_cog diff --git a/models/openai_model.py b/models/openai_model.py index c6e5e2a1..fed8c46e 100644 --- a/models/openai_model.py +++ b/models/openai_model.py @@ -474,7 +474,9 @@ async def send_request( else frequency_penalty_override, "best_of": self.best_of if not best_of_override else best_of_override, } - headers = {"Authorization": f"Bearer {self.openai_key if not custom_api_key else custom_api_key}"} + headers = { + "Authorization": f"Bearer {self.openai_key if not custom_api_key else custom_api_key}" + } async with session.post( "https://api.openai.com/v1/completions", json=payload, headers=headers ) as resp: @@ -499,7 +501,7 @@ async def send_test_request(api_key): } headers = {"Authorization": f"Bearer {api_key}"} async with session.post( - "https://api.openai.com/v1/completions", json=payload, headers=headers + "https://api.openai.com/v1/completions", json=payload, headers=headers ) as resp: response = await resp.json() try: @@ -550,9 +552,9 @@ async def send_image_request( async with session.post( "https://api.openai.com/v1/images/variations", - headers={ - "Authorization": f"Bearer {self.openai_key if not custom_api_key else custom_api_key}", - }, + headers={ + "Authorization": f"Bearer {self.openai_key if not custom_api_key else custom_api_key}", + }, data=data, ) as resp: response = await resp.json() diff --git a/models/pinecone_service_model.py b/models/pinecone_service_model.py index 47b82f76..3143fc7a 100644 --- a/models/pinecone_service_model.py +++ b/models/pinecone_service_model.py @@ -26,7 +26,9 @@ async def upsert_conversation_embedding( print("The split chunk is ", chunk) # Create an embedding for the split chunk - embedding = await model.send_embedding_request(chunk, custom_api_key=custom_api_key) + embedding = await model.send_embedding_request( + chunk, custom_api_key=custom_api_key + ) if not first_embedding: first_embedding = embedding self.index.upsert( @@ -38,7 +40,9 @@ async def upsert_conversation_embedding( ) return first_embedding else: - embedding = await model.send_embedding_request(text, custom_api_key=custom_api_key) + embedding = await model.send_embedding_request( + text, custom_api_key=custom_api_key + ) self.index.upsert( [ (