Skip to content

Commit

Permalink
Merge pull request #1049 from xtexChooser/matrix/issue1046
Browse files Browse the repository at this point in the history
#1046 for matrix
  • Loading branch information
OasisAkari authored Dec 20, 2023
2 parents d8e6733 + 5842f2a commit a805ff1
Show file tree
Hide file tree
Showing 15 changed files with 101 additions and 51 deletions.
34 changes: 22 additions & 12 deletions bots/matrix/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from bots.matrix import client
from bots.matrix.client import bot
from bots.matrix.info import client_name
from bots.matrix.message import MessageSession, FetchTarget
from bots.matrix.message import MessageSession, FetchTarget, ReactionMessageSession
from core.builtins import PrivateAssets, Url
from core.logger import Logger
from core.parser.message import parser
Expand Down Expand Up @@ -50,36 +50,45 @@ async def on_room_member(room: nio.MatrixRoom, event: nio.RoomMemberEvent):
Logger.info(f"Left empty room {room.room_id}")


async def on_message(room: nio.MatrixRoom, event: nio.RoomMessageFormatted):
async def on_message(room: nio.MatrixRoom, event: nio.Event):
if event.sender != bot.user_id and bot.olm:
for device_id, olm_device in bot.device_store[event.sender].items():
if bot.olm.is_device_verified(olm_device):
continue
bot.verify_device(olm_device)
Logger.info(f"trust olm device for device id {event.sender} -> {device_id}")
if event.source['content']['msgtype'] == 'm.notice':
if isinstance(event, nio.RoomMessageFormatted) and event.source['content']['msgtype'] == 'm.notice':
# https://spec.matrix.org/v1.7/client-server-api/#mnotice
return
is_room = room.member_count != 2 or room.join_rule != 'invite'
target_id = room.room_id if is_room else event.sender
reply_id = None
if 'm.relates_to' in event.source['content'] and 'm.in_reply_to' in event.source['content']['m.relates_to']:
reply_id = event.source['content']['m.relates_to']['m.in_reply_to']['event_id']

resp = await bot.get_displayname(event.sender)
if isinstance(resp, nio.ErrorResponse):
Logger.error(f"Failed to get display name for {event.sender}")
return
sender_name = resp.displayname

msg = MessageSession(MsgInfo(target_id=f'Matrix|{target_id}',
sender_id=f'Matrix|{event.sender}',
target_from=f'Matrix',
sender_from='Matrix',
sender_name=sender_name,
client_name=client_name,
message_id=event.event_id,
reply_id=reply_id),
Session(message=event.source, target=room.room_id, sender=event.sender))
target = MsgInfo(target_id=f'Matrix|{target_id}',
sender_id=f'Matrix|{event.sender}',
target_from=f'Matrix',
sender_from='Matrix',
sender_name=sender_name,
client_name=client_name,
message_id=event.event_id,
reply_id=reply_id)
session = Session(message=event.source, target=room.room_id, sender=event.sender)

msg = None
if isinstance(event, nio.RoomMessageFormatted):
msg = MessageSession(target, session)
elif isinstance(event, nio.ReactionEvent):
msg = ReactionMessageSession(target, session)
else:
raise NotImplemented
asyncio.create_task(parser(msg))


Expand Down Expand Up @@ -141,6 +150,7 @@ async def start():
bot.add_event_callback(on_invite, nio.InviteEvent)
bot.add_event_callback(on_room_member, nio.RoomMemberEvent)
bot.add_event_callback(on_message, nio.RoomMessageFormatted)
bot.add_event_callback(on_message, nio.ReactionEvent)
bot.add_to_device_callback(on_verify, nio.KeyVerificationEvent)
bot.add_event_callback(on_in_room_verify, nio.RoomMessageUnknown)

Expand Down
33 changes: 32 additions & 1 deletion bots/matrix/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ async def to_message_chain(self):

async def delete(self):
try:
await bot.room_redact(self.session.target, self.session.message['event_id'])
await bot.room_redact(self.session.target, self.target.message_id)
except Exception:
Logger.error(traceback.format_exc())

Expand All @@ -224,6 +224,37 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
pass


class ReactionMessageSession(MessageSession):
class Feature(MessageSession.Feature):
pass

class Typing(MessageSession.Typing):
pass

def as_display(self, text_only=False):
if text_only:
return ''
return self.session.message['content']['m.relates_to']['key']

async def to_message_chain(self):
return MessageChain([])

def is_quick_confirm(self, target: Union[MessageSession, FinishedSession]) -> bool:
content = self.session.message['content']['m.relates_to']
if content['rel_type'] == 'm.annotation':
if content['key'] in ['👍️', '✔️', '🎉']: # todo: move to config
if target is None:
return True
else:
msg = [target.target.message_id] if isinstance(target, MessageSession) else target.message_id
if content['event_id'] in msg:
return True
return False

asDisplay = as_display
toMessageChain = to_message_chain


class FetchedSession(Bot.FetchedSession):

async def _resolve_matrix_room_(self):
Expand Down
1 change: 1 addition & 0 deletions config/config.toml.example
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ debug = false
cache_path = "./cache/"
command_prefix = ["~", "~",]
confirm_command = ["是", "对", "對", "yes", "Yes", "YES", "y", "Y",]
quick_confirm = true
disabled_bots =
locale = "zh_cn"
timezone_offset = "+8"
Expand Down
10 changes: 6 additions & 4 deletions core/builtins/message/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from core.builtins.message.internal import *
from core.builtins.tasks import MessageTaskManager
from core.builtins.temp import ExecutionLockList
from core.builtins.utils import confirm_command
from core.builtins.utils import confirm_command, quick_confirm
from core.exceptions import WaitCancelException
from core.types.message import MessageSession as MessageSessionT, MsgInfo, Session
from core.types.message import MessageSession as MessageSessionT, FinishedSession, MsgInfo, Session
from core.utils.i18n import Locale
from core.utils.text import parse_time_string
from database import BotDBUtil
Expand Down Expand Up @@ -56,12 +56,14 @@ async def wait_confirm(self, message_chain=None, quote=True, delete=True, timeou
await send.delete()
if result.as_display(text_only=True) in confirm_command:
return True
if quick_confirm and result.is_quick_confirm(send):
return True
return False
else:
raise WaitCancelException

async def wait_next_message(self, message_chain=None, quote=True, delete=False, timeout=120,
append_instruction=True) -> MessageSessionT:
append_instruction=True) -> (MessageSessionT, FinishedSession):
sent = None
ExecutionLockList.remove(self)
if message_chain:
Expand All @@ -79,7 +81,7 @@ async def wait_next_message(self, message_chain=None, quote=True, delete=False,
if delete and sent:
await sent.delete()
if result:
return result
return (result, sent)
else:
raise WaitCancelException

Expand Down
2 changes: 1 addition & 1 deletion core/builtins/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def add_callback(cls, message_id, callback):
cls._callback_list[message_id] = {'callback': callback, 'ts': datetime.now().timestamp()}

@classmethod
def get_result(cls, session: MessageSession):
def get_result(cls, session: MessageSession) -> MessageSession:
if 'result' in cls._list[session.target.target_id][session.target.sender_id][session]:
return cls._list[session.target.target_id][session.target.sender_id][session]['result']
else:
Expand Down
3 changes: 2 additions & 1 deletion core/builtins/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@


confirm_command = Config('confirm_command', default=["是", "对", "對", "yes", "Yes", "YES", "y", "Y"])
quick_confirm = Config('quick_confirm', default=True)
command_prefix = Config('command_prefix', default=['~', '~']) # 消息前缀


class EnableDirtyWordCheck:
status = False


__all__ = ["confirm_command", "command_prefix", "EnableDirtyWordCheck", "PrivateAssets", "Secret"]
__all__ = ["confirm_command", "quick_confirm", "command_prefix", "EnableDirtyWordCheck", "PrivateAssets", "Secret"]
13 changes: 10 additions & 3 deletions core/types/message/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from typing import List, Union, Dict, Coroutine
from typing import List, Union, Dict, Coroutine, Self

from core.exceptions import FinishedException
from .chain import MessageChain
Expand Down Expand Up @@ -149,14 +149,14 @@ async def wait_confirm(self, message_chain=None, quote=True, delete=True, timeou
raise NotImplementedError

async def wait_next_message(self, message_chain=None, quote=True, delete=False, timeout=120,
append_instruction=True):
append_instruction=True) -> (Self, FinishedSession):
"""
一次性模板,用于等待对象的下一条消息。
:param message_chain: 需要发送的确认消息,可不填
:param quote: 是否引用传入dict中的消息(默认为True)
:param delete: 是否在触发后删除消息(默认为False)
:param timeout: 超时时间
:return: 下一条消息的MessageChain对象
:return: 下一条消息的MessageChain对象和发出的提示消息
"""
raise NotImplementedError

Expand Down Expand Up @@ -215,6 +215,13 @@ async def check_native_permission(self):
"""
raise NotImplementedError

def is_quick_confirm(self, target: Union[Self, FinishedSession] = None) -> bool:
"""
用于检查消息是否可用作快速确认事件。
:param target: 确认的目标消息
"""
return False

async def fake_forward_msg(self, nodelist):
"""
用于发送假转发消息(QQ)。
Expand Down
2 changes: 1 addition & 1 deletion modules/chemical_code/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ async def timer(start):

await asyncio.gather(ans(msg, csr['name'], random_mode), timer(time_start))
else:
result = await msg.wait_next_message([Plain(msg.locale.t('chemical_code.message.showid', id=csr["id"])),
result, _ = await msg.wait_next_message([Plain(msg.locale.t('chemical_code.message.showid', id=csr["id"])),
Image(newpath), Plain(msg.locale.t('chemical_code.message.captcha',
times=set_timeout))], timeout=3600, append_instruction=False)
if play_state[msg.target.target_id]['active']:
Expand Down
13 changes: 6 additions & 7 deletions modules/ncmusic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,16 @@ async def search(msg: Bot.MessageSession, keyword: str):
f"{' / '.join(artist['name'] for artist in song['artists'])}",
f"{song['album']['name']}" + (f" ({' / '.join(song['album']['transNames'])})" if 'transNames' in song['album'] else ''),
f"{song['id']}"
] for i, song in enumerate(songs, start=1)
]
] for i, song in enumerate(songs, start=1)
]

tables = ImageTable(data, [
msg.locale.t('ncmusic.message.search.table.header.id'),
msg.locale.t('ncmusic.message.search.table.header.name'),
msg.locale.t('ncmusic.message.search.table.header.artists'),
msg.locale.t('ncmusic.message.search.table.header.album'),
'ID'
])
])

img = await image_table_render(tables)
if img:
Expand All @@ -62,7 +62,7 @@ async def search(msg: Bot.MessageSession, keyword: str):

else:
send_msg.append(Plain(msg.locale.t('ncmusic.message.search.prompt')))
query = await msg.wait_reply(send_msg)
query, _ = await msg.wait_next_message(send_msg)
query = query.as_display(text_only=True)

if query.isdigit():
Expand All @@ -89,21 +89,20 @@ async def search(msg: Bot.MessageSession, keyword: str):
if 'transNames' in song['album']:
send_msg += f"({' / '.join(song['album']['transNames'])})"
send_msg += f"({song['id']}\n"

if song_count > 10:
song_count = 10
send_msg += msg.locale.t("message.collapse", amount="10")

if song_count == 1:
send_msg += '\n' + msg.locale.t('ncmusic.message.search.confirm')
query = await msg.wait_confirm(send_msg, delete=False)
query, _ = await msg.wait_next_message(send_msg)
if query:
sid = result['result']['songs'][0]['id']
else:
return
else:
send_msg += '\n' + msg.locale.t('ncmusic.message.search.prompt')
query = await msg.wait_reply(send_msg)
query, _ = await msg.wait_next_message(send_msg)
query = query.as_display(text_only=True)

if query.isdigit():
Expand Down
12 changes: 5 additions & 7 deletions modules/summary/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@

openai.api_key = Config('openai_api_key')

s = module('summary',
developers=['Dianliang233', 'OasisAkari'],
desc='{summary.help.desc}',
available_for=['QQ', 'QQ|Group'])
s = module('summary',
developers=['Dianliang233', 'OasisAkari'],
desc='{summary.help.desc}',
available_for=['QQ', 'QQ|Group'])


@s.handle('{{summary.help}}')
Expand All @@ -28,7 +28,7 @@ async def _(msg: Bot.MessageSession):
qc = CoolDown('call_openai', msg)
c = qc.check(60)
if c == 0 or msg.target.target_from == 'TEST|Console' or is_superuser:
f_msg = await msg.wait_next_message(msg.locale.t('summary.message'), append_instruction=False)
f_msg, _ = await msg.wait_next_message(msg.locale.t('summary.message'), append_instruction=False)
try:
f = re.search(r'\[Ke:forward,id=(.*?)\]', f_msg.as_display()).group(1)
except AttributeError:
Expand Down Expand Up @@ -86,5 +86,3 @@ async def _(msg: Bot.MessageSession):
await msg.finish(output, disable_secret_check=True)
else:
await msg.finish(msg.locale.t('message.cooldown', time=int(c), cd_time='60'))


4 changes: 2 additions & 2 deletions modules/twenty_four/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ async def _(msg: Bot.MessageSession):

numbers = [random.randint(1, 13) for _ in range(4)]
has_solution_flag = await has_solution(numbers)

answer = await msg.wait_next_message(msg.locale.t('twenty_four.message', numbers=numbers), timeout=3600, append_instruction=False)
answer, _ = await msg.wait_next_message(msg.locale.t('twenty_four.message', numbers=numbers), timeout=3600, append_instruction=False)
expression = answer.as_display(text_only=True)
if play_state[msg.target.target_id]['active']:
if expression.lower() in no_solution:
Expand Down
6 changes: 4 additions & 2 deletions modules/wiki/wiki.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import filetype

from core.builtins import Bot, Plain, Image, Voice, Url, confirm_command
from core.builtins import Bot, Plain, Image, Voice, Url, confirm_command, quick_confirm
from core.utils.image_table import image_table_render, ImageTable
from core.component import module
from core.exceptions import AbuseWarning
Expand Down Expand Up @@ -377,11 +377,13 @@ async def image_and_voice():

async def wait_confirm():
if wait_msg_list and session.Feature.wait:
confirm = await session.wait_next_message(wait_msg_list, delete=True, append_instruction=False)
confirm, sent = await session.wait_next_message(wait_msg_list, delete=True, append_instruction=False)
auto_index = False
index = 0
if confirm.as_display(text_only=True) in confirm_command:
auto_index = True
elif quick_confirm and confirm.is_quick_confirm(sent):
auto_index = True
elif confirm.as_display(text_only=True).isdigit():
index = int(confirm.as_display()) - 1
else:
Expand Down
11 changes: 5 additions & 6 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ pycryptodome = "^3.18.0"
langconv = "^0.2.0"
toml = "^0.10.2"
khl-py = "^0.3.16"
matrix-nio = "^0.21.2"
matrix-nio = "^0.22.0"
attrs = "^23.1.0"
uvicorn = {extras = ["standard"], version = "^0.23.2"}
pyjwt = {extras = ["crypto"], version = "^2.8.0"}
Expand Down
Loading

0 comments on commit a805ff1

Please sign in to comment.