Skip to content

Commit

Permalink
feat(dprompt.py): 解耦完成
Browse files Browse the repository at this point in the history
  • Loading branch information
RockChinQ committed Mar 26, 2023
1 parent f6cad85 commit bb4b897
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 153 deletions.
5 changes: 2 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,8 @@ def start(first_time_init=False):
import pkg.openai.session
import pkg.qqbot.manager
import pkg.openai.dprompt

pkg.openai.dprompt.read_prompt_from_file()
pkg.openai.dprompt.read_scenario_from_file()

pkg.openai.dprompt.register_all()

# 主启动流程
database = pkg.database.manager.DatabaseManager()
Expand Down
232 changes: 128 additions & 104 deletions pkg/openai/dprompt.py
Original file line number Diff line number Diff line change
@@ -1,121 +1,145 @@
# 多情景预设值管理
import json
import logging
import config
import os

# __current__ = "default"
# """当前默认使用的情景预设的名称

# 由管理员使用`!default <名称>`指令切换
# """

# __prompts_from_files__ = {}
# """从文件中读取的情景预设值"""

# __scenario_from_files__ = {}


__universal_first_reply__ = "ok, I'll follow your commands."
"""通用首次回复"""


class ScenarioMode:
"""情景预设模式抽象类"""

using_prompt_name = "default"
"""新session创建时使用的prompt名称"""

prompts: dict[str, list] = {}

def __init__(self):
logging.debug("prompts: {}".format(self.prompts))

def list(self) -> dict[str, list]:
"""获取所有情景预设的名称及内容"""
return self.prompts

def get_prompt(self, name: str) -> tuple[list, str]:
"""获取指定情景预设的名称及内容"""
for key in self.prompts:
if key.startswith(name):
return self.prompts[key], key
raise Exception("没有找到情景预设: {}".format(name))

def set_using_name(self, name: str) -> str:
"""设置默认情景预设"""
for key in self.prompts:
if key.startswith(name):
self.using_prompt_name = key
return key
raise Exception("没有找到情景预设: {}".format(name))

def get_full_name(self, name: str) -> str:
"""获取完整的情景预设名称"""
for key in self.prompts:
if key.startswith(name):
return key
raise Exception("没有找到情景预设: {}".format(name))

def get_using_name(self) -> str:
"""获取默认情景预设"""
return self.using_prompt_name


class NormalScenarioMode(ScenarioMode):
"""普通情景预设模式"""

def __init__(self):
global __universal_first_reply__
# 加载config中的default_prompt值
if type(config.default_prompt) == str:
self.using_prompt_name = "default"
self.prompts = {"default": [
{
"role": "user",
"content": config.default_prompt
},{
"role": "assistant",
"content": __universal_first_reply__
}
]}

elif type(config.default_prompt) == dict:
for key in config.default_prompt:
self.prompts[key] = [
{
"role": "user",
"content": config.default_prompt[key]
},{
"role": "assistant",
"content": __universal_first_reply__
}
]

__current__ = "default"
"""当前默认使用的情景预设的名称
由管理员使用`!default <名称>`指令切换
"""

__prompts_from_files__ = {}
"""从文件中读取的情景预设值"""

__scenario_from_files__ = {}


def read_prompt_from_file():
"""从文件读取预设值"""
# 读取prompts/目录下的所有文件,以文件名为键,文件内容为值
# 保存在__prompts_from_files__中
global __prompts_from_files__
import os

__prompts_from_files__ = {}
for file in os.listdir("prompts"):
with open(os.path.join("prompts", file), encoding="utf-8") as f:
__prompts_from_files__[file] = f.read()


def read_scenario_from_file():
"""从JSON文件读取情景预设"""
global __scenario_from_files__
import os

__scenario_from_files__ = {}
for file in os.listdir("scenario"):
if file == "default-template.json":
continue
with open(os.path.join("scenario", file), encoding="utf-8") as f:
__scenario_from_files__[file] = json.load(f)


def get_prompt_dict() -> dict:
"""获取预设值字典"""
import config
default_prompt = config.default_prompt
if type(default_prompt) == str:
default_prompt = {"default": default_prompt}
elif type(default_prompt) == dict:
pass
else:
raise TypeError("default_prompt must be str or dict")

# 将文件中的预设值合并到default_prompt中
for key in __prompts_from_files__:
default_prompt[key] = __prompts_from_files__[key]
# 从prompts/目录下的文件中载入
# 遍历文件
for file in os.listdir("prompts"):
with open(os.path.join("prompts", file), encoding="utf-8") as f:
self.prompts[file] = [
{
"role": "user",
"content": f.read()
},{
"role": "assistant",
"content": __universal_first_reply__
}
]

return default_prompt

class FullScenarioMode(ScenarioMode):
"""完整情景预设模式"""

def set_current(name):
global __current__
for key in get_prompt_dict():
if key.lower().startswith(name.lower()):
__current__ = key
return
raise KeyError("未找到情景预设: " + name)
def __init__(self):
"""从json读取所有"""
# 遍历scenario/目录下的所有文件,以文件名为键,文件内容中的prompt为值
for file in os.listdir("scenario"):
if file == "default-template.json":
continue
with open(os.path.join("scenario", file), encoding="utf-8") as f:
self.prompts[file] = json.load(f)["prompt"]

super().__init__()

def get_current():
global __current__
return __current__

scenario_mode_mapping = {}
"""情景预设模式名称与对象的映射"""

def set_to_default():
global __current__
default_dict = get_prompt_dict()

if "default" in default_dict:
__current__ = "default"
else:
__current__ = list(default_dict.keys())[0]
def register_all():
"""注册所有情景预设模式,不使用装饰器,因为装饰器的方式不支持热重载"""
global scenario_mode_mapping
scenario_mode_mapping = {
"normal": NormalScenarioMode(),
"full_scenario": FullScenarioMode()
}


def get_prompt(name: str = None) -> list:
global __scenario_from_files__
def mode_inst() -> ScenarioMode:
"""获取指定名称的情景预设模式对象"""
import config
preset_mode = config.preset_mode

"""获取预设值"""
if name is None:
name = get_current()

# JSON预设方式
if preset_mode == 'full_scenario':
import os

for key in __scenario_from_files__:
if key.lower().startswith(name.lower()):
logging.debug('成功加载情景预设从JSON文件: {}'.format(key))
return __scenario_from_files__[key]['prompt']

# 默认预设方式
elif preset_mode == 'default' or preset_mode == 'normal':

default_dict = get_prompt_dict()

for key in default_dict:
if key.lower().startswith(name.lower()):
return [
{
"role": "user",
"content": default_dict[key]
},
{
"role": "assistant",
"content": "好的。"
}
]
if config.preset_mode == "default":
config.preset_mode = "normal"

raise KeyError("未找到默认情景预设: " + name)
return scenario_mode_mapping[config.preset_mode]
4 changes: 2 additions & 2 deletions pkg/openai/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,9 @@ def get_default_prompt(self, use_default: str = None):
import pkg.openai.dprompt as dprompt

if use_default is None:
use_default = dprompt.get_current()
use_default = dprompt.mode_inst().get_using_name()

current_default_prompt = dprompt.get_prompt(use_default)
current_default_prompt, _ = dprompt.mode_inst().get_prompt(use_default)
return current_default_prompt

def __init__(self, name: str):
Expand Down
Loading

0 comments on commit bb4b897

Please sign in to comment.