Skip to content

Commit

Permalink
Fix and improve polls (#147)
Browse files Browse the repository at this point in the history
Refactor a little the whole view and menu system, allowing awaitable
`__init__()`.

Also added 3 new colors, to allow up to 10 choices.
Fixed a bug when we tried to add more choices than possible.
Added some icons to improve the UI of poll creation.
Fixed an i18n bug when an empty tring is passed to the i18n function.
  • Loading branch information
AiroPi authored Mar 29, 2024
1 parent 3559f2b commit be4fd92
Show file tree
Hide file tree
Showing 14 changed files with 462 additions and 356 deletions.
6 changes: 3 additions & 3 deletions requirements.dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@ pyproject-hooks==1.0.0
# via
# build
# pip-tools
pyright==1.1.354
pyright==1.1.355
# via -r requirements.dev.in
ruff==0.3.3
ruff==0.3.4
# via -r requirements.dev.in
tox==4.14.1
tox==4.14.2
# via -r requirements.dev.in
virtualenv==20.25.1
# via tox
Expand Down
8 changes: 3 additions & 5 deletions src/cogs/clear/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def working_display(self) -> MessageDisplay:
return display

async def start(self):
view = await CancelClearView(self.bot, self.inter.user.id).build()
view = await CancelClearView(self.bot, self.inter.user.id)

await self.inter.response.send_message(
**self.working_display(),
Expand Down Expand Up @@ -265,14 +265,12 @@ async def filtered_history(self) -> AsyncGenerator[discord.Message, None]:


class CancelClearView(Menu):
def __init__(self, bot: MyBot, user_id: int):
super().__init__(bot, timeout=3 * 60)
async def __init__(self, bot: MyBot, user_id: int):
await super().__init__(bot, timeout=3 * 60)
self.pressed: bool = False
self.user_id: int = user_id

async def build(self):
self.cancel.label = _("Cancel")
return self

async def interaction_check(self, interaction: discord.Interaction) -> bool:
await interaction.response.defer()
Expand Down
14 changes: 7 additions & 7 deletions src/cogs/poll/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(self, bot: MyBot):
self.current_votes: dict[int, dict[int, tuple[Interaction, ui.View]]] = {} # poll_id: {user_id: interaction}

async def cog_load(self) -> None:
self.bot.add_view(PollPublicMenu(self))
self.bot.add_view(await PollPublicMenu(self))

async def callback(self, inter: Interaction, poll_type: db.PollType) -> None:
channel_id = cast(int, inter.channel_id) # not usable in private messages
Expand Down Expand Up @@ -102,8 +102,8 @@ async def edit_poll(self, inter: Interaction, message: discord.Message) -> None:
if poll.author_id != inter.user.id:
raise NonSpecificError(_("You are not the author of this poll. You can't edit it.", _l=256))
await inter.response.send_message(
**(await PollDisplay.build(poll, self.bot)),
view=await EditPoll(self, poll, message).build(),
**(await PollDisplay(poll, self.bot)),
view=await EditPoll(self, poll, message, inter),
ephemeral=True,
)

Expand Down Expand Up @@ -133,8 +133,8 @@ async def on_submit(self, inter: discord.Interaction):
self.poll.title = self.question.value
self.poll.description = self.description.value
await inter.response.send_message(
**(await PollDisplay.build(self.poll, self.bot)),
view=await EditPoll(self.cog, self.poll, inter.message).build(),
**(await PollDisplay(self.poll, self.bot)),
view=await EditPoll(self.cog, self.poll, inter.message, inter),
ephemeral=True,
)

Expand Down Expand Up @@ -173,8 +173,8 @@ async def on_submit(self, inter: discord.Interaction):
self.poll.choices.append(db.PollChoice(poll_id=self.poll.id, label=self.choice3.value))

await inter.response.send_message(
**(await PollDisplay.build(self.poll, self.bot)),
view=await EditPoll(self.cog, self.poll, inter.message).build(),
**(await PollDisplay(self.poll, self.bot)),
view=await EditPoll(self.cog, self.poll, inter.message, inter),
ephemeral=True,
)

Expand Down
5 changes: 4 additions & 1 deletion src/cogs/poll/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from core import Emojis

COLORS_ORDER = ["blue", "red", "yellow", "purple", "brown", "green", "orange"]
COLORS_ORDER = ["blue", "red", "yellow", "purple", "brown", "green", "orange", "pink", "lime", "blue_green"]
COLOR_TO_HEX = {
"blue": 0x54ACEE,
"red": 0xDD2D44,
Expand All @@ -9,6 +9,9 @@
"brown": 0xC1694F,
"green": 0x78B159,
"orange": 0xF4900E,
"pink": 0xFFB7CE,
"lime": 0xBEFD73,
"blue_green": 0x9ADEDB,
}
LEGEND_EMOJIS = [getattr(Emojis, f"{color}_round") for color in COLORS_ORDER]
GRAPH_EMOJIS = [getattr(Emojis, f"{color}_mid") for color in COLORS_ORDER]
Expand Down
72 changes: 35 additions & 37 deletions src/cogs/poll/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from core import Emojis, db
from core.i18n import _
from core.response import MessageDisplay
from core.utils import AsyncInitMixin

from .constants import (
BOOLEAN_INDEXES,
Expand All @@ -28,56 +29,53 @@
from mybot import MyBot


class PollDisplay:
def __init__(self, poll: Poll, votes: dict[str, int] | None):
self.poll: Poll = poll
self.votes = votes
class PollDisplay(AsyncInitMixin, MessageDisplay):
async def __init__(self, poll: Poll, bot: MyBot, old_embed: Embed | None = None):
self.poll = poll
self.votes: dict[str, int] | None = await self.get_votes(bot)

@classmethod
async def build(cls, poll: Poll, bot: MyBot, old_embed: Embed | None = None) -> MessageDisplay:
content = poll.description
embed = discord.Embed(title=poll.title)

votes: dict[str, int] | None
if poll.public_results is True:
async with bot.async_session.begin() as session:
stmt = (
db.select(db.PollAnswer.value, func.count())
.select_from(db.PollAnswer)
.where(db.PollAnswer.poll_id == poll.id)
.group_by(db.PollAnswer.value)
)

votes = { # noqa: C416, dict comprehension used for typing purposes
key: value
for key, value in (await session.execute(stmt)).all() # choice_id: vote_count
}
if poll.type == db.PollType.CHOICE:
# when we delete a choice from a poll, the votes are still in the db before commit
# so we need to filter them
votes = {
key: value for key, value in votes.items() if key in (str(choice.id) for choice in poll.choices)
}
else:
votes = None

poll_display = cls(poll, votes)

description_split: list[str] = [poll_display.build_end_date(), poll_display.build_legend()]

description_split: list[str] = [self.build_end_date(), self.build_legend()]
embed.description = "\n".join(description_split)

if poll.public_results:
embed.add_field(name="\u200b", value=poll_display.build_graph())
embed.color = poll_display.build_color()
embed.add_field(name="\u200b", value=self.build_graph())
embed.color = self.build_color()

if old_embed:
embed.set_footer(text=old_embed.footer.text)
else:
author = await bot.getch_user(poll.author_id)
embed.set_footer(text=_("Poll created by {}", author.name if author else "unknown"))

return MessageDisplay(content=content, embed=embed)
MessageDisplay.__init__(self, content=content, embed=embed)

async def get_votes(self, bot: MyBot) -> dict[str, int] | None:
if self.poll.public_results is False:
return None

async with bot.async_session.begin() as session:
stmt = (
db.select(db.PollAnswer.value, func.count())
.select_from(db.PollAnswer)
.where(db.PollAnswer.poll_id == self.poll.id)
.group_by(db.PollAnswer.value)
)

votes = { # noqa: C416, dict comprehension used for typing purposes
key: value
for key, value in (await session.execute(stmt)).all() # choice_id: vote_count
}
if self.poll.type == db.PollType.CHOICE:
# when we delete a choice from a poll, the votes are still in the db before commit
# so we need to filter them
votes = {
key: value
for key, value in votes.items()
if key in (str(choice.id) for choice in self.poll.choices)
}
return votes

@property
def total_votes(self) -> int:
Expand Down
Loading

0 comments on commit be4fd92

Please sign in to comment.