diff --git a/interactions/api/events/processors/message_events.py b/interactions/api/events/processors/message_events.py index 74a10cbde..7c9b2e1cb 100644 --- a/interactions/api/events/processors/message_events.py +++ b/interactions/api/events/processors/message_events.py @@ -95,11 +95,11 @@ async def _on_raw_message_poll_vote_add(self, event: "RawGatewayEvent") -> None: """ self.dispatch( events.MessagePollVoteAdd( + event.data.get("user_id"), + event.data.get("channel_id"), + event.data.get("message_id"), + event.data.get("answer_id"), event.data.get("guild_id", None), - event.data["channel_id"], - event.data["message_id"], - event.data["user_id"], - event.data["option"], ) ) @@ -114,10 +114,10 @@ async def _on_raw_message_poll_vote_remove(self, event: "RawGatewayEvent") -> No """ self.dispatch( events.MessagePollVoteRemove( + event.data.get("user_id"), + event.data.get("channel_id"), + event.data.get("message_id"), + event.data.get("answer_id"), event.data.get("guild_id", None), - event.data["channel_id"], - event.data["message_id"], - event.data["user_id"], - event.data["option"], ) ) diff --git a/interactions/api/voice/voice_gateway.py b/interactions/api/voice/voice_gateway.py index 670e139aa..c20c0560c 100644 --- a/interactions/api/voice/voice_gateway.py +++ b/interactions/api/voice/voice_gateway.py @@ -350,7 +350,7 @@ def send_packet(self, data: bytes, encoder, needs_encode=True) -> None: self.timestamp += encoder.samples_per_frame async def send_heartbeat(self) -> None: - await self.send_json({"op": OP.HEARTBEAT, "d": random.uniform(0.0, 1.0)}) + await self.send_json({"op": OP.HEARTBEAT, "d": random.getrandbits(64)}) self.logger.debug("❤ Voice Connection is sending Heartbeat") async def _identify(self) -> None: diff --git a/interactions/client/client.py b/interactions/client/client.py index fe9a5ff6f..3c4fed523 100644 --- a/interactions/client/client.py +++ b/interactions/client/client.py @@ -1489,6 +1489,9 @@ def add_command(self, func: Callable) -> None: elif not isinstance(func, BaseCommand): raise TypeError("Invalid command type") + for hook in self._add_command_hook: + hook(func) + if not func.callback: # for group = SlashCommand(...) usage return @@ -1499,9 +1502,6 @@ def add_command(self, func: Callable) -> None: else: self.logger.debug(f"Added callback: {func.callback.__name__}") - for hook in self._add_command_hook: - hook(func) - self.dispatch(CallbackAdded(callback=func, extension=func.extension if hasattr(func, "extension") else None)) def _gather_callbacks(self) -> None: diff --git a/interactions/ext/hybrid_commands/manager.py b/interactions/ext/hybrid_commands/manager.py index 6de809e67..9cf5f464d 100644 --- a/interactions/ext/hybrid_commands/manager.py +++ b/interactions/ext/hybrid_commands/manager.py @@ -81,6 +81,27 @@ def _add_hybrid_command(self, callback: Callable): return cmd = callback + + if not cmd.callback or cmd._dummy_base: + if cmd.group_name: + if not (group := self.client.prefixed.get_command(f"{cmd.name} {cmd.group_name}")): + group = base_subcommand_generator( + str(cmd.group_name), + list(_values_wrapper(cmd.group_name.to_locale_dict())) + cmd.aliases, + str(cmd.group_name), + group=True, + ) + self.client.prefixed.commands[str(cmd.name)].add_command(group) + elif not (base := self.client.prefixed.commands.get(str(cmd.name))): + base = base_subcommand_generator( + str(cmd.name), + list(_values_wrapper(cmd.name.to_locale_dict())) + cmd.aliases, + str(cmd.name), + group=False, + ) + self.client.prefixed.add_command(base) + return + prefixed_transform = slash_to_prefixed(cmd) if self.use_slash_command_msg: @@ -91,7 +112,7 @@ def _add_hybrid_command(self, callback: Callable): if not (base := self.client.prefixed.commands.get(str(cmd.name))): base = base_subcommand_generator( str(cmd.name), - list(_values_wrapper(cmd.name.to_locale_dict())) + cmd.aliases, + list(_values_wrapper(cmd.name.to_locale_dict())), str(cmd.name), group=False, ) @@ -102,7 +123,7 @@ def _add_hybrid_command(self, callback: Callable): if not (group := base.subcommands.get(str(cmd.group_name))): group = base_subcommand_generator( str(cmd.group_name), - list(_values_wrapper(cmd.group_name.to_locale_dict())) + cmd.aliases, + list(_values_wrapper(cmd.group_name.to_locale_dict())), str(cmd.group_name), group=True, ) diff --git a/interactions/models/discord/enums.py b/interactions/models/discord/enums.py index 550711749..75dab61b1 100644 --- a/interactions/models/discord/enums.py +++ b/interactions/models/discord/enums.py @@ -448,6 +448,7 @@ class EmbedType(Enum): LINK = "link" AUTOMOD_MESSAGE = "auto_moderation_message" AUTOMOD_NOTIFICATION = "auto_moderation_notification" + POLL_RESULT = "poll_result" class MessageActivityType(CursedIntEnum): diff --git a/interactions/models/discord/message.py b/interactions/models/discord/message.py index 56a051a6b..abe982653 100644 --- a/interactions/models/discord/message.py +++ b/interactions/models/discord/message.py @@ -275,7 +275,7 @@ def _process_dict(cls, data: Dict[str, Any], client: "Client") -> Dict[str, Any] @property def user(self) -> "models.User": """Get the user associated with this interaction.""" - return self.client.get_user(self.user_id) + return self.client.get_user(self._user_id) @attrs.define(eq=False, order=False, hash=False, kw_only=False) diff --git a/interactions/models/discord/poll.py b/interactions/models/discord/poll.py index 7498e0c14..9051791ef 100644 --- a/interactions/models/discord/poll.py +++ b/interactions/models/discord/poll.py @@ -86,7 +86,7 @@ class PollResults(DictSerializationMixin): @attrs.define(eq=False, order=False, hash=False, kw_only=True) class Poll(DictSerializationMixin): - question: PollMedia = attrs.field(repr=False) + question: PollMedia = attrs.field(repr=False, converter=PollMedia.from_dict) """The question of the poll. Only text media is supported.""" answers: list[PollAnswer] = attrs.field(repr=False, factory=list, converter=PollAnswer.from_list) """Each of the answers available in the poll, up to 10.""" diff --git a/interactions/models/internal/application_commands.py b/interactions/models/internal/application_commands.py index c0f2b1b8b..a9727b8b5 100644 --- a/interactions/models/internal/application_commands.py +++ b/interactions/models/internal/application_commands.py @@ -1,4 +1,5 @@ import asyncio +from collections import defaultdict import inspect import re import typing @@ -288,6 +289,8 @@ def _dm_permission_validator(self, attribute: str, value: bool) -> None: def to_dict(self) -> dict: data = super().to_dict() + data["name_localizations"] = self.name.to_locale_dict() + if self.default_member_permissions is not None: data["default_member_permissions"] = str(int(self.default_member_permissions)) else: @@ -1466,9 +1469,9 @@ def application_commands_to_dict( # noqa: C901 `Client.interactions` should be the variable passed to this """ - cmd_bases = {} # {cmd_base: [commands]} + cmd_bases: defaultdict[str, list[InteractionCommand]] = defaultdict(list) # {cmd_base: [commands]} """A store of commands organised by their base command""" - output = {} + output: defaultdict["Snowflake_Type", list[dict]] = defaultdict(list) """The output dictionary""" def squash_subcommand(subcommands: List) -> Dict: @@ -1514,9 +1517,6 @@ def squash_subcommand(subcommands: List) -> Dict: for _scope, cmds in commands.items(): for cmd in cmds.values(): cmd_name = str(cmd.name) - if cmd_name not in cmd_bases: - cmd_bases[cmd_name] = [cmd] - continue if cmd not in cmd_bases[cmd_name]: cmd_bases[cmd_name].append(cmd) @@ -1556,15 +1556,14 @@ def squash_subcommand(subcommands: List) -> Dict: cmd.nsfw = nsfw # end validation of attributes cmd_data = squash_subcommand(cmd_list) + + for s in scopes: + output[s].append(cmd_data) else: - scopes = cmd_list[0].scopes - cmd_data = cmd_list[0].to_dict() - for s in scopes: - if s not in output: - output[s] = [cmd_data] - continue - output[s].append(cmd_data) - return output + for cmd in cmd_list: + for s in cmd.scopes: + output[s].append(cmd.to_dict()) + return dict(output) def _compare_commands(local_cmd: dict, remote_cmd: dict) -> bool: diff --git a/interactions/models/internal/context.py b/interactions/models/internal/context.py index 775e80f06..7eac000be 100644 --- a/interactions/models/internal/context.py +++ b/interactions/models/internal/context.py @@ -296,7 +296,7 @@ def from_dict(cls, client: "ClientT", payload: dict) -> Self: instance.guild_locale = payload.get("guild_locale", instance.locale) instance._context_type = payload.get("type", 0) instance.resolved = Resolved.from_dict(client, payload["data"].get("resolved", {}), payload.get("guild_id")) - instance.entitlements = Entitlement.from_list(payload["entitlements"], client) + instance.entitlements = Entitlement.from_list(payload.get("entitlements", []), client) instance.context = ContextType(payload["context"]) if payload.get("context") else None instance.authorizing_integration_owners = { IntegrationType(int(integration_type)): Snowflake(owner_id) @@ -345,8 +345,8 @@ def author_permissions(self) -> Permissions: return Permissions(0) @property - def command(self) -> InteractionCommand: - return self.client._interaction_lookup[self._command_name] + def command(self) -> typing.Optional[InteractionCommand]: + return self.client._interaction_lookup.get(self._command_name) @property def expires_at(self) -> Timestamp: diff --git a/interactions/models/internal/tasks/triggers.py b/interactions/models/internal/tasks/triggers.py index 5ac993236..d15a319ea 100644 --- a/interactions/models/internal/tasks/triggers.py +++ b/interactions/models/internal/tasks/triggers.py @@ -163,4 +163,4 @@ def __init__(self, cron: str, tz: "_TzInfo" = timezone.utc) -> None: self.tz = tz def next_fire(self) -> datetime | None: - return croniter(self.cron, datetime.now(tz=self.tz)).next(datetime) + return croniter(self.cron, self.last_call_time.astimezone(self.tz)).next(datetime) diff --git a/pyproject.toml b/pyproject.toml index b93cbf7ea..ddd04ff06 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "interactions.py" -version = "5.13.1" +version = "5.13.2" description = "Easy, simple, scalable and modular: a Python API wrapper for interactions." authors = ["LordOfPolls "] @@ -93,14 +93,12 @@ exclude_lines = ["pragma: no cover", "if TYPE_CHECKING:"] [tool.coverage.run] omit = ["tests/*"] +source = ["interactions"] [build-system] requires = ["setuptools", "tomli"] build-backend = "setuptools.build_meta" -[tools.coverage.run] -source = ["interactions"] - [tool.pytest.ini_options] addopts = "-l -ra --durations=2 --junitxml=TestResults.xml" doctest_optionflags = "NORMALIZE_WHITESPACE" diff --git a/requirements.txt b/requirements.txt index 293eabbec..ccfed7aa0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ aiohttp attrs>=22.1 +audioop-lts; python_version>='3.13' croniter discord-typings>=0.9.0 emoji diff --git a/tests/test_bot.py b/tests/test_bot.py index 9e2f8c7c6..33267eaf5 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -3,7 +3,7 @@ import os from asyncio import AbstractEventLoop from contextlib import suppress -from datetime import datetime +from datetime import datetime, timedelta import pytest import pytest_asyncio @@ -33,6 +33,8 @@ ParagraphText, Message, GuildVoice, + Poll, + PollMedia, ) from interactions.models.discord.asset import Asset from interactions.models.discord.components import ActionRow, Button, StringSelectMenu @@ -432,6 +434,95 @@ async def test_components(bot: Client, channel: GuildText) -> None: await thread.delete() +@pytest.mark.asyncio +async def test_polls(bot: Client, channel: GuildText) -> None: + msg = await channel.send("Polls Tests") + thread = await msg.create_thread("Test Thread") + + try: + poll_1 = Poll.create("Test Poll", duration=1, answers=["Answer 1", "Answer 2"]) + test_data_1 = { + "question": {"text": "Test Poll"}, + "layout_type": 1, + "duration": 1, + "allow_multiselect": False, + "answers": [{"poll_media": {"text": "Answer 1"}}, {"poll_media": {"text": "Answer 2"}}], + } + poll_1_dict = poll_1.to_dict() + for key in poll_1_dict.keys(): + assert poll_1_dict[key] == test_data_1[key] + + msg_1 = await thread.send(poll=poll_1) + + assert msg_1.poll is not None + assert msg_1.poll.question.to_dict() == PollMedia(text="Test Poll").to_dict() + assert msg_1.poll.expiry <= msg_1.created_at + timedelta(hours=1, minutes=1) + poll_1_answer_medias = [poll_answer.poll_media.to_dict() for poll_answer in msg_1.poll.answers] + assert poll_1_answer_medias == [ + PollMedia.create(text="Answer 1").to_dict(), + PollMedia.create(text="Answer 2").to_dict(), + ] + + poll_2 = Poll.create("Test Poll 2", duration=1, allow_multiselect=True) + poll_2.add_answer("Answer 1") + poll_2.add_answer("Answer 2") + test_data_2 = { + "question": {"text": "Test Poll 2"}, + "layout_type": 1, + "duration": 1, + "allow_multiselect": True, + "answers": [{"poll_media": {"text": "Answer 1"}}, {"poll_media": {"text": "Answer 2"}}], + } + poll_2_dict = poll_2.to_dict() + for key in poll_2_dict.keys(): + assert poll_2_dict[key] == test_data_2[key] + msg_2 = await thread.send(poll=poll_2) + + assert msg_2.poll is not None + assert msg_2.poll.question.to_dict() == PollMedia(text="Test Poll 2").to_dict() + assert msg_2.poll.expiry <= msg_2.created_at + timedelta(hours=1, minutes=1) + assert msg_2.poll.allow_multiselect + poll_2_answer_medias = [poll_answer.poll_media.to_dict() for poll_answer in msg_2.poll.answers] + assert poll_2_answer_medias == [ + PollMedia.create(text="Answer 1").to_dict(), + PollMedia.create(text="Answer 2").to_dict(), + ] + + poll_3 = Poll.create( + "Test Poll 3", + duration=1, + answers=[PollMedia.create(text="One", emoji="1️⃣"), PollMedia.create(text="Two", emoji="2️⃣")], + ) + test_data_3 = { + "question": {"text": "Test Poll 3"}, + "layout_type": 1, + "duration": 1, + "allow_multiselect": False, + "answers": [ + {"poll_media": {"text": "One", "emoji": {"name": "1️⃣", "animated": False}}}, + {"poll_media": {"text": "Two", "emoji": {"name": "2️⃣", "animated": False}}}, + ], + } + poll_3_dict = poll_3.to_dict() + for key in poll_3_dict.keys(): + assert poll_3_dict[key] == test_data_3[key] + + msg_3 = await thread.send(poll=poll_3) + + assert msg_3.poll is not None + assert msg_3.poll.question.to_dict() == PollMedia(text="Test Poll 3").to_dict() + assert msg_3.poll.expiry <= msg_3.created_at + timedelta(hours=1, minutes=1) + poll_3_answer_medias = [poll_answer.poll_media.to_dict() for poll_answer in msg_3.poll.answers] + assert poll_3_answer_medias == [ + PollMedia.create(text="One", emoji="1️⃣").to_dict(), + PollMedia.create(text="Two", emoji="2️⃣").to_dict(), + ] + + finally: + with suppress(interactions.errors.NotFound): + await thread.delete() + + @pytest.mark.asyncio async def test_webhooks(bot: Client, guild: Guild, channel: GuildText) -> None: test_thread = await channel.create_thread("Test Thread")