Skip to content

Commit

Permalink
fix: hook onto add_command to propagate errors correctly (#1678)
Browse files Browse the repository at this point in the history
* fix: hook onto add_command to propagate errors correctly

* fix: re-add on_callback_added to add_hybrid_command
  • Loading branch information
AstreaTSS authored May 6, 2024
1 parent f815149 commit a8f6fcf
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 19 deletions.
5 changes: 5 additions & 0 deletions interactions/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,8 @@ def __init__(
self.async_startup_tasks: list[tuple[Callable[..., Coroutine], Iterable[Any], dict[str, Any]]] = []
"""A list of coroutines to run during startup"""

self._add_command_hook: list[Callable[[Callable], Any]] = []

# callbacks
if global_pre_run_callback:
if asyncio.iscoroutinefunction(global_pre_run_callback):
Expand Down Expand Up @@ -1416,6 +1418,9 @@ def add_command(self, func: Callable) -> None:
else:
self.logger.debug(f"Added callback: {func.callback.__name__}")

for hook in self._add_command_hook:
hook(func)

self.dispatch(CallbackAdded(callback=func, extension=func.extension if hasattr(func, "extension") else None))

def _gather_callbacks(self) -> None:
Expand Down
17 changes: 9 additions & 8 deletions interactions/ext/hybrid_commands/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,21 +65,22 @@ def __init__(
self.client = cast(prefixed.PrefixedInjectedClient, client)
self.ext_command_list: dict[str, list[str]] = {}

self.client.add_listener(self.add_hybrid_command.copy_with_binding(self))
self.client.add_listener(self.handle_ext_unload.copy_with_binding(self))

self.client._add_command_hook.append(self._add_hybrid_command)

self.client.hybrid = self

@listen("on_callback_added")
async def add_hybrid_command(self, event: CallbackAdded):
if (
not isinstance(event.callback, HybridSlashCommand)
or not event.callback.callback
or event.callback._dummy_base
):
async def add_hybrid_command(self, event: CallbackAdded) -> None:
# just here for backwards compatability since it was accidentially public, don't rely on it
self._add_hybrid_command(event.callback)

def _add_hybrid_command(self, callback: Callable):
if not isinstance(callback, HybridSlashCommand) or not callback.callback or callback._dummy_base:
return

cmd = event.callback
cmd = callback
prefixed_transform = slash_to_prefixed(cmd)

if self.use_slash_command_msg:
Expand Down
17 changes: 6 additions & 11 deletions interactions/ext/prefixed_commands/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from interactions.api.events.internal import (
CommandError,
CommandCompletion,
CallbackAdded,
ExtensionUnload,
)
from interactions.client.client import Client
Expand Down Expand Up @@ -90,13 +89,13 @@ def __init__(
self.client.prefixed = self

self._dispatch_prefixed_commands = self._dispatch_prefixed_commands.copy_with_binding(self)
self._register_command = self._register_command.copy_with_binding(self)
self._handle_ext_unload = self._handle_ext_unload.copy_with_binding(self)

self.client.add_listener(self._dispatch_prefixed_commands)
self.client.add_listener(self._register_command)
self.client.add_listener(self._handle_ext_unload)

self.client._add_command_hook.append(self._register_command)

async def generate_prefixes(self, client: Client, msg: Message) -> str | list[str]:
"""
Generates a list of prefixes a prefixed command can have based on the client and message.
Expand Down Expand Up @@ -229,17 +228,13 @@ def remove_command(self, name: str, delete_parent_if_empty: bool = False) -> Opt

return command

@listen("callback_added")
async def _register_command(self, event: CallbackAdded) -> None:
def _register_command(self, callback: Callable) -> None:
"""Registers a prefixed command, if there is one given."""
if not isinstance(event.callback, PrefixedCommand):
if not isinstance(callback, PrefixedCommand):
return

cmd = event.callback
cmd.extension = event.extension

if not cmd.is_subcommand:
self.add_command(cmd)
if not callback.is_subcommand:
self.add_command(callback)

@listen("extension_unload")
async def _handle_ext_unload(self, event: ExtensionUnload) -> None:
Expand Down

0 comments on commit a8f6fcf

Please sign in to comment.