Commit

Update doc and output parser (#40)
yyiilluu committed Jun 27, 2023
1 parent bd09265 commit 5423c7e
Showing 18 changed files with 228 additions and 176 deletions.
21 changes: 15 additions & 6 deletions autochain/agent/conversational_agent/conversational_agent.py
@@ -12,7 +12,7 @@
CLARIFYING_QUESTION_PROMPT,
PLANNING_PROMPT,
)
from autochain.agent.message import BaseMessage
from autochain.agent.message import BaseMessage, ChatMessageHistory
from autochain.agent.prompt_formatter import JSONPromptTemplate
from autochain.agent.structs import AgentAction, AgentFinish
from autochain.models.base import BaseLanguageModel, Generation
@@ -64,7 +64,7 @@ def from_llm_and_tools(
)

@staticmethod
def get_final_prompt(
def format_prompt(
template: JSONPromptTemplate,
intermediate_steps: List[AgentAction],
**kwargs: Any,
@@ -105,7 +105,10 @@ def get_prompt_template(
return JSONPromptTemplate(template=template, input_variables=input_variables)

def plan(
self, intermediate_steps: List[AgentAction], **kwargs: Any
self,
history: ChatMessageHistory,
intermediate_steps: List[AgentAction],
**kwargs: Any
) -> Union[AgentAction, AgentFinish]:
"""
Plan the next step. either taking an action with AgentAction or respond to user with AgentFinish
@@ -121,8 +124,11 @@ def plan(
tool_strings = "\n\n".join(
[f"> {tool.name}: \n{tool.description}" for tool in self.tools]
)
inputs = {"tool_names": tool_names, "tools": tool_strings, **kwargs}
final_prompt = self.get_final_prompt(
inputs = {"tool_names": tool_names,
"tools": tool_strings,
"history": history.format_message(),
**kwargs}
final_prompt = self.format_prompt(
self.prompt_template, intermediate_steps, **inputs
)
logger.info(f"\nFull Input: {final_prompt[0].content} \n")
@@ -145,6 +151,7 @@ def plan(
def clarify_args_for_agent_action(
self,
agent_action: AgentAction,
history: ChatMessageHistory,
intermediate_steps: List[AgentAction],
**kwargs: Any,
):
@@ -156,6 +163,7 @@ def clarify_args_for_agent_action(
Args:
agent_action: agent action about to take
history: conversation history including the latest query
intermediate_steps: list of agent action taken so far
**kwargs:
@@ -169,14 +177,15 @@ def clarify_args_for_agent_action(
inputs = {
"tool_name": agent_action.tool,
"tool_desp": self.allowed_tools.get(agent_action.tool).description,
"history": history.format_message(),
**kwargs,
}

clarifying_template = self.get_prompt_template(
prompt=CLARIFYING_QUESTION_PROMPT
)

final_prompt = self.get_final_prompt(
final_prompt = self.format_prompt(
clarifying_template, intermediate_steps, **inputs
)
logger.info(f"\nClarification inputs: {final_prompt[0].content}")
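The main change in this file: plan() and clarify_args_for_agent_action() now receive the conversation history explicitly as a ChatMessageHistory and interpolate history.format_message() into the prompt inputs, and get_final_prompt() is renamed to format_prompt(). Below is a minimal sketch of the new call shape, not the project's own example; it assumes the class exported by this module is ConversationalAgent, that ChatMessageHistory can be constructed empty, and it omits the tools and extra prompt variables a real run would pass via **kwargs.

from autochain.agent.conversational_agent.conversational_agent import ConversationalAgent
from autochain.agent.message import ChatMessageHistory
from autochain.agent.structs import AgentAction, AgentFinish
from autochain.models.chat_openai import ChatOpenAI

# Build an agent as before; only the plan()/clarify signatures change.
agent = ConversationalAgent.from_llm_and_tools(llm=ChatOpenAI(), tools=[])

# Assumed to be constructible empty and populated elsewhere with prior turns.
history = ChatMessageHistory()

# history is now an explicit argument; its format_message() output fills the
# "history" variable of the planning prompt.
output = agent.plan(history=history, intermediate_steps=[])

if isinstance(output, AgentFinish):
    print(output.message)  # direct reply to the user
elif isinstance(output, AgentAction):
    print(output.tool)     # tool the agent wants to run next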
18 changes: 3 additions & 15 deletions autochain/agent/conversational_agent/output_parser.py
@@ -11,12 +11,7 @@

class ConvoJSONOutputParser(AgentOutputParser):
def parse(self, message: BaseMessage) -> Union[AgentAction, AgentFinish]:
text = message.content
try:
clean_text = text[text.index("{") : text.rindex("}") + 1].strip()
response = json.loads(clean_text)
except Exception:
raise OutputParserException(f"Not a valid json: `{text}`")
response = self.load_json_output(message)

action_name = response.get("tool", {}).get("name")
action_args = response.get("tool", {}).get("args")
@@ -39,17 +34,10 @@ def parse(self, message: BaseMessage) -> Union[AgentAction, AgentFinish]:
model_response=response.get("response", ""),
)

@staticmethod
def parse_clarification(
message: BaseMessage, agent_action: AgentAction
self, message: BaseMessage, agent_action: AgentAction
) -> Union[AgentAction, AgentFinish]:
text = message.content
try:
clean_text = text[text.index("{") : text.rindex("}") + 1].strip()
response = json.loads(clean_text)
print_with_color(f"Full clarification output: {response}", Fore.YELLOW)
except Exception:
raise OutputParserException(f"Not a valid json: `{text}`")
response = self.load_json_output(message)

has_arg_value = response.get("has_arg_value", "")
clarifying_question = response.get("clarifying_question", "")
2 changes: 1 addition & 1 deletion autochain/agent/conversational_agent/prompt.py
@@ -27,7 +27,7 @@
"arg_name": "arg value from conversation history or observation to run tool"
}
},
"response": "clarifying required args for that tool or response to user. this cannot be empty",
"response": "Response to user",
}
Ensure the response can be parsed by Python json.loads
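Only the tail of the planning prompt's JSON schema is visible in this hunk; the parser above reads tool.name, tool.args and response from it. Here is a hedged, fully valid example of a model reply that satisfies this contract (the tool name and argument value are invented) and, as the prompt requires, parses with Python json.loads.

import json

example_model_reply = """
{
    "tool": {
        "name": "check_order_status",
        "args": {
            "order_id": "123 (taken from conversation history)"
        }
    },
    "response": "Response to user"
}
"""

parsed = json.loads(example_model_reply)
assert parsed["tool"]["name"] == "check_order_status"
assert parsed["response"] == "Response to user"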
20 changes: 0 additions & 20 deletions autochain/agent/openai_funtions_agent/output_parser.py
@@ -22,23 +22,3 @@ def parse(self, message: AIMessage) -> Union[AgentAction, AgentFinish]:
)
else:
return AgentFinish(message=message.content, log=message.content)

@staticmethod
def parse_clarification(
message: BaseMessage, agent_action: AgentAction
) -> Union[AgentAction, AgentFinish]:
text = message.content
try:
clean_text = text[text.index("{") : text.rindex("}") + 1].strip()
response = json.loads(clean_text)
print_with_color(f"Full clarification output: {response}", Fore.YELLOW)
except Exception:
raise OutputParserException(f"Not a valid json: `{text}`")

has_arg_value = response.get("has_arg_value", "")
clarifying_question = response.get("clarifying_question", "")

if "no" in has_arg_value.lower() and clarifying_question:
return AgentFinish(message=clarifying_question, log=clarifying_question)
else:
return agent_action
8 changes: 0 additions & 8 deletions autochain/agent/openai_funtions_agent/prompt.py

This file was deleted.

17 changes: 15 additions & 2 deletions autochain/agent/structs.py
@@ -1,6 +1,8 @@
import json
from abc import abstractmethod
from typing import Union, Any, Dict, List

from autochain.errors import OutputParserException
from pydantic import BaseModel

from autochain.agent.message import BaseMessage
@@ -50,13 +52,24 @@ def format_output(self) -> Dict[str, Any]:


class AgentOutputParser(BaseModel):
@staticmethod
def load_json_output(message: BaseMessage) -> Dict[str, Any]:
"""If the message contains a json response, try to parse it into dictionary"""
text = message.content
try:
clean_text = text[text.index("{") : text.rindex("}") + 1].strip()
response = json.loads(clean_text)
except Exception:
raise OutputParserException(f"Not a valid json: `{text}`")

return response

@abstractmethod
def parse(self, message: BaseMessage) -> Union[AgentAction, AgentFinish]:
"""Parse text into agent action/finish."""

@staticmethod
def parse_clarification(
message: BaseMessage, agent_action: AgentAction
self, message: BaseMessage, agent_action: AgentAction
) -> Union[AgentAction, AgentFinish]:
"""Parse clarification outputs"""
return agent_action
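The new AgentOutputParser.load_json_output helper centralizes the JSON extraction each parser previously copied: it slices from the first "{" to the last "}" of the message content, runs json.loads, and raises OutputParserException otherwise. A small illustration of that behavior follows; UserMessage is assumed to be a BaseMessage subclass constructible from a content string, as in the workflow evaluation code later in this commit.

from autochain.agent.message import UserMessage  # assumed to live alongside BaseMessage
from autochain.agent.structs import AgentOutputParser
from autochain.errors import OutputParserException

reply = UserMessage(
    content='Sure, here is the plan: {"response": "Hello!", "has_arg_value": "no"} Let me know.'
)
# Extracts the outermost {...} span and parses it into a dict.
print(AgentOutputParser.load_json_output(reply))
# -> {'response': 'Hello!', 'has_arg_value': 'no'}

try:
    AgentOutputParser.load_json_output(UserMessage(content="no json in this reply"))
except OutputParserException as err:
    print(err)  # Not a valid json: `no json in this reply`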
23 changes: 3 additions & 20 deletions autochain/agent/support_agent/output_parser.py
@@ -1,23 +1,13 @@
import json
from typing import Union

from colorama import Fore

from autochain.agent.message import BaseMessage
from autochain.errors import OutputParserException
from autochain.agent.structs import AgentAction, AgentFinish, AgentOutputParser
from autochain.tools.simple_handoff.tool import HandOffToAgent
from autochain.utils import print_with_color


class SupportJSONOutputParser(AgentOutputParser):
def parse(self, message: BaseMessage) -> Union[AgentAction, AgentFinish]:
text = message.content
try:
clean_text = text[text.index("{") : text.rindex("}") + 1].strip()
response = json.loads(clean_text)
except Exception:
raise OutputParserException(f"Not a valid json: `{text}`")
response = self.load_json_output(message)

handoff_action = HandOffToAgent()
action_name = response.get("tool", {}).get("name")
@@ -51,17 +41,10 @@ def parse(self, message: BaseMessage) -> Union[AgentAction, AgentFinish]:
model_response=response.get("response", ""),
)

@staticmethod
def parse_clarification(
message: BaseMessage, agent_action: AgentAction
self, message: BaseMessage, agent_action: AgentAction
) -> Union[AgentAction, AgentFinish]:
text = message.content
try:
clean_text = text[text.index("{") : text.rindex("}") + 1].strip()
response = json.loads(clean_text)
print_with_color(f"Full clarification output: {response}", Fore.YELLOW)
except Exception:
raise OutputParserException(f"Not a valid json: `{text}`")
response = self.load_json_output(message)

has_arg_value = response.get("has_arg_value", "")
clarifying_question = response.get("clarifying_question", "")
4 changes: 3 additions & 1 deletion autochain/agent/support_agent/support_agent.py
@@ -155,6 +155,7 @@ def plan(
AgentAction or AgentFinish
"""
print_with_color("Planning", Fore.LIGHTYELLOW_EX)

tool_names = ", ".join([tool.name for tool in self.tools])
tool_strings = "\n\n".join(
[f"> {tool.name}: \n{tool.description}" for tool in self.tools]
@@ -177,7 +178,8 @@ def plan(
)

print_with_color(
f"Full output: {json.loads(full_output.message.content)}", Fore.YELLOW
f"Full planning output: \n{json.loads(full_output.message.content)}",
Fore.YELLOW,
)
if isinstance(agent_output, AgentAction):
print_with_color(
10 changes: 2 additions & 8 deletions autochain/tools/base.py
@@ -27,18 +27,12 @@ class Tool(ABC, BaseModel):
"""

arg_description: Optional[Dict[str, Any]] = None
"""Dictionary of arg name and description when using OpenAIFunctionAgent to provide
additional argument information"""

args_schema: Optional[Type[BaseModel]] = None
"""Pydantic model class to validate and parse the tool's input arguments."""

return_direct: bool = False
"""Whether to return the tool's output directly. Setting this to True means
that after the tool is called, the AgentExecutor will stop looping.
"""
verbose: bool = False
"""Whether to log the tool's progress."""

func: Union[Callable[..., str], None] = None

@root_validator()
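For reference, the fields kept on Tool above (func, args_schema, return_direct, verbose, plus the documented arg_description) support function-backed tools like the ones the workflow tests pass to the agents. A hedged sketch of declaring one: the function, names and descriptions are invented, and Tool accepting name/description keyword fields is inferred from the tool.name / tool.description usage elsewhere in this commit rather than shown in this hunk.

from autochain.tools.base import Tool

def check_order_status(order_id: str, **kwargs) -> str:
    """Hypothetical lookup used only for illustration."""
    return f"Order {order_id} has shipped"

order_status_tool = Tool(
    name="check_order_status",
    func=check_order_status,
    description="Look up the current status of an order by order id",
    # Optional per-argument hints used when running with OpenAIFunctionAgent.
    arg_description={"order_id": "id of the order to look up"},
)

print(order_status_tool.name, "->", order_status_tool.description)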
37 changes: 18 additions & 19 deletions autochain/workflows_evaluation/base_test.py
@@ -13,6 +13,7 @@
from autochain.models.chat_openai import ChatOpenAI
from autochain.tools.base import Tool
from autochain.utils import print_with_color
from autochain.workflows_evaluation.test_utils import parse_evaluation_response


@dataclass
@@ -134,18 +135,17 @@ def determine_if_conversation_ends(self, last_utterance: str) -> bool:
UserMessage(
content=f"""The most recent reply from assistant
assistant: "{last_utterance}"
Is assistant asking a clarifying question or getting additional information from user? answer with
yes or no"""
Has assistant finish assisting the user? Answer with yes or no"""
),
]
output: Generation = self.llm.generate(messages=messages).generations[0]

if "yes" in output.message.content.lower():
# this is a clarifying question
return False
else:
# conversation should end
# finish assisting; conversation should end
return True
else:
# not yet finished; conversation should continue
return False

def get_next_user_query(
self, conversation_history: List[Tuple[str, str]], user_context: str
@@ -175,28 +175,27 @@ def get_next_user_query(

def determine_if_agent_solved_problem(
self, conversation_history: List[Tuple[str, str]], expected_outcome: str
) -> (bool, str):
) -> Dict[str, str]:
messages = []
conversation = ""
for user_type, utterance in conversation_history:
conversation += f"{user_type}: {utterance}\n"

messages.append(
UserMessage(
content=f"""Previous conversation:
content=f"""You are an admin for assistant and check if assistant meets the expected outcome based on previous conversation.
Previous conversation:
{conversation}
Expected outcome is {expected_outcome}
Does conversation reach the expected outcome for user? Answer with yes or no with explanation"""
Expected outcome is "{expected_outcome}"
Does conversation reach the expected outcome for user? answer in JSON format
{{
"reason": "explain step by step if conversation reaches the expected outcome",
"rating": "rating from 1 to 5; 1 for not meeting the expected outcome at all, 5 for completely meeting the expected outcome",
}}"""
)
)

output: Generation = self.llm.generate(
messages=messages, stop=["."]
).generations[0]
if "yes" in output.message.content.lower():
# Agent solved the problem
return True, output.message.content
else:
# did not solve the problem
return False, output.message.content
output: Generation = self.llm.generate(messages=messages).generations[0]
return parse_evaluation_response(output.message)
@@ -41,7 +41,7 @@ def change_shipping_address(order_id: str, new_address: str, **kwargs):
}


class TestChangeShippingAddress(BaseTest):
class TestChangeShippingAddressWithFunctionCalling(BaseTest):
policy = """You are an AI assistant for customer support for the company Figs which sells nurse and medical staff clothes.
When a customer requests to change their shipping address, verify the order status in the system based on order id.
If the order has not yet shipped, update the shipping address as requested and confirm with the customer that it has been updated.
@@ -98,7 +98,7 @@ class TestChangeShippingAddress(BaseTest):

if __name__ == "__main__":
tester = WorkflowTester(
tests=[TestChangeShippingAddress()], output_dir="./test_results"
tests=[TestChangeShippingAddressWithFunctionCalling()], output_dir="./test_results"
)

args = get_test_args()
18 changes: 17 additions & 1 deletion autochain/workflows_evaluation/test_utils.py
@@ -1,6 +1,10 @@
import argparse
import logging
from typing import List, Optional
from typing import List, Optional, Dict

from autochain.agent.structs import AgentOutputParser

from autochain.agent.message import BaseMessage

from autochain.agent.support_agent.support_agent import SupportAgent
from autochain.chain.chain import Chain
@@ -54,3 +58,15 @@ def create_chain_from_test(
memory = memory or BufferMemory()
agent = agent_cls.from_llm_and_tools(llm=llm, tools=tools, **kwargs)
return Chain(agent=agent, memory=memory)


def parse_evaluation_response(message: BaseMessage) -> Dict[str, str]:
"""
Parse the reason and rating from the call to determine if the conversation reaches the
expected outcome
"""
response = AgentOutputParser.load_json_output(message)
return {
"rating": response.get("rating"),
"reason": response.get("reason"),
}
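parse_evaluation_response is what determine_if_agent_solved_problem in base_test.py now returns: the judge model's JSON verdict reduced to a rating and a reason. A hedged usage example with a fabricated judge reply, UserMessage standing in for whatever message type the model actually returns:

from autochain.agent.message import UserMessage
from autochain.workflows_evaluation.test_utils import parse_evaluation_response

fake_judge_reply = UserMessage(
    content='{"reason": "The assistant verified the order and updated the address.", "rating": "5"}'
)

verdict = parse_evaluation_response(fake_judge_reply)
print(verdict["rating"])  # "5" on the 1-5 scale defined in the evaluation prompt
print(verdict["reason"])  # step-by-step explanation from the judge model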