diff --git a/core/connection.py b/core/connection.py index 70dd988..e463684 100644 --- a/core/connection.py +++ b/core/connection.py @@ -9,8 +9,6 @@ import threading import websockets from typing import Dict, Any -from collections import deque -from core.utils.util import is_segment from core.utils.dialogue import Message, Dialogue from core.handle.textHandle import handleTextMessage from core.utils.util import get_string_no_punctuation_or_emoji @@ -21,7 +19,6 @@ from core.auth import AuthMiddleware, AuthenticationError from core.utils.auth_code_gen import AuthCodeGenerator - TAG = __name__ @@ -45,8 +42,8 @@ def __init__(self, config: Dict[str, Any], _vad, _asr, _llm, _tts, _music): self.loop = asyncio.get_event_loop() self.stop_event = threading.Event() self.tts_queue = queue.Queue() + self.audio_play_queue = queue.Queue() self.executor = ThreadPoolExecutor(max_workers=10) - self.scheduled_tasks = deque() # 依赖的组件 self.vad = _vad @@ -140,9 +137,14 @@ async def handle_connection(self, ws): await self.loop.run_in_executor(None, self._initialize_components) - tts_priority = threading.Thread(target=self._priority_thread, daemon=True) + # tts 消化线程 + tts_priority = threading.Thread(target=self._tts_priority_thread, daemon=True) tts_priority.start() + # 音频播放 消化线程 + audio_play_priority = threading.Thread(target=self._audio_play_priority_thread, daemon=True) + audio_play_priority.start() + try: async for message in self.websocket: await self._route_message(message) @@ -198,10 +200,8 @@ def isNeedAuth(self): return not self.is_device_verified def chat(self, query): - # 如果设备未验证,就发送验证码 if self.isNeedAuth(): self.llm_finish_task = True - # 创建一个新的事件循环来运行异步函数 loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: @@ -212,49 +212,61 @@ def chat(self, query): self.dialogue.put(Message(role="user", content=query)) response_message = [] - start = 0 - # 提交 LLM 任务 + processed_chars = 0 # 跟踪已处理的字符位置 try: - start_time = time.time() # 记录开始时间 + start_time = time.time() llm_responses = self.llm.response(self.session_id, self.dialogue.get_llm_dialogue()) except Exception as e: self.logger.bind(tag=TAG).error(f"LLM 处理出错 {query}: {e}") return None - # 提交 TTS 任务到线程池 + self.llm_finish_task = False for content in llm_responses: response_message.append(content) - # 如果中途被打断,就停止生成 if self.client_abort: - start = len(response_message) break - end_time = time.time() # 记录结束时间 - self.logger.bind(tag=TAG).debug(f"大模型返回时间时间: {end_time - start_time} 秒, 生成token={content}") - if is_segment(response_message): - segment_text = "".join(response_message[start:]) - segment_text = get_string_no_punctuation_or_emoji(segment_text) - if len(segment_text) > 0: + end_time = time.time() + self.logger.bind(tag=TAG).debug(f"大模型返回时间: {end_time - start_time} 秒, 生成token={content}") + + # 合并当前全部文本并处理未分割部分 + full_text = "".join(response_message) + current_text = full_text[processed_chars:] # 从未处理的位置开始 + + # 查找最后一个有效标点 + punctuations = ("。", "?", "!", ".", "?", "!", ";", ";", ":", ":", ",", ",") + last_punct_pos = -1 + for punct in punctuations: + pos = current_text.rfind(punct) + if pos > last_punct_pos: + last_punct_pos = pos + + # 找到分割点则处理 + if last_punct_pos != -1: + segment_text_raw = current_text[:last_punct_pos + 1] + segment_text = get_string_no_punctuation_or_emoji(segment_text_raw) + if segment_text: self.recode_first_last_text(segment_text) future = self.executor.submit(self.speak_and_play, segment_text) self.tts_queue.put(future) - start = len(response_message) - - # 处理剩余的响应 - if start < len(response_message): - segment_text = "".join(response_message[start:]) - if len(segment_text) > 0: + processed_chars += len(segment_text_raw) # 更新已处理字符位置 + + # 处理最后剩余的文本 + full_text = "".join(response_message) + remaining_text = full_text[processed_chars:] + if remaining_text: + segment_text = get_string_no_punctuation_or_emoji(remaining_text) + if segment_text: self.recode_first_last_text(segment_text) future = self.executor.submit(self.speak_and_play, segment_text) self.tts_queue.put(future) self.llm_finish_task = True - # 更新对话 self.dialogue.put(Message(role="assistant", content="".join(response_message))) self.logger.bind(tag=TAG).debug(json.dumps(self.dialogue.get_llm_dialogue(), indent=4, ensure_ascii=False)) return True - def _priority_thread(self): + def _tts_priority_thread(self): while not self.stop_event.is_set(): text = None try: @@ -276,7 +288,6 @@ def _priority_thread(self): else: self.logger.bind(tag=TAG).error(f"TTS文件不存在: {tts_file}") opus_datas = [] - duration = 0 except TimeoutError: self.logger.bind(tag=TAG).error("TTS 任务超时") continue @@ -285,9 +296,7 @@ def _priority_thread(self): continue if not self.client_abort: # 如果没有中途打断就发送语音 - asyncio.run_coroutine_threadsafe( - sendAudioMessage(self, opus_datas, duration, text), self.loop - ) + self.audio_play_queue.put((opus_datas, text)) if self.tts.delete_audio_file and os.path.exists(tts_file): os.remove(tts_file) except Exception as e: @@ -299,6 +308,16 @@ def _priority_thread(self): ) self.logger.bind(tag=TAG).error(f"tts_priority priority_thread: {text}{e}") + def _audio_play_priority_thread(self): + while not self.stop_event.is_set(): + text = None + try: + opus_datas, text = self.audio_play_queue.get() + future = asyncio.run_coroutine_threadsafe(sendAudioMessage(self, opus_datas, text), self.loop) + future.result() + except Exception as e: + self.logger.bind(tag=TAG).error(f"audio_play_priority priority_thread: {text}{e}") + def speak_and_play(self, text): if text is None or len(text) <= 0: self.logger.bind(tag=TAG).info(f"无需tts转换,query为空,{text}") @@ -340,9 +359,3 @@ def reset_vad_states(self): self.client_have_voice_last_time = 0 self.client_voice_stop = False self.logger.bind(tag=TAG).debug("VAD states reset.") - - def stop_all_tasks(self): - while self.scheduled_tasks: - task = self.scheduled_tasks.popleft() - task.cancel() - self.scheduled_tasks.clear() \ No newline at end of file diff --git a/core/handle/abortHandle.py b/core/handle/abortHandle.py index 12d10ce..3183487 100644 --- a/core/handle/abortHandle.py +++ b/core/handle/abortHandle.py @@ -1,4 +1,5 @@ import json +import queue from config.logger import setup_logging TAG = __name__ @@ -9,8 +10,6 @@ async def handleAbortMessage(conn): logger.bind(tag=TAG).info("Abort message received") # 设置成打断状态,会自动打断llm、tts任务 conn.client_abort = True - # 打断屏显任务 - conn.stop_all_tasks() # 打断客户端说话状态 await conn.websocket.send(json.dumps({"type": "tts", "state": "stop", "session_id": conn.session_id})) conn.clearSpeakStatus() diff --git a/core/handle/musicHandler.py b/core/handle/musicHandler.py index a3d865d..672550c 100644 --- a/core/handle/musicHandler.py +++ b/core/handle/musicHandler.py @@ -4,7 +4,7 @@ import difflib import re import traceback -from core.handle.sendAudioHandle import sendAudioMessage, send_stt_message +from core.handle.sendAudioHandle import send_stt_message TAG = __name__ logger = setup_logging() @@ -102,7 +102,8 @@ async def play_local_music(self, conn, specific_file=None): conn.tts_last_text = selected_music conn.llm_finish_task = True opus_packets, duration = conn.tts.wav_to_opus_data(music_path) - await sendAudioMessage(conn, opus_packets, duration, selected_music) + + conn.audio_play_queue.put((opus_packets, selected_music)) except Exception as e: logger.bind(tag=TAG).error(f"播放音乐失败: {str(e)}") diff --git a/core/handle/receiveAudioHandle.py b/core/handle/receiveAudioHandle.py index 59e0734..f0e7b4a 100644 --- a/core/handle/receiveAudioHandle.py +++ b/core/handle/receiveAudioHandle.py @@ -1,8 +1,7 @@ from config.logger import setup_logging -import asyncio import time from core.utils.util import remove_punctuation_and_length -from core.handle.sendAudioHandle import schedule_with_interrupt, send_stt_message +from core.handle.sendAudioHandle import send_stt_message TAG = __name__ logger = setup_logging() @@ -61,10 +60,7 @@ async def handleCMDMessage(conn, text): async def startToChat(conn, text): # 异步发送 stt 信息 - stt_task = asyncio.create_task( - schedule_with_interrupt(0, send_stt_message(conn, text)) - ) - conn.scheduled_tasks.append(stt_task) + await send_stt_message(conn, text) conn.executor.submit(conn.chat, text) diff --git a/core/handle/sendAudioHandle.py b/core/handle/sendAudioHandle.py index bcf639a..ab653a3 100644 --- a/core/handle/sendAudioHandle.py +++ b/core/handle/sendAudioHandle.py @@ -8,49 +8,39 @@ logger = setup_logging() -async def isLLMWantToFinish(conn): - first_text = conn.tts_first_text - last_text = conn.tts_last_text +async def isLLMWantToFinish(last_text): _, last_text_without_punctuation = remove_punctuation_and_length(last_text) if "再见" in last_text_without_punctuation or "拜拜" in last_text_without_punctuation: return True - _, first_text_without_punctuation = remove_punctuation_and_length(first_text) - if "再见" in first_text_without_punctuation or "拜拜" in first_text_without_punctuation: - return True return False -async def sendAudioMessage(conn, audios, duration, text): - base_delay = conn.tts_duration - +async def sendAudioMessage(conn, audios, text): # 发送 tts.start if text == conn.tts_first_text: logger.bind(tag=TAG).info(f"发送第一段语音: {text}") conn.tts_start_speak_time = time.time() - - # 发送 sentence_start(每个音频文件之前发送一次) - sentence_task = asyncio.create_task( - schedule_with_interrupt(base_delay, send_tts_message(conn, "sentence_start", text)) - ) - conn.scheduled_tasks.append(sentence_task) - - conn.tts_duration += duration + await send_tts_message(conn, "sentence_start", text) # 发送音频数据 + frame_duration = 60 # 初始帧持续时间(毫秒) + start_time = time.time() # 记录开始时间 for idx, opus_packet in enumerate(audios): + if conn.client_abort: + return + # 计算当前包的预期发送时间 + expected_time = start_time + idx * (frame_duration / 1000) + current_time = time.time() + # 如果未到预期时间则等待差值 + if current_time < expected_time: + await asyncio.sleep(expected_time - current_time) + # 发送音频包 await conn.websocket.send(opus_packet) if conn.llm_finish_task and text == conn.tts_last_text: - stop_duration = conn.tts_duration - (time.time() - conn.tts_start_speak_time) - stop_task = asyncio.create_task( - schedule_with_interrupt(stop_duration, send_tts_message(conn, 'stop')) - ) - conn.scheduled_tasks.append(stop_task) - if await isLLMWantToFinish(conn): - finish_task = asyncio.create_task( - schedule_with_interrupt(stop_duration, await conn.close()) - ) - conn.scheduled_tasks.append(finish_task) + await send_tts_message(conn, 'stop') + if await isLLMWantToFinish(text): + await conn.close() async def send_tts_message(conn, state, text=None): @@ -84,12 +74,3 @@ async def send_stt_message(conn, text): "session_id": conn.session_id} )) await send_tts_message(conn, "start") - - -async def schedule_with_interrupt(delay, coro): - """可中断的延迟调度""" - try: - await asyncio.sleep(delay) - await coro - except asyncio.CancelledError: - pass diff --git a/core/utils/llm.py b/core/utils/llm.py index 8fdec09..4c6ae27 100644 --- a/core/utils/llm.py +++ b/core/utils/llm.py @@ -1,7 +1,5 @@ import os import sys -import asyncio -from typing import List, Dict, Any # 添加项目根目录到Python路径 current_dir = os.path.dirname(os.path.abspath(__file__)) @@ -10,10 +8,6 @@ from config.logger import setup_logging import importlib -from datetime import datetime -from core.utils.util import is_segment -from core.utils.util import get_string_no_punctuation_or_emoji -from core.utils.util import read_config, get_project_dir logger = setup_logging() @@ -27,117 +21,3 @@ def create_instance(class_name, *args, **kwargs): return sys.modules[lib_name].LLMProvider(*args, **kwargs) raise ValueError(f"不支持的LLM类型: {class_name},请检查该配置的type是否设置正确") - - -async def test_single_model(llm_name: str, llm_config: Dict[str, Any], test_prompt: str, config: Dict[str, Any]) -> Dict[str, Any]: - """异步测试单个模型""" - try: - # 获取实际的LLM类型 - llm_type = llm_config["type"] if "type" in llm_config else llm_name - llm = create_instance(llm_type, llm_config) - - # 开始测试 - dialogue = [] - dialogue.append({"role": "system", "content": config.get("prompt")}) - dialogue.append({"role": "user", "content": test_prompt}) - - start_time = datetime.now() - llm_responses = llm.response("test", dialogue) - response_message = [] - first_response_time = None - total_response_time = None - start = 0 - full_response = "" - - for content in llm_responses: - response_message.append(content) - full_response += content - - if is_segment(response_message): - segment_text = "".join(response_message[start:]) - segment_text = get_string_no_punctuation_or_emoji(segment_text) - if len(segment_text) > 0: - if first_response_time is None: - first_response_time = (datetime.now() - start_time).total_seconds() - start = len(response_message) - - total_response_time = (datetime.now() - start_time).total_seconds() - - return { - "name": llm_name, - "type": llm_type, - "first_response_time": first_response_time, - "total_response_time": total_response_time, - "response_length": len(full_response), - "status": "成功", - "response": full_response - } - - except Exception as e: - print(f"测试 {llm_name} 时发生错误: {str(e)}") - return { - "name": llm_name, - "type": llm_config.get("type", llm_name), - "first_response_time": None, - "total_response_time": None, - "response_length": 0, - "status": f"失败 - {str(e)}", - "response": "" - } - - -async def main(): - """ - LLM模型响应速度测试和排行(异步版本) - """ - config = read_config(get_project_dir() + "config.yaml") - test_prompt = "你好小智" - - print("开始并发测试所有模型...") - - # 创建所有模型的测试任务 - tasks = [] - for llm_name, llm_config in config["LLM"].items(): - task = asyncio.create_task(test_single_model(llm_name, llm_config, test_prompt, config)) - tasks.append(task) - - # 等待所有测试完成 - test_results = await asyncio.gather(*tasks) - - # 打印测试结果排行榜 - print("\n========= LLM模型性能测试排行榜 =========") - print("测试提示词:", test_prompt) - - # 过滤出成功的结果,并确保数值有效 - successful_results = [r for r in test_results if r["status"] == "成功" and r["first_response_time"] is not None] - - if successful_results: - print("\n1. 首次响应时间排行:") - sorted_by_first = sorted(successful_results, key=lambda x: x["first_response_time"]) - for i, result in enumerate(sorted_by_first, 1): - print(f"{i}. {result['name']}({result['type']}) - {result['first_response_time']:.2f}秒") - print(f" 响应内容: {result['response'][:50]}...") # 只显示前50个字符 - - print("\n2. 总响应时间排行:") - sorted_by_total = sorted(successful_results, key=lambda x: x["total_response_time"] or float('inf')) - for i, result in enumerate(sorted_by_total, 1): - if result["total_response_time"] is not None: - print(f"{i}. {result['name']}({result['type']}) - {result['total_response_time']:.2f}秒") - - print("\n3. 响应长度比较:") - sorted_by_length = sorted(successful_results, key=lambda x: x["response_length"], reverse=True) - for i, result in enumerate(sorted_by_length, 1): - print(f"{i}. {result['name']}({result['type']}) - {result['response_length']}字符") - else: - print("\n没有成功完成测试的模型。") - - if len(test_results) != len(successful_results): - print("\n测试失败的模型:") - failed_results = [r for r in test_results if r["status"] != "成功" or r["first_response_time"] is None] - for result in failed_results: - print(f"- {result['name']}({result['type']}): {result['status']}") - - -if __name__ == "__main__": - # 运行异步主函数 - asyncio.run(main()) diff --git a/core/utils/tts.py b/core/utils/tts.py index 2fc733f..7f71536 100644 --- a/core/utils/tts.py +++ b/core/utils/tts.py @@ -2,8 +2,6 @@ import sys from config.logger import setup_logging import importlib -from datetime import datetime -from core.utils.util import read_config, get_project_dir logger = setup_logging() @@ -16,26 +14,4 @@ def create_instance(class_name, *args, **kwargs): sys.modules[lib_name] = importlib.import_module(f'{lib_name}') return sys.modules[lib_name].TTSProvider(*args, **kwargs) - raise ValueError(f"不支持的TTS类型: {class_name},请检查该配置的type是否设置正确") - - -if __name__ == "__main__": - """ - 响应速度测试 - """ - config = read_config(get_project_dir() + "config.yaml") - tts = create_instance( - config["selected_module"]["TTS"] - if not 'type' in config["TTS"][config["selected_module"]["TTS"]] - else - config["TTS"][config["selected_module"]["TTS"]]["type"], - config["TTS"][config["selected_module"]["TTS"]], - config["delete_audio"] - ) - tts.output_file = get_project_dir() + tts.output_file - start = datetime.now() - file_path = tts.to_tts("你好,测试,我是人工智能小智") - print("语音合成耗时:" + str(datetime.now() - start)) - start = datetime.now() - tts.wav_to_opus_data(file_path) - print("语音opus耗时:" + str(datetime.now() - start)) + raise ValueError(f"不支持的TTS类型: {class_name},请检查该配置的type是否设置正确") \ No newline at end of file diff --git a/core/utils/util.py b/core/utils/util.py index 19495c3..a61f770 100644 --- a/core/utils/util.py +++ b/core/utils/util.py @@ -34,13 +34,6 @@ def write_json_file(file_path, data): json.dump(data, file, ensure_ascii=False, indent=4) -def is_segment(tokens): - if tokens[-1] in (",", ".", "?", ",", "。", "?", "!", "!", ";", ";", ":", ":"): - return True - else: - return False - - def is_punctuation_or_emoji(char): """检查字符是否为空格、指定标点或表情符号""" # 定义需要去除的中英文标点(包括全角/半角)