Skip to content

Commit

Permalink
Merge pull request #1719 from interactions-py/unstable
Browse files Browse the repository at this point in the history
5.13.2
  • Loading branch information
silasary authored Aug 27, 2024
2 parents acd44d0 + d621ef7 commit 369a291
Show file tree
Hide file tree
Showing 13 changed files with 149 additions and 38 deletions.
16 changes: 8 additions & 8 deletions interactions/api/events/processors/message_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,11 @@ async def _on_raw_message_poll_vote_add(self, event: "RawGatewayEvent") -> None:
"""
self.dispatch(
events.MessagePollVoteAdd(
event.data.get("user_id"),
event.data.get("channel_id"),
event.data.get("message_id"),
event.data.get("answer_id"),
event.data.get("guild_id", None),
event.data["channel_id"],
event.data["message_id"],
event.data["user_id"],
event.data["option"],
)
)

Expand All @@ -114,10 +114,10 @@ async def _on_raw_message_poll_vote_remove(self, event: "RawGatewayEvent") -> No
"""
self.dispatch(
events.MessagePollVoteRemove(
event.data.get("user_id"),
event.data.get("channel_id"),
event.data.get("message_id"),
event.data.get("answer_id"),
event.data.get("guild_id", None),
event.data["channel_id"],
event.data["message_id"],
event.data["user_id"],
event.data["option"],
)
)
2 changes: 1 addition & 1 deletion interactions/api/voice/voice_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def send_packet(self, data: bytes, encoder, needs_encode=True) -> None:
self.timestamp += encoder.samples_per_frame

async def send_heartbeat(self) -> None:
await self.send_json({"op": OP.HEARTBEAT, "d": random.uniform(0.0, 1.0)})
await self.send_json({"op": OP.HEARTBEAT, "d": random.getrandbits(64)})
self.logger.debug("❤ Voice Connection is sending Heartbeat")

async def _identify(self) -> None:
Expand Down
6 changes: 3 additions & 3 deletions interactions/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1489,6 +1489,9 @@ def add_command(self, func: Callable) -> None:
elif not isinstance(func, BaseCommand):
raise TypeError("Invalid command type")

for hook in self._add_command_hook:
hook(func)

if not func.callback:
# for group = SlashCommand(...) usage
return
Expand All @@ -1499,9 +1502,6 @@ def add_command(self, func: Callable) -> None:
else:
self.logger.debug(f"Added callback: {func.callback.__name__}")

for hook in self._add_command_hook:
hook(func)

self.dispatch(CallbackAdded(callback=func, extension=func.extension if hasattr(func, "extension") else None))

def _gather_callbacks(self) -> None:
Expand Down
25 changes: 23 additions & 2 deletions interactions/ext/hybrid_commands/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,27 @@ def _add_hybrid_command(self, callback: Callable):
return

cmd = callback

if not cmd.callback or cmd._dummy_base:
if cmd.group_name:
if not (group := self.client.prefixed.get_command(f"{cmd.name} {cmd.group_name}")):
group = base_subcommand_generator(
str(cmd.group_name),
list(_values_wrapper(cmd.group_name.to_locale_dict())) + cmd.aliases,
str(cmd.group_name),
group=True,
)
self.client.prefixed.commands[str(cmd.name)].add_command(group)
elif not (base := self.client.prefixed.commands.get(str(cmd.name))):
base = base_subcommand_generator(
str(cmd.name),
list(_values_wrapper(cmd.name.to_locale_dict())) + cmd.aliases,
str(cmd.name),
group=False,
)
self.client.prefixed.add_command(base)
return

prefixed_transform = slash_to_prefixed(cmd)

if self.use_slash_command_msg:
Expand All @@ -91,7 +112,7 @@ def _add_hybrid_command(self, callback: Callable):
if not (base := self.client.prefixed.commands.get(str(cmd.name))):
base = base_subcommand_generator(
str(cmd.name),
list(_values_wrapper(cmd.name.to_locale_dict())) + cmd.aliases,
list(_values_wrapper(cmd.name.to_locale_dict())),
str(cmd.name),
group=False,
)
Expand All @@ -102,7 +123,7 @@ def _add_hybrid_command(self, callback: Callable):
if not (group := base.subcommands.get(str(cmd.group_name))):
group = base_subcommand_generator(
str(cmd.group_name),
list(_values_wrapper(cmd.group_name.to_locale_dict())) + cmd.aliases,
list(_values_wrapper(cmd.group_name.to_locale_dict())),
str(cmd.group_name),
group=True,
)
Expand Down
1 change: 1 addition & 0 deletions interactions/models/discord/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,7 @@ class EmbedType(Enum):
LINK = "link"
AUTOMOD_MESSAGE = "auto_moderation_message"
AUTOMOD_NOTIFICATION = "auto_moderation_notification"
POLL_RESULT = "poll_result"


class MessageActivityType(CursedIntEnum):
Expand Down
2 changes: 1 addition & 1 deletion interactions/models/discord/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def _process_dict(cls, data: Dict[str, Any], client: "Client") -> Dict[str, Any]
@property
def user(self) -> "models.User":
"""Get the user associated with this interaction."""
return self.client.get_user(self.user_id)
return self.client.get_user(self._user_id)


@attrs.define(eq=False, order=False, hash=False, kw_only=False)
Expand Down
2 changes: 1 addition & 1 deletion interactions/models/discord/poll.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class PollResults(DictSerializationMixin):

@attrs.define(eq=False, order=False, hash=False, kw_only=True)
class Poll(DictSerializationMixin):
question: PollMedia = attrs.field(repr=False)
question: PollMedia = attrs.field(repr=False, converter=PollMedia.from_dict)
"""The question of the poll. Only text media is supported."""
answers: list[PollAnswer] = attrs.field(repr=False, factory=list, converter=PollAnswer.from_list)
"""Each of the answers available in the poll, up to 10."""
Expand Down
25 changes: 12 additions & 13 deletions interactions/models/internal/application_commands.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
from collections import defaultdict
import inspect
import re
import typing
Expand Down Expand Up @@ -288,6 +289,8 @@ def _dm_permission_validator(self, attribute: str, value: bool) -> None:
def to_dict(self) -> dict:
data = super().to_dict()

data["name_localizations"] = self.name.to_locale_dict()

if self.default_member_permissions is not None:
data["default_member_permissions"] = str(int(self.default_member_permissions))
else:
Expand Down Expand Up @@ -1466,9 +1469,9 @@ def application_commands_to_dict( # noqa: C901
`Client.interactions` should be the variable passed to this
"""
cmd_bases = {} # {cmd_base: [commands]}
cmd_bases: defaultdict[str, list[InteractionCommand]] = defaultdict(list) # {cmd_base: [commands]}
"""A store of commands organised by their base command"""
output = {}
output: defaultdict["Snowflake_Type", list[dict]] = defaultdict(list)
"""The output dictionary"""

def squash_subcommand(subcommands: List) -> Dict:
Expand Down Expand Up @@ -1514,9 +1517,6 @@ def squash_subcommand(subcommands: List) -> Dict:
for _scope, cmds in commands.items():
for cmd in cmds.values():
cmd_name = str(cmd.name)
if cmd_name not in cmd_bases:
cmd_bases[cmd_name] = [cmd]
continue
if cmd not in cmd_bases[cmd_name]:
cmd_bases[cmd_name].append(cmd)

Expand Down Expand Up @@ -1556,15 +1556,14 @@ def squash_subcommand(subcommands: List) -> Dict:
cmd.nsfw = nsfw
# end validation of attributes
cmd_data = squash_subcommand(cmd_list)

for s in scopes:
output[s].append(cmd_data)
else:
scopes = cmd_list[0].scopes
cmd_data = cmd_list[0].to_dict()
for s in scopes:
if s not in output:
output[s] = [cmd_data]
continue
output[s].append(cmd_data)
return output
for cmd in cmd_list:
for s in cmd.scopes:
output[s].append(cmd.to_dict())
return dict(output)


def _compare_commands(local_cmd: dict, remote_cmd: dict) -> bool:
Expand Down
6 changes: 3 additions & 3 deletions interactions/models/internal/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def from_dict(cls, client: "ClientT", payload: dict) -> Self:
instance.guild_locale = payload.get("guild_locale", instance.locale)
instance._context_type = payload.get("type", 0)
instance.resolved = Resolved.from_dict(client, payload["data"].get("resolved", {}), payload.get("guild_id"))
instance.entitlements = Entitlement.from_list(payload["entitlements"], client)
instance.entitlements = Entitlement.from_list(payload.get("entitlements", []), client)
instance.context = ContextType(payload["context"]) if payload.get("context") else None
instance.authorizing_integration_owners = {
IntegrationType(int(integration_type)): Snowflake(owner_id)
Expand Down Expand Up @@ -345,8 +345,8 @@ def author_permissions(self) -> Permissions:
return Permissions(0)

@property
def command(self) -> InteractionCommand:
return self.client._interaction_lookup[self._command_name]
def command(self) -> typing.Optional[InteractionCommand]:
return self.client._interaction_lookup.get(self._command_name)

@property
def expires_at(self) -> Timestamp:
Expand Down
2 changes: 1 addition & 1 deletion interactions/models/internal/tasks/triggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,4 +163,4 @@ def __init__(self, cron: str, tz: "_TzInfo" = timezone.utc) -> None:
self.tz = tz

def next_fire(self) -> datetime | None:
return croniter(self.cron, datetime.now(tz=self.tz)).next(datetime)
return croniter(self.cron, self.last_call_time.astimezone(self.tz)).next(datetime)
6 changes: 2 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "interactions.py"
version = "5.13.1"
version = "5.13.2"
description = "Easy, simple, scalable and modular: a Python API wrapper for interactions."
authors = ["LordOfPolls <[email protected]>"]

Expand Down Expand Up @@ -93,14 +93,12 @@ exclude_lines = ["pragma: no cover", "if TYPE_CHECKING:"]

[tool.coverage.run]
omit = ["tests/*"]
source = ["interactions"]

[build-system]
requires = ["setuptools", "tomli"]
build-backend = "setuptools.build_meta"

[tools.coverage.run]
source = ["interactions"]

[tool.pytest.ini_options]
addopts = "-l -ra --durations=2 --junitxml=TestResults.xml"
doctest_optionflags = "NORMALIZE_WHITESPACE"
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
aiohttp
attrs>=22.1
audioop-lts; python_version>='3.13'
croniter
discord-typings>=0.9.0
emoji
Expand Down
93 changes: 92 additions & 1 deletion tests/test_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
from asyncio import AbstractEventLoop
from contextlib import suppress
from datetime import datetime
from datetime import datetime, timedelta

import pytest
import pytest_asyncio
Expand Down Expand Up @@ -33,6 +33,8 @@
ParagraphText,
Message,
GuildVoice,
Poll,
PollMedia,
)
from interactions.models.discord.asset import Asset
from interactions.models.discord.components import ActionRow, Button, StringSelectMenu
Expand Down Expand Up @@ -432,6 +434,95 @@ async def test_components(bot: Client, channel: GuildText) -> None:
await thread.delete()


@pytest.mark.asyncio
async def test_polls(bot: Client, channel: GuildText) -> None:
msg = await channel.send("Polls Tests")
thread = await msg.create_thread("Test Thread")

try:
poll_1 = Poll.create("Test Poll", duration=1, answers=["Answer 1", "Answer 2"])
test_data_1 = {
"question": {"text": "Test Poll"},
"layout_type": 1,
"duration": 1,
"allow_multiselect": False,
"answers": [{"poll_media": {"text": "Answer 1"}}, {"poll_media": {"text": "Answer 2"}}],
}
poll_1_dict = poll_1.to_dict()
for key in poll_1_dict.keys():
assert poll_1_dict[key] == test_data_1[key]

msg_1 = await thread.send(poll=poll_1)

assert msg_1.poll is not None
assert msg_1.poll.question.to_dict() == PollMedia(text="Test Poll").to_dict()
assert msg_1.poll.expiry <= msg_1.created_at + timedelta(hours=1, minutes=1)
poll_1_answer_medias = [poll_answer.poll_media.to_dict() for poll_answer in msg_1.poll.answers]
assert poll_1_answer_medias == [
PollMedia.create(text="Answer 1").to_dict(),
PollMedia.create(text="Answer 2").to_dict(),
]

poll_2 = Poll.create("Test Poll 2", duration=1, allow_multiselect=True)
poll_2.add_answer("Answer 1")
poll_2.add_answer("Answer 2")
test_data_2 = {
"question": {"text": "Test Poll 2"},
"layout_type": 1,
"duration": 1,
"allow_multiselect": True,
"answers": [{"poll_media": {"text": "Answer 1"}}, {"poll_media": {"text": "Answer 2"}}],
}
poll_2_dict = poll_2.to_dict()
for key in poll_2_dict.keys():
assert poll_2_dict[key] == test_data_2[key]
msg_2 = await thread.send(poll=poll_2)

assert msg_2.poll is not None
assert msg_2.poll.question.to_dict() == PollMedia(text="Test Poll 2").to_dict()
assert msg_2.poll.expiry <= msg_2.created_at + timedelta(hours=1, minutes=1)
assert msg_2.poll.allow_multiselect
poll_2_answer_medias = [poll_answer.poll_media.to_dict() for poll_answer in msg_2.poll.answers]
assert poll_2_answer_medias == [
PollMedia.create(text="Answer 1").to_dict(),
PollMedia.create(text="Answer 2").to_dict(),
]

poll_3 = Poll.create(
"Test Poll 3",
duration=1,
answers=[PollMedia.create(text="One", emoji="1️⃣"), PollMedia.create(text="Two", emoji="2️⃣")],
)
test_data_3 = {
"question": {"text": "Test Poll 3"},
"layout_type": 1,
"duration": 1,
"allow_multiselect": False,
"answers": [
{"poll_media": {"text": "One", "emoji": {"name": "1️⃣", "animated": False}}},
{"poll_media": {"text": "Two", "emoji": {"name": "2️⃣", "animated": False}}},
],
}
poll_3_dict = poll_3.to_dict()
for key in poll_3_dict.keys():
assert poll_3_dict[key] == test_data_3[key]

msg_3 = await thread.send(poll=poll_3)

assert msg_3.poll is not None
assert msg_3.poll.question.to_dict() == PollMedia(text="Test Poll 3").to_dict()
assert msg_3.poll.expiry <= msg_3.created_at + timedelta(hours=1, minutes=1)
poll_3_answer_medias = [poll_answer.poll_media.to_dict() for poll_answer in msg_3.poll.answers]
assert poll_3_answer_medias == [
PollMedia.create(text="One", emoji="1️⃣").to_dict(),
PollMedia.create(text="Two", emoji="2️⃣").to_dict(),
]

finally:
with suppress(interactions.errors.NotFound):
await thread.delete()


@pytest.mark.asyncio
async def test_webhooks(bot: Client, guild: Guild, channel: GuildText) -> None:
test_thread = await channel.create_thread("Test Thread")
Expand Down

0 comments on commit 369a291

Please sign in to comment.