Skip to content

Commit

Permalink
完善模型导入
Browse files Browse the repository at this point in the history
  • Loading branch information
PeterH0323 committed Jul 3, 2024
1 parent 0d1f5b6 commit 10ccccf
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 15 deletions.
23 changes: 12 additions & 11 deletions utils/model_loader.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,10 @@
from .web_configs import WEB_CONFIGS

from .rag.rag_worker import load_rag_model
from .asr.asr_worker import load_asr_model
from .digital_human.realtime_inference import digital_human_preprocess
from .infer.load_infer_model import load_turbomind_model
from .tts.gpt_sovits.inference_gpt_sovits import get_tts_model
from .web_configs import WEB_CONFIGS


# ==================================================================
# RAG 模型
# ==================================================================

if WEB_CONFIGS.ENABLE_RAG:
RAG_RETRIEVER = load_rag_model()
else:
RAG_RETRIEVER = None


# ==================================================================
Expand All @@ -33,6 +24,16 @@
DIGITAL_HUMAN_HANDLER = None


# ==================================================================
# RAG 模型
# ==================================================================

if WEB_CONFIGS.ENABLE_RAG:
RAG_RETRIEVER = load_rag_model()
else:
RAG_RETRIEVER = None


# ==================================================================
# TTS 模型
# ==================================================================
Expand Down
11 changes: 7 additions & 4 deletions utils/web_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class WebConfigs:
LLM_MODEL_NAME: str = "HinGwenWoong/streamer-sales-lelemiao-7b"

SALES_NAME: str = "乐乐喵" # 启动的角色名

LLM_MODEL_DIR: str = r"./weights/llm_weights/"

# ==================================================================
Expand All @@ -26,12 +26,14 @@ class WebConfigs:
ENABLE_RAG: bool = True # True 启用 RAG 检索增强,False 不启用
ENABLE_TTS: bool = True # True 启动 tts,False 不启用
ENABLE_DIGITAL_HUMAN: bool = True # True 启动 数字人,False 不启用
ENABLE_AGENT: bool = os.environ.get("ENABLE_AGENT", 'true') == 'true' # True 启动 Agent,False 不启用
ENABLE_ASR: bool = os.environ.get("ENABLE_ASR", 'true') == 'true' # True 启动 语音转文字,False 不启用
ENABLE_AGENT: bool = os.environ.get("ENABLE_AGENT", "true") == "true" # True 启动 Agent,False 不启用
ENABLE_ASR: bool = os.environ.get("ENABLE_ASR", "true") == "true" # True 启动 语音转文字,False 不启用

DISABLE_UPLOAD: bool = os.getenv("DISABLE_UPLOAD") == "true"

CACHE_MAX_ENTRY_COUNT: float = float(os.environ.get("KV_CACHE", 0.1)) # KV cache 占比,如果部署出现 OOM 降低这个配置,反之可以加大
CACHE_MAX_ENTRY_COUNT: float = float(
os.environ.get("KV_CACHE", 0.1)
) # KV cache 占比,如果部署出现 OOM 降低这个配置,反之可以加大

# ==================================================================
# 页面配置
Expand Down Expand Up @@ -93,5 +95,6 @@ class WebConfigs:
ASR_WAV_SAVE_PATH: str = r"./work_dirs/asr_wavs"
ASR_MODEL_DIR: str = r"./weights/asr_weights/"


# 实例化
WEB_CONFIGS = WebConfigs()

0 comments on commit 10ccccf

Please sign in to comment.