Skip to content

Commit

Permalink
Fix long audio bug (#158)
Browse files Browse the repository at this point in the history
* update:异步生成音频

* update:优化LLM断句

---------

Co-authored-by: hrz <[email protected]>
  • Loading branch information
xinnan-tech and openrz authored Feb 28, 2025
1 parent 2c88a36 commit b94842c
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 235 deletions.
87 changes: 50 additions & 37 deletions core/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
import threading
import websockets
from typing import Dict, Any
from collections import deque
from core.utils.util import is_segment
from core.utils.dialogue import Message, Dialogue
from core.handle.textHandle import handleTextMessage
from core.utils.util import get_string_no_punctuation_or_emoji
Expand All @@ -21,7 +19,6 @@
from core.auth import AuthMiddleware, AuthenticationError
from core.utils.auth_code_gen import AuthCodeGenerator


TAG = __name__


Expand All @@ -45,8 +42,8 @@ def __init__(self, config: Dict[str, Any], _vad, _asr, _llm, _tts, _music):
self.loop = asyncio.get_event_loop()
self.stop_event = threading.Event()
self.tts_queue = queue.Queue()
self.audio_play_queue = queue.Queue()
self.executor = ThreadPoolExecutor(max_workers=10)
self.scheduled_tasks = deque()

# 依赖的组件
self.vad = _vad
Expand Down Expand Up @@ -140,9 +137,14 @@ async def handle_connection(self, ws):

await self.loop.run_in_executor(None, self._initialize_components)

tts_priority = threading.Thread(target=self._priority_thread, daemon=True)
# tts 消化线程
tts_priority = threading.Thread(target=self._tts_priority_thread, daemon=True)
tts_priority.start()

# 音频播放 消化线程
audio_play_priority = threading.Thread(target=self._audio_play_priority_thread, daemon=True)
audio_play_priority.start()

try:
async for message in self.websocket:
await self._route_message(message)
Expand Down Expand Up @@ -198,10 +200,8 @@ def isNeedAuth(self):
return not self.is_device_verified

def chat(self, query):
# 如果设备未验证,就发送验证码
if self.isNeedAuth():
self.llm_finish_task = True
# 创建一个新的事件循环来运行异步函数
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
Expand All @@ -212,49 +212,61 @@ def chat(self, query):

self.dialogue.put(Message(role="user", content=query))
response_message = []
start = 0
# 提交 LLM 任务
processed_chars = 0 # 跟踪已处理的字符位置
try:
start_time = time.time() # 记录开始时间
start_time = time.time()
llm_responses = self.llm.response(self.session_id, self.dialogue.get_llm_dialogue())
except Exception as e:
self.logger.bind(tag=TAG).error(f"LLM 处理出错 {query}: {e}")
return None
# 提交 TTS 任务到线程池

self.llm_finish_task = False
for content in llm_responses:
response_message.append(content)
# 如果中途被打断,就停止生成
if self.client_abort:
start = len(response_message)
break

end_time = time.time() # 记录结束时间
self.logger.bind(tag=TAG).debug(f"大模型返回时间时间: {end_time - start_time} 秒, 生成token={content}")
if is_segment(response_message):
segment_text = "".join(response_message[start:])
segment_text = get_string_no_punctuation_or_emoji(segment_text)
if len(segment_text) > 0:
end_time = time.time()
self.logger.bind(tag=TAG).debug(f"大模型返回时间: {end_time - start_time} 秒, 生成token={content}")

# 合并当前全部文本并处理未分割部分
full_text = "".join(response_message)
current_text = full_text[processed_chars:] # 从未处理的位置开始

# 查找最后一个有效标点
punctuations = ("。", "?", "!", ".", "?", "!", ";", ";", ":", ":", ",", ",")
last_punct_pos = -1
for punct in punctuations:
pos = current_text.rfind(punct)
if pos > last_punct_pos:
last_punct_pos = pos

# 找到分割点则处理
if last_punct_pos != -1:
segment_text_raw = current_text[:last_punct_pos + 1]
segment_text = get_string_no_punctuation_or_emoji(segment_text_raw)
if segment_text:
self.recode_first_last_text(segment_text)
future = self.executor.submit(self.speak_and_play, segment_text)
self.tts_queue.put(future)
start = len(response_message)

# 处理剩余的响应
if start < len(response_message):
segment_text = "".join(response_message[start:])
if len(segment_text) > 0:
processed_chars += len(segment_text_raw) # 更新已处理字符位置

# 处理最后剩余的文本
full_text = "".join(response_message)
remaining_text = full_text[processed_chars:]
if remaining_text:
segment_text = get_string_no_punctuation_or_emoji(remaining_text)
if segment_text:
self.recode_first_last_text(segment_text)
future = self.executor.submit(self.speak_and_play, segment_text)
self.tts_queue.put(future)

self.llm_finish_task = True
# 更新对话
self.dialogue.put(Message(role="assistant", content="".join(response_message)))
self.logger.bind(tag=TAG).debug(json.dumps(self.dialogue.get_llm_dialogue(), indent=4, ensure_ascii=False))
return True

def _priority_thread(self):
def _tts_priority_thread(self):
while not self.stop_event.is_set():
text = None
try:
Expand All @@ -276,7 +288,6 @@ def _priority_thread(self):
else:
self.logger.bind(tag=TAG).error(f"TTS文件不存在: {tts_file}")
opus_datas = []
duration = 0
except TimeoutError:
self.logger.bind(tag=TAG).error("TTS 任务超时")
continue
Expand All @@ -285,9 +296,7 @@ def _priority_thread(self):
continue
if not self.client_abort:
# 如果没有中途打断就发送语音
asyncio.run_coroutine_threadsafe(
sendAudioMessage(self, opus_datas, duration, text), self.loop
)
self.audio_play_queue.put((opus_datas, text))
if self.tts.delete_audio_file and os.path.exists(tts_file):
os.remove(tts_file)
except Exception as e:
Expand All @@ -299,6 +308,16 @@ def _priority_thread(self):
)
self.logger.bind(tag=TAG).error(f"tts_priority priority_thread: {text}{e}")

def _audio_play_priority_thread(self):
while not self.stop_event.is_set():
text = None
try:
opus_datas, text = self.audio_play_queue.get()
future = asyncio.run_coroutine_threadsafe(sendAudioMessage(self, opus_datas, text), self.loop)
future.result()
except Exception as e:
self.logger.bind(tag=TAG).error(f"audio_play_priority priority_thread: {text}{e}")

def speak_and_play(self, text):
if text is None or len(text) <= 0:
self.logger.bind(tag=TAG).info(f"无需tts转换,query为空,{text}")
Expand Down Expand Up @@ -340,9 +359,3 @@ def reset_vad_states(self):
self.client_have_voice_last_time = 0
self.client_voice_stop = False
self.logger.bind(tag=TAG).debug("VAD states reset.")

def stop_all_tasks(self):
while self.scheduled_tasks:
task = self.scheduled_tasks.popleft()
task.cancel()
self.scheduled_tasks.clear()
3 changes: 1 addition & 2 deletions core/handle/abortHandle.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import queue
from config.logger import setup_logging

TAG = __name__
Expand All @@ -9,8 +10,6 @@ async def handleAbortMessage(conn):
logger.bind(tag=TAG).info("Abort message received")
# 设置成打断状态,会自动打断llm、tts任务
conn.client_abort = True
# 打断屏显任务
conn.stop_all_tasks()
# 打断客户端说话状态
await conn.websocket.send(json.dumps({"type": "tts", "state": "stop", "session_id": conn.session_id}))
conn.clearSpeakStatus()
Expand Down
5 changes: 3 additions & 2 deletions core/handle/musicHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import difflib
import re
import traceback
from core.handle.sendAudioHandle import sendAudioMessage, send_stt_message
from core.handle.sendAudioHandle import send_stt_message

TAG = __name__
logger = setup_logging()
Expand Down Expand Up @@ -102,7 +102,8 @@ async def play_local_music(self, conn, specific_file=None):
conn.tts_last_text = selected_music
conn.llm_finish_task = True
opus_packets, duration = conn.tts.wav_to_opus_data(music_path)
await sendAudioMessage(conn, opus_packets, duration, selected_music)

conn.audio_play_queue.put((opus_packets, selected_music))

except Exception as e:
logger.bind(tag=TAG).error(f"播放音乐失败: {str(e)}")
Expand Down
8 changes: 2 additions & 6 deletions core/handle/receiveAudioHandle.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from config.logger import setup_logging
import asyncio
import time
from core.utils.util import remove_punctuation_and_length
from core.handle.sendAudioHandle import schedule_with_interrupt, send_stt_message
from core.handle.sendAudioHandle import send_stt_message

TAG = __name__
logger = setup_logging()
Expand Down Expand Up @@ -61,10 +60,7 @@ async def handleCMDMessage(conn, text):

async def startToChat(conn, text):
# 异步发送 stt 信息
stt_task = asyncio.create_task(
schedule_with_interrupt(0, send_stt_message(conn, text))
)
conn.scheduled_tasks.append(stt_task)
await send_stt_message(conn, text)
conn.executor.submit(conn.chat, text)


Expand Down
53 changes: 17 additions & 36 deletions core/handle/sendAudioHandle.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,49 +8,39 @@
logger = setup_logging()


async def isLLMWantToFinish(conn):
first_text = conn.tts_first_text
last_text = conn.tts_last_text
async def isLLMWantToFinish(last_text):
_, last_text_without_punctuation = remove_punctuation_and_length(last_text)
if "再见" in last_text_without_punctuation or "拜拜" in last_text_without_punctuation:
return True
_, first_text_without_punctuation = remove_punctuation_and_length(first_text)
if "再见" in first_text_without_punctuation or "拜拜" in first_text_without_punctuation:
return True
return False


async def sendAudioMessage(conn, audios, duration, text):
base_delay = conn.tts_duration

async def sendAudioMessage(conn, audios, text):
# 发送 tts.start
if text == conn.tts_first_text:
logger.bind(tag=TAG).info(f"发送第一段语音: {text}")
conn.tts_start_speak_time = time.time()

# 发送 sentence_start(每个音频文件之前发送一次)
sentence_task = asyncio.create_task(
schedule_with_interrupt(base_delay, send_tts_message(conn, "sentence_start", text))
)
conn.scheduled_tasks.append(sentence_task)

conn.tts_duration += duration
await send_tts_message(conn, "sentence_start", text)

# 发送音频数据
frame_duration = 60 # 初始帧持续时间(毫秒)
start_time = time.time() # 记录开始时间
for idx, opus_packet in enumerate(audios):
if conn.client_abort:
return
# 计算当前包的预期发送时间
expected_time = start_time + idx * (frame_duration / 1000)
current_time = time.time()
# 如果未到预期时间则等待差值
if current_time < expected_time:
await asyncio.sleep(expected_time - current_time)
# 发送音频包
await conn.websocket.send(opus_packet)

if conn.llm_finish_task and text == conn.tts_last_text:
stop_duration = conn.tts_duration - (time.time() - conn.tts_start_speak_time)
stop_task = asyncio.create_task(
schedule_with_interrupt(stop_duration, send_tts_message(conn, 'stop'))
)
conn.scheduled_tasks.append(stop_task)
if await isLLMWantToFinish(conn):
finish_task = asyncio.create_task(
schedule_with_interrupt(stop_duration, await conn.close())
)
conn.scheduled_tasks.append(finish_task)
await send_tts_message(conn, 'stop')
if await isLLMWantToFinish(text):
await conn.close()


async def send_tts_message(conn, state, text=None):
Expand Down Expand Up @@ -84,12 +74,3 @@ async def send_stt_message(conn, text):
"session_id": conn.session_id}
))
await send_tts_message(conn, "start")


async def schedule_with_interrupt(delay, coro):
"""可中断的延迟调度"""
try:
await asyncio.sleep(delay)
await coro
except asyncio.CancelledError:
pass
Loading

0 comments on commit b94842c

Please sign in to comment.