From 2d7e134650d473b3feb2ece9b03568f9d2714103 Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Thu, 8 Feb 2024 15:37:46 -0500 Subject: [PATCH 1/2] Give users access to the tool calls and tool outputs in post_run_hook --- src/marvin/beta/assistants/assistants.py | 8 +++++++- src/marvin/beta/assistants/runs.py | 15 +++++++++++---- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/marvin/beta/assistants/assistants.py b/src/marvin/beta/assistants/assistants.py index 570167f99..1c79de191 100644 --- a/src/marvin/beta/assistants/assistants.py +++ b/src/marvin/beta/assistants/assistants.py @@ -1,6 +1,7 @@ from typing import TYPE_CHECKING, Callable, Optional, Union from pydantic import BaseModel, Field, PrivateAttr +from openai.types.beta.threads.required_action_function_tool_call import RequiredActionFunctionToolCall import marvin.utilities.tools from marvin.tools.assistants import AssistantTool @@ -168,5 +169,10 @@ def chat(self, thread: Thread = None): def pre_run_hook(self, run: "Run"): pass - def post_run_hook(self, run: "Run"): + def post_run_hook( + self, + run: "Run", + tool_calls: Optional[list[RequiredActionFunctionToolCall]] = None, + tool_outputs: Optional[list[dict[str, str]]] = None + ): pass diff --git a/src/marvin/beta/assistants/runs.py b/src/marvin/beta/assistants/runs.py index 014aad8b1..4b940e491 100644 --- a/src/marvin/beta/assistants/runs.py +++ b/src/marvin/beta/assistants/runs.py @@ -11,6 +11,7 @@ from marvin.utilities.asyncio import ExposeSyncMethodsMixin, expose_sync_method from marvin.utilities.logging import get_logger from marvin.utilities.openai import get_openai_client +from openai.types.beta.threads.required_action_function_tool_call import RequiredActionFunctionToolCall from .assistants import Assistant from .threads import Thread @@ -85,11 +86,12 @@ async def cancel_async(self): run_id=self.run.id, thread_id=self.thread.id ) - async def _handle_step_requires_action(self): + async def _handle_step_requires_action(self) -> tuple[list[RequiredActionFunctionToolCall], list[dict[str, str]]]: client = get_openai_client() if self.run.status != "requires_action": - return + return None, None if self.run.required_action.type == "submit_tool_outputs": + tool_calls = [] tool_outputs = [] tools = self.get_tools() @@ -110,10 +112,12 @@ async def _handle_step_requires_action(self): tool_outputs.append( dict(tool_call_id=tool_call.id, output=output or "") ) + tool_calls.append(tool_call) await client.beta.threads.runs.submit_tool_outputs( thread_id=self.thread.id, run_id=self.run.id, tool_outputs=tool_outputs ) + return tool_calls, tool_outputs def get_instructions(self) -> str: if self.instructions is None: @@ -157,10 +161,13 @@ async def run_async(self) -> "Run": self.assistant.pre_run_hook(run=self) + tool_calls = None + tool_outputs = None + try: while self.run.status in ("queued", "in_progress", "requires_action"): if self.run.status == "requires_action": - await self._handle_step_requires_action() + tool_calls, tool_outputs = await self._handle_step_requires_action() await asyncio.sleep(0.1) await self.refresh_async() except CancelRun as exc: @@ -174,7 +181,7 @@ async def run_async(self) -> "Run": if self.run.status == "failed": logger.debug(f"Run failed. Last error was: {self.run.last_error}") - self.assistant.post_run_hook(run=self) + self.assistant.post_run_hook(run=self, tool_calls=tool_calls, tool_outputs=tool_outputs) return self From ecda861d048e868a0dc99e8a5bbcad36f9d57d1f Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Thu, 8 Feb 2024 16:01:43 -0500 Subject: [PATCH 2/2] Run pre-commit formatting --- src/marvin/beta/assistants/assistants.py | 14 ++++++++------ src/marvin/beta/assistants/runs.py | 17 +++++++++++++---- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/src/marvin/beta/assistants/assistants.py b/src/marvin/beta/assistants/assistants.py index 1c79de191..5c1ddbc5d 100644 --- a/src/marvin/beta/assistants/assistants.py +++ b/src/marvin/beta/assistants/assistants.py @@ -1,7 +1,9 @@ from typing import TYPE_CHECKING, Callable, Optional, Union +from openai.types.beta.threads.required_action_function_tool_call import ( + RequiredActionFunctionToolCall, +) from pydantic import BaseModel, Field, PrivateAttr -from openai.types.beta.threads.required_action_function_tool_call import RequiredActionFunctionToolCall import marvin.utilities.tools from marvin.tools.assistants import AssistantTool @@ -170,9 +172,9 @@ def pre_run_hook(self, run: "Run"): pass def post_run_hook( - self, - run: "Run", - tool_calls: Optional[list[RequiredActionFunctionToolCall]] = None, - tool_outputs: Optional[list[dict[str, str]]] = None - ): + self, + run: "Run", + tool_calls: Optional[list[RequiredActionFunctionToolCall]] = None, + tool_outputs: Optional[list[dict[str, str]]] = None, + ): pass diff --git a/src/marvin/beta/assistants/runs.py b/src/marvin/beta/assistants/runs.py index 4b940e491..7f46756e9 100644 --- a/src/marvin/beta/assistants/runs.py +++ b/src/marvin/beta/assistants/runs.py @@ -1,6 +1,9 @@ import asyncio from typing import Any, Callable, Optional, Union +from openai.types.beta.threads.required_action_function_tool_call import ( + RequiredActionFunctionToolCall, +) from openai.types.beta.threads.run import Run as OpenAIRun from openai.types.beta.threads.runs import RunStep as OpenAIRunStep from pydantic import BaseModel, Field, PrivateAttr, field_validator @@ -11,7 +14,6 @@ from marvin.utilities.asyncio import ExposeSyncMethodsMixin, expose_sync_method from marvin.utilities.logging import get_logger from marvin.utilities.openai import get_openai_client -from openai.types.beta.threads.required_action_function_tool_call import RequiredActionFunctionToolCall from .assistants import Assistant from .threads import Thread @@ -86,7 +88,9 @@ async def cancel_async(self): run_id=self.run.id, thread_id=self.thread.id ) - async def _handle_step_requires_action(self) -> tuple[list[RequiredActionFunctionToolCall], list[dict[str, str]]]: + async def _handle_step_requires_action( + self, + ) -> tuple[list[RequiredActionFunctionToolCall], list[dict[str, str]]]: client = get_openai_client() if self.run.status != "requires_action": return None, None @@ -167,7 +171,10 @@ async def run_async(self) -> "Run": try: while self.run.status in ("queued", "in_progress", "requires_action"): if self.run.status == "requires_action": - tool_calls, tool_outputs = await self._handle_step_requires_action() + ( + tool_calls, + tool_outputs, + ) = await self._handle_step_requires_action() await asyncio.sleep(0.1) await self.refresh_async() except CancelRun as exc: @@ -181,7 +188,9 @@ async def run_async(self) -> "Run": if self.run.status == "failed": logger.debug(f"Run failed. Last error was: {self.run.last_error}") - self.assistant.post_run_hook(run=self, tool_calls=tool_calls, tool_outputs=tool_outputs) + self.assistant.post_run_hook( + run=self, tool_calls=tool_calls, tool_outputs=tool_outputs + ) return self