add memory-related agents: super-long dialogue and virtual memory
tuhahaha authored and JianxinMa committed Jun 5, 2024
1 parent 9a21e64 commit 7fcde89
Showing 19 changed files with 464 additions and 141 deletions.
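
For quick orientation before the per-file diffs, here is a minimal usage sketch of the two new agents, condensed from the examples/long_dialogue.py and examples/virtual_memory_qa.py files added below. The model names and queries are simply the ones those examples use; treat this as a sketch of the intended usage, not an API reference.

from qwen_agent.agents import DialogueRetrievalAgent, VirtualMemoryAgent

# Super-long dialogue: the agent saves the oversized history to a file and
# answers by retrieving from that file.
long_bot = DialogueRetrievalAgent(llm={'model': 'qwen-max'})
for response in long_bot.run([{'role': 'user', 'content': '小明爸爸叫什么?\n' + '这是干扰内容,' * 2000}]):
    print(response)

# Virtual memory: document QA over a file referenced in the query.
vm_bot = VirtualMemoryAgent(llm={'model': 'qwen-max'})
query = '简单列出这篇文章的贡献https://qianwen-res.oss-cn-beijing.aliyuncs.com/QWEN_TECHNICAL_REPORT.pdf'
for response in vm_bot.run([{'role': 'user', 'content': query}]):
    print(response)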
5 changes: 5 additions & 0 deletions .gitignore
@@ -36,3 +36,8 @@ examples/*.ipynb
 **/workspace/*
 test/*
 tests/env.sh
+examples/docqa_multi_agent.py
+examples/docqa_multihp_agents.py
+**/workspace/*
+test/*
+tests/env.sh
32 changes: 32 additions & 0 deletions examples/assistant_rag.py
@@ -0,0 +1,32 @@
from qwen_agent.agents import Assistant
from qwen_agent.gui import WebUI


def test():
    bot = Assistant(llm={'model': 'qwen-plus'})
    # '介绍图一' asks the bot to describe Figure 1 of the attached paper
    messages = [{'role': 'user', 'content': [{'text': '介绍图一'}, {'file': 'https://arxiv.org/pdf/1706.03762.pdf'}]}]
    for rsp in bot.run(messages):
        print(rsp)


def app_gui():
    # Define the agent
    # description: "Retrieves with RAG and answers; supported file types: PDF/Word/PPT/TXT/HTML."
    bot = Assistant(llm={'model': 'qwen-plus'},
                    name='Assistant',
                    description='使用RAG检索并回答,支持文件类型:PDF/Word/PPT/TXT/HTML。')
    chatbot_config = {
        'prompt.suggestions': [
            {
                'text': '介绍图一'  # "Describe Figure 1"
            },
            {
                'text': '第二章第一句话是什么?'  # "What is the first sentence of Chapter 2?"
            },
        ]
    }
    WebUI(bot, chatbot_config=chatbot_config).run()


if __name__ == '__main__':
    # test()
    app_gui()
2 changes: 1 addition & 1 deletion examples/gpt_mentions.py
@@ -45,7 +45,7 @@ def app_gui():
         'content': [{
             'text': '试试看 @代码解释器 来问我~'
         }]
-    }])
+    }], enable_mention=True)


if __name__ == '__main__':
2 changes: 1 addition & 1 deletion examples/group_chat_chess.py
@@ -38,7 +38,7 @@
 }


-def test(query: str):
+def test(query: str = '<1,1>'):
     bot = GroupChat(agents=CFGS, llm={'model': 'qwen-max'})

     messages = [Message('user', query, name=USER_NAME)]
41 changes: 41 additions & 0 deletions examples/long_dialogue.py
@@ -0,0 +1,41 @@
from qwen_agent.agents import DialogueRetrievalAgent
from qwen_agent.gui import WebUI


def test():
    # Define the agent
    bot = DialogueRetrievalAgent(llm={'model': 'qwen-max'})

    # Chat: a needle-in-a-haystack check. 2000 distractor clauses ('这是干扰内容',
    # "this is distractor content") surround one fact ('小明的爸爸叫大头',
    # "Xiaoming's dad is called Datou"), and the question asks for that fact.
    long_text = ','.join(['这是干扰内容'] * 1000 + ['小明的爸爸叫大头'] + ['这是干扰内容'] * 1000)
    messages = [{'role': 'user', 'content': f'小明爸爸叫什么?\n{long_text}'}]

    for response in bot.run(messages):
        print('bot response:', response)


def app_tui():
    bot = DialogueRetrievalAgent(llm={'model': 'qwen-max'})

    # Chat
    messages = []
    while True:
        query = input('user question: ')
        messages.append({'role': 'user', 'content': query})
        response = []
        for response in bot.run(messages=messages):
            print('bot response:', response)
        messages.extend(response)


def app_gui():
    # Define the agent
    bot = DialogueRetrievalAgent(llm={'model': 'qwen-max'})

    WebUI(bot).run()


if __name__ == '__main__':
    # test()
    # app_tui()
    app_gui()
74 changes: 74 additions & 0 deletions examples/virtual_memory_qa.py
@@ -0,0 +1,74 @@
"""A retrieval docqa assistant implemented by virtual memory agent"""

import os

from qwen_agent.agents import VirtualMemoryAgent
from qwen_agent.gui import WebUI

ROOT_RESOURCE = os.path.join(os.path.dirname(__file__), 'resource')


def init_agent_service():
llm_cfg = {'model': 'qwen-max'}
system = '一个文档问答助手。'
bot = VirtualMemoryAgent(
llm=llm_cfg,
system_message=system,
)

return bot


def test(query='简单列出这篇文章的贡献https://qianwen-res.oss-cn-beijing.aliyuncs.com/QWEN_TECHNICAL_REPORT.pdf',):
# Define the agent
bot = init_agent_service()

# Chat
messages = [{'role': 'user', 'content': query}]

for response in bot.run(messages):
print('bot response:', response)


def app_tui():
# Define the agent
bot = init_agent_service()

# Chat
messages = []
while True:
# Query example: 简单列出这篇文章的贡献https://qianwen-res.oss-cn-beijing.aliyuncs.com/QWEN_TECHNICAL_REPORT.pdf
query = input('user question: ')
# File example: resource/poem.pdf
file = input('file url (press enter if no file): ').strip()
if not query:
print('user question cannot be empty!')
continue
if not file:
messages.append({'role': 'user', 'content': query})
else:
messages.append({'role': 'user', 'content': [{'text': query}, {'file': file}]})

response = []
for response in bot.run(messages):
print('bot response:', response)
messages.extend(response)


def app_gui():
# Define the agent
bot = init_agent_service()
chatbot_config = {
'prompt.suggestions': ['简单列出这篇文章的贡献https://qianwen-res.oss-cn-beijing.aliyuncs.com/QWEN_TECHNICAL_REPORT.pdf']
}

WebUI(
bot,
chatbot_config=chatbot_config,
).run()


if __name__ == '__main__':
# test()
# app_tui()
app_gui()
4 changes: 2 additions & 2 deletions qwen_agent/agent.py
@@ -134,10 +134,10 @@ def _call_llm(
         if messages[0][ROLE] != SYSTEM:
             messages.insert(0, Message(role=SYSTEM, content=self.system_message))
         elif isinstance(messages[0][CONTENT], str):
-            messages[0][CONTENT] = self.system_message + messages[0][CONTENT]
+            messages[0][CONTENT] = self.system_message + '\n\n' + messages[0][CONTENT]
         else:
             assert isinstance(messages[0][CONTENT], list)
-            messages[0][CONTENT] = [ContentItem(text=self.system_message)] + messages[0][CONTENT]
+            messages[0][CONTENT] = [ContentItem(text=self.system_message + '\n\n')] + messages[0][CONTENT]
         return self.llm.chat(messages=messages,
                              functions=functions,
                              stream=stream,
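
To illustrate the hunk above, a small sketch of how the system message now merges into the first message for the two content shapes it handles; the message values here are invented for illustration only.

# String content: a blank line now separates the system text from the user text.
system_message = 'You are a helpful assistant.'
content = 'Summarize the attached report.'
merged = system_message + '\n\n' + content
# -> 'You are a helpful assistant.\n\nSummarize the attached report.'

# List content: the separator travels inside the prepended text item instead.
items = [{'text': system_message + '\n\n'}, {'text': content}]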
6 changes: 5 additions & 1 deletion qwen_agent/agents/__init__.py
@@ -3,16 +3,18 @@

 from .article_agent import ArticleAgent
 from .assistant import Assistant
+from .dialogue_retrieval_agent import DialogueRetrievalAgent
 # DocQAAgent is the default solution for long document question answering.
 # The actual implementation of DocQAAgent may change with every release.
-from .doc_qa.basic_doc_qa import BasicDocQA as DocQAAgent
+from .doc_qa import BasicDocQA as DocQAAgent
 from .fncall_agent import FnCallAgent
 from .group_chat import GroupChat
 from .group_chat_auto_router import GroupChatAutoRouter
 from .group_chat_creator import GroupChatCreator
 from .react_chat import ReActChat
 from .router import Router
 from .user_agent import UserAgent
+from .virtual_memory_agent import VirtualMemoryAgent
 from .write_from_scratch import WriteFromScratch

__all__ = [
@@ -29,4 +31,6 @@
     'GroupChatCreator',
     'GroupChatAutoRouter',
     'FnCallAgent',
+    'VirtualMemoryAgent',
+    'DialogueRetrievalAgent',
 ]
17 changes: 6 additions & 11 deletions qwen_agent/agents/assistant.py
@@ -4,23 +4,18 @@

 import json5

+from qwen_agent.agents.fncall_agent import FnCallAgent
+from qwen_agent.llm import BaseChatModel
 from qwen_agent.llm.schema import CONTENT, DEFAULT_SYSTEM_MESSAGE, ROLE, SYSTEM, Message
 from qwen_agent.log import logger
+from qwen_agent.tools import BaseTool
 from qwen_agent.utils.utils import get_basename_from_url, print_traceback

-from ..llm import BaseChatModel
-from ..tools import BaseTool
-from .fncall_agent import FnCallAgent
-
-KNOWLEDGE_TEMPLATE_ZH = """
-# 知识库
+KNOWLEDGE_TEMPLATE_ZH = """# 知识库
 {knowledge}"""

-KNOWLEDGE_TEMPLATE_EN = """
-# Knowledge Base
+KNOWLEDGE_TEMPLATE_EN = """# Knowledge Base
 {knowledge}"""
@@ -131,7 +126,7 @@ def _prepend_knowledge_prompt(self,

         if knowledge_prompt:
             if messages[0][ROLE] == SYSTEM:
-                messages[0][CONTENT] += knowledge_prompt
+                messages[0][CONTENT] += '\n\n' + knowledge_prompt
             else:
                 messages = [Message(role=SYSTEM, content=knowledge_prompt)] + messages
         return messages
77 changes: 77 additions & 0 deletions qwen_agent/agents/dialogue_retrieval_agent.py
@@ -0,0 +1,77 @@
import datetime
import os
from typing import Iterator, List

from qwen_agent.agents.assistant import Assistant
from qwen_agent.llm.schema import SYSTEM, USER, ContentItem, Message
from qwen_agent.settings import DEFAULT_WORKSPACE
from qwen_agent.utils.utils import extract_text_from_message, save_text_to_file

MAX_TRUNCATED_QUERY_LENGTH = 1000

EXTRACT_QUERY_TEMPLATE_ZH = """<给定文本>
{ref_doc}
上面的文本是包括一段材料和一个用户请求,这个请求一般在最开头或最末尾,请你帮我提取出那个请求,你不需要回答这个请求,只需要打印出用户的请求即可!"""

EXTRACT_QUERY_TEMPLATE_EN = """<Given Text>
{ref_doc}
The text above includes a section of reference material and a user request. The user request is typically located either at the beginning or the end. Please extract that request for me. You do not need to answer the request, just print out the user's request!"""

EXTRACT_QUERY_TEMPLATE = {'zh': EXTRACT_QUERY_TEMPLATE_ZH, 'en': EXTRACT_QUERY_TEMPLATE_EN}


# TODO: merge into the retrieval tool
class DialogueRetrievalAgent(Assistant):
    """An agent for super-long dialogues."""

    def _run(self,
             messages: List[Message],
             lang: str = 'en',
             session_id: str = '',
             **kwargs) -> Iterator[List[Message]]:
        """Process messages and generate a response.

        Answers questions by saving the long dialogue to a file and using
        file retrieval to fetch the relevant information.
        """
        assert messages and messages[-1].role == USER
        new_messages = []
        content = []
        for msg in messages[:-1]:
            if msg.role == SYSTEM:
                new_messages.append(msg)
            else:
                content.append(f'{msg.role}: {extract_text_from_message(msg, add_upload_info=True)}')
        # Process the newest user message
        text = extract_text_from_message(messages[-1], add_upload_info=False)
        if len(text) <= MAX_TRUNCATED_QUERY_LENGTH:
            query = text
        else:
            if len(text) <= MAX_TRUNCATED_QUERY_LENGTH * 2:
                latent_query = text
            else:
                # The request usually sits at the head or tail of a long paste,
                # so keep both ends and let the LLM extract it
                latent_query = f'{text[:MAX_TRUNCATED_QUERY_LENGTH]} ... {text[-MAX_TRUNCATED_QUERY_LENGTH:]}'

            *_, last = self._call_llm(
                messages=[Message(role=USER, content=EXTRACT_QUERY_TEMPLATE[lang].format(ref_doc=latent_query))])
            query = last[-1].content
            # A little tricky: if the extracted query differs from the original query, it cannot be removed
            text = text.replace(query, '')
            content.append(text)

        # Save the content as a file whose name records the session and the time
        content = '\n'.join(content)
        file_path = os.path.join(DEFAULT_WORKSPACE, f'dialogue_history_{session_id}_{datetime.datetime.now()}.txt')
        save_text_to_file(file_path, content)

        new_content = [ContentItem(text=query), ContentItem(file=file_path)]
        if isinstance(messages[-1].content, list):
            for item in messages[-1].content:
                if item.file or item.image:
                    new_content.append(item)
        new_messages.append(Message(role=USER, content=new_content))

        return super()._run(messages=new_messages, lang=lang, **kwargs)
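
To make the head-plus-tail truncation above concrete, a standalone sketch; the function name is ours, the constant matches the file.

MAX_TRUNCATED_QUERY_LENGTH = 1000


def make_latent_query(text: str) -> str:
    # Up to twice the limit, the whole text fits into the extraction prompt.
    if len(text) <= MAX_TRUNCATED_QUERY_LENGTH * 2:
        return text
    # Otherwise keep only the head and the tail, where the user request usually lives.
    return f'{text[:MAX_TRUNCATED_QUERY_LENGTH]} ... {text[-MAX_TRUNCATED_QUERY_LENGTH:]}'


# A 5000-character paste collapses to 1000 + 5 + 1000 = 2005 characters.
assert len(make_latent_query('x' * 5000)) == 2005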
30 changes: 23 additions & 7 deletions qwen_agent/agents/doc_qa/basic_doc_qa.py
@@ -1,14 +1,27 @@
+import copy
 from typing import Dict, Iterator, List, Optional, Union

 from qwen_agent.agents.assistant import Assistant
 from qwen_agent.llm.base import BaseChatModel
-from qwen_agent.llm.schema import CONTENT, DEFAULT_SYSTEM_MESSAGE, Message
-from qwen_agent.prompts import DocQA
+from qwen_agent.llm.schema import CONTENT, DEFAULT_SYSTEM_MESSAGE, ROLE, SYSTEM, Message
 from qwen_agent.tools import BaseTool

 DEFAULT_NAME = 'Basic DocQA'
 DEFAULT_DESC = '可以根据问题,检索出知识库中的某个相关细节来回答。适用于需要定位到具体位置的问题,例如“介绍表1”等类型的问题'

+PROMPT_TEMPLATE_ZH = """请充分理解以下参考资料内容,组织出满足用户提问的条理清晰的回复。
+#参考资料:
+{ref_doc}"""
+
+PROMPT_TEMPLATE_EN = """Please fully understand the content of the following reference materials and organize a clear response that meets the user's questions.
+# Reference materials:
+{ref_doc}"""
+
+PROMPT_TEMPLATE = {
+    'zh': PROMPT_TEMPLATE_ZH,
+    'en': PROMPT_TEMPLATE_EN,
+}


 class BasicDocQA(Assistant):
     """This is an agent for doc QA."""
@@ -28,16 +41,19 @@ def __init__(self,
                          description=description,
                          files=files,
                          rag_cfg=rag_cfg)
-        self.doc_qa = DocQA(llm=self.llm)

     def _run(self, messages: List[Message], lang: str = 'en', **kwargs) -> Iterator[List[Message]]:
         """This agent uses a different doc-QA prompt than Assistant."""
         # Need to use the Memory agent for data management
         *_, last = self.mem.run(messages=messages, **kwargs)
-        _ref = last[-1][CONTENT]
+        knowledge = last[-1][CONTENT]

-        # Use RetrievalQA agent
-        # Todo: Prompt engineering
-        response = self.doc_qa.run(messages=messages, lang=lang, knowledge=_ref)
+        messages = copy.deepcopy(messages)
+        system_prompt = PROMPT_TEMPLATE[lang].format(ref_doc=knowledge)
+        if messages[0][ROLE] == SYSTEM:
+            messages[0][CONTENT] += '\n\n' + system_prompt
+        else:
+            messages.insert(0, Message(SYSTEM, system_prompt))
+
+        response = self._call_llm(messages=messages)
         return response