diff --git a/src/marvin/beta/assistants/assistants.py b/src/marvin/beta/assistants/assistants.py index 570167f99..5c1ddbc5d 100644 --- a/src/marvin/beta/assistants/assistants.py +++ b/src/marvin/beta/assistants/assistants.py @@ -1,5 +1,8 @@ from typing import TYPE_CHECKING, Callable, Optional, Union +from openai.types.beta.threads.required_action_function_tool_call import ( + RequiredActionFunctionToolCall, +) from pydantic import BaseModel, Field, PrivateAttr import marvin.utilities.tools @@ -168,5 +171,10 @@ def chat(self, thread: Thread = None): def pre_run_hook(self, run: "Run"): pass - def post_run_hook(self, run: "Run"): + def post_run_hook( + self, + run: "Run", + tool_calls: Optional[list[RequiredActionFunctionToolCall]] = None, + tool_outputs: Optional[list[dict[str, str]]] = None, + ): pass diff --git a/src/marvin/beta/assistants/runs.py b/src/marvin/beta/assistants/runs.py index 014aad8b1..7f46756e9 100644 --- a/src/marvin/beta/assistants/runs.py +++ b/src/marvin/beta/assistants/runs.py @@ -1,6 +1,9 @@ import asyncio from typing import Any, Callable, Optional, Union +from openai.types.beta.threads.required_action_function_tool_call import ( + RequiredActionFunctionToolCall, +) from openai.types.beta.threads.run import Run as OpenAIRun from openai.types.beta.threads.runs import RunStep as OpenAIRunStep from pydantic import BaseModel, Field, PrivateAttr, field_validator @@ -85,11 +88,14 @@ async def cancel_async(self): run_id=self.run.id, thread_id=self.thread.id ) - async def _handle_step_requires_action(self): + async def _handle_step_requires_action( + self, + ) -> tuple[list[RequiredActionFunctionToolCall], list[dict[str, str]]]: client = get_openai_client() if self.run.status != "requires_action": - return + return None, None if self.run.required_action.type == "submit_tool_outputs": + tool_calls = [] tool_outputs = [] tools = self.get_tools() @@ -110,10 +116,12 @@ async def _handle_step_requires_action(self): tool_outputs.append( dict(tool_call_id=tool_call.id, output=output or "") ) + tool_calls.append(tool_call) await client.beta.threads.runs.submit_tool_outputs( thread_id=self.thread.id, run_id=self.run.id, tool_outputs=tool_outputs ) + return tool_calls, tool_outputs def get_instructions(self) -> str: if self.instructions is None: @@ -157,10 +165,16 @@ async def run_async(self) -> "Run": self.assistant.pre_run_hook(run=self) + tool_calls = None + tool_outputs = None + try: while self.run.status in ("queued", "in_progress", "requires_action"): if self.run.status == "requires_action": - await self._handle_step_requires_action() + ( + tool_calls, + tool_outputs, + ) = await self._handle_step_requires_action() await asyncio.sleep(0.1) await self.refresh_async() except CancelRun as exc: @@ -174,7 +188,9 @@ async def run_async(self) -> "Run": if self.run.status == "failed": logger.debug(f"Run failed. Last error was: {self.run.last_error}") - self.assistant.post_run_hook(run=self) + self.assistant.post_run_hook( + run=self, tool_calls=tool_calls, tool_outputs=tool_outputs + ) return self