From 27053236f7b935b0271a9edd75503023dc665766 Mon Sep 17 00:00:00 2001 From: mAxYoLo01 Date: Fri, 25 Aug 2023 12:34:58 +0200 Subject: [PATCH] refactor!: Auto mod refactor --- interactions/api/events/discord.py | 2 + .../api/events/processors/auto_mod.py | 7 +- interactions/models/discord/auto_mod.py | 100 ++++++++++-------- interactions/models/discord/enums.py | 11 +- 4 files changed, 68 insertions(+), 52 deletions(-) diff --git a/interactions/api/events/discord.py b/interactions/api/events/discord.py index 35fe5c783..c621fd01f 100644 --- a/interactions/api/events/discord.py +++ b/interactions/api/events/discord.py @@ -132,6 +132,8 @@ class AutoModExec(BaseEvent): @attrs.define(eq=False, order=False, hash=False, kw_only=False) class AutoModCreated(BaseEvent): + """Dispatched when an auto mod rule is created""" + guild: "Guild" = attrs.field(repr=False, metadata=docs("The guild the rule was modified in")) rule: "AutoModRule" = attrs.field(repr=False, metadata=docs("The rule that was modified")) diff --git a/interactions/api/events/processors/auto_mod.py b/interactions/api/events/processors/auto_mod.py index 31f08476f..dfee26c2d 100644 --- a/interactions/api/events/processors/auto_mod.py +++ b/interactions/api/events/processors/auto_mod.py @@ -1,8 +1,9 @@ from typing import TYPE_CHECKING from interactions.models.discord.auto_mod import AutoModerationAction, AutoModRule -from ._template import EventMixinTemplate, Processor + from ... import events +from ._template import EventMixinTemplate, Processor if TYPE_CHECKING: from interactions.api.events import RawGatewayEvent @@ -25,13 +26,13 @@ async def raw_auto_moderation_rule_create(self, event: "RawGatewayEvent") -> Non self.dispatch(events.AutoModCreated(guild, rule)) @Processor.define() - async def raw_auto_moderation_rule_delete(self, event: "RawGatewayEvent") -> None: + async def raw_auto_moderation_rule_update(self, event: "RawGatewayEvent") -> None: rule = AutoModRule.from_dict(event.data, self) guild = self.get_guild(event.data["guild_id"]) self.dispatch(events.AutoModUpdated(guild, rule)) @Processor.define() - async def raw_auto_moderation_rule_update(self, event: "RawGatewayEvent") -> None: + async def raw_auto_moderation_rule_delete(self, event: "RawGatewayEvent") -> None: rule = AutoModRule.from_dict(event.data, self) guild = self.get_guild(event.data["guild_id"]) self.dispatch(events.AutoModDeleted(guild, rule)) diff --git a/interactions/models/discord/auto_mod.py b/interactions/models/discord/auto_mod.py index c9be2d142..030d1700d 100644 --- a/interactions/models/discord/auto_mod.py +++ b/interactions/models/discord/auto_mod.py @@ -1,22 +1,30 @@ -from typing import Any, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Optional, Union import attrs -from interactions.client.const import get_logger, MISSING, Absent +from interactions.client.const import MISSING, Absent, get_logger from interactions.client.mixins.serialization import DictSerializationMixin from interactions.client.utils import list_converter, optional from interactions.client.utils.attr_utils import docs from interactions.models.discord.base import ClientObject, DiscordObject from interactions.models.discord.enums import ( - AutoModTriggerType, AutoModAction, AutoModEvent, - AutoModLanuguageType, + AutoModTriggerType, + KeywordPresetType, ) -from interactions.models.discord.snowflake import to_snowflake_list, to_snowflake +from interactions.models.discord.snowflake import to_snowflake, to_snowflake_list if TYPE_CHECKING: - from interactions import Snowflake_Type, Guild, GuildText, Message, Client, Member, User + from interactions import ( + Client, + Guild, + GuildText, + Member, + Message, + Snowflake_Type, + User, + ) __all__ = ("AutoModerationAction", "AutoModRule") @@ -71,7 +79,7 @@ def _process_dict(cls, data: dict[str, Any]) -> dict[str, Any]: return data @classmethod - def from_dict_factory(cls, data: dict) -> "BaseAction": + def from_dict_factory(cls, data: dict) -> "BaseTrigger": trigger_class = TRIGGER_MAPPING.get(data.get("trigger_type")) meta = data.get("trigger_metadata", {}) if not trigger_class: @@ -97,29 +105,23 @@ def _keyword_converter(filter: str | list[str]) -> list[str]: class KeywordTrigger(BaseTrigger): """A trigger that checks if content contains words from a user defined list of keywords""" - type: AutoModTriggerType = attrs.field( - default=AutoModTriggerType.KEYWORD, - converter=AutoModTriggerType, + keyword_filter: list[str] = attrs.field( + factory=list, repr=True, - metadata=docs("The type of trigger"), + metadata=docs("Substrings which will be searched for in content"), + converter=_keyword_converter, ) - keyword_filter: str | list[str] = attrs.field( + regex_patterns: list[str] = attrs.field( factory=list, repr=True, - metadata=docs("What words will trigger this"), + metadata=docs("Regular expression patterns which will be matched against content"), converter=_keyword_converter, ) - - -@attrs.define(eq=False, order=False, hash=False, kw_only=True) -class HarmfulLinkFilter(BaseTrigger): - """A trigger that checks if content contains any harmful links""" - - type: AutoModTriggerType = attrs.field( - default=AutoModTriggerType.HARMFUL_LINK, - converter=AutoModTriggerType, + allow_list: list[str] = attrs.field( + factory=list, repr=True, - metadata=docs("The type of trigger"), + metadata=docs("Substrings which should not trigger the rule"), + converter=_keyword_converter, ) @@ -127,47 +129,57 @@ class HarmfulLinkFilter(BaseTrigger): class KeywordPresetTrigger(BaseTrigger): """A trigger that checks if content contains words from internal pre-defined wordsets""" - type: AutoModTriggerType = attrs.field( - default=AutoModTriggerType.KEYWORD_PRESET, - converter=AutoModTriggerType, + presets: list[KeywordPresetType] = attrs.field( + factory=list, + converter=list_converter(KeywordPresetType), repr=True, - metadata=docs("The type of trigger"), + metadata=docs("The internally pre-defined wordsets which will be searched for in content"), ) - keyword_lists: list[AutoModLanuguageType] = attrs.field( + allow_list: str | list[str] = attrs.field( factory=list, - converter=list_converter(AutoModLanuguageType), repr=True, - metadata=docs("The preset list of keywords that will trigger this"), + metadata=docs("Substrings which should not trigger the rule"), + converter=_keyword_converter, ) @attrs.define(eq=False, order=False, hash=False, kw_only=True) class MentionSpamTrigger(BaseTrigger): - """A trigger that checks if content contains more mentions than allowed""" + """A trigger that checks if content contains more unique mentions than allowed""" mention_total_limit: int = attrs.field( default=3, repr=True, metadata=docs("The maximum number of mentions allowed") ) + mention_raid_protection_enabled: bool = attrs.field( + repr=True, metadata=docs("Whether to automatically detect mention raids") + ) @attrs.define(eq=False, order=False, hash=False, kw_only=True) class MemberProfileTrigger(BaseTrigger): + """A trigger that checks if member profile contains words from a user defined list of keywords""" + regex_patterns: list[str] = attrs.field( - factory=list, repr=True, metadata=docs("The regex patterns to check against") + factory=list, repr=True, metadata=docs("The regex patterns to check against"), converter=_keyword_converter ) keyword_filter: str | list[str] = attrs.field( - factory=list, repr=True, metadata=docs("The keywords to check against") + factory=list, repr=True, metadata=docs("The keywords to check against"), converter=_keyword_converter ) allow_list: list["Snowflake_Type"] = attrs.field( - factory=list, repr=True, metadata=docs("The roles exempt from this rule") + factory=list, repr=True, metadata=docs("The roles exempt from this rule"), converter=_keyword_converter ) +@attrs.define(eq=False, order=False, hash=False, kw_only=True) +class SpamTrigger(BaseTrigger): + """A trigger that checks if content represents generic spam""" + + @attrs.define(eq=False, order=False, hash=False, kw_only=True) class BlockMessage(BaseAction): """blocks the content of a message according to the rule""" - type: AutoModAction = attrs.field(repr=False, default=AutoModAction.BLOCK_MESSAGE, converter=AutoModAction) + custom_message: Optional[str] = attrs.field(repr=True, default=None) @attrs.define(eq=False, order=False, hash=False, kw_only=True) @@ -175,7 +187,6 @@ class AlertMessage(BaseAction): """logs user content to a specified channel""" channel_id: "Snowflake_Type" = attrs.field(repr=True) - type: AutoModAction = attrs.field(repr=False, default=AutoModAction.ALERT_MESSAGE, converter=AutoModAction) @attrs.define(eq=False, order=False, hash=False, kw_only=False) @@ -183,7 +194,6 @@ class TimeoutUser(BaseAction): """timeout user for a specified duration""" duration_seconds: int = attrs.field(repr=True, default=60) - type: AutoModAction = attrs.field(repr=False, default=AutoModAction.TIMEOUT_USER, converter=AutoModAction) @attrs.define(eq=False, order=False, hash=False, kw_only=False) @@ -204,13 +214,13 @@ class AutoModRule(DiscordObject): enabled: bool = attrs.field(repr=False, default=False) """whether the rule is enabled""" - actions: list[BaseAction] = attrs.field(repr=False, factory=list) + actions: list["TYPE_ALL_ACTION"] = attrs.field(repr=False, factory=list) """the actions which will execute when the rule is triggered""" event_type: AutoModEvent = attrs.field( repr=False, ) """the rule event type""" - trigger: BaseTrigger = attrs.field( + trigger: "TYPE_ALL_TRIGGER" = attrs.field( repr=False, ) """The trigger for this rule""" @@ -262,10 +272,10 @@ async def modify( self, *, name: Absent[str] = MISSING, - trigger: Absent[BaseTrigger] = MISSING, + trigger: Absent["TYPE_ALL_TRIGGER"] = MISSING, trigger_type: Absent[AutoModTriggerType] = MISSING, trigger_metadata: Absent[dict] = MISSING, - actions: Absent[list[BaseAction]] = MISSING, + actions: Absent[list["TYPE_ALL_ACTION"]] = MISSING, exempt_channels: Absent[list["Snowflake_Type"]] = MISSING, exempt_roles: Absent[list["Snowflake_Type"]] = MISSING, event_type: Absent[AutoModEvent] = MISSING, @@ -318,7 +328,7 @@ class AutoModerationAction(ClientObject): repr=False, ) - action: BaseAction = attrs.field(default=MISSING, repr=True) + action: "TYPE_ALL_ACTION" = attrs.field(default=MISSING, repr=True) matched_keyword: str = attrs.field(repr=True) matched_content: Optional[str] = attrs.field(repr=False, default=None) @@ -368,8 +378,12 @@ def member(self) -> "Optional[Member]": TRIGGER_MAPPING = { AutoModTriggerType.KEYWORD: KeywordTrigger, - AutoModTriggerType.HARMFUL_LINK: HarmfulLinkFilter, + AutoModTriggerType.SPAM: SpamTrigger, AutoModTriggerType.KEYWORD_PRESET: KeywordPresetTrigger, AutoModTriggerType.MENTION_SPAM: MentionSpamTrigger, AutoModTriggerType.MEMBER_PROFILE: MemberProfileTrigger, } + +TYPE_ALL_TRIGGER = Union[KeywordTrigger, SpamTrigger, KeywordPresetTrigger, MentionSpamTrigger, MemberProfileTrigger] + +TYPE_ALL_ACTION = Union[BlockMessage, AlertMessage, TimeoutUser, BlockMemberInteraction] diff --git a/interactions/models/discord/enums.py b/interactions/models/discord/enums.py index c201f6d50..6537207ea 100644 --- a/interactions/models/discord/enums.py +++ b/interactions/models/discord/enums.py @@ -1,7 +1,7 @@ from enum import Enum, EnumMeta, IntEnum, IntFlag from functools import reduce from operator import or_ -from typing import Iterator, Tuple, TypeVar, Type, Optional +from typing import Iterator, Optional, Tuple, Type, TypeVar from interactions.client.const import get_logger @@ -991,7 +991,6 @@ class AuditLogEventType(CursedIntEnum): class AutoModTriggerType(CursedIntEnum): KEYWORD = 1 - HARMFUL_LINK = 2 SPAM = 3 KEYWORD_PRESET = 4 MENTION_SPAM = 5 @@ -1010,10 +1009,10 @@ class AutoModEvent(CursedIntEnum): MEMBER_UPDATE = 2 -class AutoModLanuguageType(Enum): - PROFANITY = "PROFANITY" - SEXUAL = "SEXUAL_CONTENT" - INSULTS_AND_SLURS = "SLURS" +class KeywordPresetType(CursedIntEnum): + PROFANITY = 1 + SEXUAL_CONTENT = 2 + SLURS = 3 class MemberFlags(DiscordIntFlag):