-
Notifications
You must be signed in to change notification settings - Fork 8.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
381 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Web channel | ||
使用SSE(Server-Sent Events,服务器推送事件)实现,提供了一个默认的网页。也可以自己实现加入api | ||
|
||
#使用方法 | ||
- 在配置文件中channel_type填入web即可 | ||
- 访问地址 http://localhost:9899 | ||
- port可以在配置项 web_port中设置 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
<!DOCTYPE html> | ||
<html lang="zh"> | ||
<head> | ||
<meta charset="UTF-8"> | ||
<meta name="viewport" content="width=device-width, initial-scale=1.0"> | ||
<title>Chat</title> | ||
<style> | ||
body { | ||
font-family: Arial, sans-serif; | ||
display: flex; | ||
flex-direction: column; | ||
height: 100vh; /* 占据所有高度 */ | ||
margin: 0; | ||
/* background-color: #f8f9fa; */ | ||
} | ||
#chat-container { | ||
display: flex; | ||
flex-direction: column; | ||
width: 100%; | ||
max-width: 500px; | ||
margin: auto; | ||
border: 1px solid #ccc; | ||
border-radius: 5px; | ||
overflow: hidden; | ||
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1); | ||
flex: 1; /* 使聊天容器占据剩余空间 */ | ||
} | ||
#messages { | ||
flex-direction: column; | ||
display: flex; | ||
flex: 1; | ||
overflow-y: auto; | ||
padding: 10px; | ||
overflow-y: auto; | ||
border-bottom: 1px solid #ccc; | ||
background-color: #ffffff; | ||
} | ||
|
||
.message { | ||
margin: 5px 0; /* 间隔 */ | ||
padding: 10px 15px; /* 内边距 */ | ||
border-radius: 15px; /* 圆角 */ | ||
max-width: 80%; /* 限制最大宽度 */ | ||
min-width: 80px; /* 设置最小宽度 */ | ||
min-height: 40px; /* 设置最小高度 */ | ||
word-wrap: break-word; /* 自动换行 */ | ||
position: relative; /* 时间戳定位 */ | ||
display: inline-block; /* 内容自适应宽度 */ | ||
box-sizing: border-box; /* 包括内边距和边框 */ | ||
flex-shrink: 0; /* 禁止高度被压缩 */ | ||
word-wrap: break-word; /* 自动换行,防止单行过长 */ | ||
white-space: normal; /* 允许正常换行 */ | ||
overflow: hidden; | ||
} | ||
|
||
.bot { | ||
background-color: #f1f1f1; /* 灰色背景 */ | ||
color: black; /* 黑色字体 */ | ||
align-self: flex-start; /* 左对齐 */ | ||
margin-right: auto; /* 确保消息靠左 */ | ||
text-align: left; /* 内容左对齐 */ | ||
} | ||
|
||
.user { | ||
background-color: #2bc840; /* 蓝色背景 */ | ||
align-self: flex-end; /* 右对齐 */ | ||
margin-left: auto; /* 确保消息靠右 */ | ||
text-align: left; /* 内容左对齐 */ | ||
} | ||
.timestamp { | ||
font-size: 0.8em; /* 时间戳字体大小 */ | ||
color: rgba(0, 0, 0, 0.5); /* 半透明黑色 */ | ||
margin-bottom: 5px; /* 时间戳下方间距 */ | ||
display: block; /* 时间戳独占一行 */ | ||
} | ||
#input-container { | ||
display: flex; | ||
padding: 10px; | ||
background-color: #ffffff; | ||
border-top: 1px solid #ccc; | ||
} | ||
#input { | ||
flex: 1; | ||
padding: 10px; | ||
border: 1px solid #ccc; | ||
border-radius: 5px; | ||
margin-right: 10px; | ||
} | ||
#send { | ||
padding: 10px; | ||
border: none; | ||
background-color: #007bff; | ||
color: white; | ||
border-radius: 5px; | ||
cursor: pointer; | ||
} | ||
#send:hover { | ||
background-color: #0056b3; | ||
} | ||
</style> | ||
</head> | ||
<body> | ||
<div id="chat-container"> | ||
<div id="messages"></div> | ||
<div id="input-container"> | ||
<input type="text" id="input" placeholder="输入消息..." /> | ||
<button id="send">发送</button> | ||
</div> | ||
</div> | ||
|
||
<script> | ||
const messagesDiv = document.getElementById('messages'); | ||
const input = document.getElementById('input'); | ||
const sendButton = document.getElementById('send'); | ||
|
||
// 生成唯一的 user_id | ||
const userId = 'user_' + Math.random().toString(36).substr(2, 9); | ||
|
||
// 连接 SSE | ||
const eventSource = new EventSource(`/sse/${userId}`); | ||
|
||
eventSource.onmessage = function(event) { | ||
const message = JSON.parse(event.data); | ||
const messageDiv = document.createElement('div'); | ||
messageDiv.className = 'message bot'; | ||
const timestamp = new Date(message.timestamp).toLocaleTimeString(); // 假设消息中有时间戳 | ||
messageDiv.innerHTML = `<div class="timestamp">${timestamp}</div>${message.content}`; // 显示时间 | ||
messagesDiv.appendChild(messageDiv); | ||
messagesDiv.scrollTop = messagesDiv.scrollHeight; // 滚动到底部 | ||
}; | ||
|
||
sendButton.onclick = function() { | ||
sendMessage(); | ||
}; | ||
|
||
input.addEventListener('keypress', function(event) { | ||
if (event.key === 'Enter') { | ||
sendMessage(); | ||
event.preventDefault(); // 防止换行 | ||
} | ||
}); | ||
|
||
function sendMessage() { | ||
const userMessage = input.value; | ||
if (userMessage) { | ||
const timestamp = new Date().toISOString(); // 获取当前时间戳 | ||
fetch('/message', { | ||
method: 'POST', | ||
headers: { | ||
'Content-Type': 'application/json' | ||
}, | ||
body: JSON.stringify({ user_id: userId, message: userMessage, timestamp: timestamp }) // 发送时间戳 | ||
}); | ||
const messageDiv = document.createElement('div'); | ||
messageDiv.className = 'message user'; | ||
const userTimestamp = new Date().toLocaleTimeString(); // 获取当前时间 | ||
messageDiv.innerHTML = `<div class="timestamp">${userTimestamp}</div>${userMessage}`; // 显示时间 | ||
messagesDiv.appendChild(messageDiv); | ||
messagesDiv.scrollTop = messagesDiv.scrollHeight; // 滚动到底部 | ||
input.value = ''; // 清空输入框 | ||
} | ||
} | ||
</script> | ||
</body> | ||
</html> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,204 @@ | ||
import sys | ||
import time | ||
import web | ||
import json | ||
from queue import Queue | ||
from bridge.context import * | ||
from bridge.reply import Reply, ReplyType | ||
from channel.chat_channel import ChatChannel, check_prefix | ||
from channel.chat_message import ChatMessage | ||
from common.log import logger | ||
from common.singleton import singleton | ||
from config import conf | ||
import os | ||
|
||
|
||
class WebMessage(ChatMessage): | ||
def __init__( | ||
self, | ||
msg_id, | ||
content, | ||
ctype=ContextType.TEXT, | ||
from_user_id="User", | ||
to_user_id="Chatgpt", | ||
other_user_id="Chatgpt", | ||
): | ||
self.msg_id = msg_id | ||
self.ctype = ctype | ||
self.content = content | ||
self.from_user_id = from_user_id | ||
self.to_user_id = to_user_id | ||
self.other_user_id = other_user_id | ||
|
||
|
||
@singleton | ||
class WebChannel(ChatChannel): | ||
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE] | ||
_instance = None | ||
|
||
# def __new__(cls): | ||
# if cls._instance is None: | ||
# cls._instance = super(WebChannel, cls).__new__(cls) | ||
# return cls._instance | ||
|
||
def __init__(self): | ||
super().__init__() | ||
self.message_queues = {} # 为每个用户存储一个消息队列 | ||
self.msg_id_counter = 0 # 添加消息ID计数器 | ||
|
||
def _generate_msg_id(self): | ||
"""生成唯一的消息ID""" | ||
self.msg_id_counter += 1 | ||
return str(int(time.time())) + str(self.msg_id_counter) | ||
|
||
def send(self, reply: Reply, context: Context): | ||
try: | ||
if reply.type == ReplyType.IMAGE: | ||
from PIL import Image | ||
|
||
image_storage = reply.content | ||
image_storage.seek(0) | ||
img = Image.open(image_storage) | ||
print("<IMAGE>") | ||
img.show() | ||
elif reply.type == ReplyType.IMAGE_URL: | ||
import io | ||
|
||
import requests | ||
from PIL import Image | ||
|
||
img_url = reply.content | ||
pic_res = requests.get(img_url, stream=True) | ||
image_storage = io.BytesIO() | ||
for block in pic_res.iter_content(1024): | ||
image_storage.write(block) | ||
image_storage.seek(0) | ||
img = Image.open(image_storage) | ||
print(img_url) | ||
img.show() | ||
else: | ||
print(reply.content) | ||
|
||
# 获取用户ID,如果没有则使用默认值 | ||
# user_id = getattr(context.get("session", None), "session_id", "default_user") | ||
user_id = context["receiver"] | ||
# 确保用户有对应的消息队列 | ||
if user_id not in self.message_queues: | ||
self.message_queues[user_id] = Queue() | ||
|
||
# 将消息放入对应用户的队列 | ||
message_data = { | ||
"type": str(reply.type), | ||
"content": reply.content, | ||
"timestamp": time.time() | ||
} | ||
self.message_queues[user_id].put(message_data) | ||
logger.debug(f"Message queued for user {user_id}") | ||
|
||
except Exception as e: | ||
logger.error(f"Error in send method: {e}") | ||
raise | ||
|
||
def sse_handler(self, user_id): | ||
""" | ||
Handle Server-Sent Events (SSE) for real-time communication. | ||
""" | ||
web.header('Content-Type', 'text/event-stream') | ||
web.header('Cache-Control', 'no-cache') | ||
web.header('Connection', 'keep-alive') | ||
|
||
# 确保用户有消息队列 | ||
if user_id not in self.message_queues: | ||
self.message_queues[user_id] = Queue() | ||
|
||
try: | ||
while True: | ||
try: | ||
# 发送心跳 | ||
yield f": heartbeat\n\n" | ||
|
||
# 非阻塞方式获取消息 | ||
if not self.message_queues[user_id].empty(): | ||
message = self.message_queues[user_id].get_nowait() | ||
yield f"data: {json.dumps(message)}\n\n" | ||
time.sleep(0.5) | ||
except Exception as e: | ||
logger.error(f"SSE Error: {e}") | ||
break | ||
finally: | ||
# 清理资源 | ||
if user_id in self.message_queues: | ||
# 只有当队列为空时才删除 | ||
if self.message_queues[user_id].empty(): | ||
del self.message_queues[user_id] | ||
|
||
def post_message(self): | ||
""" | ||
Handle incoming messages from users via POST request. | ||
""" | ||
try: | ||
data = web.data() # 获取原始POST数据 | ||
json_data = json.loads(data) | ||
user_id = json_data.get('user_id', 'default_user') | ||
prompt = json_data.get('message', '') | ||
except json.JSONDecodeError: | ||
return json.dumps({"status": "error", "message": "Invalid JSON"}) | ||
except Exception as e: | ||
return json.dumps({"status": "error", "message": str(e)}) | ||
|
||
if not prompt: | ||
return json.dumps({"status": "error", "message": "No message provided"}) | ||
|
||
try: | ||
msg_id = self._generate_msg_id() | ||
context = self._compose_context(ContextType.TEXT, prompt, msg=WebMessage(msg_id, | ||
prompt, | ||
from_user_id=user_id, | ||
other_user_id = user_id | ||
)) | ||
context["isgroup"] = False | ||
# context["session"] = web.storage(session_id=user_id) | ||
|
||
if not context: | ||
return json.dumps({"status": "error", "message": "Failed to process message"}) | ||
|
||
self.produce(context) | ||
return json.dumps({"status": "success", "message": "Message received"}) | ||
|
||
except Exception as e: | ||
logger.error(f"Error processing message: {e}") | ||
return json.dumps({"status": "error", "message": "Internal server error"}) | ||
|
||
def chat_page(self): | ||
"""Serve the chat HTML page.""" | ||
file_path = os.path.join(os.path.dirname(__file__), 'chat.html') # 使用绝对路径 | ||
with open(file_path, 'r', encoding='utf-8') as f: | ||
return f.read() | ||
|
||
def startup(self): | ||
logger.setLevel("WARN") | ||
print("\nWeb Channel is running. Send POST requests to /message to send messages.") | ||
|
||
urls = ( | ||
'/sse/(.+)', 'SSEHandler', # 修改路由以接收用户ID | ||
'/message', 'MessageHandler', | ||
'/chat', 'ChatHandler', | ||
) | ||
port = conf().get("web_port", 9899) | ||
app = web.application(urls, globals(), autoreload=False) | ||
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port)) | ||
|
||
|
||
class SSEHandler: | ||
def GET(self, user_id): | ||
return WebChannel().sse_handler(user_id) | ||
|
||
|
||
class MessageHandler: | ||
def POST(self): | ||
return WebChannel().post_message() | ||
|
||
|
||
class ChatHandler: | ||
def GET(self): | ||
return WebChannel().chat_page() |
Oops, something went wrong.