Skip to content

Commit

Permalink
fix(agents): generate proper history blocks in agent (#570)
Browse files Browse the repository at this point in the history
While the streaming changes were in flight, we unfortunately merged a
commit that left message selection in a broken state. The existing
tests were not sufficient to catch the issue. This PR attempts to restore
proper message selection functionality.

Follow-on PRs will be necessary to clean up and streamline the selection
logic added in this PR.

Co-authored-by: Douglas Reid <[email protected]>
  • Loading branch information
douglas-reid and Douglas Reid authored Oct 3, 2023
1 parent f659229 commit b9e32cc
Show file tree
Hide file tree
Showing 13 changed files with 384 additions and 61 deletions.
18 changes: 16 additions & 2 deletions src/steamship/agents/examples/example_assistant.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from typing import Type

from pydantic.fields import Field

from steamship.agents.functional import FunctionsBasedAgent
from steamship.agents.llms.openai import ChatOpenAI
from steamship.agents.schema.message_selectors import MessageWindowMessageSelector
from steamship.agents.service.agent_service import AgentService
from steamship.agents.tools.image_generation import DalleTool
from steamship.agents.tools.search import SearchTool
from steamship.invocable import Config
from steamship.utils.repl import AgentREPL


Expand All @@ -13,6 +18,13 @@ class MyFunctionsBasedAssistant(AgentService):
to provide an overview of the types of tasks it can accomplish (here, search
and image generation)."""

# Configuration options for this example assistant; returned by config_cls().
class AgentConfig(Config):
# Name of the chat model to use; defaults to "gpt-4" when not supplied.
model_name: str = Field(default="gpt-4")

@classmethod
def config_cls(cls) -> Type[Config]:
# Return the Config subclass (AgentConfig) that describes this service's configuration.
return MyFunctionsBasedAssistant.AgentConfig

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.set_default_agent(
Expand All @@ -21,7 +33,7 @@ def __init__(self, **kwargs):
SearchTool(),
DalleTool(),
],
llm=ChatOpenAI(self.client, temperature=0),
llm=ChatOpenAI(self.client, temperature=0, model_name=self.config.model_name),
message_selector=MessageWindowMessageSelector(k=2),
)
)
Expand All @@ -31,4 +43,6 @@ def __init__(self, **kwargs):
# AgentREPL provides a mechanism for local execution of an AgentService method.
# This is used for simplified debugging as agents and tools are developed and
# added.
AgentREPL(MyFunctionsBasedAssistant, agent_package_config={}).run(dump_history_on_exit=True)
AgentREPL(MyFunctionsBasedAssistant, agent_package_config={"model_name": "gpt-3.5-turbo"}).run(
dump_history_on_exit=True
)
86 changes: 78 additions & 8 deletions src/steamship/agents/functional/functions_based.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import json
from operator import attrgetter
from typing import List

from steamship import Block
from steamship import Block, MimeTypes, Tag
from steamship.agents.functional.output_parser import FunctionsBasedOutputParser
from steamship.agents.schema import Action, AgentContext, ChatAgent, ChatLLM, Tool
from steamship.data.tags.tag_constants import RoleTag
from steamship.agents.schema import Action, AgentContext, ChatAgent, ChatLLM, FinishAction, Tool
from steamship.data.tags.tag_constants import ChatTag, RoleTag, TagKind, TagValueKey
from steamship.data.tags.tag_utils import get_tag


class FunctionsBasedAgent(ChatAgent):
Expand Down Expand Up @@ -54,6 +57,8 @@ def build_chat_history_for_tool(self, context: AgentContext) -> List[Block]:
# get most recent context
messages_from_memory.extend(context.chat_history.select_messages(self.message_selector))

messages_from_memory.sort(key=attrgetter("index_in_file"))

# de-dupe the messages from memory
ids = [context.chat_history.last_user_message.id]
for msg in messages_from_memory:
Expand All @@ -67,10 +72,8 @@ def build_chat_history_for_tool(self, context: AgentContext) -> List[Block]:
# this should happen BEFORE any agent/assistant messages related to tool selection
messages.append(context.chat_history.last_user_message)

# get completed steps
actions = context.completed_steps
for action in actions:
messages.extend(action.to_chat_messages())
# get working history (completed actions)
messages.extend(self._function_calls_since_last_user_message(context))

return messages

Expand All @@ -81,4 +84,71 @@ def next_action(self, context: AgentContext) -> Action:
# Run the default LLM on those messages
output_blocks = self.llm.chat(messages=messages, tools=self.tools)

return self.output_parser.parse(output_blocks[0].text, context)
future_action = self.output_parser.parse(output_blocks[0].text, context)
if not isinstance(future_action, FinishAction):
# record the LLM's function response in history
self._record_action_selection(future_action, context)
return future_action

def _function_calls_since_last_user_message(self, context: AgentContext) -> List[Block]:
    """Collect the function-related blocks recorded after the latest user message.

    Walks ``context.chat_history.messages`` from newest to oldest and stops at
    the first block whose ``chat_role`` is ``RoleTag.USER``. Blocks carrying
    either a ``TagKind.ROLE`` / ``RoleTag.FUNCTION`` tag or a
    ``TagKind.FUNCTION_SELECTION`` tag are kept.

    Args:
        context: Agent context whose chat history is scanned.

    Returns:
        The matching blocks as a real list, in chronological (oldest-first)
        order. Previously a one-shot ``reversed`` iterator was returned, which
        contradicted the annotation and could only be consumed once.
    """
    function_calls = []
    # reversed() iterates lazily, so the full history is not copied on every call.
    for block in reversed(context.chat_history.messages):
        if block.chat_role == RoleTag.USER:
            # Everything older than the latest user message belongs to earlier turns.
            break
        if get_tag(block.tags, kind=TagKind.ROLE, name=RoleTag.FUNCTION) or get_tag(
            block.tags, kind=TagKind.FUNCTION_SELECTION
        ):
            function_calls.append(block)
    # Blocks were gathered newest-first; restore chronological order.
    function_calls.reverse()
    return function_calls

def _to_openai_function_selection(self, action: Action) -> str:
    """Serialize an action as an OpenAI-style function-call JSON string.

    NOTE: Temporary placeholder. Should be refactored.

    Every input block tagged with ``TagKind.FUNCTION_ARG`` contributes one
    argument, keyed by the tag name and valued by the block's LLM input text.

    Args:
        action: The selected action; its ``tool`` and ``input`` are read.

    Returns:
        JSON text of the form ``{"name": ..., "arguments": "<json string>"}``.
    """
    arguments = {
        tag.name: input_block.as_llm_input(exclude_block_wrapper=True)
        for input_block in action.input
        for tag in input_block.tags
        if tag.kind == TagKind.FUNCTION_ARG
    }
    # the arguments must be a string value NOT a dict
    return json.dumps({"name": action.tool, "arguments": json.dumps(arguments)})

def _record_action_selection(self, action: Action, context: AgentContext):
    """Append the chosen function call to the chat history file.

    The appended block holds the serialized function selection as plain text
    and is tagged with the assistant chat role plus a
    ``TagKind.FUNCTION_SELECTION`` tag named after the selected tool.

    Args:
        action: The action chosen for the next step.
        context: Agent context whose chat history file receives the block.
    """
    assistant_role_tag = Tag(
        kind=TagKind.CHAT,
        name=ChatTag.ROLE,
        value={TagValueKey.STRING_VALUE: RoleTag.ASSISTANT},
    )
    selection_tag = Tag(kind=TagKind.FUNCTION_SELECTION, name=action.tool)
    context.chat_history.file.append_block(
        text=self._to_openai_function_selection(action),
        tags=[assistant_role_tag, selection_tag],
        mime_type=MimeTypes.TXT,
    )

def record_action_run(self, action: Action, context: AgentContext):
    """Record a completed action and append its outputs to the chat history file.

    Delegates to the parent implementation first. For any action other than a
    ``FinishAction``, each output block is then appended to the chat history
    file, tagged with the function role and the tool name.

    Args:
        action: The action that was just run.
        context: Agent context whose chat history file receives the outputs.
    """
    super().record_action_run(action, context)

    # FinishActions are not written to history as function results.
    if isinstance(action, FinishAction):
        return

    function_role_tag = Tag(
        kind=TagKind.ROLE,
        name=RoleTag.FUNCTION,
        value={TagValueKey.STRING_VALUE: action.tool},
    )
    # need the following tag for backwards compatibility with older gpt-4 plugin
    legacy_name_tag = Tag(kind="name", name=action.tool)
    result_tags = [function_role_tag, legacy_name_tag]

    # TODO(dougreid): I'm not convinced this is correct for tools that return multiple values.
    # It _feels_ like these should be named and inlined as a single message in history, etc.
    for output_block in action.output:
        context.chat_history.file.append_block(
            text=output_block.as_llm_input(exclude_block_wrapper=True),
            tags=result_tags,
            mime_type=output_block.mime_type,
        )
43 changes: 36 additions & 7 deletions src/steamship/agents/functional/output_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from json import JSONDecodeError
from typing import Dict, List, Optional

from steamship import Block, MimeTypes, Steamship
from steamship import Block, MimeTypes, Steamship, Tag
from steamship.agents.schema import Action, AgentContext, FinishAction, OutputParser, Tool
from steamship.data.tags.tag_constants import RoleTag
from steamship.data.tags.tag_constants import RoleTag, TagKind
from steamship.utils.utils import is_valid_uuid4


Expand Down Expand Up @@ -43,16 +43,45 @@ def _extract_action_from_function_call(self, text: str, context: AgentContext) -
try:
args = json.loads(arguments)
if text := args.get("text"):
input_blocks.append(Block(text=text, mime_type=MimeTypes.TXT))
input_blocks.append(
Block(
text=text,
tags=[Tag(kind=TagKind.FUNCTION_ARG, name="text")],
mime_type=MimeTypes.TXT,
)
)
elif uuid_arg := args.get("uuid"):
input_blocks.append(Block.get(context.client, _id=uuid_arg))
existing_block = Block.get(context.client, _id=uuid_arg)
tag = Tag.create(
existing_block.client,
file_id=existing_block.file_id,
block_id=existing_block.id,
kind=TagKind.FUNCTION_ARG,
name="uuid",
)
existing_block.tags.append(tag)
input_blocks.append(existing_block)
except json.decoder.JSONDecodeError:
if isinstance(arguments, str):
if is_valid_uuid4(arguments):
input_blocks.append(Block.get(context.client, _id=uuid_arg))
existing_block = Block.get(context.client, _id=arguments)
tag = Tag.create(
existing_block.client,
file_id=existing_block.file_id,
block_id=existing_block.id,
kind=TagKind.FUNCTION_ARG,
name="uuid",
)
existing_block.tags.append(tag)
input_blocks.append(existing_block)
else:
input_blocks.append(Block(text=arguments, mime_type=MimeTypes.TXT))

input_blocks.append(
Block(
text=arguments,
tags=[Tag(kind=TagKind.FUNCTION_ARG, name="text")],
mime_type=MimeTypes.TXT,
)
)
return Action(tool=tool.name, input=input_blocks, context=context)

@staticmethod
Expand Down
52 changes: 30 additions & 22 deletions src/steamship/agents/schema/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@

from pydantic import BaseModel

from steamship import Block, Tag
from steamship.data import TagKind
from steamship.data.tags.tag_constants import RoleTag
from steamship import Block


class Action(BaseModel):
Expand All @@ -28,25 +26,35 @@ class Action(BaseModel):
Setting this to True means that the executing Agent should halt any reasoning.
"""

def to_chat_messages(self) -> List[Block]:
tags = [
Tag(kind=TagKind.ROLE, name=RoleTag.FUNCTION),
Tag(kind="name", name=self.tool),
]
blocks = []
for block in self.output:
# TODO(dougreid): should we revisit as_llm_input? we might need only the UUID...
blocks.append(
Block(
text=block.as_llm_input(exclude_block_wrapper=True),
tags=tags,
mime_type=block.mime_type,
)
)

# TODO(dougreid): revisit when have multiple output functions.
# Current thinking: LLM will be OK with multiple function blocks in a row. NEEDS validation.
return blocks
# def to_chat_messages(self) -> List[Block]:
# blocks = []
# for arg in self.input:
#
#
# blocks.append(
# Block(
# text=json.dumps({"name": f"{self.tool}", "arguments": "{ \"text\": \"who is the current president of Taiwan?\" }"}),
# )
# )
#
# tags = [
# Tag(kind=TagKind.ROLE, name=RoleTag.FUNCTION),
# Tag(kind="name", name=self.tool),
# ]
#
# for block in self.output:
# # TODO(dougreid): should we revisit as_llm_input? we might need only the UUID...
# blocks.append(
# Block(
# text=block.as_llm_input(exclude_block_wrapper=True),
# tags=tags,
# mime_type=block.mime_type,
# )
# )
#
# # TODO(dougreid): revisit when have multiple output functions.
# # Current thinking: LLM will be OK with multiple function blocks in a row. NEEDS validation.
# return blocks


class FinishAction(Action):
Expand Down
4 changes: 4 additions & 0 deletions src/steamship/agents/schema/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ class Agent(BaseModel, ABC):
def next_action(self, context: AgentContext) -> Action:
pass

def record_action_run(self, action: Action, context: AgentContext):
# Record `action` as a completed step by appending it to `context.completed_steps`.
# TODO(dougreid): should this method (or just bit) actually be on AgentContext?
context.completed_steps.append(action)


class LLMAgent(Agent):
"""LLMAgents choose next actions for an AgentService based on interactions with an LLM."""
Expand Down
6 changes: 5 additions & 1 deletion src/steamship/agents/schema/chathistory.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,11 @@ def append_message_with_role(
text=text, tags=tags, content=content, url=url, mime_type=mime_type
)
# don't index status messages
if self.embedding_index is not None and role is not RoleTag.AGENT:
if self.embedding_index is not None and role not in [
RoleTag.AGENT,
RoleTag.TOOL,
RoleTag.LLM,
]:
chunk_tags = self.text_splitter.chunk_text_to_tags(
block, kind=TagKind.CHAT, name=ChatTag.CHUNK
)
Expand Down
55 changes: 44 additions & 11 deletions src/steamship/agents/schema/message_selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from pydantic.main import BaseModel

from steamship import Block
from steamship.data.tags.tag_constants import RoleTag
from steamship.data.tags.tag_constants import RoleTag, TagKind
from steamship.data.tags.tag_utils import get_tag


class MessageSelector(BaseModel, ABC):
Expand All @@ -29,23 +30,53 @@ def is_assistant_message(block: Block) -> bool:
return role == RoleTag.ASSISTANT


def is_function_message(block: Block) -> bool:
    """Return True when ``block`` carries a ``TagKind.FUNCTION_SELECTION`` tag.

    The ``get_tag`` result is coerced with ``bool()`` so the function honors
    its ``-> bool`` annotation instead of leaking the raw lookup result.
    Truthiness is unchanged for callers that already used it as a condition.
    """
    return bool(get_tag(block.tags, kind=TagKind.FUNCTION_SELECTION))


def is_tool_function_message(block: Block) -> bool:
    """Return True when ``block`` carries a ``TagKind.ROLE`` tag named ``RoleTag.FUNCTION``.

    The ``get_tag`` result is coerced with ``bool()`` so the function honors
    its ``-> bool`` annotation instead of leaking the raw lookup result.
    Truthiness is unchanged for callers that already used it as a condition.
    """
    return bool(get_tag(block.tags, kind=TagKind.ROLE, name=RoleTag.FUNCTION))


def is_user_history_message(block: Block) -> bool:
    """Tell whether ``block`` is a user message or a non-function assistant message.

    A block qualifies when it is a user message, or when it is an assistant
    message that is not a function-selection message.
    """
    from_user = is_user_message(block)
    if from_user:
        return from_user
    return is_assistant_message(block) and not is_function_message(block)


class MessageWindowMessageSelector(MessageSelector):
    """Select a window of the most recent conversation messages.

    Up to ``k * 2`` user/assistant history messages are selected, newest
    first, and returned in chronological order. Function-selection and
    function-result blocks encountered along the way are included once a user
    message has been seen; they do not count toward the window.
    """

    # Window size; at most k * 2 user/assistant messages are selected.
    k: int

    def get_messages(self, messages: List[Block]) -> List[Block]:
        """Return the selected history messages, oldest first.

        Args:
            messages: Full message list; a trailing user message is treated as
                the current prompt and excluded.

        Returns:
            A new list of selected blocks in their original order. An empty
            input yields an empty list.
        """
        msgs = messages[:]
        if not msgs:
            # Nothing to select; also avoids indexing msgs[-1] on an empty list.
            return []

        have_seen_user_message = False
        if is_user_message(msgs[-1]):
            have_seen_user_message = True
            msgs.pop()  # don't add the current prompt to the memory

        selected_msgs = []
        conversation_messages = 0
        limit = self.k * 2
        # Walk newest -> oldest, including the very first message (index 0),
        # which the previous `message_index > 0` loop condition never examined.
        for block in reversed(msgs):
            if conversation_messages >= limit:
                break
            # TODO(dougreid): i _think_ we don't need the function return if we have a user-assistant pair
            # but, for safety here, we try to add non-current function blocks from past iterations.
            if is_user_message(block):
                have_seen_user_message = True
            if is_user_history_message(block):
                selected_msgs.append(block)
                conversation_messages += 1
            elif have_seen_user_message and (
                is_function_message(block) or is_tool_function_message(block)
            ):
                # conditionally append working function call messages
                selected_msgs.append(block)

        # Blocks were gathered newest-first; return a real list in original order.
        selected_msgs.reverse()
        return selected_msgs


def tokens(block: Block) -> int:
Expand All @@ -62,9 +93,11 @@ def get_messages(self, messages: List[Block]) -> List[Block]:
current_tokens = 0

msgs = messages[:]
msgs.pop() # don't add the current prompt to the memory
if is_user_message(msgs[-1]):
msgs.pop() # don't add the current prompt to the memory

for block in reversed(msgs):
if block.chat_role != RoleTag.SYSTEM and current_tokens < self.max_tokens:
if is_user_history_message(block) and current_tokens < self.max_tokens:
block_tokens = tokens(block)
if block_tokens + current_tokens < self.max_tokens:
selected_messages.append(block)
Expand Down
Loading

0 comments on commit b9e32cc

Please sign in to comment.