diff --git a/src/cogs/admin.py b/src/cogs/admin.py index f0aa401..9ab9445 100644 --- a/src/cogs/admin.py +++ b/src/cogs/admin.py @@ -6,6 +6,7 @@ from discord import app_commands from core import ExtendedCog, config +from core.checkers import is_me if TYPE_CHECKING: from discord import Interaction @@ -15,9 +16,10 @@ logger = logging.getLogger(__name__) -class Admin(ExtendedCog): # TODO(airo.pi_): add checkers +class Admin(ExtendedCog): @app_commands.command() @app_commands.guilds(config.support_guild_id) + @is_me async def reload_extension(self, inter: Interaction, extension: str): await self.bot.reload_extension(extension) await inter.response.send_message(f"Extension [{extension}] reloaded successfully") @@ -32,6 +34,7 @@ async def extension_autocompleter(self, inter: Interaction, current: str) -> lis @app_commands.command() @app_commands.guilds(config.support_guild_id) + @is_me async def sync_tree(self, inter: Interaction): await inter.response.defer() await self.bot.sync_tree() diff --git a/src/cogs/clear/__init__.py b/src/cogs/clear/__init__.py index 97244e5..e9f52c4 100644 --- a/src/cogs/clear/__init__.py +++ b/src/cogs/clear/__init__.py @@ -48,7 +48,7 @@ def __init__(self, bot: MyBot): self.clear_max_concurrency = checkers.MaxConcurrency(1, key=channel_bucket, wait=False) - @checkers.app.bot_required_permissions( + @checkers.bot_required_permissions( manage_messages=True, read_message_history=True, read_messages=True, connect=True ) @app_commands.command( diff --git a/src/cogs/eval.py b/src/cogs/eval.py index c77abc3..04f5fcd 100644 --- a/src/cogs/eval.py +++ b/src/cogs/eval.py @@ -17,8 +17,7 @@ from core import ExtendedCog from core._config import config -from core.checkers.app import is_me -from core.checkers.base import is_me_bool +from core.checkers import is_me, is_me_test from core.utils import size_text if TYPE_CHECKING: @@ -32,7 +31,7 @@ class Eval(ExtendedCog): @commands.command(name="+eval") - @commands.check(lambda ctx: is_me_bool(ctx.author.id)) + @commands.check(lambda ctx: is_me_test(ctx.author.id)) async def add_eval(self, ctx: commands.Context[MyBot]) -> None: try: self.bot.tree.add_command(self._eval, guild=ctx.guild) @@ -44,7 +43,7 @@ async def add_eval(self, ctx: commands.Context[MyBot]) -> None: await ctx.send("Command added.") @commands.command(name="-eval") - @commands.check(lambda ctx: is_me_bool(ctx.author.id)) + @commands.check(lambda ctx: is_me_test(ctx.author.id)) async def remove_eval(self, ctx: commands.Context[MyBot]) -> None: if self.bot.tree.remove_command("eval", guild=ctx.guild) is None: await ctx.send("Command not registered. Cleaning eventual leftovers...") diff --git a/src/cogs/poll/__init__.py b/src/cogs/poll/__init__.py index b62de0c..02ec5b6 100644 --- a/src/cogs/poll/__init__.py +++ b/src/cogs/poll/__init__.py @@ -13,7 +13,7 @@ from sqlalchemy.orm import selectinload from core import ExtendedGroupCog, db -from core.checkers.app import bot_required_permissions +from core.checkers import bot_required_permissions from core.errors import NonSpecificError from core.i18n import _ diff --git a/src/cogs/restore.py b/src/cogs/restore.py index a5e2f13..34579dc 100644 --- a/src/cogs/restore.py +++ b/src/cogs/restore.py @@ -4,8 +4,10 @@ import re from typing import TYPE_CHECKING +from discord import Interaction, app_commands + from core import ExtendedCog, MiscCommandContext, misc_command -from core.checkers.misc import bot_required_permissions, is_activated, is_user_authorized, misc_check +from core.checkers import bot_required_permissions, check, is_activated_predicate, is_user_authorized_predicate if TYPE_CHECKING: from discord import Message @@ -27,11 +29,15 @@ def contains_message_link(self, message: Message) -> bool: trigger_condition=contains_message_link, ) @bot_required_permissions(manage_webhooks=True) - @misc_check(is_activated) - @misc_check(is_user_authorized) + @check(is_activated_predicate) + @check(is_user_authorized_predicate) async def on_message(self, ctx: MiscCommandContext[MyBot], message: Message) -> None: raise NotImplementedError("Restore is not implemented.") + @app_commands.command() + async def test(self, inter: Interaction): + pass + async def setup(bot: MyBot): await bot.add_cog(Restore(bot)) diff --git a/src/cogs/translate/__init__.py b/src/cogs/translate/__init__.py index 59f1b96..cff12c8 100644 --- a/src/cogs/translate/__init__.py +++ b/src/cogs/translate/__init__.py @@ -13,7 +13,7 @@ from discord.app_commands import locale_str as __ from core import ExtendedCog, MiscCommandContext, ResponseType, TemporaryCache, db, misc_command, response_constructor -from core.checkers.misc import bot_required_permissions, is_activated, is_user_authorized, misc_check +from core.checkers import bot_required_permissions, check, is_activated_predicate, is_user_authorized_predicate from core.constants import EmbedsCharLimits from core.errors import BadArgument, NonSpecificError from core.i18n import _ @@ -280,8 +280,8 @@ async def translate_misc_condition(self, payload: RawReactionActionEvent) -> boo trigger_condition=translate_misc_condition, ) @bot_required_permissions(send_messages=True, embed_links=True) - @misc_check(is_activated) - @misc_check(is_user_authorized) + @check(is_activated_predicate) + @check(is_user_authorized_predicate) async def translate_misc_command(self, ctx: MiscCommandContext[MyBot], payload: RawReactionActionEvent): channel = await self.bot.getch_channel(payload.channel_id) if channel is None: diff --git a/src/core/checkers/__init__.py b/src/core/checkers/__init__.py index 49f488c..722bbb5 100644 --- a/src/core/checkers/__init__.py +++ b/src/core/checkers/__init__.py @@ -1,2 +1,157 @@ -from . import app as app +from __future__ import annotations + +import inspect +import logging +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, cast + +import discord +from discord import Interaction +from discord.app_commands import Command, ContextMenu, check as app_check + +from .._config import config +from .._types import CoroT +from ..errors import BotMissingPermissions, NotAllowedUser +from ..extended_commands import MiscCommandContext, check as misc_check +from ..utils import CommandType from .max_concurrency import MaxConcurrency as MaxConcurrency + +if TYPE_CHECKING: + from mybot import MyBot + + type Context = Interaction | MiscCommandContext[Any] + + +logger = logging.getLogger(__name__) + + +def _determine_type(obj: Any) -> CommandType: + """This function will determine the type of the command. + + It makes some assumptions about the type of the command based on the annotations of the function. + """ + if isinstance(obj, Command | ContextMenu): + return CommandType.APP + if hasattr(obj, "__listener_as_command__"): + return CommandType.MISC + else: + annotations = inspect.get_annotations(obj) + target = next(iter(annotations.values())) # get the first annotation + if target is MiscCommandContext: + return CommandType.MISC + if target is Interaction: + return CommandType.APP + if isinstance(target, str): + # I don't know how to handle this case properly because MyBot is not imported in this file + if target.startswith("Interaction"): + return CommandType.APP + if target.startswith("MiscCommandContext"): + return CommandType.MISC + raise TypeError("Could not determine the type of the command.") + + +def _add_extra[T](type_: CommandType, func: T, name: str, value: Any) -> T: + copy_func = func # typing behavior + if type_ is CommandType.APP: + if isinstance(func, Command | ContextMenu): + func.extras[name] = value + else: + logger.critical( + "Because we need to add extras, this decorator must be above the command decorator. " + "(Command should already be defined)" + ) + elif type_ is CommandType.MISC: + if hasattr(func, "__listener_as_command__"): + command: Command[Any, ..., Any] = getattr(func, "__listener_as_command__") + command.extras[name] = value + else: + if not hasattr(func, "__misc_commands_extras__"): + setattr(func, "__misc_commands_extras__", {}) + getattr(func, "__misc_commands_extras__")[name] = value + return copy_func + + +def check[C: Interaction | MiscCommandContext[Any], F]( + predicate: Callable[[C], bool | CoroT[bool]], +) -> Callable[[F], F]: + def decorator(func: F) -> F: + match _determine_type(func): + case CommandType.APP: + p = cast(Callable[[Interaction], bool | CoroT[bool]], predicate) + return app_check(p)(func) + case CommandType.MISC: + p = cast(Callable[[MiscCommandContext[Any]], bool | CoroT[bool]], predicate) + return misc_check(p)(func) + + return decorator + + +def _bot_required_permissions_test(perms: dict[str, bool]) -> Callable[..., bool]: + def predicate(ctx: Context): + match ctx: + case discord.Interaction(): + permissions = ctx.app_permissions + case MiscCommandContext(): + permissions = ctx.bot_permissions + + missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] + + if not missing: + return True + + raise BotMissingPermissions(missing) + + return predicate + + +def bot_required_permissions[T](**perms: bool) -> Callable[[T], T]: + invalid = set(perms) - set(discord.Permissions.VALID_FLAGS) + if invalid: + raise TypeError(f"Invalid permission(s): {", ".join(invalid)}") + + def decorator(func: T) -> T: + type_ = _determine_type(func) + _add_extra( + type_, + func, + "bot_required_permissions", + [perm for perm, value in perms.items() if value is True], + ) + match type_: + case CommandType.APP: + return app_check(_bot_required_permissions_test(perms))(func) + case CommandType.MISC: + return misc_check(_bot_required_permissions_test(perms))(func) + + return decorator + + +async def is_user_authorized_predicate(context: MiscCommandContext[MyBot]) -> bool: + del context # unused + # TODO(airo.pi_): check using the database if the user is authorized + return True + + +is_user_authorized = check(is_user_authorized_predicate) # misc commands only + + +async def is_activated_predicate(context: MiscCommandContext[MyBot]) -> bool: + del context # unused + # TODO(airo.pi_): check using the database if the misc command is activated + return True + + +is_activated = check(is_activated_predicate) # misc commands only + + +def allowed_users_test(*user_ids: int) -> Callable[..., bool]: + def inner(user_id: int) -> bool: + if user_id not in user_ids: + raise NotAllowedUser(user_id) + return True + + return inner + + +is_me_test = allowed_users_test(*config.owners_ids) # test function used for eval commands +is_me = check(lambda ctx: is_me_test(ctx.user.id)) diff --git a/src/core/checkers/app.py b/src/core/checkers/app.py deleted file mode 100644 index 55b0d04..0000000 --- a/src/core/checkers/app.py +++ /dev/null @@ -1,16 +0,0 @@ -from __future__ import annotations - -from collections.abc import Callable - -from discord.app_commands import check - -from ..extended_commands import misc_check as misc_check -from ..utils import CommandType -from .base import T, bot_required_permissions_base, is_me_bool - - -def bot_required_permissions(**perms: bool) -> Callable[[T], T]: - return bot_required_permissions_base(CommandType.APP, **perms) - - -is_me = check(lambda inter: is_me_bool(inter.user.id)) diff --git a/src/core/checkers/base.py b/src/core/checkers/base.py deleted file mode 100644 index 50afcf5..0000000 --- a/src/core/checkers/base.py +++ /dev/null @@ -1,100 +0,0 @@ -from __future__ import annotations - -import logging -from collections.abc import Callable -from typing import TYPE_CHECKING, Any, TypeVar - -import discord -from discord.app_commands import Command, ContextMenu, check as app_check - -from .._config import config -from ..errors import BotMissingPermissions, NotAllowedUser -from ..extended_commands import MiscCommandContext, misc_check as misc_check -from ..utils import CommandType - -T = TypeVar("T") - - -logger = logging.getLogger(__name__) - - -if TYPE_CHECKING: - from discord import Interaction - - from mybot import MyBot - - -def add_extra(type_: CommandType, func: T, name: str, value: Any) -> T: - copy_func = func # typing behavior - if type_ is CommandType.APP: - if isinstance(func, Command | ContextMenu): - func.extras[name] = value - else: - logger.critical( - "Because we need to add extras, this decorator must be above the command decorator. " - "(Command should already be defined)" - ) - elif type_ is CommandType.MISC: - if hasattr(func, "__listener_as_command__"): - command: Command[Any, ..., Any] = getattr(func, "__listener_as_command__") - command.extras[name] = value - else: - if not hasattr(func, "__misc_commands_extras__"): - setattr(func, "__misc_commands_extras__", {}) - getattr(func, "__misc_commands_extras__")[name] = value - return copy_func - - -def _bot_required_permissions_predicate(perms: dict[str, bool]) -> Callable[..., bool]: - def predicate(ctx: Interaction | MiscCommandContext[MyBot]): - match ctx: - case discord.Interaction(): - permissions = ctx.app_permissions - case MiscCommandContext(): - permissions = ctx.bot_permissions - - missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] - - if not missing: - return True - - raise BotMissingPermissions(missing) - - return predicate - - -# This exist only to avoid the duplication of invalid perms check. It mays be removed in the future. -def bot_required_permissions_base(type_: CommandType, **perms: bool) -> Callable[[T], T]: - invalid = set(perms) - set(discord.Permissions.VALID_FLAGS) - if invalid: - raise TypeError(f"Invalid permission(s): {", ".join(invalid)}") - - def decorator(func: T) -> T: - match type_: - case CommandType.APP: - add_extra( - type_, func, "bot_required_permissions", [perm for perm, value in perms.items() if value is True] - ) - return app_check(_bot_required_permissions_predicate(perms))(func) - case CommandType.MISC: - add_extra( - CommandType.MISC, - func, - "bot_required_permissions", - [perm for perm, value in perms.items() if value is True], - ) - return misc_check(_bot_required_permissions_predicate(perms))(func) - - return decorator - - -def allowed_users_bool(*user_ids: int) -> Callable[..., bool]: - def inner(user_id: int) -> bool: - if user_id not in user_ids: - raise NotAllowedUser(user_id) - return True - - return inner - - -is_me_bool = allowed_users_bool(*config.owners_ids) diff --git a/src/core/checkers/max_concurrency.py b/src/core/checkers/max_concurrency.py index fe1bbb6..0e4f9f7 100644 --- a/src/core/checkers/max_concurrency.py +++ b/src/core/checkers/max_concurrency.py @@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, Any, Self, TypeVar from ..errors import MaxConcurrencyReached -from ..extended_commands import misc_check as misc_check T = TypeVar("T") diff --git a/src/core/checkers/misc.py b/src/core/checkers/misc.py deleted file mode 100644 index 4ea5346..0000000 --- a/src/core/checkers/misc.py +++ /dev/null @@ -1,27 +0,0 @@ -from __future__ import annotations - -from collections.abc import Callable -from typing import TYPE_CHECKING - -from ..extended_commands import MiscCommandContext, misc_check as misc_check -from ..utils import CommandType -from .base import T, bot_required_permissions_base - -if TYPE_CHECKING: - from mybot import MyBot - - -def bot_required_permissions(**perms: bool) -> Callable[[T], T]: - return bot_required_permissions_base(CommandType.MISC, **perms) - - -async def is_user_authorized(context: MiscCommandContext[MyBot]) -> bool: - del context # unused - # TODO(airo.pi_): check using the database if the user is authorized - return True - - -async def is_activated(context: MiscCommandContext[MyBot]) -> bool: - del context # unused - # TODO(airo.pi_): check using the database if the misc command is activated - return True diff --git a/src/core/extended_commands.py b/src/core/extended_commands.py index 9951a27..757fb9d 100644 --- a/src/core/extended_commands.py +++ b/src/core/extended_commands.py @@ -285,7 +285,7 @@ def decorator(func: R) -> R: return decorator -def misc_check(predicate: Callable[[MiscCommandContext[Any]], CoroT[bool] | bool]) -> Callable[[R], R]: +def check(predicate: Callable[[MiscCommandContext[Any]], CoroT[bool] | bool]) -> Callable[[R], R]: def decorator(func: R) -> R: if hasattr(func, "__listener_as_command__"): misc_command: MiscCommand[Any, Any] = getattr(func, "__listener_as_command__")