Skip to content

Commit

Permalink
Format Python code with psf/black push
Browse files Browse the repository at this point in the history
  • Loading branch information
github-actions authored and github-actions committed Jan 10, 2023
1 parent c6ccfd9 commit b47d52f
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 43 deletions.
58 changes: 50 additions & 8 deletions cogs/draw_image_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ async def encapsulated_send(

try:
file, image_urls = await self.model.send_image_request(
ctx, prompt, vary=vary if not draw_from_optimizer else None, custom_api_key=custom_api_key
ctx,
prompt,
vary=vary if not draw_from_optimizer else None,
custom_api_key=custom_api_key,
)
except ValueError as e:
(
Expand Down Expand Up @@ -96,7 +99,14 @@ async def encapsulated_send(
)

await result_message.edit(
view=SaveView(ctx, image_urls, self, self.converser_cog, result_message, custom_api_key=custom_api_key)
view=SaveView(
ctx,
image_urls,
self,
self.converser_cog,
result_message,
custom_api_key=custom_api_key,
)
)

self.converser_cog.users_to_interactions[user_id] = []
Expand All @@ -115,7 +125,14 @@ async def encapsulated_send(
file=file,
)
await message.edit(
view=SaveView(ctx, image_urls, self, self.converser_cog, message, custom_api_key=custom_api_key)
view=SaveView(
ctx,
image_urls,
self,
self.converser_cog,
message,
custom_api_key=custom_api_key,
)
)
else: # Varying case
if not draw_from_optimizer:
Expand Down Expand Up @@ -144,7 +161,12 @@ async def encapsulated_send(
)
await result_message.edit(
view=SaveView(
ctx, image_urls, self, self.converser_cog, result_message, custom_api_key=custom_api_key
ctx,
image_urls,
self,
self.converser_cog,
result_message,
custom_api_key=custom_api_key,
)
)

Expand Down Expand Up @@ -179,7 +201,11 @@ async def draw(self, ctx: discord.ApplicationContext, prompt: str):
return

try:
asyncio.ensure_future(self.encapsulated_send(user.id, prompt, ctx, custom_api_key=user_api_key))
asyncio.ensure_future(
self.encapsulated_send(
user.id, prompt, ctx, custom_api_key=user_api_key
)
)

except Exception as e:
print(e)
Expand Down Expand Up @@ -258,11 +284,21 @@ def __init__(
self.add_item(SaveButton(x, image_urls[x - 1]))
if not only_save:
if not no_retry:
self.add_item(RedoButton(self.cog, converser_cog=self.converser_cog, custom_api_key=self.custom_api_key))
self.add_item(
RedoButton(
self.cog,
converser_cog=self.converser_cog,
custom_api_key=self.custom_api_key,
)
)
for x in range(1, len(image_urls) + 1):
self.add_item(
VaryButton(
x, image_urls[x - 1], self.cog, converser_cog=self.converser_cog, custom_api_key=self.custom_api_key
x,
image_urls[x - 1],
self.cog,
converser_cog=self.converser_cog,
custom_api_key=self.custom_api_key,
)
)

Expand Down Expand Up @@ -404,5 +440,11 @@ async def callback(self, interaction: discord.Interaction):
self.converser_cog.users_to_interactions[user_id].append(message.id)

asyncio.ensure_future(
self.cog.encapsulated_send(user_id, prompt, ctx, response_message, custom_api_key=self.custom_api_key)
self.cog.encapsulated_send(
user_id,
prompt,
ctx,
response_message,
custom_api_key=self.custom_api_key,
)
)
103 changes: 81 additions & 22 deletions cogs/gpt_3_commands_and_converser.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,13 @@
USER_INPUT_API_KEYS = EnvService.get_user_input_api_keys()
USER_KEY_DB = None
if USER_INPUT_API_KEYS:
print("This server was configured to enforce user input API keys. Doing the required database setup now")
print(
"This server was configured to enforce user input API keys. Doing the required database setup now"
)
USER_KEY_DB = SqliteDict("user_key_db.sqlite")
print("Retrieved/created the user key database")



class GPT3ComCon(discord.Cog, name="GPT3ComCon"):
def __init__(
self,
Expand Down Expand Up @@ -148,9 +149,13 @@ async def get_user_api_key(user_id, ctx):
modal = SetupModal(title="API Key Setup")
if isinstance(ctx, discord.ApplicationContext):
await ctx.send_modal(modal)
await ctx.send_followup("You must set up your API key before using this command.")
await ctx.send_followup(
"You must set up your API key before using this command."
)
else:
await ctx.reply("You must set up your API key before typing in a GPT3 powered channel, type `/setup` to enter your API key.")
await ctx.reply(
"You must set up your API key before typing in a GPT3 powered channel, type `/setup` to enter your API key."
)
return user_api_key

async def load_file(self, file, ctx):
Expand Down Expand Up @@ -199,7 +204,9 @@ async def on_ready(self):
self.DEBUG_CHANNEL
)
if USER_INPUT_API_KEYS:
print("This bot was set to use user input API keys. Doing the required SQLite setup now")
print(
"This bot was set to use user input API keys. Doing the required SQLite setup now"
)

await self.bot.sync_commands(
commands=None,
Expand Down Expand Up @@ -644,7 +651,9 @@ async def on_message(self, message):
# Extract all the text after the !g and use it as the prompt.
user_api_key = None
if USER_INPUT_API_KEYS:
user_api_key = await GPT3ComCon.get_user_api_key(message.author.id, message)
user_api_key = await GPT3ComCon.get_user_api_key(
message.author.id, message
)
if not user_api_key:
return

Expand Down Expand Up @@ -790,7 +799,11 @@ async def encapsulated_send(

# Create and upsert the embedding for the conversation id, prompt, timestamp
embedding = await self.pinecone_service.upsert_conversation_embedding(
self.model, conversation_id, new_prompt, timestamp, custom_api_key=custom_api_key,
self.model,
conversation_id,
new_prompt,
timestamp,
custom_api_key=custom_api_key,
)

embedding_prompt_less_author = await self.model.send_embedding_request(
Expand Down Expand Up @@ -953,7 +966,11 @@ async def encapsulated_send(

# Create and upsert the embedding for the conversation id, prompt, timestamp
embedding = await self.pinecone_service.upsert_conversation_embedding(
self.model, conversation_id, response_text, timestamp, custom_api_key=custom_api_key
self.model,
conversation_id,
response_text,
timestamp,
custom_api_key=custom_api_key,
)

# Cleanse
Expand All @@ -967,12 +984,16 @@ async def encapsulated_send(
response_message = (
await ctx.respond(
response_text,
view=ConversationView(ctx, self, ctx.channel.id, custom_api_key=custom_api_key),
view=ConversationView(
ctx, self, ctx.channel.id, custom_api_key=custom_api_key
),
)
if from_context
else await ctx.reply(
response_text,
view=ConversationView(ctx, self, ctx.channel.id, custom_api_key=custom_api_key),
view=ConversationView(
ctx, self, ctx.channel.id, custom_api_key=custom_api_key
),
)
)

Expand Down Expand Up @@ -1368,12 +1389,18 @@ async def help(self, ctx: discord.ApplicationContext):
await self.send_help_text(ctx)

@discord.slash_command(
name="setup", description="Setup your API key for use with GPT3Discord", guild_ids=ALLOWED_GUILDS
name="setup",
description="Setup your API key for use with GPT3Discord",
guild_ids=ALLOWED_GUILDS,
)
@discord.guild_only()
async def setup(self, ctx: discord.ApplicationContext):
if not USER_INPUT_API_KEYS:
await ctx.respond("This server doesn't support user input API keys.", ephemeral=True, delete_after=30)
await ctx.respond(
"This server doesn't support user input API keys.",
ephemeral=True,
delete_after=30,
)

modal = SetupModal(title="API Key Setup")
await ctx.send_modal(modal)
Expand Down Expand Up @@ -1437,8 +1464,10 @@ def __init__(self, ctx, converser_cog, id, custom_api_key=None):
super().__init__(timeout=3600) # 1 hour interval to redo.
self.converser_cog = converser_cog
self.ctx = ctx
self.custom_api_key= custom_api_key
self.add_item(RedoButton(self.converser_cog, custom_api_key=self.custom_api_key))
self.custom_api_key = custom_api_key
self.add_item(
RedoButton(self.converser_cog, custom_api_key=self.custom_api_key)
)

if id in self.converser_cog.conversation_threads:
self.add_item(EndConvoButton(self.converser_cog))
Expand Down Expand Up @@ -1511,7 +1540,11 @@ async def callback(self, interaction: discord.Interaction):
)

await self.converser_cog.encapsulated_send(
id=user_id, prompt=prompt, ctx=ctx, response_message=response_message, custom_api_key=self.custom_api_key
id=user_id,
prompt=prompt,
ctx=ctx,
response_message=response_message,
custom_api_key=self.custom_api_key,
)
else:
await interaction.response.send_message(
Expand All @@ -1520,37 +1553,63 @@ async def callback(self, interaction: discord.Interaction):
delete_after=10,
)


class SetupModal(discord.ui.Modal):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

self.add_item(discord.ui.InputText(label="OpenAI API Key", placeholder="sk--......", ))
self.add_item(
discord.ui.InputText(
label="OpenAI API Key",
placeholder="sk--......",
)
)

async def callback(self, interaction: discord.Interaction):
user = interaction.user
api_key = self.children[0].value
# Validate that api_key is indeed in this format
if not re.match(r"sk-[a-zA-Z0-9]{32}", api_key):
await interaction.response.send_message("Your API key looks invalid, please check that it is correct before proceeding. Please run the /setup command to set your key.", ephemeral=True, delete_after=100)
await interaction.response.send_message(
"Your API key looks invalid, please check that it is correct before proceeding. Please run the /setup command to set your key.",
ephemeral=True,
delete_after=100,
)
else:
# We can save the key for the user to the database.

# Make a test request using the api key to ensure that it is valid.
try:
await Model.send_test_request(api_key)
await interaction.response.send_message("Your API key was successfully validated.", ephemeral=True, delete_after=10)
await interaction.response.send_message(
"Your API key was successfully validated.",
ephemeral=True,
delete_after=10,
)
except Exception as e:
await interaction.response.send_message(f"Your API key looks invalid, the API returned: {e}. Please check that your API key is correct before proceeding", ephemeral=True, delete_after=30)
await interaction.response.send_message(
f"Your API key looks invalid, the API returned: {e}. Please check that your API key is correct before proceeding",
ephemeral=True,
delete_after=30,
)
return

# Save the key to the database
try:
USER_KEY_DB[user.id] = api_key
USER_KEY_DB.commit()
await interaction.followup.send("Your API key was successfully saved.", ephemeral=True, delete_after=10)
await interaction.followup.send(
"Your API key was successfully saved.",
ephemeral=True,
delete_after=10,
)
except Exception as e:
traceback.print_exc()
await interaction.followup.send("There was an error saving your API key.", ephemeral=True, delete_after=30)
await interaction.followup.send(
"There was an error saving your API key.",
ephemeral=True,
delete_after=30,
)
return

pass
pass
36 changes: 30 additions & 6 deletions cogs/image_prompt_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
if USER_INPUT_API_KEYS:
USER_KEY_DB = SqliteDict("user_key_db.sqlite")


class ImgPromptOptimizer(discord.Cog, name="ImgPromptOptimizer"):
_OPTIMIZER_PRETEXT = "Optimize the following text for DALL-E image generation to have the most detailed and realistic image possible. Prompt:"

Expand Down Expand Up @@ -123,7 +124,10 @@ async def optimize(self, ctx: discord.ApplicationContext, prompt: str):
self.converser_cog.redo_users[user.id].add_interaction(response_message.id)
await response_message.edit(
view=OptimizeView(
self.converser_cog, self.image_service_cog, self.deletion_queue, custom_api_key=user_api_key,
self.converser_cog,
self.image_service_cog,
self.deletion_queue,
custom_api_key=user_api_key,
)
)

Expand All @@ -142,18 +146,36 @@ async def optimize(self, ctx: discord.ApplicationContext, prompt: str):


class OptimizeView(discord.ui.View):
def __init__(self, converser_cog, image_service_cog, deletion_queue, custom_api_key=None):
def __init__(
self, converser_cog, image_service_cog, deletion_queue, custom_api_key=None
):
super().__init__(timeout=None)
self.cog = converser_cog
self.image_service_cog = image_service_cog
self.deletion_queue = deletion_queue
self.custom_api_key = custom_api_key
self.add_item(RedoButton(self.cog, self.image_service_cog, self.deletion_queue, custom_api_key=self.custom_api_key))
self.add_item(DrawButton(self.cog, self.image_service_cog, self.deletion_queue, custom_api_key=self.custom_api_key))
self.add_item(
RedoButton(
self.cog,
self.image_service_cog,
self.deletion_queue,
custom_api_key=self.custom_api_key,
)
)
self.add_item(
DrawButton(
self.cog,
self.image_service_cog,
self.deletion_queue,
custom_api_key=self.custom_api_key,
)
)


class DrawButton(discord.ui.Button["OptimizeView"]):
def __init__(self, converser_cog, image_service_cog, deletion_queue, custom_api_key):
def __init__(
self, converser_cog, image_service_cog, deletion_queue, custom_api_key
):
super().__init__(style=discord.ButtonStyle.green, label="Draw")
self.converser_cog = converser_cog
self.image_service_cog = image_service_cog
Expand Down Expand Up @@ -206,7 +228,9 @@ async def callback(self, interaction: discord.Interaction):


class RedoButton(discord.ui.Button["OptimizeView"]):
def __init__(self, converser_cog, image_service_cog, deletion_queue, custom_api_key=None):
def __init__(
self, converser_cog, image_service_cog, deletion_queue, custom_api_key=None
):
super().__init__(style=discord.ButtonStyle.danger, label="Retry")
self.converser_cog = converser_cog
self.image_service_cog = image_service_cog
Expand Down
Loading

0 comments on commit b47d52f

Please sign in to comment.