diff --git a/.gitignore b/.gitignore index dd19aa5a..4ac91687 100644 --- a/.gitignore +++ b/.gitignore @@ -34,4 +34,7 @@ data/labels/instance_id.json .DS_Store /data botpy.log* -/poc \ No newline at end of file +/poc +/libs/wecom_api/test.py +/venv + diff --git a/libs/wecom_api/WXBizMsgCrypt3.py b/libs/wecom_api/WXBizMsgCrypt3.py new file mode 100644 index 00000000..0123c7d1 --- /dev/null +++ b/libs/wecom_api/WXBizMsgCrypt3.py @@ -0,0 +1,279 @@ +#!/usr/bin/env python +# -*- encoding:utf-8 -*- + +""" 对企业微信发送给企业后台的消息加解密示例代码. +@copyright: Copyright (c) 1998-2014 Tencent Inc. + +""" +# ------------------------------------------------------------------------ +import logging +import base64 +import random +import hashlib +import time +import struct +from Crypto.Cipher import AES +import xml.etree.cElementTree as ET +import socket + +from . import ierror + + +""" +Crypto.Cipher包已不再维护,开发者可以通过以下命令下载安装最新版的加解密工具包 + pip install pycryptodome +""" + + +class FormatException(Exception): + pass + + +def throw_exception(message, exception_class=FormatException): + """my define raise exception function""" + raise exception_class(message) + + +class SHA1: + """计算企业微信的消息签名接口""" + + def getSHA1(self, token, timestamp, nonce, encrypt): + """用SHA1算法生成安全签名 + @param token: 票据 + @param timestamp: 时间戳 + @param encrypt: 密文 + @param nonce: 随机字符串 + @return: 安全签名 + """ + try: + sortlist = [token, timestamp, nonce, encrypt] + sortlist.sort() + sha = hashlib.sha1() + sha.update("".join(sortlist).encode()) + return ierror.WXBizMsgCrypt_OK, sha.hexdigest() + except Exception as e: + logger = logging.getLogger() + logger.error(e) + return ierror.WXBizMsgCrypt_ComputeSignature_Error, None + + +class XMLParse: + """提供提取消息格式中的密文及生成回复消息格式的接口""" + + # xml消息模板 + AES_TEXT_RESPONSE_TEMPLATE = """ + + +%(timestamp)s + +""" + + def extract(self, xmltext): + """提取出xml数据包中的加密消息 + @param xmltext: 待提取的xml字符串 + @return: 提取出的加密消息字符串 + """ + try: + xml_tree = ET.fromstring(xmltext) + encrypt = xml_tree.find("Encrypt") + return ierror.WXBizMsgCrypt_OK, encrypt.text + except Exception as e: + logger = logging.getLogger() + logger.error(e) + return ierror.WXBizMsgCrypt_ParseXml_Error, None + + def generate(self, encrypt, signature, timestamp, nonce): + """生成xml消息 + @param encrypt: 加密后的消息密文 + @param signature: 安全签名 + @param timestamp: 时间戳 + @param nonce: 随机字符串 + @return: 生成的xml字符串 + """ + resp_dict = { + 'msg_encrypt': encrypt, + 'msg_signaturet': signature, + 'timestamp': timestamp, + 'nonce': nonce, + } + resp_xml = self.AES_TEXT_RESPONSE_TEMPLATE % resp_dict + return resp_xml + + +class PKCS7Encoder(): + """提供基于PKCS7算法的加解密接口""" + + block_size = 32 + + def encode(self, text): + """ 对需要加密的明文进行填充补位 + @param text: 需要进行填充补位操作的明文 + @return: 补齐明文字符串 + """ + text_length = len(text) + # 计算需要填充的位数 + amount_to_pad = self.block_size - (text_length % self.block_size) + if amount_to_pad == 0: + amount_to_pad = self.block_size + # 获得补位所用的字符 + pad = chr(amount_to_pad) + return text + (pad * amount_to_pad).encode() + + def decode(self, decrypted): + """删除解密后明文的补位字符 + @param decrypted: 解密后的明文 + @return: 删除补位字符后的明文 + """ + pad = ord(decrypted[-1]) + if pad < 1 or pad > 32: + pad = 0 + return decrypted[:-pad] + + +class Prpcrypt(object): + """提供接收和推送给企业微信消息的加解密接口""" + + def __init__(self, key): + + # self.key = base64.b64decode(key+"=") + self.key = key + # 设置加解密模式为AES的CBC模式 + self.mode = AES.MODE_CBC + + def encrypt(self, text, receiveid): + """对明文进行加密 + @param text: 需要加密的明文 + @return: 加密得到的字符串 + """ + # 16位随机字符串添加到明文开头 + text = text.encode() + text = self.get_random_str() + struct.pack("I", socket.htonl(len(text))) + text + receiveid.encode() + + # 使用自定义的填充方式对明文进行补位填充 + pkcs7 = PKCS7Encoder() + text = pkcs7.encode(text) + # 加密 + cryptor = AES.new(self.key, self.mode, self.key[:16]) + try: + ciphertext = cryptor.encrypt(text) + # 使用BASE64对加密后的字符串进行编码 + return ierror.WXBizMsgCrypt_OK, base64.b64encode(ciphertext) + except Exception as e: + logger = logging.getLogger() + logger.error(e) + return ierror.WXBizMsgCrypt_EncryptAES_Error, None + + def decrypt(self, text, receiveid): + """对解密后的明文进行补位删除 + @param text: 密文 + @return: 删除填充补位后的明文 + """ + try: + cryptor = AES.new(self.key, self.mode, self.key[:16]) + # 使用BASE64对密文进行解码,然后AES-CBC解密 + plain_text = cryptor.decrypt(base64.b64decode(text)) + except Exception as e: + logger = logging.getLogger() + logger.error(e) + return ierror.WXBizMsgCrypt_DecryptAES_Error, None + try: + pad = plain_text[-1] + # 去掉补位字符串 + # pkcs7 = PKCS7Encoder() + # plain_text = pkcs7.encode(plain_text) + # 去除16位随机字符串 + content = plain_text[16:-pad] + xml_len = socket.ntohl(struct.unpack("I", content[: 4])[0]) + xml_content = content[4: xml_len + 4] + from_receiveid = content[xml_len + 4:] + except Exception as e: + logger = logging.getLogger() + logger.error(e) + return ierror.WXBizMsgCrypt_IllegalBuffer, None + + if from_receiveid.decode('utf8') != receiveid: + return ierror.WXBizMsgCrypt_ValidateCorpid_Error, None + return 0, xml_content + + def get_random_str(self): + """ 随机生成16位字符串 + @return: 16位字符串 + """ + return str(random.randint(1000000000000000, 9999999999999999)).encode() + + +class WXBizMsgCrypt(object): + # 构造函数 + def __init__(self, sToken, sEncodingAESKey, sReceiveId): + try: + self.key = base64.b64decode(sEncodingAESKey + "=") + assert len(self.key) == 32 + except: + throw_exception("[error]: EncodingAESKey unvalid !", FormatException) + # return ierror.WXBizMsgCrypt_IllegalAesKey,None + self.m_sToken = sToken + self.m_sReceiveId = sReceiveId + + # 验证URL + # @param sMsgSignature: 签名串,对应URL参数的msg_signature + # @param sTimeStamp: 时间戳,对应URL参数的timestamp + # @param sNonce: 随机串,对应URL参数的nonce + # @param sEchoStr: 随机串,对应URL参数的echostr + # @param sReplyEchoStr: 解密之后的echostr,当return返回0时有效 + # @return:成功0,失败返回对应的错误码 + + def VerifyURL(self, sMsgSignature, sTimeStamp, sNonce, sEchoStr): + sha1 = SHA1() + ret, signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, sEchoStr) + if ret != 0: + return ret, None + if not signature == sMsgSignature: + return ierror.WXBizMsgCrypt_ValidateSignature_Error, None + pc = Prpcrypt(self.key) + ret, sReplyEchoStr = pc.decrypt(sEchoStr, self.m_sReceiveId) + return ret, sReplyEchoStr + + def EncryptMsg(self, sReplyMsg, sNonce, timestamp=None): + # 将企业回复用户的消息加密打包 + # @param sReplyMsg: 企业号待回复用户的消息,xml格式的字符串 + # @param sTimeStamp: 时间戳,可以自己生成,也可以用URL参数的timestamp,如为None则自动用当前时间 + # @param sNonce: 随机串,可以自己生成,也可以用URL参数的nonce + # sEncryptMsg: 加密后的可以直接回复用户的密文,包括msg_signature, timestamp, nonce, encrypt的xml格式的字符串, + # return:成功0,sEncryptMsg,失败返回对应的错误码None + pc = Prpcrypt(self.key) + ret, encrypt = pc.encrypt(sReplyMsg, self.m_sReceiveId) + encrypt = encrypt.decode('utf8') + if ret != 0: + return ret, None + if timestamp is None: + timestamp = str(int(time.time())) + # 生成安全签名 + sha1 = SHA1() + ret, signature = sha1.getSHA1(self.m_sToken, timestamp, sNonce, encrypt) + if ret != 0: + return ret, None + xmlParse = XMLParse() + return ret, xmlParse.generate(encrypt, signature, timestamp, sNonce) + + def DecryptMsg(self, sPostData, sMsgSignature, sTimeStamp, sNonce): + # 检验消息的真实性,并且获取解密后的明文 + # @param sMsgSignature: 签名串,对应URL参数的msg_signature + # @param sTimeStamp: 时间戳,对应URL参数的timestamp + # @param sNonce: 随机串,对应URL参数的nonce + # @param sPostData: 密文,对应POST请求的数据 + # xml_content: 解密后的原文,当return返回0时有效 + # @return: 成功0,失败返回对应的错误码 + # 验证安全签名 + xmlParse = XMLParse() + ret, encrypt = xmlParse.extract(sPostData) + if ret != 0: + return ret, None + sha1 = SHA1() + ret, signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, encrypt) + if ret != 0: + return ret, None + if not signature == sMsgSignature: + return ierror.WXBizMsgCrypt_ValidateSignature_Error, None + pc = Prpcrypt(self.key) + ret, xml_content = pc.decrypt(encrypt, self.m_sReceiveId) + return ret, xml_content diff --git a/libs/wecom_api/__init__.py b/libs/wecom_api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/wecom_api/api.py b/libs/wecom_api/api.py new file mode 100644 index 00000000..0d02ffcd --- /dev/null +++ b/libs/wecom_api/api.py @@ -0,0 +1,305 @@ +from quart import request +from .WXBizMsgCrypt3 import WXBizMsgCrypt +import base64 +import binascii +import httpx +from quart import Quart +import xml.etree.ElementTree as ET +from typing import Callable, Dict, Any +from .wecomevent import WecomEvent +from pkg.platform.types import events as platform_events, message as platform_message +import aiofiles + + +class WecomClient(): + def __init__(self,corpid:str,secret:str,token:str,EncodingAESKey:str,contacts_secret:str): + self.corpid = corpid + self.secret = secret + self.access_token_for_contacts ='' + self.token = token + self.aes = EncodingAESKey + self.base_url = 'https://qyapi.weixin.qq.com/cgi-bin' + self.access_token = '' + self.secret_for_contacts = contacts_secret + self.app = Quart(__name__) + self.wxcpt = WXBizMsgCrypt(self.token, self.aes, self.corpid) + self.app.add_url_rule('/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST']) + self._message_handlers = { + "example":[], + } + + #access——token操作 + async def check_access_token(self): + return bool(self.access_token and self.access_token.strip()) + + async def check_access_token_for_contacts(self): + return bool(self.access_token_for_contacts and self.access_token_for_contacts.strip()) + + async def get_access_token(self,secret): + url = f'https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid={self.corpid}&corpsecret={secret}' + async with httpx.AsyncClient() as client: + response = await client.get(url) + data = response.json() + if 'access_token' in data: + return data['access_token'] + else: + raise Exception(f"未获取access token: {data}") + + async def get_users(self): + if not self.check_access_token_for_contacts(): + self.access_token_for_contacts = await self.get_access_token(self.secret_for_contacts) + + url = self.base_url+'/user/list_id?access_token='+self.access_token_for_contacts + async with httpx.AsyncClient() as client: + params = { + "cursor":"", + "limit":10000, + } + response = await client.post(url,json=params) + data = response.json() + if data['errcode'] == 0: + dept_users = data['dept_user'] + userid = [] + for user in dept_users: + userid.append(user["userid"]) + return userid + else: + raise Exception("未获取用户") + + async def send_to_all(self,content:str): + if not self.check_access_token_for_contacts(): + self.access_token_for_contacts = await self.get_access_token(self.secret_for_contacts) + + url = self.base_url+'/message/send?access_token='+self.access_token_for_contacts + user_ids = await self.get_users() + user_ids_string = "|".join(user_ids) + async with httpx.AsyncClient() as client: + params = { + "touser" : user_ids_string, + "msgtype" : "text", + "agentid" : 1000002, + "text" : { + "content" : content, + }, + "safe":0, + "enable_id_trans": 0, + "enable_duplicate_check": 0, + "duplicate_check_interval": 1800 + } + response = await client.post(url,json=params) + data = response.json() + if data['errcode'] != 0: + raise Exception("Failed to send message: "+str(data)) + + async def send_image(self,user_id:str,agent_id:int,media_id:str): + if not await self.check_access_token(): + self.access_token = await self.get_access_token(self.secret) + url = self.base_url+'/media/upload?access_token='+self.access_token + async with httpx.AsyncClient() as client: + params = { + "touser" : user_id, + "toparty" : "", + "totag":"", + "agentid" : agent_id, + "msgtype" : "image", + "image" : { + "media_id" : media_id, + }, + "safe":0, + "enable_id_trans": 0, + "enable_duplicate_check": 0, + "duplicate_check_interval": 1800 + } + response = await client.post(url,json=params) + data = response.json() + if data['errcode'] != 0: + raise Exception("Failed to send image: "+str(data)) + + async def send_private_msg(self,user_id:str, agent_id:int,content:str): + if not await self.check_access_token(): + self.access_token = await self.get_access_token(self.secret) + + url = self.base_url+'/message/send?access_token='+self.access_token + + async with httpx.AsyncClient() as client: + params={ + "touser" : user_id, + "msgtype" : "text", + "agentid" : agent_id, + "text" : { + "content" : content, + }, + "safe":0, + "enable_id_trans": 0, + "enable_duplicate_check": 0, + "duplicate_check_interval": 1800 + } + response = await client.post(url,json=params) + data = response.json() + + if data['errcode'] != 0: + raise Exception("Failed to send message: "+str(data)) + + async def handle_callback_request(self): + """ + 处理回调请求,包括 GET 验证和 POST 消息接收。 + """ + try: + + msg_signature = request.args.get("msg_signature") + timestamp = request.args.get("timestamp") + nonce = request.args.get("nonce") + + if request.method == "GET": + echostr = request.args.get("echostr") + ret, reply_echo_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr) + if ret != 0: + raise Exception(f"验证失败,错误码: {ret}") + return reply_echo_str + + elif request.method == "POST": + encrypt_msg = await request.data + ret, xml_msg = self.wxcpt.DecryptMsg(encrypt_msg, msg_signature, timestamp, nonce) + if ret != 0: + raise Exception(f"消息解密失败,错误码: {ret}") + + # 解析消息并处理 + message_data = await self.get_message(xml_msg) + if message_data: + event = WecomEvent.from_payload(message_data) # 转换为 WecomEvent 对象 + if event: + await self._handle_message(event) + + return "success" + except Exception as e: + return f"Error processing request: {str(e)}", 400 + + async def run_task(self, host: str, port: int, *args, **kwargs): + """ + 启动 Quart 应用。 + """ + await self.app.run_task(host=host, port=port, *args, **kwargs) + + def on_message(self, msg_type: str): + """ + 注册消息类型处理器。 + """ + def decorator(func: Callable[[WecomEvent], None]): + if msg_type not in self._message_handlers: + self._message_handlers[msg_type] = [] + self._message_handlers[msg_type].append(func) + return func + return decorator + + async def _handle_message(self, event: WecomEvent): + """ + 处理消息事件。 + """ + msg_type = event.type + if msg_type in self._message_handlers: + for handler in self._message_handlers[msg_type]: + await handler(event) + + async def get_message(self, xml_msg: str) -> Dict[str, Any]: + """ + 解析微信返回的 XML 消息并转换为字典。 + """ + root = ET.fromstring(xml_msg) + message_data = { + "ToUserName": root.find("ToUserName").text, + "FromUserName": root.find("FromUserName").text, + "CreateTime": int(root.find("CreateTime").text), + "MsgType": root.find("MsgType").text, + "Content": root.find("Content").text if root.find("Content") is not None else None, + "MsgId": int(root.find("MsgId").text) if root.find("MsgId") is not None else None, + "AgentID": int(root.find("AgentID").text) if root.find("AgentID") is not None else None, + } + if message_data["MsgType"] == "image": + message_data["MediaId"] = root.find("MediaId").text if root.find("MediaId") is not None else None + message_data["PicUrl"] = root.find("PicUrl").text if root.find("PicUrl") is not None else None + + return message_data + + @staticmethod + async def get_image_type(image_bytes: bytes) -> str: + """ + 通过图片的magic numbers判断图片类型 + """ + magic_numbers = { + b'\xFF\xD8\xFF': 'jpg', + b'\x89\x50\x4E\x47': 'png', + b'\x47\x49\x46': 'gif', + b'\x42\x4D': 'bmp', + b'\x00\x00\x01\x00': 'ico' + } + + for magic, ext in magic_numbers.items(): + if image_bytes.startswith(magic): + return ext + return 'jpg' # 默认返回jpg + + + async def upload_to_work(self, image: platform_message.Image): + """ + 获取 media_id + """ + if not await self.check_access_token(): + self.access_token = await self.get_access_token(self.secret) + + url = self.base_url + '/media/upload?access_token=' + self.access_token + '&type=file' + file_bytes = None + file_name = "uploaded_file.txt" + + # 获取文件的二进制数据 + if image.path: + async with aiofiles.open(image.path, 'rb') as f: + file_bytes = await f.read() + file_name = image.path.split('/')[-1] + elif image.url: + file_bytes = await self.download_image_to_bytes(image.url) + file_name = image.url.split('/')[-1] + elif image.base64: + try: + base64_data = image.base64 + if ',' in base64_data: + base64_data = base64_data.split(',', 1)[1] + padding = 4 - (len(base64_data) % 4) if len(base64_data) % 4 else 0 + padded_base64 = base64_data + '=' * padding + file_bytes = base64.b64decode(padded_base64) + except binascii.Error as e: + raise ValueError(f"Invalid base64 string: {str(e)}") + else: + raise ValueError("image对象出错") + + # 设置 multipart/form-data 格式的文件 + boundary = "-------------------------acebdf13572468" + headers = { + 'Content-Type': f'multipart/form-data; boundary={boundary}' + } + body = ( + f"--{boundary}\r\n" + f"Content-Disposition: form-data; name=\"media\"; filename=\"{file_name}\"; filelength={len(file_bytes)}\r\n" + f"Content-Type: application/octet-stream\r\n\r\n" + ).encode('utf-8') + file_bytes + f"\r\n--{boundary}--\r\n".encode('utf-8') + + # 上传文件 + async with httpx.AsyncClient() as client: + response = await client.post(url, headers=headers, content=body) + data = response.json() + if data.get('errcode', 0) != 0: + raise Exception("failed to upload file") + + return data.get('media_id') + + + async def download_image_to_bytes(self,url:str) -> bytes: + async with httpx.AsyncClient() as client: + response = await client.get(url) + response.raise_for_status() + return response.content + + #进行media_id的获取 + async def get_media_id(self, image: platform_message.Image): + + media_id = await self.upload_to_work(image=image) + return media_id diff --git a/libs/wecom_api/ierror.py b/libs/wecom_api/ierror.py new file mode 100644 index 00000000..8985b886 --- /dev/null +++ b/libs/wecom_api/ierror.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +######################################################################### +# Author: jonyqin +# Created Time: Thu 11 Sep 2014 01:53:58 PM CST +# File Name: ierror.py +# Description:定义错误码含义 +######################################################################### +WXBizMsgCrypt_OK = 0 +WXBizMsgCrypt_ValidateSignature_Error = -40001 +WXBizMsgCrypt_ParseXml_Error = -40002 +WXBizMsgCrypt_ComputeSignature_Error = -40003 +WXBizMsgCrypt_IllegalAesKey = -40004 +WXBizMsgCrypt_ValidateCorpid_Error = -40005 +WXBizMsgCrypt_EncryptAES_Error = -40006 +WXBizMsgCrypt_DecryptAES_Error = -40007 +WXBizMsgCrypt_IllegalBuffer = -40008 +WXBizMsgCrypt_EncodeBase64_Error = -40009 +WXBizMsgCrypt_DecodeBase64_Error = -40010 +WXBizMsgCrypt_GenReturnXml_Error = -40011 \ No newline at end of file diff --git a/libs/wecom_api/wecomevent.py b/libs/wecom_api/wecomevent.py new file mode 100644 index 00000000..3606cdf5 --- /dev/null +++ b/libs/wecom_api/wecomevent.py @@ -0,0 +1,179 @@ +from typing import Dict, Any, Optional + + +class WecomEvent(dict): + """ + 封装从企业微信收到的事件数据对象(字典),提供属性以获取其中的字段。 + + 除 `type` 和 `detail_type` 属性对于任何事件都有效外,其它属性是否存在(若不存在则返回 `None`)依事件类型不同而不同。 + """ + + @staticmethod + def from_payload(payload: Dict[str, Any]) -> Optional["WecomEvent"]: + """ + 从企业微信事件数据构造 `WecomEvent` 对象。 + + Args: + payload (Dict[str, Any]): 解密后的企业微信事件数据。 + + Returns: + Optional[WecomEvent]: 如果事件数据合法,则返回 WecomEvent 对象;否则返回 None。 + """ + try: + event = WecomEvent(payload) + _ = event.type, event.detail_type # 确保必须字段存在 + return event + except KeyError: + return None + + @property + def type(self) -> str: + """ + 事件类型,例如 "message"、"event"、"text" 等。 + + Returns: + str: 事件类型。 + """ + return self.get("MsgType", "") + + @property + def picurl(self) -> str: + """ + 图片链接 + """ + return self.get("PicUrl") + + @property + def detail_type(self) -> str: + """ + 事件详细类型,依 `type` 的不同而不同。例如: + - 消息事件: "text", "image", "voice", 等 + - 事件通知: "subscribe", "unsubscribe", "click", 等 + + Returns: + str: 事件详细类型。 + """ + if self.type == "event": + return self.get("Event", "") + return self.type + + @property + def name(self) -> str: + """ + 事件名,对于消息事件是 `type.detail_type`,对于其他事件是 `event_type`。 + + Returns: + str: 事件名。 + """ + return f"{self.type}.{self.detail_type}" + + @property + def user_id(self) -> Optional[str]: + """ + 用户 ID,例如消息的发送者或事件的触发者。 + + Returns: + Optional[str]: 用户 ID。 + """ + return self.get("FromUserName") + + @property + def agent_id(self) -> Optional[int]: + """ + 机器人 ID,仅在消息类型事件中存在。 + + Returns: + Optional[int]: 机器人 ID。 + """ + return self.get("AgentID") + + @property + def receiver_id(self) -> Optional[str]: + """ + 接收者 ID,例如机器人自身的企业微信 ID。 + + Returns: + Optional[str]: 接收者 ID。 + """ + return self.get("ToUserName") + + @property + def message_id(self) -> Optional[str]: + """ + 消息 ID,仅在消息类型事件中存在。 + + Returns: + Optional[str]: 消息 ID。 + """ + return self.get("MsgId") + + @property + def message(self) -> Optional[str]: + """ + 消息内容,仅在消息类型事件中存在。 + + Returns: + Optional[str]: 消息内容。 + """ + return self.get("Content") + + @property + def media_id(self) -> Optional[str]: + """ + 媒体文件 ID,仅在图片、语音等消息类型中存在。 + + Returns: + Optional[str]: 媒体文件 ID。 + """ + return self.get("MediaId") + + @property + def timestamp(self) -> Optional[int]: + """ + 事件发生的时间戳。 + + Returns: + Optional[int]: 时间戳。 + """ + return self.get("CreateTime") + + @property + def event_key(self) -> Optional[str]: + """ + 事件的 Key 值,例如点击菜单时的 `EventKey`。 + + Returns: + Optional[str]: 事件 Key。 + """ + return self.get("EventKey") + + def __getattr__(self, key: str) -> Optional[Any]: + """ + 允许通过属性访问数据中的任意字段。 + + Args: + key (str): 字段名。 + + Returns: + Optional[Any]: 字段值。 + """ + return self.get(key) + + def __setattr__(self, key: str, value: Any) -> None: + """ + 允许通过属性设置数据中的任意字段。 + + Args: + key (str): 字段名。 + value (Any): 字段值。 + """ + self[key] = value + + def __repr__(self) -> str: + """ + 生成事件对象的字符串表示。 + + Returns: + str: 字符串表示。 + """ + return f"" diff --git a/pkg/core/entities.py b/pkg/core/entities.py index 502fc73e..a57fbbfd 100644 --- a/pkg/core/entities.py +++ b/pkg/core/entities.py @@ -45,10 +45,10 @@ class Query(pydantic.BaseModel): launcher_type: LauncherTypes """会话类型,platform处理阶段设置""" - launcher_id: int + launcher_id: typing.Union[int, str] """会话ID,platform处理阶段设置""" - sender_id: int + sender_id: typing.Union[int, str] """发送者ID,platform处理阶段设置""" message_event: platform_events.MessageEvent @@ -114,9 +114,9 @@ class Session(pydantic.BaseModel): """会话,一个 Session 对应一个 {launcher_type.value}_{launcher_id}""" launcher_type: LauncherTypes - launcher_id: int + launcher_id: typing.Union[int, str] - sender_id: typing.Optional[int] = 0 + sender_id: typing.Optional[typing.Union[int, str]] = 0 use_prompt_name: typing.Optional[str] = 'default' diff --git a/pkg/core/migrations/m020_wecom_config.py b/pkg/core/migrations/m020_wecom_config.py new file mode 100644 index 00000000..a501eee2 --- /dev/null +++ b/pkg/core/migrations/m020_wecom_config.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from .. import migration + + +@migration.migration_class("wecom-config", 20) +class WecomConfigMigration(migration.Migration): + """迁移""" + + async def need_migrate(self) -> bool: + """判断当前环境是否需要运行此迁移""" + + for adapter in self.ap.platform_cfg.data['platform-adapters']: + if adapter['adapter'] == 'wecom': + return False + + return True + + async def run(self): + """执行迁移""" + self.ap.platform_cfg.data['platform-adapters'].append({ + "adapter": "wecom", + "enable": False, + "host": "0.0.0.0", + "port": 2290, + "corpid": "", + "secret": "", + "token": "", + "EncodingAESKey": "", + "contacts_secret": "" + }) + + await self.ap.platform_cfg.dump_config() diff --git a/pkg/core/stages/migrate.py b/pkg/core/stages/migrate.py index 2dfaeb50..5d96d6da 100644 --- a/pkg/core/stages/migrate.py +++ b/pkg/core/stages/migrate.py @@ -8,6 +8,7 @@ from ..migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate, m009_msg_truncator_cfg from ..migrations import m010_ollama_requester_config, m011_command_prefix_config, m012_runner_config, m013_http_api_config, m014_force_delay_config from ..migrations import m015_gitee_ai_config, m016_dify_service_api, m017_dify_api_timeout_params, m018_xai_config, m019_zhipuai_config +from ..migrations import m020_wecom_config @stage.stage_class("MigrationStage") diff --git a/pkg/pipeline/controller.py b/pkg/pipeline/controller.py index 3113b3bf..807bec05 100644 --- a/pkg/pipeline/controller.py +++ b/pkg/pipeline/controller.py @@ -68,7 +68,7 @@ async def _process_query(selected_query): except Exception as e: # traceback.print_exc() self.ap.logger.error(f"控制器循环出错: {e}") - self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") + self.ap.logger.error(f"Traceback: {traceback.format_exc()}") async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult): """检查输出 @@ -163,29 +163,30 @@ async def _execute_from_stage( async def process_query(self, query: entities.Query): """处理请求 """ + try: - # ======== 触发 MessageReceived 事件 ======== - event_type = events.PersonMessageReceived if query.launcher_type == entities.LauncherTypes.PERSON else events.GroupMessageReceived + # ======== 触发 MessageReceived 事件 ======== + event_type = events.PersonMessageReceived if query.launcher_type == entities.LauncherTypes.PERSON else events.GroupMessageReceived - event_ctx = await self.ap.plugin_mgr.emit_event( - event=event_type( - launcher_type=query.launcher_type.value, - launcher_id=query.launcher_id, - sender_id=query.sender_id, - message_chain=query.message_chain, - query=query + event_ctx = await self.ap.plugin_mgr.emit_event( + event=event_type( + launcher_type=query.launcher_type.value, + launcher_id=query.launcher_id, + sender_id=query.sender_id, + message_chain=query.message_chain, + query=query + ) ) - ) - if event_ctx.is_prevented_default(): - return - - self.ap.logger.debug(f"Processing query {query}") + if event_ctx.is_prevented_default(): + return + + self.ap.logger.debug(f"Processing query {query}") - try: await self._execute_from_stage(0, query) except Exception as e: - self.ap.logger.error(f"处理请求时出错 query_id={query.query_id} stage={query.current_stage.inst_name} : {e}") + inst_name = query.current_stage.inst_name if query.current_stage else 'unknown' + self.ap.logger.error(f"处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}") self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") finally: self.ap.logger.debug(f"Query {query} processed") diff --git a/pkg/pipeline/pool.py b/pkg/pipeline/pool.py index 45f16e66..5b18b08c 100644 --- a/pkg/pipeline/pool.py +++ b/pkg/pipeline/pool.py @@ -1,7 +1,7 @@ from __future__ import annotations import asyncio - +import typing from ..core import entities from ..platform import adapter as msadapter @@ -29,8 +29,8 @@ def __init__(self): async def add_query( self, launcher_type: entities.LauncherTypes, - launcher_id: int, - sender_id: int, + launcher_id: typing.Union[int, str], + sender_id: typing.Union[int, str], message_event: platform_events.MessageEvent, message_chain: platform_message.MessageChain, adapter: msadapter.MessageSourceAdapter diff --git a/pkg/pipeline/ratelimit/algo.py b/pkg/pipeline/ratelimit/algo.py index af4def16..9b418dd2 100644 --- a/pkg/pipeline/ratelimit/algo.py +++ b/pkg/pipeline/ratelimit/algo.py @@ -31,7 +31,7 @@ async def initialize(self): pass @abc.abstractmethod - async def require_access(self, launcher_type: str, launcher_id: int) -> bool: + async def require_access(self, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool: """进入处理流程 这个方法对等待是友好的,意味着算法可以实现在这里等待一段时间以控制速率。 @@ -46,7 +46,7 @@ async def require_access(self, launcher_type: str, launcher_id: int) -> bool: raise NotImplementedError @abc.abstractmethod - async def release_access(self, launcher_type: str, launcher_id: int): + async def release_access(self, launcher_type: str, launcher_id: typing.Union[int, str]): """退出处理流程 Args: diff --git a/pkg/pipeline/ratelimit/algos/fixedwin.py b/pkg/pipeline/ratelimit/algos/fixedwin.py index 3bf8a5e5..3cc1ab94 100644 --- a/pkg/pipeline/ratelimit/algos/fixedwin.py +++ b/pkg/pipeline/ratelimit/algos/fixedwin.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio import time +import typing from .. import algo # 固定窗口算法 @@ -29,7 +30,7 @@ async def initialize(self): self.containers_lock = asyncio.Lock() self.containers = {} - async def require_access(self, launcher_type: str, launcher_id: int) -> bool: + async def require_access(self, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool: # 加锁,找容器 container: SessionContainer = None @@ -83,5 +84,5 @@ async def require_access(self, launcher_type: str, launcher_id: int) -> bool: # 返回True return True - async def release_access(self, launcher_type: str, launcher_id: int): + async def release_access(self, launcher_type: str, launcher_id: typing.Union[int, str]): pass diff --git a/pkg/platform/adapter.py b/pkg/platform/adapter.py index 7cf64a12..e04ec29c 100644 --- a/pkg/platform/adapter.py +++ b/pkg/platform/adapter.py @@ -52,7 +52,7 @@ def __init__(self, config: dict, ap: app.Application): self.config = config self.ap = ap - async def send_message( + async def send_message( self, target_type: str, target_id: str, diff --git a/pkg/platform/manager.py b/pkg/platform/manager.py index 2e241c7a..f8809750 100644 --- a/pkg/platform/manager.py +++ b/pkg/platform/manager.py @@ -37,7 +37,7 @@ def __init__(self, ap: app.Application = None): async def initialize(self): - from .sources import nakuru, aiocqhttp, qqbotpy + from .sources import nakuru, aiocqhttp, qqbotpy,wecom async def on_friend_message(event: platform_events.FriendMessage, adapter: msadapter.MessageSourceAdapter): diff --git a/pkg/platform/sources/wecom.py b/pkg/platform/sources/wecom.py new file mode 100644 index 00000000..dda22816 --- /dev/null +++ b/pkg/platform/sources/wecom.py @@ -0,0 +1,258 @@ +from __future__ import annotations +import typing +import asyncio +import traceback +import time +import datetime + +import aiocqhttp +import aiohttp +from libs.wecom_api.api import WecomClient +from pkg.platform.adapter import MessageSourceAdapter +from pkg.platform.types import events as platform_events, message as platform_message +from libs.wecom_api.wecomevent import WecomEvent +from pkg.core import app +from .. import adapter +from ...pipeline.longtext.strategies import forward +from ...core import app +from ..types import message as platform_message +from ..types import events as platform_events +from ..types import entities as platform_entities +from ...command.errors import ParamNotEnoughError +from ...utils import image + +class WecomMessageConverter(adapter.MessageConverter): + + @staticmethod + async def yiri2target( + message_chain: platform_message.MessageChain, bot: WecomClient + ): + content_list = [] + + [ + { + "type": "text", + "content": "text", + }, + { + "type": "image", + "media_id": "media_id", + } + ] + + for msg in message_chain: + if type(msg) is platform_message.Plain: + content_list.append({ + "type": "text", + "content": msg.text, + }) + elif type(msg) is platform_message.Image: + content_list.append({ + "type": "image", + "media_id": await bot.get_media_id(msg), + }) + elif type(msg) is platform_message.Forward: + for node in msg.node_list: + content_list.extend((await WecomMessageConverter.yiri2target(node.message_chain, bot))) + else: + content_list.append({ + "type": "text", + "content": str(msg), + }) + + return content_list + + @staticmethod + async def target2yiri(message: str, message_id: int = -1): + yiri_msg_list = [] + yiri_msg_list.append( + platform_message.Source(id=message_id, time=datetime.datetime.now()) + ) + + yiri_msg_list.append(platform_message.Plain(text=message)) + chain = platform_message.MessageChain(yiri_msg_list) + + return chain + + @staticmethod + async def target2yiri_image(picurl: str, message_id: int = -1): + yiri_msg_list = [] + yiri_msg_list.append( + platform_message.Source(id=message_id, time=datetime.datetime.now()) + ) + image_base64, image_format = await image.get_wecom_image_base64(pic_url=picurl) + yiri_msg_list.append(platform_message.Image(base64=f"data:image/{image_format};base64,{image_base64}")) + chain = platform_message.MessageChain(yiri_msg_list) + + return chain + + +class WecomEventConverter: + + @staticmethod + async def yiri2target( + event: platform_events.Event, bot_account_id: int, bot: WecomClient + ) -> WecomEvent: + # only for extracting user information + + if type(event) is platform_events.GroupMessage: + pass + + if type(event) is platform_events.FriendMessage: + + payload = { + "MsgType": "text", + "Content": '', + "FromUserName": event.sender.id, + "ToUserName": bot_account_id, + "CreateTime": int(datetime.datetime.now().timestamp()), + "AgentID": event.sender.nickname, + } + wecom_event = WecomEvent.from_payload(payload=payload) + if not wecom_event: + raise ValueError("无法从 message_data 构造 WecomEvent 对象") + + return wecom_event + + @staticmethod + async def target2yiri(event: WecomEvent): + """ + 将 WecomEvent 转换为平台的 FriendMessage 对象。 + + Args: + event (WecomEvent): 企业微信事件。 + + Returns: + platform_events.FriendMessage: 转换后的 FriendMessage 对象。 + """ + # 转换消息链 + if event.type == "text": + yiri_chain = await WecomMessageConverter.target2yiri( + event.message, event.message_id + ) + + friend = platform_entities.Friend( + id=event.user_id, + nickname=str(event.agent_id), + remark="", + ) + + return platform_events.FriendMessage( + sender=friend, message_chain=yiri_chain, time=event.timestamp + ) + elif event.type == "image": + friend = platform_entities.Friend( + id=event.user_id, + nickname=str(event.agent_id), + remark="", + ) + + yiri_chain = await WecomMessageConverter.target2yiri_image( + picurl=event.picurl, message_id=event.message_id + ) + + return platform_events.FriendMessage( + sender=friend, message_chain=yiri_chain, time=event.timestamp + ) + + +@adapter.adapter_class("wecom") +class WecomeAdapter(adapter.MessageSourceAdapter): + + bot: WecomClient + ap: app.Application + bot_account_id: str + message_converter: WecomMessageConverter = WecomMessageConverter() + event_converter: WecomEventConverter = WecomEventConverter() + config: dict + ap: app.Application + + def __init__(self, config: dict, ap: app.Application): + self.config = config + + self.ap = ap + + required_keys = [ + "corpid", + "secret", + "token", + "EncodingAESKey", + "contacts_secret", + ] + missing_keys = [key for key in required_keys if key not in config] + if missing_keys: + raise ParamNotEnoughError("企业微信缺少相关配置项,请查看文档或联系管理员") + + self.bot = WecomClient( + corpid=config["corpid"], + secret=config["secret"], + token=config["token"], + EncodingAESKey=config["EncodingAESKey"], + contacts_secret=config["contacts_secret"], + ) + + async def reply_message( + self, + message_source: platform_events.MessageEvent, + message: platform_message.MessageChain, + quote_origin: bool = False, + ): + + Wecom_event = await WecomEventConverter.yiri2target( + message_source, self.bot_account_id, self.bot + ) + content_list = await WecomMessageConverter.yiri2target(message, self.bot) + + for content in content_list: + if content["type"] == "text": + await self.bot.send_private_msg(Wecom_event.user_id, Wecom_event.agent_id, content["content"]) + elif content["type"] == "image": + await self.bot.send_image(Wecom_event.user_id, Wecom_event.agent_id, content["media_id"]) + + async def send_message( + self, target_type: str, target_id: str, message: platform_message.MessageChain + ): + pass + + def register_listener( + self, + event_type: typing.Type[platform_events.Event], + callback: typing.Callable[ + [platform_events.Event, adapter.MessageSourceAdapter], None + ], + ): + async def on_message(event: WecomEvent): + self.bot_account_id = event.receiver_id + try: + return await callback( + await self.event_converter.target2yiri(event), self + ) + except: + traceback.print_exc() + + if event_type == platform_events.FriendMessage: + self.bot.on_message("text")(on_message) + self.bot.on_message("image")(on_message) + elif event_type == platform_events.GroupMessage: + pass + + async def run_async(self): + async def shutdown_trigger_placeholder(): + while True: + await asyncio.sleep(1) + + await self.bot.run_task( + host=self.config["host"], + port=self.config["port"], + shutdown_trigger=shutdown_trigger_placeholder, + ) + + async def kill(self) -> bool: + return False + + async def unregister_listener( + self, + event_type: type, + callback: typing.Callable[[platform_events.Event, MessageSourceAdapter], None], + ): + return super().unregister_listener(event_type, callback) diff --git a/pkg/platform/types/entities.py b/pkg/platform/types/entities.py index a9aa32a2..57d9372a 100644 --- a/pkg/platform/types/entities.py +++ b/pkg/platform/types/entities.py @@ -25,7 +25,7 @@ def get_name(self) -> str: class Friend(Entity): """好友。""" - id: int + id: typing.Union[int, str] """QQ 号。""" nickname: typing.Optional[str] """昵称。""" @@ -52,7 +52,7 @@ def __repr__(self) -> str: class Group(Entity): """群。""" - id: int + id: typing.Union[int, str] """群号。""" name: str """群名称。""" @@ -67,7 +67,7 @@ def get_name(self) -> str: class GroupMember(Entity): """群成员。""" - id: int + id: typing.Union[int, str] """QQ 号。""" member_name: str """群成员名称。""" @@ -92,7 +92,7 @@ def get_name(self) -> str: class Client(Entity): """来自其他客户端的用户。""" - id: int + id: typing.Union[int, str] """识别 id。""" platform: str """来源平台。""" @@ -105,7 +105,7 @@ def get_name(self) -> str: class Subject(pydantic.BaseModel): """另一种实体类型表示。""" - id: int + id: typing.Union[int, str] """QQ 号或群号。""" kind: typing.Literal['Friend', 'Group', 'Stranger'] """类型。""" diff --git a/pkg/platform/types/message.py b/pkg/platform/types/message.py index 9bd33be7..45d37a41 100644 --- a/pkg/platform/types/message.py +++ b/pkg/platform/types/message.py @@ -485,11 +485,11 @@ class Quote(MessageComponent): """消息组件类型。""" id: typing.Optional[int] = None """被引用回复的原消息的 message_id。""" - group_id: typing.Optional[int] = None + group_id: typing.Optional[typing.Union[int, str]] = None """被引用回复的原消息所接收的群号,当为好友消息时为0。""" - sender_id: typing.Optional[int] = None + sender_id: typing.Optional[typing.Union[int, str]] = None """被引用回复的原消息的发送者的QQ号。""" - target_id: typing.Optional[int] = None + target_id: typing.Optional[typing.Union[int, str]] = None """被引用回复的原消息的接收者者的QQ号(或群号)。""" origin: MessageChain """被引用回复的原消息的消息链对象。""" @@ -749,7 +749,7 @@ async def from_local( class ForwardMessageNode(pydantic.BaseModel): """合并转发中的一条消息。""" - sender_id: typing.Optional[int] = None + sender_id: typing.Optional[typing.Union[int, str]] = None """发送人QQ号。""" sender_name: typing.Optional[str] = None """显示名称。""" diff --git a/pkg/plugin/events.py b/pkg/plugin/events.py index f1aff459..152ac39f 100644 --- a/pkg/plugin/events.py +++ b/pkg/plugin/events.py @@ -25,10 +25,10 @@ class PersonMessageReceived(BaseEventModel): launcher_type: str """发起对象类型(group/person)""" - launcher_id: int + launcher_id: typing.Union[int, str] """发起对象ID(群号/QQ号)""" - sender_id: int + sender_id: typing.Union[int, str] """发送者ID(QQ号)""" message_chain: platform_message.MessageChain @@ -39,9 +39,9 @@ class GroupMessageReceived(BaseEventModel): launcher_type: str - launcher_id: int + launcher_id: typing.Union[int, str] - sender_id: int + sender_id: typing.Union[int, str] message_chain: platform_message.MessageChain @@ -51,9 +51,9 @@ class PersonNormalMessageReceived(BaseEventModel): launcher_type: str - launcher_id: int + launcher_id: typing.Union[int, str] - sender_id: int + sender_id: typing.Union[int, str] text_message: str @@ -69,9 +69,9 @@ class PersonCommandSent(BaseEventModel): launcher_type: str - launcher_id: int + launcher_id: typing.Union[int, str] - sender_id: int + sender_id: typing.Union[int, str] command: str @@ -93,9 +93,9 @@ class GroupNormalMessageReceived(BaseEventModel): launcher_type: str - launcher_id: int + launcher_id: typing.Union[int, str] - sender_id: int + sender_id: typing.Union[int, str] text_message: str @@ -111,9 +111,9 @@ class GroupCommandSent(BaseEventModel): launcher_type: str - launcher_id: int + launcher_id: typing.Union[int, str] - sender_id: int + sender_id: typing.Union[int, str] command: str @@ -135,9 +135,9 @@ class NormalMessageResponded(BaseEventModel): launcher_type: str - launcher_id: int + launcher_id: typing.Union[int, str] - sender_id: int + sender_id: typing.Union[int, str] session: core_entities.Session """会话对象""" diff --git a/pkg/utils/image.py b/pkg/utils/image.py index 06885175..6f769b26 100644 --- a/pkg/utils/image.py +++ b/pkg/utils/image.py @@ -7,6 +7,31 @@ import aiohttp import PIL.Image +async def get_wecom_image_base64(pic_url: str) -> tuple[str, str]: + """ + 下载企业微信图片并转换为 base64 + :param pic_url: 企业微信图片URL + :return: (base64_str, image_format) + """ + async with aiohttp.ClientSession() as session: + async with session.get(pic_url) as response: + if response.status != 200: + raise Exception(f"Failed to download image: {response.status}") + + # 读取图片数据 + image_data = await response.read() + + # 获取图片格式 + content_type = response.headers.get('Content-Type', '') + image_format = content_type.split('/')[-1] # 例如 'image/jpeg' -> 'jpeg' + + # 转换为 base64 + import base64 + image_base64 = base64.b64encode(image_data).decode('utf-8') + + return image_base64, image_format + + def get_qq_image_downloadable_url(image_url: str) -> tuple[str, dict]: """获取QQ图片的下载链接""" diff --git a/requirements.txt b/requirements.txt index 5d142ce2..845ae48c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,6 +24,7 @@ aiofiles aioshutil argon2-cffi pyjwt +pycryptodome # indirect taskgroup==0.0.0a4 \ No newline at end of file diff --git a/templates/platform.json b/templates/platform.json index 5ed9e353..299656fd 100644 --- a/templates/platform.json +++ b/templates/platform.json @@ -24,6 +24,17 @@ "public_guild_messages", "direct_message" ] + }, + { + "adapter": "wecom", + "enable": false, + "host": "0.0.0.0", + "port": 2290, + "corpid": "", + "secret": "", + "token": "", + "EncodingAESKey": "", + "contacts_secret": "" } ], "track-function-calls": true,