Skip to content

Commit

Permalink
Added backwards compatability for yield_messages parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
VRSEN committed May 9, 2024
1 parent 9945078 commit 2834240
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 11 deletions.
26 changes: 20 additions & 6 deletions agency_swarm/agency/agency.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,25 @@ def get_completion(self, message: str,
Generator or final response: Depending on the 'yield_messages' flag, this method returns either a generator yielding intermediate messages or the final response from the main thread.
"""
if yield_messages:
print("Warning: yield_messages parameter is deprecated. Use streaming instead.")
print("Warning: yield_messages parameter will be deprecated soon. Use streaming instead.")

return self.main_thread.get_completion(message=message,
res = self.main_thread.get_completion(message=message,
message_files=message_files,
attachments=attachments,
recipient_agent=recipient_agent,
additional_instructions=additional_instructions,
tool_choice=tool_choice)
tool_choice=tool_choice,
yield_messages=yield_messages)

if not yield_messages:
while True:
try:
next(res)
except StopIteration as e:
return e.value

return res


def get_completion_stream(self,
message: str,
Expand Down Expand Up @@ -178,9 +189,12 @@ def get_completion_stream(self,
tool_choice=tool_choice
)

event_handler.on_all_streams_end()

return res
while True:
try:
next(res)
except StopIteration as e:
event_handler.on_all_streams_end()
return e.value

def demo_gradio(self, height=450, dark_mode=True, **kwargs):
"""
Expand Down
6 changes: 3 additions & 3 deletions agency_swarm/messages/message_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from rich.live import Live

console = Console()
live_display = Live()

class MessageOutput:
def __init__(self, msg_type: Literal["function", "function_output", "text", "system"], sender_name: str,
Expand Down Expand Up @@ -110,8 +109,9 @@ def __init__(self, msg_type: Literal["function", "function_output", "text", "sys
console.rule()

def __del__(self):
self.live_display.stop()
self.live_display = None
if self.live_display:
self.live_display.stop()
self.live_display = None

def cprint_update(self, snapshot):
"""
Expand Down
20 changes: 20 additions & 0 deletions agency_swarm/threads/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ def get_completion(self,
attachments=attachments
)

if yield_messages:
yield MessageOutput("text", self.agent.name, recipient_agent.name, message)

self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice)

error_attempts = 0
Expand All @@ -126,13 +129,23 @@ def get_completion(self,
tool_outputs = []
tool_names = []
for tool_call in tool_calls:
if yield_messages:
yield MessageOutput("function", recipient_agent.name, self.agent.name,
str(tool_call.function))

output = self.execute_tool(tool_call, recipient_agent, event_handler, tool_names)
if inspect.isgenerator(output):
try:
while True:
item = next(output)
if isinstance(item, MessageOutput) and yield_messages:
yield item
except StopIteration as e:
output = e.value
else:
if yield_messages:
yield MessageOutput("function_output", tool_call.function.name, recipient_agent.name,
output)
if event_handler:
event_handler.agent_name = self.agent.name
event_handler.recipient_agent_name = recipient_agent.name
Expand Down Expand Up @@ -199,6 +212,9 @@ def get_completion(self,
else:
full_message += self._get_last_message_text()

if yield_messages:
yield MessageOutput("text", recipient_agent.name, self.agent.name, full_message)

if recipient_agent.response_validator:
try:
if isinstance(recipient_agent, Agent):
Expand All @@ -211,6 +227,10 @@ def get_completion(self,
content=str(e),
)

if yield_messages:
yield MessageOutput("text", self.agent.name, recipient_agent.name,
message.content[0].text.value)

if event_handler:
handler = event_handler()
handler.on_message_created(message)
Expand Down
9 changes: 7 additions & 2 deletions agency_swarm/threads/thread_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,19 @@ def worker(self,
additional_instructions: str = None,
tool_choice: AssistantToolChoice = None
):
output = super().get_completion(message=message,
gen = super().get_completion(message=message,
message_files=message_files,
attachments=attachments,
recipient_agent=recipient_agent,
additional_instructions=additional_instructions,
tool_choice=tool_choice)

self.response = f"""{self.recipient_agent.name}'s Response: '{output}'"""
while True:
try:
next(gen)
except StopIteration as e:
self.response = f"""{self.recipient_agent.name}'s Response: '{e.value}'"""
break

return

Expand Down

0 comments on commit 2834240

Please sign in to comment.