forked from RockChinQ/LangBot
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
192 additions
and
153 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,121 +1,145 @@ | ||
# 多情景预设值管理 | ||
import json | ||
import logging | ||
import config | ||
import os | ||
|
||
# __current__ = "default" | ||
# """当前默认使用的情景预设的名称 | ||
|
||
# 由管理员使用`!default <名称>`指令切换 | ||
# """ | ||
|
||
# __prompts_from_files__ = {} | ||
# """从文件中读取的情景预设值""" | ||
|
||
# __scenario_from_files__ = {} | ||
|
||
|
||
__universal_first_reply__ = "ok, I'll follow your commands." | ||
"""通用首次回复""" | ||
|
||
|
||
class ScenarioMode: | ||
"""情景预设模式抽象类""" | ||
|
||
using_prompt_name = "default" | ||
"""新session创建时使用的prompt名称""" | ||
|
||
prompts: dict[str, list] = {} | ||
|
||
def __init__(self): | ||
logging.debug("prompts: {}".format(self.prompts)) | ||
|
||
def list(self) -> dict[str, list]: | ||
"""获取所有情景预设的名称及内容""" | ||
return self.prompts | ||
|
||
def get_prompt(self, name: str) -> tuple[list, str]: | ||
"""获取指定情景预设的名称及内容""" | ||
for key in self.prompts: | ||
if key.startswith(name): | ||
return self.prompts[key], key | ||
raise Exception("没有找到情景预设: {}".format(name)) | ||
|
||
def set_using_name(self, name: str) -> str: | ||
"""设置默认情景预设""" | ||
for key in self.prompts: | ||
if key.startswith(name): | ||
self.using_prompt_name = key | ||
return key | ||
raise Exception("没有找到情景预设: {}".format(name)) | ||
|
||
def get_full_name(self, name: str) -> str: | ||
"""获取完整的情景预设名称""" | ||
for key in self.prompts: | ||
if key.startswith(name): | ||
return key | ||
raise Exception("没有找到情景预设: {}".format(name)) | ||
|
||
def get_using_name(self) -> str: | ||
"""获取默认情景预设""" | ||
return self.using_prompt_name | ||
|
||
|
||
class NormalScenarioMode(ScenarioMode): | ||
"""普通情景预设模式""" | ||
|
||
def __init__(self): | ||
global __universal_first_reply__ | ||
# 加载config中的default_prompt值 | ||
if type(config.default_prompt) == str: | ||
self.using_prompt_name = "default" | ||
self.prompts = {"default": [ | ||
{ | ||
"role": "user", | ||
"content": config.default_prompt | ||
},{ | ||
"role": "assistant", | ||
"content": __universal_first_reply__ | ||
} | ||
]} | ||
|
||
elif type(config.default_prompt) == dict: | ||
for key in config.default_prompt: | ||
self.prompts[key] = [ | ||
{ | ||
"role": "user", | ||
"content": config.default_prompt[key] | ||
},{ | ||
"role": "assistant", | ||
"content": __universal_first_reply__ | ||
} | ||
] | ||
|
||
__current__ = "default" | ||
"""当前默认使用的情景预设的名称 | ||
由管理员使用`!default <名称>`指令切换 | ||
""" | ||
|
||
__prompts_from_files__ = {} | ||
"""从文件中读取的情景预设值""" | ||
|
||
__scenario_from_files__ = {} | ||
|
||
|
||
def read_prompt_from_file(): | ||
"""从文件读取预设值""" | ||
# 读取prompts/目录下的所有文件,以文件名为键,文件内容为值 | ||
# 保存在__prompts_from_files__中 | ||
global __prompts_from_files__ | ||
import os | ||
|
||
__prompts_from_files__ = {} | ||
for file in os.listdir("prompts"): | ||
with open(os.path.join("prompts", file), encoding="utf-8") as f: | ||
__prompts_from_files__[file] = f.read() | ||
|
||
|
||
def read_scenario_from_file(): | ||
"""从JSON文件读取情景预设""" | ||
global __scenario_from_files__ | ||
import os | ||
|
||
__scenario_from_files__ = {} | ||
for file in os.listdir("scenario"): | ||
if file == "default-template.json": | ||
continue | ||
with open(os.path.join("scenario", file), encoding="utf-8") as f: | ||
__scenario_from_files__[file] = json.load(f) | ||
|
||
|
||
def get_prompt_dict() -> dict: | ||
"""获取预设值字典""" | ||
import config | ||
default_prompt = config.default_prompt | ||
if type(default_prompt) == str: | ||
default_prompt = {"default": default_prompt} | ||
elif type(default_prompt) == dict: | ||
pass | ||
else: | ||
raise TypeError("default_prompt must be str or dict") | ||
|
||
# 将文件中的预设值合并到default_prompt中 | ||
for key in __prompts_from_files__: | ||
default_prompt[key] = __prompts_from_files__[key] | ||
# 从prompts/目录下的文件中载入 | ||
# 遍历文件 | ||
for file in os.listdir("prompts"): | ||
with open(os.path.join("prompts", file), encoding="utf-8") as f: | ||
self.prompts[file] = [ | ||
{ | ||
"role": "user", | ||
"content": f.read() | ||
},{ | ||
"role": "assistant", | ||
"content": __universal_first_reply__ | ||
} | ||
] | ||
|
||
return default_prompt | ||
|
||
class FullScenarioMode(ScenarioMode): | ||
"""完整情景预设模式""" | ||
|
||
def set_current(name): | ||
global __current__ | ||
for key in get_prompt_dict(): | ||
if key.lower().startswith(name.lower()): | ||
__current__ = key | ||
return | ||
raise KeyError("未找到情景预设: " + name) | ||
def __init__(self): | ||
"""从json读取所有""" | ||
# 遍历scenario/目录下的所有文件,以文件名为键,文件内容中的prompt为值 | ||
for file in os.listdir("scenario"): | ||
if file == "default-template.json": | ||
continue | ||
with open(os.path.join("scenario", file), encoding="utf-8") as f: | ||
self.prompts[file] = json.load(f)["prompt"] | ||
|
||
super().__init__() | ||
|
||
def get_current(): | ||
global __current__ | ||
return __current__ | ||
|
||
scenario_mode_mapping = {} | ||
"""情景预设模式名称与对象的映射""" | ||
|
||
def set_to_default(): | ||
global __current__ | ||
default_dict = get_prompt_dict() | ||
|
||
if "default" in default_dict: | ||
__current__ = "default" | ||
else: | ||
__current__ = list(default_dict.keys())[0] | ||
def register_all(): | ||
"""注册所有情景预设模式,不使用装饰器,因为装饰器的方式不支持热重载""" | ||
global scenario_mode_mapping | ||
scenario_mode_mapping = { | ||
"normal": NormalScenarioMode(), | ||
"full_scenario": FullScenarioMode() | ||
} | ||
|
||
|
||
def get_prompt(name: str = None) -> list: | ||
global __scenario_from_files__ | ||
def mode_inst() -> ScenarioMode: | ||
"""获取指定名称的情景预设模式对象""" | ||
import config | ||
preset_mode = config.preset_mode | ||
|
||
"""获取预设值""" | ||
if name is None: | ||
name = get_current() | ||
|
||
# JSON预设方式 | ||
if preset_mode == 'full_scenario': | ||
import os | ||
|
||
for key in __scenario_from_files__: | ||
if key.lower().startswith(name.lower()): | ||
logging.debug('成功加载情景预设从JSON文件: {}'.format(key)) | ||
return __scenario_from_files__[key]['prompt'] | ||
|
||
# 默认预设方式 | ||
elif preset_mode == 'default' or preset_mode == 'normal': | ||
|
||
default_dict = get_prompt_dict() | ||
|
||
for key in default_dict: | ||
if key.lower().startswith(name.lower()): | ||
return [ | ||
{ | ||
"role": "user", | ||
"content": default_dict[key] | ||
}, | ||
{ | ||
"role": "assistant", | ||
"content": "好的。" | ||
} | ||
] | ||
if config.preset_mode == "default": | ||
config.preset_mode = "normal" | ||
|
||
raise KeyError("未找到默认情景预设: " + name) | ||
return scenario_mode_mapping[config.preset_mode] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.