fix: updated the system prompt
JieGenius committed Feb 4, 2024
1 parent a994e95 commit 4abd5ee
Showing 5 changed files with 109 additions and 7 deletions.
4 changes: 4 additions & 0 deletions app.py
@@ -0,0 +1,4 @@
import os

if __name__ == '__main__':
    os.system('streamlit run web_demo.py --server.port 7860 --server.enableStaticServing True')
3 changes: 2 additions & 1 deletion requirements.txt
@@ -8,4 +8,5 @@ xtuner[all]
 lmdeploy[all]
 streamlit
 lagent==0.1.3
-onnxruntime-gpu
+onnxruntime-gpu
+openxlab
1 change: 1 addition & 0 deletions utils/actions/fundus_diagnosis.py
@@ -10,6 +10,7 @@
 DEFAULT_DESCRIPTION = """A tool for diagnosing fundus images.
 It can diagnose the lesion types present in a fundus image, such as glaucoma, and whether it shows diabetic retinopathy.
 The input is the path to a fundus image, which can be a local path or a web address (URL).
+Call this tool if and only if the user has uploaded an image.
 """

 class FundusDiagnosis(BaseAction):
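For context, a minimal sketch of exercising this action through lagent's ActionExecutor, mirroring how web_demo.py wires it up; the action name string, model path, and image path below are illustrative assumptions, not values taken from this commit:

from lagent.actions import ActionExecutor

from utils.actions.fundus_diagnosis import FundusDiagnosis

# Assumed path: web_demo.py downloads the ONNX weights to this location.
MODEL_PATH = "glaucoma_cls_dr_grading/flyer123/GlauClsDRGrading/model.onnx"

executor = ActionExecutor(actions=[FundusDiagnosis(model_path=MODEL_PATH)])

# lagent dispatches on the action name (the class name by default) and passes
# the raw action input through; here that input is a fundus image path.
ret = executor("FundusDiagnosis", "/tmp/fundus_example.jpg")
print(ret.result)  # an ActionReturn carrying the diagnosis text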
59 changes: 59 additions & 0 deletions utils/agent.py
@@ -0,0 +1,59 @@
from lagent.agents import ReAct
from lagent import AgentReturn, ActionReturn
import copy
from transformers import GenerationConfig


class MyReAct(ReAct):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def chat(self, message: str) -> AgentReturn:
        self._inner_history = []
        self._inner_history.append(dict(role='user', content=message))
        agent_return = AgentReturn()
        default_response = 'Sorry that I cannot answer your question.'
        gen_config = GenerationConfig(
            max_new_tokens=2048,
            do_sample=True,
            temperature=0.1,
            top_p=0.75,
            top_k=40,
            repetition_penalty=1.002,
        )

        for turn in range(self.max_turn):
            prompt = self._protocol.format(
                chat_history=self.session_history,
                inner_step=self._inner_history,
                action_executor=self._action_executor,
                force_stop=(turn == self.max_turn - 1))
            response = self._llm.generate_from_template(prompt, 512, generation_config=gen_config)
            self._inner_history.append(
                dict(role='assistant', content=response))
            thought, action, action_input = self._protocol.parse(
                response, self._action_executor)
            action_return: ActionReturn = self._action_executor(
                action, action_input)

            if action_return.type == "NoAction":
                # No action could be parsed from the response.
                action_return.thought = "This answer does not require calling any Action"
                agent_return.response = response
                print("Abnormal model output: no Action generated per the required template; returning the raw output")
                break
            action_return.thought = thought
            agent_return.actions.append(action_return)
            if action_return.type == self._action_executor.finish_action.name:
                agent_return.response = action_return.result['text']
                break
            self._inner_history.append(
                dict(
                    role='system',
                    content=self._protocol.format_response(action_return)))
        else:
            agent_return.response = default_response
        agent_return.inner_steps = copy.deepcopy(self._inner_history)
        # Only append the user message and the final response to the session history.
        self._session_history.append(dict(role='user', content=message))
        self._session_history.append(
            dict(role='assistant', content=agent_return.response))
        return agent_return
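As a usage sketch, MyReAct plugs in wherever lagent's ReAct would; the paths are copied from web_demo.py for illustration and the user message is a placeholder:

from lagent.actions import ActionExecutor
from lagent.agents.react import ReActProtocol
from lagent.llms.huggingface import HFTransformerCasualLM
from lagent.llms.meta_template import INTERNLM2_META as META

from utils.actions.fundus_diagnosis import FundusDiagnosis
from utils.agent import MyReAct

# Assumed paths, taken from web_demo.py.
llm = HFTransformerCasualLM('/root/OculiChatDA/merged_model_e1', meta_template=META)
agent = MyReAct(
    llm=llm,
    action_executor=ActionExecutor(actions=[
        FundusDiagnosis(model_path='glaucoma_cls_dr_grading/flyer123/GlauClsDRGrading/model.onnx')]),
    protocol=ReActProtocol())  # web_demo.py additionally passes call_protocol=CALL_PROTOCOL_CN

ret = agent.chat('Please take a look at this fundus image: /tmp/fundus_example.jpg')
print(ret.response)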
49 changes: 43 additions & 6 deletions web_demo.py
@@ -6,12 +6,35 @@
 from streamlit.logger import get_logger

 from lagent.actions import ActionExecutor
-from lagent.agents.react import ReAct
-from modelscope import snapshot_download
+from lagent.agents.react import ReAct, ReActProtocol
 from lagent.llms.huggingface import HFTransformerCasualLM
 from utils.actions.fundus_diagnosis import FundusDiagnosis
 from lagent.llms.meta_template import INTERNLM2_META as META

+from utils.agent import MyReAct
+
+
+# MODEL_DIR = "/share/model_repos/internlm2-chat-7b"
+MODEL_DIR = "/root/OculiChatDA/merged_model_e1"
+CALL_PROTOCOL_CN = """You are an ophthalmology expert who can help users diagnose the condition of their eyes through text and images. (Do not reveal your personal information or employer in your replies.)
+You can call external tools to help you solve problems.
+The available tools include:
+{tool_description}
+If you use a tool, reply in the following format:
+```
+{thought} Think about what problem the current step needs to solve and whether a tool is required
+{action} The tool name; it must be chosen from [{action_names}]
+{action_input} The input parameters for the tool
+```
+Tool results are returned in the following format:
+```
+{response} The result of calling the tool
+```
+If you already know the answer, or you do not need a tool, reply in the following format:
+```
+{thought} The reasoning that leads to the final answer
+{finish} The final answer
+```
+Begin!"""
 class SessionState:

     def init_state(self):
@@ -23,6 +23,7 @@ def init_state(self):
         cache_dir = "glaucoma_cls_dr_grading"
         model_path = os.path.join(cache_dir, "flyer123/GlauClsDRGrading", "model.onnx")
         if not os.path.exists(model_path):
+            from modelscope import snapshot_download
             snapshot_download("flyer123/GlauClsDRGrading", cache_dir=cache_dir)

         action_list = [FundusDiagnosis(model_path=model_path)]
@@ -33,12 +33,15 @@ def init_state(self):
         st.session_state['model_map'] = {}
         st.session_state['model_selected'] = None
         st.session_state['plugin_actions'] = set()
+        st.session_state["turn"] = 0  # track the current session's turn count; the system prompt is added on the first turn
+

     def clear_state(self):
         """Clear the existing session state."""
         st.session_state['assistant'] = []
         st.session_state['user'] = []
         st.session_state['model_selected'] = None
+        st.session_state["turn"] = 0
         if 'chatbot' in st.session_state:
             st.session_state['chatbot']._session_history = []
@@ -102,12 +129,12 @@ def init_model(self, option):
         @st.cache_resource
         def load_internlm2():
             return HFTransformerCasualLM(
-                '/share/model_repos/internlm2-chat-7b', meta_template=META)
+                MODEL_DIR, meta_template=META)

     def initialize_chatbot(self, model, plugin_action):
         """Initialize the chatbot with the given model and plugin actions."""
-        return ReAct(
-            llm=model, action_executor=ActionExecutor(actions=plugin_action))
+        return MyReAct(
+            llm=model, action_executor=ActionExecutor(actions=plugin_action), protocol=ReActProtocol(call_protocol=CALL_PROTOCOL_CN))

     def render_user(self, prompt: str):
         with st.chat_message('user', avatar="👦"):
@@ -231,10 +258,20 @@ def main():
         st.session_state['assistant'].append(copy.deepcopy(agent_return))
         logger.info("agent_return:",agent_return.inner_steps)
         st.session_state['ui'].render_assistant(agent_return)
+        st.session_state["turn"] += 1


 if __name__ == '__main__':
     root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
     root_dir = os.path.join(root_dir, 'tmp_dir')
     os.makedirs(root_dir, exist_ok=True)

+    if not os.path.exists(MODEL_DIR):
+        from openxlab.model import download
+
+        download(model_repo='OpenLMLab/internlm2-chat-7b', output=MODEL_DIR)
+
+    print("Directory contents after extraction:")
+    print(os.listdir(MODEL_DIR))

     main()
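Note on the CALL_PROTOCOL_CN template: its {thought}, {action}, {action_input}, {response}, and {finish} placeholders are filled in by ReActProtocol when it formats each prompt. A rough sketch of that substitution follows; the label strings and the tool description are illustrative assumptions, not values verified against lagent 0.1.3:

# Hypothetical illustration of the placeholder substitution performed by
# ReActProtocol; every label string below is an assumption.
from web_demo import CALL_PROTOCOL_CN  # reuses the template defined above

system_prompt = CALL_PROTOCOL_CN.format(
    tool_description='FundusDiagnosis: diagnoses lesions in a fundus image given its path.',
    action_names='FundusDiagnosis, FinishAction',
    thought='Thought: ',
    action='Action: ',
    action_input='Action Input: ',
    response='Response: ',
    finish='Final Answer: ',
)
print(system_prompt)  # roughly the system prompt the model sees on the first turn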
