From 10ccccf2833ecbb750a426a5439397da62f82473 Mon Sep 17 00:00:00 2001 From: HinGwenWoong Date: Wed, 3 Jul 2024 19:50:15 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E5=96=84=E6=A8=A1=E5=9E=8B=E5=AF=BC?= =?UTF-8?q?=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- utils/model_loader.py | 23 ++++++++++++----------- utils/web_configs.py | 11 +++++++---- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/utils/model_loader.py b/utils/model_loader.py index 522b597..5e8f6e8 100644 --- a/utils/model_loader.py +++ b/utils/model_loader.py @@ -1,19 +1,10 @@ +from .web_configs import WEB_CONFIGS + from .rag.rag_worker import load_rag_model from .asr.asr_worker import load_asr_model from .digital_human.realtime_inference import digital_human_preprocess from .infer.load_infer_model import load_turbomind_model from .tts.gpt_sovits.inference_gpt_sovits import get_tts_model -from .web_configs import WEB_CONFIGS - - -# ================================================================== -# RAG 模型 -# ================================================================== - -if WEB_CONFIGS.ENABLE_RAG: - RAG_RETRIEVER = load_rag_model() -else: - RAG_RETRIEVER = None # ================================================================== @@ -33,6 +24,16 @@ DIGITAL_HUMAN_HANDLER = None +# ================================================================== +# RAG 模型 +# ================================================================== + +if WEB_CONFIGS.ENABLE_RAG: + RAG_RETRIEVER = load_rag_model() +else: + RAG_RETRIEVER = None + + # ================================================================== # TTS 模型 # ================================================================== diff --git a/utils/web_configs.py b/utils/web_configs.py index 7524be0..8a19b88 100644 --- a/utils/web_configs.py +++ b/utils/web_configs.py @@ -17,7 +17,7 @@ class WebConfigs: LLM_MODEL_NAME: str = "HinGwenWoong/streamer-sales-lelemiao-7b" SALES_NAME: str = "乐乐喵" # 启动的角色名 - + LLM_MODEL_DIR: str = r"./weights/llm_weights/" # ================================================================== @@ -26,12 +26,14 @@ class WebConfigs: ENABLE_RAG: bool = True # True 启用 RAG 检索增强,False 不启用 ENABLE_TTS: bool = True # True 启动 tts,False 不启用 ENABLE_DIGITAL_HUMAN: bool = True # True 启动 数字人,False 不启用 - ENABLE_AGENT: bool = os.environ.get("ENABLE_AGENT", 'true') == 'true' # True 启动 Agent,False 不启用 - ENABLE_ASR: bool = os.environ.get("ENABLE_ASR", 'true') == 'true' # True 启动 语音转文字,False 不启用 + ENABLE_AGENT: bool = os.environ.get("ENABLE_AGENT", "true") == "true" # True 启动 Agent,False 不启用 + ENABLE_ASR: bool = os.environ.get("ENABLE_ASR", "true") == "true" # True 启动 语音转文字,False 不启用 DISABLE_UPLOAD: bool = os.getenv("DISABLE_UPLOAD") == "true" - CACHE_MAX_ENTRY_COUNT: float = float(os.environ.get("KV_CACHE", 0.1)) # KV cache 占比,如果部署出现 OOM 降低这个配置,反之可以加大 + CACHE_MAX_ENTRY_COUNT: float = float( + os.environ.get("KV_CACHE", 0.1) + ) # KV cache 占比,如果部署出现 OOM 降低这个配置,反之可以加大 # ================================================================== # 页面配置 @@ -93,5 +95,6 @@ class WebConfigs: ASR_WAV_SAVE_PATH: str = r"./work_dirs/asr_wavs" ASR_MODEL_DIR: str = r"./weights/asr_weights/" + # 实例化 WEB_CONFIGS = WebConfigs()