Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor!: Auto mod refactor #1543

Open
wants to merge 1 commit into
base: unstable
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions interactions/api/events/discord.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ class AutoModExec(BaseEvent):

@attrs.define(eq=False, order=False, hash=False, kw_only=False)
class AutoModCreated(BaseEvent):
"""Dispatched when an auto mod rule is created"""

guild: "Guild" = attrs.field(repr=False, metadata=docs("The guild the rule was modified in"))
rule: "AutoModRule" = attrs.field(repr=False, metadata=docs("The rule that was modified"))

Expand Down
7 changes: 4 additions & 3 deletions interactions/api/events/processors/auto_mod.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import TYPE_CHECKING

from interactions.models.discord.auto_mod import AutoModerationAction, AutoModRule
from ._template import EventMixinTemplate, Processor

from ... import events
from ._template import EventMixinTemplate, Processor

if TYPE_CHECKING:
from interactions.api.events import RawGatewayEvent
Expand All @@ -25,13 +26,13 @@ async def raw_auto_moderation_rule_create(self, event: "RawGatewayEvent") -> Non
self.dispatch(events.AutoModCreated(guild, rule))

@Processor.define()
async def raw_auto_moderation_rule_delete(self, event: "RawGatewayEvent") -> None:
async def raw_auto_moderation_rule_update(self, event: "RawGatewayEvent") -> None:
rule = AutoModRule.from_dict(event.data, self)
guild = self.get_guild(event.data["guild_id"])
self.dispatch(events.AutoModUpdated(guild, rule))

@Processor.define()
async def raw_auto_moderation_rule_update(self, event: "RawGatewayEvent") -> None:
async def raw_auto_moderation_rule_delete(self, event: "RawGatewayEvent") -> None:
rule = AutoModRule.from_dict(event.data, self)
guild = self.get_guild(event.data["guild_id"])
self.dispatch(events.AutoModDeleted(guild, rule))
100 changes: 57 additions & 43 deletions interactions/models/discord/auto_mod.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,30 @@
from typing import Any, Optional, TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Optional, Union

import attrs

from interactions.client.const import get_logger, MISSING, Absent
from interactions.client.const import MISSING, Absent, get_logger
from interactions.client.mixins.serialization import DictSerializationMixin
from interactions.client.utils import list_converter, optional
from interactions.client.utils.attr_utils import docs
from interactions.models.discord.base import ClientObject, DiscordObject
from interactions.models.discord.enums import (
AutoModTriggerType,
AutoModAction,
AutoModEvent,
AutoModLanuguageType,
AutoModTriggerType,
KeywordPresetType,
)
from interactions.models.discord.snowflake import to_snowflake_list, to_snowflake
from interactions.models.discord.snowflake import to_snowflake, to_snowflake_list

if TYPE_CHECKING:
from interactions import Snowflake_Type, Guild, GuildText, Message, Client, Member, User
from interactions import (
Client,
Guild,
GuildText,
Member,
Message,
Snowflake_Type,
User,
)

__all__ = ("AutoModerationAction", "AutoModRule")

Expand Down Expand Up @@ -71,7 +79,7 @@ def _process_dict(cls, data: dict[str, Any]) -> dict[str, Any]:
return data

@classmethod
def from_dict_factory(cls, data: dict) -> "BaseAction":
def from_dict_factory(cls, data: dict) -> "BaseTrigger":
trigger_class = TRIGGER_MAPPING.get(data.get("trigger_type"))
meta = data.get("trigger_metadata", {})
if not trigger_class:
Expand All @@ -97,93 +105,95 @@ def _keyword_converter(filter: str | list[str]) -> list[str]:
class KeywordTrigger(BaseTrigger):
"""A trigger that checks if content contains words from a user defined list of keywords"""

type: AutoModTriggerType = attrs.field(
default=AutoModTriggerType.KEYWORD,
converter=AutoModTriggerType,
keyword_filter: list[str] = attrs.field(
factory=list,
repr=True,
metadata=docs("The type of trigger"),
metadata=docs("Substrings which will be searched for in content"),
converter=_keyword_converter,
)
keyword_filter: str | list[str] = attrs.field(
regex_patterns: list[str] = attrs.field(
factory=list,
repr=True,
metadata=docs("What words will trigger this"),
metadata=docs("Regular expression patterns which will be matched against content"),
converter=_keyword_converter,
)


@attrs.define(eq=False, order=False, hash=False, kw_only=True)
class HarmfulLinkFilter(BaseTrigger):
"""A trigger that checks if content contains any harmful links"""

type: AutoModTriggerType = attrs.field(
default=AutoModTriggerType.HARMFUL_LINK,
converter=AutoModTriggerType,
allow_list: list[str] = attrs.field(
factory=list,
repr=True,
metadata=docs("The type of trigger"),
metadata=docs("Substrings which should not trigger the rule"),
converter=_keyword_converter,
)


@attrs.define(eq=False, order=False, hash=False, kw_only=True)
class KeywordPresetTrigger(BaseTrigger):
"""A trigger that checks if content contains words from internal pre-defined wordsets"""

type: AutoModTriggerType = attrs.field(
default=AutoModTriggerType.KEYWORD_PRESET,
converter=AutoModTriggerType,
presets: list[KeywordPresetType] = attrs.field(
factory=list,
converter=list_converter(KeywordPresetType),
repr=True,
metadata=docs("The type of trigger"),
metadata=docs("The internally pre-defined wordsets which will be searched for in content"),
)
keyword_lists: list[AutoModLanuguageType] = attrs.field(
allow_list: str | list[str] = attrs.field(
factory=list,
converter=list_converter(AutoModLanuguageType),
repr=True,
metadata=docs("The preset list of keywords that will trigger this"),
metadata=docs("Substrings which should not trigger the rule"),
converter=_keyword_converter,
)


@attrs.define(eq=False, order=False, hash=False, kw_only=True)
class MentionSpamTrigger(BaseTrigger):
"""A trigger that checks if content contains more mentions than allowed"""
"""A trigger that checks if content contains more unique mentions than allowed"""

mention_total_limit: int = attrs.field(
default=3, repr=True, metadata=docs("The maximum number of mentions allowed")
)
mention_raid_protection_enabled: bool = attrs.field(
repr=True, metadata=docs("Whether to automatically detect mention raids")
)


@attrs.define(eq=False, order=False, hash=False, kw_only=True)
class MemberProfileTrigger(BaseTrigger):
"""A trigger that checks if member profile contains words from a user defined list of keywords"""

regex_patterns: list[str] = attrs.field(
factory=list, repr=True, metadata=docs("The regex patterns to check against")
factory=list, repr=True, metadata=docs("The regex patterns to check against"), converter=_keyword_converter
)
keyword_filter: str | list[str] = attrs.field(
factory=list, repr=True, metadata=docs("The keywords to check against")
factory=list, repr=True, metadata=docs("The keywords to check against"), converter=_keyword_converter
)
allow_list: list["Snowflake_Type"] = attrs.field(
factory=list, repr=True, metadata=docs("The roles exempt from this rule")
factory=list, repr=True, metadata=docs("The roles exempt from this rule"), converter=_keyword_converter
)


@attrs.define(eq=False, order=False, hash=False, kw_only=True)
class SpamTrigger(BaseTrigger):
"""A trigger that checks if content represents generic spam"""


@attrs.define(eq=False, order=False, hash=False, kw_only=True)
class BlockMessage(BaseAction):
"""blocks the content of a message according to the rule"""

type: AutoModAction = attrs.field(repr=False, default=AutoModAction.BLOCK_MESSAGE, converter=AutoModAction)
custom_message: Optional[str] = attrs.field(repr=True, default=None)


@attrs.define(eq=False, order=False, hash=False, kw_only=True)
class AlertMessage(BaseAction):
"""logs user content to a specified channel"""

channel_id: "Snowflake_Type" = attrs.field(repr=True)
type: AutoModAction = attrs.field(repr=False, default=AutoModAction.ALERT_MESSAGE, converter=AutoModAction)


@attrs.define(eq=False, order=False, hash=False, kw_only=False)
class TimeoutUser(BaseAction):
"""timeout user for a specified duration"""

duration_seconds: int = attrs.field(repr=True, default=60)
type: AutoModAction = attrs.field(repr=False, default=AutoModAction.TIMEOUT_USER, converter=AutoModAction)


@attrs.define(eq=False, order=False, hash=False, kw_only=False)
Expand All @@ -204,13 +214,13 @@ class AutoModRule(DiscordObject):
enabled: bool = attrs.field(repr=False, default=False)
"""whether the rule is enabled"""

actions: list[BaseAction] = attrs.field(repr=False, factory=list)
actions: list["TYPE_ALL_ACTION"] = attrs.field(repr=False, factory=list)
"""the actions which will execute when the rule is triggered"""
event_type: AutoModEvent = attrs.field(
repr=False,
)
"""the rule event type"""
trigger: BaseTrigger = attrs.field(
trigger: "TYPE_ALL_TRIGGER" = attrs.field(
repr=False,
)
"""The trigger for this rule"""
Expand Down Expand Up @@ -262,10 +272,10 @@ async def modify(
self,
*,
name: Absent[str] = MISSING,
trigger: Absent[BaseTrigger] = MISSING,
trigger: Absent["TYPE_ALL_TRIGGER"] = MISSING,
trigger_type: Absent[AutoModTriggerType] = MISSING,
trigger_metadata: Absent[dict] = MISSING,
actions: Absent[list[BaseAction]] = MISSING,
actions: Absent[list["TYPE_ALL_ACTION"]] = MISSING,
exempt_channels: Absent[list["Snowflake_Type"]] = MISSING,
exempt_roles: Absent[list["Snowflake_Type"]] = MISSING,
event_type: Absent[AutoModEvent] = MISSING,
Expand Down Expand Up @@ -318,7 +328,7 @@ class AutoModerationAction(ClientObject):
repr=False,
)

action: BaseAction = attrs.field(default=MISSING, repr=True)
action: "TYPE_ALL_ACTION" = attrs.field(default=MISSING, repr=True)

matched_keyword: str = attrs.field(repr=True)
matched_content: Optional[str] = attrs.field(repr=False, default=None)
Expand Down Expand Up @@ -368,8 +378,12 @@ def member(self) -> "Optional[Member]":

TRIGGER_MAPPING = {
AutoModTriggerType.KEYWORD: KeywordTrigger,
AutoModTriggerType.HARMFUL_LINK: HarmfulLinkFilter,
AutoModTriggerType.SPAM: SpamTrigger,
AutoModTriggerType.KEYWORD_PRESET: KeywordPresetTrigger,
AutoModTriggerType.MENTION_SPAM: MentionSpamTrigger,
AutoModTriggerType.MEMBER_PROFILE: MemberProfileTrigger,
}

TYPE_ALL_TRIGGER = Union[KeywordTrigger, SpamTrigger, KeywordPresetTrigger, MentionSpamTrigger, MemberProfileTrigger]

TYPE_ALL_ACTION = Union[BlockMessage, AlertMessage, TimeoutUser, BlockMemberInteraction]
11 changes: 5 additions & 6 deletions interactions/models/discord/enums.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from enum import Enum, EnumMeta, IntEnum, IntFlag
from functools import reduce
from operator import or_
from typing import Iterator, Tuple, TypeVar, Type, Optional
from typing import Iterator, Optional, Tuple, Type, TypeVar

from interactions.client.const import get_logger

Expand Down Expand Up @@ -991,7 +991,6 @@ class AuditLogEventType(CursedIntEnum):

class AutoModTriggerType(CursedIntEnum):
KEYWORD = 1
HARMFUL_LINK = 2
SPAM = 3
KEYWORD_PRESET = 4
MENTION_SPAM = 5
Expand All @@ -1010,10 +1009,10 @@ class AutoModEvent(CursedIntEnum):
MEMBER_UPDATE = 2


class AutoModLanuguageType(Enum):
PROFANITY = "PROFANITY"
SEXUAL = "SEXUAL_CONTENT"
INSULTS_AND_SLURS = "SLURS"
class KeywordPresetType(CursedIntEnum):
PROFANITY = 1
SEXUAL_CONTENT = 2
SLURS = 3


class MemberFlags(DiscordIntFlag):
Expand Down