Skip to content

Commit

Permalink
Introduce confidence to autochain (#166)
Browse files Browse the repository at this point in the history
  • Loading branch information
yyiilluu authored Nov 29, 2023
1 parent 6c756ee commit 5a1203b
Show file tree
Hide file tree
Showing 9 changed files with 219 additions and 39 deletions.
12 changes: 10 additions & 2 deletions autochain/agent/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
from string import Template
from typing import Any, List, Optional, Sequence, Union

from pydantic import BaseModel, Extra

from autochain.agent.message import ChatMessageHistory
from autochain.agent.prompt_formatter import JSONPromptTemplate
from autochain.agent.structs import AgentAction, AgentFinish, AgentOutputParser
from autochain.models.base import BaseLanguageModel
from autochain.tools.base import Tool
from pydantic import BaseModel


class BaseAgent(BaseModel, ABC):
Expand Down Expand Up @@ -104,3 +103,12 @@ def get_prompt_template(
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
return JSONPromptTemplate(template=template, input_variables=input_variables)

def is_generation_confident(
self,
history: ChatMessageHistory,
agent_output: Union[AgentAction, AgentFinish],
min_confidence: int = 3,
) -> bool:
"""Check if the generation is confident enough to take action"""
return True
15 changes: 13 additions & 2 deletions autochain/agent/message.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import enum
from abc import abstractmethod
from typing import List, Any, Dict
from typing import Any, Dict, List

from pydantic import BaseModel, Field

Expand Down Expand Up @@ -60,6 +60,7 @@ class FunctionMessage(BaseMessage):
"""Type of message that is a function message."""

name: str
conversational_message: str = ""

@property
def type(self) -> str:
Expand All @@ -76,14 +77,24 @@ def save_message(self, message: str, message_type: MessageType, **kwargs):
elif message_type == MessageType.UserMessage:
self.messages.append(UserMessage(content=message))
elif message_type == MessageType.FunctionMessage:
self.messages.append(FunctionMessage(content=message, name=kwargs["name"]))
self.messages.append(
FunctionMessage(
content=message,
name=kwargs["name"],
conversational_message=kwargs["conversational_message"],
)
)
elif message_type == MessageType.SystemMessage:
self.messages.append(SystemMessage(content=message))

def format_message(self):
string_messages = []
if len(self.messages) > 0:
for m in self.messages:
if isinstance(m, FunctionMessage):
string_messages.append(f"Action: {m.conversational_message}")
continue

if isinstance(m, UserMessage):
role = "User"
elif isinstance(m, AIMessage):
Expand Down
110 changes: 86 additions & 24 deletions autochain/agent/openai_functions_agent/openai_functions_agent.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
from __future__ import annotations

import logging
from string import Template
from typing import Any, Dict, List, Optional, Union

from colorama import Fore

from autochain.agent.base_agent import BaseAgent
from autochain.agent.message import ChatMessageHistory, SystemMessage
from autochain.agent.message import ChatMessageHistory, SystemMessage, UserMessage
from autochain.agent.openai_functions_agent.output_parser import (
OpenAIFunctionOutputParser,
)
from autochain.agent.openai_functions_agent.prompt import ESTIMATE_CONFIDENCE_PROMPT
from autochain.agent.structs import AgentAction, AgentFinish
from autochain.models.base import BaseLanguageModel, Generation
from autochain.tools.base import Tool
from autochain.utils import print_with_color
from colorama import Fore

logger = logging.getLogger(__name__)

Expand All @@ -30,6 +31,7 @@ class OpenAIFunctionsAgent(BaseAgent):
allowed_tools: Dict[str, Tool] = {}
tools: List[Tool] = []
prompt: Optional[str] = None
min_confidence: int = 3

@classmethod
def from_llm_and_tools(
Expand All @@ -38,6 +40,7 @@ def from_llm_and_tools(
tools: Optional[List[Tool]] = None,
output_parser: Optional[OpenAIFunctionOutputParser] = None,
prompt: str = None,
min_confidence: int = 3,
**kwargs: Any,
) -> OpenAIFunctionsAgent:
tools = tools or []
Expand All @@ -50,39 +53,98 @@ def from_llm_and_tools(
output_parser=_output_parser,
tools=tools,
prompt=prompt,
min_confidence=min_confidence,
**kwargs,
)

def plan(
self,
history: ChatMessageHistory,
intermediate_steps: List[AgentAction],
retries: int = 2,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
print_with_color("Planning", Fore.LIGHTYELLOW_EX)
while retries > 0:
print_with_color("Planning", Fore.LIGHTYELLOW_EX)

final_messages = []
if self.prompt:
final_messages.append(SystemMessage(content=self.prompt))
final_messages += history.messages
final_messages = []
if self.prompt:
final_messages.append(SystemMessage(content=self.prompt))
final_messages += history.messages

logger.info(f"\nPlanning Input: {[m.content for m in final_messages]} \n")
full_output: Generation = self.llm.generate(
final_messages, self.tools
).generations[0]
logger.info(f"\nPlanning Input: {[m.content for m in final_messages]} \n")
full_output: Generation = self.llm.generate(
final_messages, self.tools
).generations[0]

agent_output: Union[AgentAction, AgentFinish] = self.output_parser.parse(
full_output.message
agent_output: Union[AgentAction, AgentFinish] = self.output_parser.parse(
full_output.message
)
print(
f"Planning output: \nmessage content: {repr(full_output.message.content)}; "
f"function_call: "
f"{repr(full_output.message.function_call)}",
Fore.YELLOW,
)
if isinstance(agent_output, AgentAction):
print_with_color(
f"Plan to take action '{agent_output.tool}'", Fore.LIGHTYELLOW_EX
)

generation_is_confident = self.is_generation_confident(
history=history,
agent_output=agent_output,
min_confidence=self.min_confidence,
)
if not generation_is_confident:
retries -= 1
print_with_color(
f"Generation is not confident, {retries} retries left",
Fore.LIGHTYELLOW_EX,
)
continue
else:
return agent_output

def is_generation_confident(
self,
history: ChatMessageHistory,
agent_output: Union[AgentAction, AgentFinish],
min_confidence: int = 3,
) -> bool:
"""
Estimate the confidence of the generation
Args:
history: history of the conversation
agent_output: the output from the agent
min_confidence: minimum confidence score to be considered as confident
"""

def _format_assistant_message(action_output: Union[AgentAction, AgentFinish]):
if isinstance(action_output, AgentFinish):
assistant_message = f"Assistant: {action_output.message}"
elif isinstance(action_output, AgentAction):
assistant_message = f"Action: {action_output.tool} with input: {action_output.tool_input}"
else:
raise ValueError("Unsupported action for estimating confidence score")

return assistant_message

prompt = Template(ESTIMATE_CONFIDENCE_PROMPT).substitute(
policy=self.prompt,
conversation_history=history.format_message(),
assistant_message=_format_assistant_message(agent_output),
)
print(
f"Planning output: \nmessage content: {repr(full_output.message.content)}; "
f"function_call: "
f"{repr(full_output.message.function_call)}",
Fore.YELLOW,
logger.info(f"\nEstimate confidence prompt: {prompt} \n")

message = UserMessage(content=prompt)

full_output: Generation = self.llm.generate([message], self.tools).generations[
0
]

estimated_confidence = self.output_parser.parse_estimated_confidence(
full_output.message
)
if isinstance(agent_output, AgentAction):
print_with_color(
f"Plan to take action '{agent_output.tool}'", Fore.LIGHTYELLOW_EX
)

return agent_output
return estimated_confidence >= min_confidence
27 changes: 27 additions & 0 deletions autochain/agent/openai_functions_agent/output_parser.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import json
import logging
import re
from typing import Union

from autochain.agent.message import AIMessage
from autochain.agent.structs import AgentAction, AgentFinish, AgentOutputParser

logger = logging.getLogger(__name__)


class OpenAIFunctionOutputParser(AgentOutputParser):
def parse(self, message: AIMessage) -> Union[AgentAction, AgentFinish]:
Expand All @@ -18,3 +22,26 @@ def parse(self, message: AIMessage) -> Union[AgentAction, AgentFinish]:
)
else:
return AgentFinish(message=message.content, log=message.content)

def parse_estimated_confidence(self, message: AIMessage) -> int:
"""Parse estimated confidence from the message"""

def find_first_integer(input_string):
# Define a regular expression pattern to match integers
pattern = re.compile(r"\d+")

# Search for the first match in the input string
match = pattern.search(input_string)

# Check if a match is found
if match:
# Extract and return the matched integer
return int(match.group())
else:
# Return 0 if no integer is found
logger.info(f"\nCannot find confidence in message: {input_string}\n")
return 0

content = message.content.strip()

return find_first_integer(content)
11 changes: 11 additions & 0 deletions autochain/agent/openai_functions_agent/prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
ESTIMATE_CONFIDENCE_PROMPT = """Given the system policy assistant needs to strictly follow and
the conversation history between user and assistant so far,
"System policy: ${policy}
${conversation_history}"
How confident are you the next step from assistant should be the following:
"${assistant_message}"
Estimate the confidence from 1-5, 1 being the least confident and 5 being the most confident.
Confidence:
"""
14 changes: 7 additions & 7 deletions autochain/agent/structs.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
import json
import re
from abc import abstractmethod
from typing import Union, Any, Dict, List
from typing import Any, Dict, List, Union

from autochain.agent.message import BaseMessage, UserMessage
from autochain.chain import constants
from autochain.models.base import Generation

from autochain.models.chat_openai import ChatOpenAI
from pydantic import BaseModel

from autochain.agent.message import BaseMessage, UserMessage
from autochain.chain import constants
from autochain.errors import OutputParserException


class AgentAction(BaseModel):
"""Agent's action to take."""
Expand Down Expand Up @@ -89,3 +85,7 @@ def parse_clarification(
) -> Union[AgentAction, AgentFinish]:
"""Parse clarification outputs"""
return agent_action

def parse_estimated_confidence(self, message: BaseMessage) -> int:
"""Parse estimated confidence from the message"""
return 1
2 changes: 2 additions & 0 deletions autochain/chain/base_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ def _run(
self.memory.save_conversation(
message=str(next_step_output.tool_output),
name=next_step_output.tool,
conversational_message=f"{next_step_output.tool} with input: "
f"{next_step_output.tool_input}",
message_type=MessageType.FunctionMessage,
)

Expand Down
2 changes: 1 addition & 1 deletion autochain/chain/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def handle_repeated_action(self, agent_action: AgentAction) -> AgentFinish:
print("No response from agent. Gracefully exit due to repeated action")
return AgentFinish(
message=self.graceful_exit_tool.run(),
log=f"Gracefully exit due to repeated action",
log="Gracefully exit due to repeated action",
)

def take_next_step(
Expand Down
Loading

0 comments on commit 5a1203b

Please sign in to comment.