feat: Optimize Chat loading and create a simple server for it #124

Merged · 3 commits · Oct 30, 2024
9 changes: 9 additions & 0 deletions .env
@@ -96,9 +96,18 @@ COMMAND_START=["","/"]
# See https://zhuanlan.zhihu.com/p/618011122 for how to configure the strategy
#CHAT_STRATEGY=cuda fp16

# Whether to use the local API
#chat_use_local_server = False

# Chat API timeout in seconds; set it higher on slower machines
#chat_server_timeout = 15

# Chat API retry count
#chat_server_retry = 3

# TTS-related configuration

# Vocoder; options: pwgan_aishell3, wavernn_csmsc
#TTS_VOCODER=pwgan_aishell3
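For example, to route chat through the local server with a more forgiving timeout on a slow machine, the new options might be set like this (the values are illustrative):

```
chat_use_local_server = True
chat_server_timeout = 30
chat_server_retry = 3
```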


12 changes: 10 additions & 2 deletions docs/AIDeployment.md
@@ -65,9 +65,17 @@ All AI features place higher demands on hardware and are somewhat more complex to set up.
python -m pip install tokenizers rwkv
```

3. If you're interested, try tweaking the opening incantation `INIT_PROMPT` in `src/plugins/chat/prompt.py`
3. (Optional) Configure in `.env` whether to enable the chat server, which loads the chat model in a separate process. It is disabled by default, in which case Pallas-Bot loads the chat model directly

```bash
python src/plugins/chat/server.py
```

The port in `src/plugins/chat/server.py` can be changed (it defaults to 5000); just keep it consistent with the one in `src/plugins/chat/__init__.py`. You can also put it behind a production server such as gunicorn, as sketched below.
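A minimal sketch of such a deployment, assuming gunicorn has been installed; a single worker keeps one copy of the model and one shared session store:

```bash
pip install gunicorn
cd src/plugins/chat
gunicorn -w 1 -b 0.0.0.0:5000 server:app
```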

4. If you're interested, try tweaking the opening incantation `INIT_PROMPT` in `src/plugins/chat/prompt.py`

4. `STRATEGY` in `src/plugins/chat/model.py` can be tweaked per the upstream repo's [notes](https://github.com/BlinkDL/ChatRWKV/tree/main#%E4%B8%AD%E6%96%87%E6%A8%A1%E5%9E%8B) to save some VRAM and the like
5. `STRATEGY` in `src/plugins/chat/model.py` can be tweaked per the upstream repo's [notes](https://github.com/BlinkDL/ChatRWKV/tree/main#%E4%B8%AD%E6%96%87%E6%A8%A1%E5%9E%8B) to save some VRAM and the like; see the example below
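For instance, int8 quantization on the GPU trades a little quality for roughly half the VRAM; an illustrative `.env` setting (valid strategy strings are listed in the upstream notes):

```
CHAT_STRATEGY=cuda fp16i8
```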

## Drunken voice messages (TTS)

3 changes: 3 additions & 0 deletions requirements.txt
@@ -14,3 +14,6 @@ pydantic~=1.10.0
pymongo~=4.3.3
jieba~=0.42.1
pypinyin~=0.49.0

# chat
httpx~=0.27.0
6 changes: 6 additions & 0 deletions src/common/config/__init__.py
@@ -75,6 +75,12 @@ class PluginConfig(BaseModel, extra=Extra.ignore):
    tts_vocoder: str = 'pwgan_aishell3'
    # strategy for the chat model
    chat_strategy: str = ''
    # whether chat uses the local API
    chat_use_local_server: bool = False
    # chat API timeout (seconds)
    chat_server_timeout: int = 15
    # chat API retry count
    chat_server_retry: int = 3


try:
73 changes: 60 additions & 13 deletions src/plugins/chat/__init__.py
@@ -1,9 +1,11 @@
import asyncio
from asyncer import asyncify
from nonebot.adapters.onebot.v11 import MessageSegment, permission, GroupMessageEvent
from nonebot.adapters import Bot, Event
from nonebot.rule import Rule
from nonebot.typing import T_State
from nonebot import on_message, logger
import httpx

from src.common.config import BotConfig, GroupConfig, plugin_config

@@ -21,34 +23,70 @@
    raise error

TTS_MIN_LENGTH = 10
CHAT_API_BASE = 'http://127.0.0.1:5000'
CHAT_API_URL = f'{CHAT_API_BASE}/chat'
USE_API = plugin_config.chat_use_local_server
TIMEOUT = plugin_config.chat_server_timeout
MAX_RETRIES = plugin_config.chat_server_retry
RETRY_BACKOFF_FACTOR = 1  # base delay between retries, in seconds

# shared client; the transport-level retries cover connection failures
client = httpx.AsyncClient(
    timeout=httpx.Timeout(timeout=TIMEOUT),
    transport=httpx.AsyncHTTPTransport(retries=MAX_RETRIES)
)

try:
    chat = Chat(plugin_config.chat_strategy)
except Exception as error:
    logger.error('Chat model init error: ', error)
    raise error

if USE_API:
    # the chat model lives in the standalone server process, so nothing to load here
    chat = None
else:
    try:
        chat = Chat(plugin_config.chat_strategy)
    except Exception as error:
        logger.error(f'Chat model init error: {error}')
        raise error

@BotConfig.handle_sober_up
def on_sober_up(bot_id, group_id, drunkenness) -> None:
    session = f'{bot_id}_{group_id}'
    logger.info(
        f'bot [{bot_id}] sober up in group [{group_id}], clear session [{session}]')
    chat.del_session(session)

    logger.info(f'bot [{bot_id}] sober up in group [{group_id}], clear session [{session}]')
    if USE_API:
        try:
            # this handler is synchronous, so use httpx's sync helper instead of the async client
            response = httpx.delete(f'{CHAT_API_BASE}/del_session',
                                    params={'session': session}, timeout=TIMEOUT)
            response.raise_for_status()
        except httpx.HTTPError as error:
            logger.error(f'Failed to delete session [{session}]: {error}')
    else:
        if chat is not None:
            chat.del_session(session)

def is_drunk(bot: Bot, event: Event, state: T_State) -> int:
    config = BotConfig(event.self_id, event.group_id)
    return config.drunkenness()


drunk_msg = on_message(
    rule=Rule(is_drunk),
    priority=13,
    block=True,
    permission=permission.GROUP,
)

async def make_api_request(url, method, json_data=None, params=None):
    for attempt in range(MAX_RETRIES + 1):
        try:
            if method == 'POST':
                response = await client.post(url, json=json_data)
            elif method == 'DELETE':
                response = await client.delete(url, params=params)
            else:
                raise ValueError(f'Unsupported method: {method}')
            response.raise_for_status()
            return response
        except httpx.HTTPError as error:
            logger.error(f'Request failed (attempt {attempt + 1}): {error}')
            if attempt < MAX_RETRIES:
                # exponential backoff: 1 s, 2 s, 4 s, ... between attempts
                await asyncio.sleep(RETRY_BACKOFF_FACTOR * (2 ** attempt))
    return None

@drunk_msg.handle()
async def _(bot: Bot, event: GroupMessageEvent, state: T_State):
@@ -71,7 +109,16 @@ async def _(bot: Bot, event: GroupMessageEvent, state: T_State):
    text = text[:50]
    if not text:
        return
    ans = await asyncify(chat.chat)(session, text)

    if USE_API:
        response = await make_api_request(
            CHAT_API_URL, 'POST',
            json_data={'session': session, 'text': text, 'token_count': 50})
        if response:
            ans = response.json().get('response', '')
        else:
            return
    else:
        ans = await asyncify(chat.chat)(session, text)

    logger.info(f'session [{session}]: {text} -> {ans}')

    if TTS_AVAIABLE and len(ans) >= TTS_MIN_LENGTH:
@@ -80,4 +127,4 @@ async def _(bot: Bot, event: GroupMessageEvent, state: T_State):
        await drunk_msg.send(voice)

    config.reset_cooldown(cd_key)
    await drunk_msg.finish(ans)
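A note on the retry schedule above: with `RETRY_BACKOFF_FACTOR = 1` and the default `chat_server_retry = 3`, `make_api_request` makes up to four attempts and sleeps 1 s, 2 s, then 4 s between them. A standalone sketch of the same arithmetic:

```python
RETRY_BACKOFF_FACTOR = 1
MAX_RETRIES = 3

# delay inserted after failed attempt `a`, before attempt `a + 1`
delays = [RETRY_BACKOFF_FACTOR * (2 ** a) for a in range(MAX_RETRIES)]
print(delays)  # [1, 2, 4]
```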
23 changes: 20 additions & 3 deletions src/plugins/chat/model.py
@@ -2,8 +2,11 @@
from threading import Lock
from copy import deepcopy
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
import os
import time
import torch
import threading

cuda = torch.cuda.is_available()
os.environ['RWKV_JIT_ON'] = '1'
@@ -34,6 +37,17 @@ def __init__(self, strategy=DEFAULT_STRATEGY, model_dir=DEFAULT_MODEL_DIR) -> None:
            raise Exception(f'Chat model not found in {self.MODEL_DIR}')
        if not self.TOKEN_PATH.exists():
            raise Exception(f'Chat token not found in {self.TOKEN_PATH}')

        self.pipeline = None
        self.args = None
        self.all_state = defaultdict(lambda: None)
        self.all_occurrence = {}
        self.chat_locker = Lock()
        self.executor = ThreadPoolExecutor(max_workers=10)

        # load the heavy RWKV model in a background thread so __init__ returns immediately
        threading.Thread(target=self._load_model).start()

    def _load_model(self):
        model = RWKV(model=str(self.MODEL_PATH), strategy=self.STRATEGY)
        self.pipeline = PIPELINE(model, str(self.TOKEN_PATH))
        self.args = PIPELINE_ARGS(
@@ -49,11 +63,14 @@ def __init__(self, strategy=DEFAULT_STRATEGY, model_dir=DEFAULT_MODEL_DIR) -> None:
        INIT_STATE = deepcopy(self.pipeline.generate(
            INIT_PROMPT, token_count=200, args=self.args)[1])
        self.all_state = defaultdict(lambda: deepcopy(INIT_STATE))
        self.all_occurrence = {}

        self.chat_locker = Lock()

    def chat(self, session: str, text: str, token_count: int = 50) -> str:
        # spin until the background thread has finished loading the model
        while self.pipeline is None:
            time.sleep(0.1)
        future = self.executor.submit(self._chat, session, text, token_count)
        return future.result()

    def _chat(self, session: str, text: str, token_count: int = 50) -> str:
        with self.chat_locker:
            state = self.all_state[session]
            ctx = CHAT_FORMAT.format(text)
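A minimal usage sketch of the lazy loading above (the session id is hypothetical; assumes the model files are present under `resource/chat/models`): the constructor returns immediately while `_load_model` runs in the background, and `chat()` blocks until the pipeline is ready before dispatching to the thread pool.

```python
from src.plugins.chat.model import Chat

bot = Chat()  # returns at once; the RWKV model loads in a background thread
ans = bot.chat('demo_session', '博士,要好好休息哦')  # waits for the model, then generates
print(ans)
bot.del_session('demo_session')  # drop the cached state for this session
```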
102 changes: 102 additions & 0 deletions src/plugins/chat/server.py
@@ -0,0 +1,102 @@
from flask import Flask, request, jsonify
from pathlib import Path
from threading import Lock
from copy import deepcopy
from collections import defaultdict
import os
import torch


app = Flask(__name__)

cuda = torch.cuda.is_available()
os.environ['RWKV_JIT_ON'] = '1'
# This needs ninja and the like set up in your environment; it can greatly speed up inference, configure it yourself if needed (CUDA GPUs only)
os.environ["RWKV_CUDA_ON"] = '0'

from rwkv.model import RWKV
import prompt
import pipeline

DEFAULT_STRATEGY = 'cuda fp16' if cuda else 'cpu fp32'
API_DIR = Path(__file__).resolve().parent.parent.parent.parent
DEFAULT_MODEL_DIR = API_DIR / 'resource' / 'chat' / 'models'
print(f"DEFAULT_MODEL_DIR: {DEFAULT_MODEL_DIR}")
print("Files in directory:")
for f in DEFAULT_MODEL_DIR.iterdir():
    print(f)


class Chat:
    def __init__(self, strategy=DEFAULT_STRATEGY, model_dir=DEFAULT_MODEL_DIR) -> None:
        self.STRATEGY = strategy if strategy else DEFAULT_STRATEGY
        self.MODEL_DIR = model_dir
        self.MODEL_EXT = '.pth'
        self.MODEL_PATH = None
        self.TOKEN_PATH = self.MODEL_DIR / '20B_tokenizer.json'
        for f in self.MODEL_DIR.glob('*'):
            if f.suffix != self.MODEL_EXT:
                continue
            self.MODEL_PATH = f.with_suffix('')
            break
        if not self.MODEL_PATH:
            raise Exception(f'Chat model not found in {self.MODEL_DIR}')
        if not self.TOKEN_PATH.exists():
            raise Exception(f'Chat token not found in {self.TOKEN_PATH}')
        model = RWKV(model=str(self.MODEL_PATH), strategy=self.STRATEGY)
        self.pipeline = pipeline.PIPELINE(model, str(self.TOKEN_PATH))
        self.args = pipeline.PIPELINE_ARGS(
            temperature=1.0,
            top_p=0.7,
            alpha_frequency=0.25,
            alpha_presence=0.25,
            token_ban=[0],  # ban the generation of some tokens
            token_stop=[],  # stop generation whenever you see any token here
            ends=('\n'),
            ends_if_too_long=("。", "!", "?", "\n"))

        INIT_STATE = deepcopy(self.pipeline.generate(
            prompt.INIT_PROMPT, token_count=200, args=self.args)[1])
        self.all_state = defaultdict(lambda: deepcopy(INIT_STATE))
        self.all_occurrence = {}

        self.chat_locker = Lock()

    def chat(self, session: str, text: str, token_count: int = 50) -> str:
        with self.chat_locker:
            state = self.all_state[session]
            ctx = prompt.CHAT_FORMAT.format(text)
            occurrence = self.all_occurrence.get(session, {})

            out, state, occurrence = self.pipeline.generate(
                ctx, token_count=token_count, args=self.args, state=state, occurrence=occurrence)

            self.all_state[session] = deepcopy(state)
            self.all_occurrence[session] = occurrence
            return out.strip()

    def del_session(self, session: str):
        with self.chat_locker:
            if session in self.all_state:
                del self.all_state[session]
            if session in self.all_occurrence:
                del self.all_occurrence[session]

chat_instance = Chat('cpu fp32')


@app.route('/chat', methods=['POST'])
def chat():
    data = request.json
    session = data.get('session', 'main')
    text = data.get('text', '')
    token_count = data.get('token_count', 50)
    response = chat_instance.chat(session, text, token_count)
    return jsonify({'response': response})


@app.route('/del_session', methods=['DELETE'])
def del_session():
    # the session id arrives as a query parameter, matching the client's `params=` call
    session = request.args.get('session', 'main')
    chat_instance.del_session(session)
    return jsonify({'status': 'success'})


if __name__ == "__main__":
    app.run(host='0.0.0.0', port=5000)
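Once the server is running, both endpoints can be smoke-tested from the command line (the session id and text are just examples):

```bash
curl -X POST http://127.0.0.1:5000/chat \
     -H 'Content-Type: application/json' \
     -d '{"session": "demo", "text": "博士,要不要来一杯?", "token_count": 50}'

curl -X DELETE 'http://127.0.0.1:5000/del_session?session=demo'
```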