From 20916608a29ea333c3650095f2d160b294aa59bd Mon Sep 17 00:00:00 2001 From: HinGwenWoong Date: Thu, 10 Oct 2024 23:47:03 +0800 Subject: [PATCH] =?UTF-8?q?=E6=A0=BC=E5=BC=8F=E5=8C=96=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/base/base_server.py | 1 + server/base/database/product_db.py | 4 ++-- server/base/models/user_model.py | 6 +++--- server/base/routers/llm.py | 2 +- server/base/routers/streaming_room.py | 2 +- server/base/utils.py | 2 -- 6 files changed, 8 insertions(+), 9 deletions(-) diff --git a/server/base/base_server.py b/server/base/base_server.py index 0e7529c..6f62dcd 100644 --- a/server/base/base_server.py +++ b/server/base/base_server.py @@ -59,6 +59,7 @@ async def lifespan(app: FastAPI): if WEB_CONFIGS.ENABLE_RAG: from .modules.rag.rag_worker import load_rag_model + # 生成 rag 数据库 await load_rag_model(user_id=1) diff --git a/server/base/database/product_db.py b/server/base/database/product_db.py index 20c2e72..f79b143 100644 --- a/server/base/database/product_db.py +++ b/server/base/database/product_db.py @@ -26,7 +26,7 @@ async def get_db_product_info( page_size: int = 10, product_name: str | None = None, product_id: int | None = None, - exclude_list: List[int] | None = None + exclude_list: List[int] | None = None, ) -> Tuple[List[ProductInfo], int]: """查询数据库中的商品信息 @@ -64,7 +64,7 @@ async def get_db_product_info( query_condiction = and_( ProductInfo.user_id == user_id, ProductInfo.delete == False, ProductInfo.product_id == product_id ) - + elif exclude_list is not None: # 排除查询 query_condiction = and_( diff --git a/server/base/models/user_model.py b/server/base/models/user_model.py index e038db1..d1d208f 100644 --- a/server/base/models/user_model.py +++ b/server/base/models/user_model.py @@ -28,16 +28,16 @@ class UserBaseInfo(BaseModel): username: str = Field(index=True, unique=True) email: str | None = None avatar: str | None = None - create_time: datetime =datetime.now() + create_time: datetime = datetime.now() # ======================================================= # 数据库模型 # ======================================================= class UserInfo(UserBaseInfo, SQLModel, table=True): - + __tablename__ = "user_info" - + hashed_password: str ip_address: IPv4Address | None = None delete: bool = False diff --git a/server/base/routers/llm.py b/server/base/routers/llm.py index 2715f88..f05857d 100644 --- a/server/base/routers/llm.py +++ b/server/base/routers/llm.py @@ -73,7 +73,7 @@ async def gen_poduct_base_prompt( """ assert (streamer_id == -1 and streamer_info is not None) or (streamer_id != -1 and streamer_info is None) - assert (product_id == -1 and product_info is not None) or (product_id != -1 and product_info is None) + assert (product_id == -1 and product_info is not None) or (product_id != -1 and product_info is None) # 加载对话配置文件 dataset_yaml = await get_llm_product_prompt_base_info() diff --git a/server/base/routers/streaming_room.py b/server/base/routers/streaming_room.py index aba4e5e..0a67bc0 100644 --- a/server/base/routers/streaming_room.py +++ b/server/base/routers/streaming_room.py @@ -75,7 +75,7 @@ async def get_streaming_room_id_api( # 直接返回会导致字段丢失,需要转 dict 确保返回值里面有该字段 format_product_list = [] for db_product in streaming_room_list[0].product_list: - + product_dict = dict(db_product) # 将 start_video 改为服务器地址 if product_dict["start_video"] != "": diff --git a/server/base/utils.py b/server/base/utils.py index 3f1b93a..a7bc122 100644 --- a/server/base/utils.py +++ b/server/base/utils.py @@ -329,7 +329,6 @@ def create_default_user(): session.add(admin_user) session.commit() - def init_user() -> bool: """判断是否需要创建默认用户 @@ -347,7 +346,6 @@ def init_user() -> bool: return False - def create_default_product_item(): """生成商品默认数据库""" delivery_company_list = ["京东", "顺丰", "韵达", "圆通", "中通"]