diff --git a/app.py b/app.py index ff2a6c774..5a1e01aac 100644 --- a/app.py +++ b/app.py @@ -27,7 +27,7 @@ def func(_signo, _stack_frame): def start_channel(channel_name: str): channel = channel_factory.create_channel(channel_name) - if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app", "wework", + if channel_name in ["wx", "wxy", "terminal", "wechatmp","web", "wechatmp_service", "wechatcom_app", "wework", const.FEISHU, const.DINGTALK]: PluginManager().load_plugins() diff --git a/channel/channel_factory.py b/channel/channel_factory.py index c2c6937cf..43a8be3eb 100644 --- a/channel/channel_factory.py +++ b/channel/channel_factory.py @@ -21,6 +21,9 @@ def create_channel(channel_type) -> Channel: elif channel_type == "terminal": from channel.terminal.terminal_channel import TerminalChannel ch = TerminalChannel() + elif channel_type == 'web': + from channel.web.web_channel import WebChannel + ch = WebChannel() elif channel_type == "wechatmp": from channel.wechatmp.wechatmp_channel import WechatMPChannel ch = WechatMPChannel(passive_reply=True) diff --git a/channel/web/README.md b/channel/web/README.md new file mode 100644 index 000000000..34484dd6c --- /dev/null +++ b/channel/web/README.md @@ -0,0 +1,7 @@ +# Web channel +使用SSE(Server-Sent Events,服务器推送事件)实现,提供了一个默认的网页。也可以自己实现加入api + +#使用方法 +- 在配置文件中channel_type填入web即可 +- 访问地址 http://localhost:9899 +- port可以在配置项 web_port中设置 diff --git a/channel/web/chat.html b/channel/web/chat.html new file mode 100644 index 000000000..5afd911a0 --- /dev/null +++ b/channel/web/chat.html @@ -0,0 +1,165 @@ + + + + + + Chat + + + +
+
+
+ + +
+
+ + + + \ No newline at end of file diff --git a/channel/web/web_channel.py b/channel/web/web_channel.py new file mode 100644 index 000000000..661eb3862 --- /dev/null +++ b/channel/web/web_channel.py @@ -0,0 +1,204 @@ +import sys +import time +import web +import json +from queue import Queue +from bridge.context import * +from bridge.reply import Reply, ReplyType +from channel.chat_channel import ChatChannel, check_prefix +from channel.chat_message import ChatMessage +from common.log import logger +from common.singleton import singleton +from config import conf +import os + + +class WebMessage(ChatMessage): + def __init__( + self, + msg_id, + content, + ctype=ContextType.TEXT, + from_user_id="User", + to_user_id="Chatgpt", + other_user_id="Chatgpt", + ): + self.msg_id = msg_id + self.ctype = ctype + self.content = content + self.from_user_id = from_user_id + self.to_user_id = to_user_id + self.other_user_id = other_user_id + + +@singleton +class WebChannel(ChatChannel): + NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE] + _instance = None + + # def __new__(cls): + # if cls._instance is None: + # cls._instance = super(WebChannel, cls).__new__(cls) + # return cls._instance + + def __init__(self): + super().__init__() + self.message_queues = {} # 为每个用户存储一个消息队列 + self.msg_id_counter = 0 # 添加消息ID计数器 + + def _generate_msg_id(self): + """生成唯一的消息ID""" + self.msg_id_counter += 1 + return str(int(time.time())) + str(self.msg_id_counter) + + def send(self, reply: Reply, context: Context): + try: + if reply.type == ReplyType.IMAGE: + from PIL import Image + + image_storage = reply.content + image_storage.seek(0) + img = Image.open(image_storage) + print("") + img.show() + elif reply.type == ReplyType.IMAGE_URL: + import io + + import requests + from PIL import Image + + img_url = reply.content + pic_res = requests.get(img_url, stream=True) + image_storage = io.BytesIO() + for block in pic_res.iter_content(1024): + image_storage.write(block) + image_storage.seek(0) + img = Image.open(image_storage) + print(img_url) + img.show() + else: + print(reply.content) + + # 获取用户ID,如果没有则使用默认值 + # user_id = getattr(context.get("session", None), "session_id", "default_user") + user_id = context["receiver"] + # 确保用户有对应的消息队列 + if user_id not in self.message_queues: + self.message_queues[user_id] = Queue() + + # 将消息放入对应用户的队列 + message_data = { + "type": str(reply.type), + "content": reply.content, + "timestamp": time.time() + } + self.message_queues[user_id].put(message_data) + logger.debug(f"Message queued for user {user_id}") + + except Exception as e: + logger.error(f"Error in send method: {e}") + raise + + def sse_handler(self, user_id): + """ + Handle Server-Sent Events (SSE) for real-time communication. + """ + web.header('Content-Type', 'text/event-stream') + web.header('Cache-Control', 'no-cache') + web.header('Connection', 'keep-alive') + + # 确保用户有消息队列 + if user_id not in self.message_queues: + self.message_queues[user_id] = Queue() + + try: + while True: + try: + # 发送心跳 + yield f": heartbeat\n\n" + + # 非阻塞方式获取消息 + if not self.message_queues[user_id].empty(): + message = self.message_queues[user_id].get_nowait() + yield f"data: {json.dumps(message)}\n\n" + time.sleep(0.5) + except Exception as e: + logger.error(f"SSE Error: {e}") + break + finally: + # 清理资源 + if user_id in self.message_queues: + # 只有当队列为空时才删除 + if self.message_queues[user_id].empty(): + del self.message_queues[user_id] + + def post_message(self): + """ + Handle incoming messages from users via POST request. + """ + try: + data = web.data() # 获取原始POST数据 + json_data = json.loads(data) + user_id = json_data.get('user_id', 'default_user') + prompt = json_data.get('message', '') + except json.JSONDecodeError: + return json.dumps({"status": "error", "message": "Invalid JSON"}) + except Exception as e: + return json.dumps({"status": "error", "message": str(e)}) + + if not prompt: + return json.dumps({"status": "error", "message": "No message provided"}) + + try: + msg_id = self._generate_msg_id() + context = self._compose_context(ContextType.TEXT, prompt, msg=WebMessage(msg_id, + prompt, + from_user_id=user_id, + other_user_id = user_id + )) + context["isgroup"] = False + # context["session"] = web.storage(session_id=user_id) + + if not context: + return json.dumps({"status": "error", "message": "Failed to process message"}) + + self.produce(context) + return json.dumps({"status": "success", "message": "Message received"}) + + except Exception as e: + logger.error(f"Error processing message: {e}") + return json.dumps({"status": "error", "message": "Internal server error"}) + + def chat_page(self): + """Serve the chat HTML page.""" + file_path = os.path.join(os.path.dirname(__file__), 'chat.html') # 使用绝对路径 + with open(file_path, 'r', encoding='utf-8') as f: + return f.read() + + def startup(self): + logger.setLevel("WARN") + print("\nWeb Channel is running. Send POST requests to /message to send messages.") + + urls = ( + '/sse/(.+)', 'SSEHandler', # 修改路由以接收用户ID + '/message', 'MessageHandler', + '/chat', 'ChatHandler', + ) + port = conf().get("web_port", 9899) + app = web.application(urls, globals(), autoreload=False) + web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port)) + + +class SSEHandler: + def GET(self, user_id): + return WebChannel().sse_handler(user_id) + + +class MessageHandler: + def POST(self): + return WebChannel().post_message() + + +class ChatHandler: + def GET(self): + return WebChannel().chat_page() diff --git a/config.py b/config.py index a28c5b959..b03aec225 100644 --- a/config.py +++ b/config.py @@ -179,6 +179,7 @@ "Minimax_api_key": "", "Minimax_group_id": "", "Minimax_base_url": "", + "web_port": 9899, }