diff --git a/agency_swarm/agency/agency.py b/agency_swarm/agency/agency.py index 9aef2e84..dfb120a6 100644 --- a/agency_swarm/agency/agency.py +++ b/agency_swarm/agency/agency.py @@ -133,14 +133,25 @@ def get_completion(self, message: str, Generator or final response: Depending on the 'yield_messages' flag, this method returns either a generator yielding intermediate messages or the final response from the main thread. """ if yield_messages: - print("Warning: yield_messages parameter is deprecated. Use streaming instead.") + print("Warning: yield_messages parameter will be deprecated soon. Use streaming instead.") - return self.main_thread.get_completion(message=message, + res = self.main_thread.get_completion(message=message, message_files=message_files, attachments=attachments, recipient_agent=recipient_agent, additional_instructions=additional_instructions, - tool_choice=tool_choice) + tool_choice=tool_choice, + yield_messages=yield_messages) + + if not yield_messages: + while True: + try: + next(res) + except StopIteration as e: + return e.value + + return res + def get_completion_stream(self, message: str, @@ -178,9 +189,12 @@ def get_completion_stream(self, tool_choice=tool_choice ) - event_handler.on_all_streams_end() - - return res + while True: + try: + next(res) + except StopIteration as e: + event_handler.on_all_streams_end() + return e.value def demo_gradio(self, height=450, dark_mode=True, **kwargs): """ diff --git a/agency_swarm/messages/message_output.py b/agency_swarm/messages/message_output.py index c97e816b..f6ae0ab3 100644 --- a/agency_swarm/messages/message_output.py +++ b/agency_swarm/messages/message_output.py @@ -5,7 +5,6 @@ from rich.live import Live console = Console() -live_display = Live() class MessageOutput: def __init__(self, msg_type: Literal["function", "function_output", "text", "system"], sender_name: str, @@ -110,8 +109,9 @@ def __init__(self, msg_type: Literal["function", "function_output", "text", "sys console.rule() def __del__(self): - self.live_display.stop() - self.live_display = None + if self.live_display: + self.live_display.stop() + self.live_display = None def cprint_update(self, snapshot): """ diff --git a/agency_swarm/threads/thread.py b/agency_swarm/threads/thread.py index 5af2d013..a7ee18f4 100644 --- a/agency_swarm/threads/thread.py +++ b/agency_swarm/threads/thread.py @@ -112,6 +112,9 @@ def get_completion(self, attachments=attachments ) + if yield_messages: + yield MessageOutput("text", self.agent.name, recipient_agent.name, message) + self._create_run(recipient_agent, additional_instructions, event_handler, tool_choice) error_attempts = 0 @@ -126,13 +129,23 @@ def get_completion(self, tool_outputs = [] tool_names = [] for tool_call in tool_calls: + if yield_messages: + yield MessageOutput("function", recipient_agent.name, self.agent.name, + str(tool_call.function)) + output = self.execute_tool(tool_call, recipient_agent, event_handler, tool_names) if inspect.isgenerator(output): try: while True: item = next(output) + if isinstance(item, MessageOutput) and yield_messages: + yield item except StopIteration as e: output = e.value + else: + if yield_messages: + yield MessageOutput("function_output", tool_call.function.name, recipient_agent.name, + output) if event_handler: event_handler.agent_name = self.agent.name event_handler.recipient_agent_name = recipient_agent.name @@ -199,6 +212,9 @@ def get_completion(self, else: full_message += self._get_last_message_text() + if yield_messages: + yield MessageOutput("text", recipient_agent.name, self.agent.name, full_message) + if recipient_agent.response_validator: try: if isinstance(recipient_agent, Agent): @@ -211,6 +227,10 @@ def get_completion(self, content=str(e), ) + if yield_messages: + yield MessageOutput("text", self.agent.name, recipient_agent.name, + message.content[0].text.value) + if event_handler: handler = event_handler() handler.on_message_created(message) diff --git a/agency_swarm/threads/thread_async.py b/agency_swarm/threads/thread_async.py index f4c6daad..73954251 100644 --- a/agency_swarm/threads/thread_async.py +++ b/agency_swarm/threads/thread_async.py @@ -22,14 +22,19 @@ def worker(self, additional_instructions: str = None, tool_choice: AssistantToolChoice = None ): - output = super().get_completion(message=message, + gen = super().get_completion(message=message, message_files=message_files, attachments=attachments, recipient_agent=recipient_agent, additional_instructions=additional_instructions, tool_choice=tool_choice) - self.response = f"""{self.recipient_agent.name}'s Response: '{output}'""" + while True: + try: + next(gen) + except StopIteration as e: + self.response = f"""{self.recipient_agent.name}'s Response: '{e.value}'""" + break return