diff --git a/cogs/commands.py b/cogs/commands.py index f1adc542..94fdc259 100644 --- a/cogs/commands.py +++ b/cogs/commands.py @@ -983,7 +983,14 @@ async def draw_old(self, ctx: discord.ApplicationContext, prompt: str): default="natural", autocomplete=Settings_autocompleter.get_dalle3_image_styles, ) - async def draw(self, ctx: discord.ApplicationContext, prompt: str, quality: str, image_size: str, style: str): + async def draw( + self, + ctx: discord.ApplicationContext, + prompt: str, + quality: str, + image_size: str, + style: str, + ): await self.image_draw_cog.draw_command(ctx, prompt, quality, image_size, style) @add_to_group("dalle") diff --git a/cogs/image_service_cog.py b/cogs/image_service_cog.py index c1bad18b..2418cf67 100644 --- a/cogs/image_service_cog.py +++ b/cogs/image_service_cog.py @@ -37,7 +37,13 @@ def __init__( self.redo_users = {} async def draw_command( - self, ctx: discord.ApplicationContext, prompt: str, quality: str, image_size: str, style: str, from_action=False + self, + ctx: discord.ApplicationContext, + prompt: str, + quality: str, + image_size: str, + style: str, + from_action=False, ): """With an ApplicationContext and prompt, send a dalle image to the invoked channel. Ephemeral if from an action""" user_api_key = None @@ -63,7 +69,15 @@ async def draw_command( try: asyncio.ensure_future( ImageService.encapsulated_send( - self, user.id, prompt, ctx, custom_api_key=user_api_key, dalle_3=True, quality=quality, image_size=image_size, style=style + self, + user.id, + prompt, + ctx, + custom_api_key=user_api_key, + dalle_3=True, + quality=quality, + image_size=image_size, + style=style, ) ) @@ -116,7 +130,14 @@ async def draw_old_command( async def draw_action(self, ctx, message): """decoupler to handle context actions for the draw command""" - await self.draw_command(ctx, message.content, quality="hd",image_size="1024x1024", style="natural", from_action=True) + await self.draw_command( + ctx, + message.content, + quality="hd", + image_size="1024x1024", + style="natural", + from_action=True, + ) async def local_size_command(self, ctx: discord.ApplicationContext): """Get the folder size of the image folder""" diff --git a/models/openai_model.py b/models/openai_model.py index 50dbadd6..988d7571 100644 --- a/models/openai_model.py +++ b/models/openai_model.py @@ -1265,8 +1265,8 @@ async def send_test_request(api_key): raise ValueError(str(response["error"]["message"])) return response - async def save_image_urls_and_return(self, image_urls, ctx): + async def save_image_urls_and_return(self, image_urls, ctx): # For each image url, open it as an image object using PIL images = await asyncio.get_running_loop().run_in_executor( None, @@ -1373,7 +1373,9 @@ async def save_image_urls_and_return(self, image_urls, ctx): return discord.File(temp_file.name), image_urls - async def make_image_request_individual(self, session, url, json_payload, headers) -> dict: + async def make_image_request_individual( + self, session, url, json_payload, headers + ) -> dict: async with session.post(url, json=json_payload, headers=headers) as resp: return await resp.json() @@ -1385,8 +1387,8 @@ async def make_image_request_individual(self, session, url, json_payload, header max_tries=4, on_backoff=backoff_handler_http, ) - async def send_image_request( - self, ctx, prompt, quality, image_size, style, custom_api_key=None + async def send_image_request( + self, ctx, prompt, quality, image_size, style, custom_api_key=None ) -> tuple[File, List[Any]]: words = len(prompt.split(" ")) if words < 1 or words > 75: @@ -1398,7 +1400,13 @@ async def send_image_request( image_urls = [] tasks = [] - payload = {"prompt": prompt, "quality": quality, "style": style, "model": "dall-e-3", "size": image_size} + payload = { + "prompt": prompt, + "quality": quality, + "style": style, + "model": "dall-e-3", + "size": image_size, + } headers = { "Content-Type": "application/json", "Authorization": f"Bearer {self.openai_key if not custom_api_key else custom_api_key}", @@ -1408,10 +1416,17 @@ async def send_image_request( headers["OpenAI-Organization"] = self.openai_organization # Setup the client session outside of the loop - async with aiohttp.ClientSession(raise_for_status=True, timeout=aiohttp.ClientTimeout(total=300)) as session: + async with aiohttp.ClientSession( + raise_for_status=True, timeout=aiohttp.ClientTimeout(total=300) + ) as session: # Create a coroutine for each image request and store it in the tasks list for _ in range(self.num_images): - task = self.make_image_request_individual(session, "https://api.openai.com/v1/images/generations", payload, headers) + task = self.make_image_request_individual( + session, + "https://api.openai.com/v1/images/generations", + payload, + headers, + ) tasks.append(task) # Run all tasks in parallel and wait for them to complete diff --git a/models/user_model.py b/models/user_model.py index c60d7d3f..b1ebb144 100644 --- a/models/user_model.py +++ b/models/user_model.py @@ -5,7 +5,19 @@ class RedoUser: - def __init__(self, prompt, instruction, message, ctx, response, paginator, dalle_3=False, quality=None, image_size=None, style=None): + def __init__( + self, + prompt, + instruction, + message, + ctx, + response, + paginator, + dalle_3=False, + quality=None, + image_size=None, + style=None, + ): self.prompt = prompt self.instruction = instruction self.message = message diff --git a/services/image_service.py b/services/image_service.py index 93657895..1f99098a 100644 --- a/services/image_service.py +++ b/services/image_service.py @@ -118,7 +118,7 @@ async def encapsulated_send( image_service_cog.converser_cog, result_message, custom_api_key=custom_api_key, - dalle_3=dalle_3 + dalle_3=dalle_3, ) ) @@ -158,7 +158,7 @@ async def encapsulated_send( image_service_cog.converser_cog, message, custom_api_key=custom_api_key, - dalle_3=dalle_3 + dalle_3=dalle_3, ) ) else: # Varying case @@ -225,7 +225,7 @@ def __init__( no_retry=False, only_save=None, custom_api_key=None, - dalle_3=False + dalle_3=False, ): super().__init__( timeout=3600 if not only_save else None @@ -237,7 +237,7 @@ def __init__( self.converser_cog = converser_cog self.message = message self.custom_api_key = custom_api_key - self.dalle_3=dalle_3 + self.dalle_3 = dalle_3 for x in range(1, len(image_urls) + 1): self.add_item(SaveButton(x, image_urls[x - 1])) if not only_save: