Skip to content

Commit

Permalink
增加web channel
Browse files Browse the repository at this point in the history
  • Loading branch information
stonyz committed Nov 27, 2024
1 parent cf84e57 commit 71c18c0
Show file tree
Hide file tree
Showing 6 changed files with 381 additions and 1 deletion.
2 changes: 1 addition & 1 deletion app.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def func(_signo, _stack_frame):

def start_channel(channel_name: str):
channel = channel_factory.create_channel(channel_name)
if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app", "wework",
if channel_name in ["wx", "wxy", "terminal", "wechatmp","web", "wechatmp_service", "wechatcom_app", "wework",
const.FEISHU, const.DINGTALK]:
PluginManager().load_plugins()

Expand Down
3 changes: 3 additions & 0 deletions channel/channel_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ def create_channel(channel_type) -> Channel:
elif channel_type == "terminal":
from channel.terminal.terminal_channel import TerminalChannel
ch = TerminalChannel()
elif channel_type == 'web':
from channel.web.web_channel import WebChannel
ch = WebChannel()
elif channel_type == "wechatmp":
from channel.wechatmp.wechatmp_channel import WechatMPChannel
ch = WechatMPChannel(passive_reply=True)
Expand Down
7 changes: 7 additions & 0 deletions channel/web/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Web channel
使用SSE(Server-Sent Events,服务器推送事件)实现,提供了一个默认的网页。也可以自己实现加入api

#使用方法
- 在配置文件中channel_type填入web即可
- 访问地址 http://localhost:9899
- port可以在配置项 web_port中设置
165 changes: 165 additions & 0 deletions channel/web/chat.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
<!DOCTYPE html>
<html lang="zh">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Chat</title>
<style>
body {
font-family: Arial, sans-serif;
display: flex;
flex-direction: column;
height: 100vh; /* 占据所有高度 */
margin: 0;
/* background-color: #f8f9fa; */
}
#chat-container {
display: flex;
flex-direction: column;
width: 100%;
max-width: 500px;
margin: auto;
border: 1px solid #ccc;
border-radius: 5px;
overflow: hidden;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
flex: 1; /* 使聊天容器占据剩余空间 */
}
#messages {
flex-direction: column;
display: flex;
flex: 1;
overflow-y: auto;
padding: 10px;
overflow-y: auto;
border-bottom: 1px solid #ccc;
background-color: #ffffff;
}

.message {
margin: 5px 0; /* 间隔 */
padding: 10px 15px; /* 内边距 */
border-radius: 15px; /* 圆角 */
max-width: 80%; /* 限制最大宽度 */
min-width: 80px; /* 设置最小宽度 */
min-height: 40px; /* 设置最小高度 */
word-wrap: break-word; /* 自动换行 */
position: relative; /* 时间戳定位 */
display: inline-block; /* 内容自适应宽度 */
box-sizing: border-box; /* 包括内边距和边框 */
flex-shrink: 0; /* 禁止高度被压缩 */
word-wrap: break-word; /* 自动换行,防止单行过长 */
white-space: normal; /* 允许正常换行 */
overflow: hidden;
}

.bot {
background-color: #f1f1f1; /* 灰色背景 */
color: black; /* 黑色字体 */
align-self: flex-start; /* 左对齐 */
margin-right: auto; /* 确保消息靠左 */
text-align: left; /* 内容左对齐 */
}

.user {
background-color: #2bc840; /* 蓝色背景 */
align-self: flex-end; /* 右对齐 */
margin-left: auto; /* 确保消息靠右 */
text-align: left; /* 内容左对齐 */
}
.timestamp {
font-size: 0.8em; /* 时间戳字体大小 */
color: rgba(0, 0, 0, 0.5); /* 半透明黑色 */
margin-bottom: 5px; /* 时间戳下方间距 */
display: block; /* 时间戳独占一行 */
}
#input-container {
display: flex;
padding: 10px;
background-color: #ffffff;
border-top: 1px solid #ccc;
}
#input {
flex: 1;
padding: 10px;
border: 1px solid #ccc;
border-radius: 5px;
margin-right: 10px;
}
#send {
padding: 10px;
border: none;
background-color: #007bff;
color: white;
border-radius: 5px;
cursor: pointer;
}
#send:hover {
background-color: #0056b3;
}
</style>
</head>
<body>
<div id="chat-container">
<div id="messages"></div>
<div id="input-container">
<input type="text" id="input" placeholder="输入消息..." />
<button id="send">发送</button>
</div>
</div>

<script>
const messagesDiv = document.getElementById('messages');
const input = document.getElementById('input');
const sendButton = document.getElementById('send');

// 生成唯一的 user_id
const userId = 'user_' + Math.random().toString(36).substr(2, 9);

// 连接 SSE
const eventSource = new EventSource(`/sse/${userId}`);

eventSource.onmessage = function(event) {
const message = JSON.parse(event.data);
const messageDiv = document.createElement('div');
messageDiv.className = 'message bot';
const timestamp = new Date(message.timestamp).toLocaleTimeString(); // 假设消息中有时间戳
messageDiv.innerHTML = `<div class="timestamp">${timestamp}</div>${message.content}`; // 显示时间
messagesDiv.appendChild(messageDiv);
messagesDiv.scrollTop = messagesDiv.scrollHeight; // 滚动到底部
};

sendButton.onclick = function() {
sendMessage();
};

input.addEventListener('keypress', function(event) {
if (event.key === 'Enter') {
sendMessage();
event.preventDefault(); // 防止换行
}
});

function sendMessage() {
const userMessage = input.value;
if (userMessage) {
const timestamp = new Date().toISOString(); // 获取当前时间戳
fetch('/message', {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({ user_id: userId, message: userMessage, timestamp: timestamp }) // 发送时间戳
});
const messageDiv = document.createElement('div');
messageDiv.className = 'message user';
const userTimestamp = new Date().toLocaleTimeString(); // 获取当前时间
messageDiv.innerHTML = `<div class="timestamp">${userTimestamp}</div>${userMessage}`; // 显示时间
messagesDiv.appendChild(messageDiv);
messagesDiv.scrollTop = messagesDiv.scrollHeight; // 滚动到底部
input.value = ''; // 清空输入框
}
}
</script>
</body>
</html>
204 changes: 204 additions & 0 deletions channel/web/web_channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
import sys
import time
import web
import json
from queue import Queue
from bridge.context import *
from bridge.reply import Reply, ReplyType
from channel.chat_channel import ChatChannel, check_prefix
from channel.chat_message import ChatMessage
from common.log import logger
from common.singleton import singleton
from config import conf
import os


class WebMessage(ChatMessage):
def __init__(
self,
msg_id,
content,
ctype=ContextType.TEXT,
from_user_id="User",
to_user_id="Chatgpt",
other_user_id="Chatgpt",
):
self.msg_id = msg_id
self.ctype = ctype
self.content = content
self.from_user_id = from_user_id
self.to_user_id = to_user_id
self.other_user_id = other_user_id


@singleton
class WebChannel(ChatChannel):
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE]
_instance = None

# def __new__(cls):
# if cls._instance is None:
# cls._instance = super(WebChannel, cls).__new__(cls)
# return cls._instance

def __init__(self):
super().__init__()
self.message_queues = {} # 为每个用户存储一个消息队列
self.msg_id_counter = 0 # 添加消息ID计数器

def _generate_msg_id(self):
"""生成唯一的消息ID"""
self.msg_id_counter += 1
return str(int(time.time())) + str(self.msg_id_counter)

def send(self, reply: Reply, context: Context):
try:
if reply.type == ReplyType.IMAGE:
from PIL import Image

image_storage = reply.content
image_storage.seek(0)
img = Image.open(image_storage)
print("<IMAGE>")
img.show()
elif reply.type == ReplyType.IMAGE_URL:
import io

import requests
from PIL import Image

img_url = reply.content
pic_res = requests.get(img_url, stream=True)
image_storage = io.BytesIO()
for block in pic_res.iter_content(1024):
image_storage.write(block)
image_storage.seek(0)
img = Image.open(image_storage)
print(img_url)
img.show()
else:
print(reply.content)

# 获取用户ID,如果没有则使用默认值
# user_id = getattr(context.get("session", None), "session_id", "default_user")
user_id = context["receiver"]
# 确保用户有对应的消息队列
if user_id not in self.message_queues:
self.message_queues[user_id] = Queue()

# 将消息放入对应用户的队列
message_data = {
"type": str(reply.type),
"content": reply.content,
"timestamp": time.time()
}
self.message_queues[user_id].put(message_data)
logger.debug(f"Message queued for user {user_id}")

except Exception as e:
logger.error(f"Error in send method: {e}")
raise

def sse_handler(self, user_id):
"""
Handle Server-Sent Events (SSE) for real-time communication.
"""
web.header('Content-Type', 'text/event-stream')
web.header('Cache-Control', 'no-cache')
web.header('Connection', 'keep-alive')

# 确保用户有消息队列
if user_id not in self.message_queues:
self.message_queues[user_id] = Queue()

try:
while True:
try:
# 发送心跳
yield f": heartbeat\n\n"

# 非阻塞方式获取消息
if not self.message_queues[user_id].empty():
message = self.message_queues[user_id].get_nowait()
yield f"data: {json.dumps(message)}\n\n"
time.sleep(0.5)
except Exception as e:
logger.error(f"SSE Error: {e}")
break
finally:
# 清理资源
if user_id in self.message_queues:
# 只有当队列为空时才删除
if self.message_queues[user_id].empty():
del self.message_queues[user_id]

def post_message(self):
"""
Handle incoming messages from users via POST request.
"""
try:
data = web.data() # 获取原始POST数据
json_data = json.loads(data)
user_id = json_data.get('user_id', 'default_user')
prompt = json_data.get('message', '')
except json.JSONDecodeError:
return json.dumps({"status": "error", "message": "Invalid JSON"})
except Exception as e:
return json.dumps({"status": "error", "message": str(e)})

if not prompt:
return json.dumps({"status": "error", "message": "No message provided"})

try:
msg_id = self._generate_msg_id()
context = self._compose_context(ContextType.TEXT, prompt, msg=WebMessage(msg_id,
prompt,
from_user_id=user_id,
other_user_id = user_id
))
context["isgroup"] = False
# context["session"] = web.storage(session_id=user_id)

if not context:
return json.dumps({"status": "error", "message": "Failed to process message"})

self.produce(context)
return json.dumps({"status": "success", "message": "Message received"})

except Exception as e:
logger.error(f"Error processing message: {e}")
return json.dumps({"status": "error", "message": "Internal server error"})

def chat_page(self):
"""Serve the chat HTML page."""
file_path = os.path.join(os.path.dirname(__file__), 'chat.html') # 使用绝对路径
with open(file_path, 'r', encoding='utf-8') as f:
return f.read()

def startup(self):
logger.setLevel("WARN")
print("\nWeb Channel is running. Send POST requests to /message to send messages.")

urls = (
'/sse/(.+)', 'SSEHandler', # 修改路由以接收用户ID
'/message', 'MessageHandler',
'/chat', 'ChatHandler',
)
port = conf().get("web_port", 9899)
app = web.application(urls, globals(), autoreload=False)
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))


class SSEHandler:
def GET(self, user_id):
return WebChannel().sse_handler(user_id)


class MessageHandler:
def POST(self):
return WebChannel().post_message()


class ChatHandler:
def GET(self):
return WebChannel().chat_page()
Loading

0 comments on commit 71c18c0

Please sign in to comment.