From 040a6ab9e9a4c7b709d94768a01d1cf266438f5c Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Sat, 6 Apr 2024 11:57:35 -0400 Subject: [PATCH 1/5] =?UTF-8?q?CancelRun=20=E2=86=92=20EndRun?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/marvin/beta/ai_flow/ai_task.py | 8 ++++---- src/marvin/beta/assistants/runs.py | 8 ++++---- src/marvin/tools/assistants.py | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/marvin/beta/ai_flow/ai_task.py b/src/marvin/beta/ai_flow/ai_task.py index 9dce85d4a..b7c925701 100644 --- a/src/marvin/beta/ai_flow/ai_task.py +++ b/src/marvin/beta/ai_flow/ai_task.py @@ -8,7 +8,7 @@ from typing_extensions import ParamSpec from marvin.beta.assistants import Assistant, Run, Thread -from marvin.beta.assistants.runs import CancelRun +from marvin.beta.assistants.runs import EndRun from marvin.tools.assistants import AssistantTool from marvin.utilities.context import ScopedContext from marvin.utilities.jinja import Environment as JinjaEnvironment @@ -238,7 +238,7 @@ def _task_completed_tool(self): def task_completed(): self.status = Status.COMPLETED - raise CancelRun() + raise EndRun() return task_completed @@ -249,7 +249,7 @@ def task_completed(): def task_completed_with_result(result: T): self.status = Status.COMPLETED self.result = result - raise CancelRun() + raise EndRun() tool.function._python_fn = task_completed_with_result @@ -261,7 +261,7 @@ def task_failed(reason: str) -> None: """Indicate that the task failed for the provided `reason`.""" self.status = Status.FAILED self.result = reason - raise CancelRun() + raise EndRun() return tool_from_function(task_failed) diff --git a/src/marvin/beta/assistants/runs.py b/src/marvin/beta/assistants/runs.py index 8cb5bf3c2..f0cbbcfbf 100644 --- a/src/marvin/beta/assistants/runs.py +++ b/src/marvin/beta/assistants/runs.py @@ -8,7 +8,7 @@ import marvin.utilities.openai import marvin.utilities.tools -from marvin.tools.assistants import AssistantTool, CancelRun +from marvin.tools.assistants import AssistantTool, EndRun from marvin.types import Tool from marvin.utilities.asyncio import ExposeSyncMethodsMixin, expose_sync_method from marvin.utilities.logging import get_logger @@ -172,7 +172,7 @@ async def get_tool_outputs(self, run: OpenAIRun) -> list[dict[str, str]]: function_arguments_json=tool_call.function.arguments, return_string=True, ) - except CancelRun as exc: + except EndRun as exc: logger.debug(f"Ending run with data: {exc.data}") raise except Exception as exc: @@ -229,8 +229,8 @@ async def run_async(self) -> "Run": await stream.until_done() await self._update_run_from_handler(handler) - except CancelRun as exc: - logger.debug(f"`CancelRun` raised; ending run with data: {exc.data}") + except EndRun as exc: + logger.debug(f"`EndRun` raised; ending run with data: {exc.data}") await self.cancel_async() self.data = exc.data diff --git a/src/marvin/tools/assistants.py b/src/marvin/tools/assistants.py index abe32c3cc..176258ab1 100644 --- a/src/marvin/tools/assistants.py +++ b/src/marvin/tools/assistants.py @@ -8,7 +8,7 @@ AssistantTool = Union[RetrievalTool, CodeInterpreterTool, Tool] -class CancelRun(Exception): +class EndRun(Exception): """ A special exception that can be raised in a tool to end the run immediately. """ From 1b178faff97f333a7f1fef6c039c34f0b81e72fd Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Sat, 6 Apr 2024 11:57:49 -0400 Subject: [PATCH 2/5] Convert outputs to strings in helper fn --- src/marvin/beta/assistants/runs.py | 11 +++++++++-- src/marvin/utilities/tools.py | 22 ++++++++++++++++++++-- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/src/marvin/beta/assistants/runs.py b/src/marvin/beta/assistants/runs.py index f0cbbcfbf..64632170f 100644 --- a/src/marvin/beta/assistants/runs.py +++ b/src/marvin/beta/assistants/runs.py @@ -156,7 +156,9 @@ def _get_run_kwargs(self, thread: Thread = None, **run_kwargs) -> dict: return run_kwargs - async def get_tool_outputs(self, run: OpenAIRun) -> list[dict[str, str]]: + async def get_tool_outputs( + self, run: OpenAIRun, as_strings: bool = True + ) -> list[dict[str, str]]: if run.status != "requires_action": return None, None if run.required_action.type == "submit_tool_outputs": @@ -183,6 +185,9 @@ async def get_tool_outputs(self, run: OpenAIRun) -> list[dict[str, str]]: ) tool_calls.append(tool_call) + if as_strings: + tool_outputs = marvin.utilities.tools.get_string_outputs(tool_outputs) + return tool_outputs async def run_async(self) -> "Run": @@ -216,7 +221,9 @@ async def run_async(self) -> "Run": await self._update_run_from_handler(handler) while handler.current_run.status in ["requires_action"]: - tool_outputs = await self.get_tool_outputs(run=handler.current_run) + tool_outputs = await self.get_tool_outputs( + run=handler.current_run, as_strings=True + ) handler = event_handler_class(**self.event_handler_kwargs) diff --git a/src/marvin/utilities/tools.py b/src/marvin/utilities/tools.py index d9485e913..529d2b093 100644 --- a/src/marvin/utilities/tools.py +++ b/src/marvin/utilities/tools.py @@ -180,6 +180,24 @@ def call_function_tool( if len(truncated_output) < len(str(output)): truncated_output += "..." logger.debug_kv(f"{tool.function.name}", f"returned: {truncated_output}", "green") - if return_string and not isinstance(output, str): - output = json.dumps(output) return output + + +def get_string_outputs(tool_outputs: list[Any]) -> list[str]: + """ + Function outputs must be provided as strings + """ + string_outputs = [] + for o in tool_outputs: + if isinstance(o, None): + o = "" + elif not isinstance(o, str): + if isinstance(o, BaseModel): + o = o.model_dump_json() + else: + try: + o = json.dumps(o) + except json.JSONDecodeError: + o = str(o) + string_outputs.append(o) + return string_outputs From 13820da63af7abb89b4da0c857030cf56ded2af4 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Sat, 6 Apr 2024 13:55:40 -0400 Subject: [PATCH 3/5] Improve endrun handling --- docs/docs/interactive/assistants.md | 37 +++++++++++++++++++++++++- src/marvin/beta/assistants/__init__.py | 2 +- src/marvin/beta/assistants/runs.py | 26 +++++++++--------- src/marvin/tools/assistants.py | 2 ++ src/marvin/utilities/tools.py | 28 +++++++++---------- 5 files changed, 65 insertions(+), 30 deletions(-) diff --git a/docs/docs/interactive/assistants.md b/docs/docs/interactive/assistants.md index cc772311c..77716395a 100644 --- a/docs/docs/interactive/assistants.md +++ b/docs/docs/interactive/assistants.md @@ -228,6 +228,41 @@ Marvin makes it easy to give your assistants custom tools. To do so, pass one or !!! success "Result" ![](/assets/images/docs/assistants/custom_tools.png) +#### Ending a run early + +Normally, the assistant will continue to run until it decides to stop, which usually happens after generating a response. Sometimes it may be useful to end a run early, for example if the assistant uses a tool that indicates the conversation is over. To do this, you can raise an `EndRun` exception from within a tool. This will cause the assistant to cancel the current run and return control. EndRun exceptions can contain data. + +There are three ways to raise an `EndRun` exception: + +1. Raise the exception directly from the tool function: +```python +from marvin.beta.assistants import Assistant, EndRun + +def my_tool(): + raise EndRun(data="The final result") + +ai = Assistant(tools=[my_tool]) +``` +1. Return the exception from the tool function. This is useful if e.g. your tools are wrapped in custom exception handlers: +```python +from marvin.beta.assistants import Assistant, EndRun + +def my_tool(): + return EndRun(data="The final result") + +ai = Assistant(tools=[my_tool]) +``` +1. Return a special string value from the tool function. This is useful if you don't have full control over the tool itself, or need to ensure the tool output is JSON-compatible. Note that this approach does not allow you to attach any data to the exception: +```python +from marvin.beta.assistants import Assistant, ENDRUN_TOKEN + +def my_tool(): + return ENDRUN_TOKEN + +ai = Assistant(tools=[my_tool]) +``` + + ### Lifecycle management Assistants are Marvin objects that correspond to remote objects in the OpenAI API. You can not communicate with an assistant unless it has been registered with the API. @@ -387,7 +422,7 @@ This will return a `Run` object that represents the OpenAI run. You can use this When threads are `run` with an assistant, the same lifecycle management rules apply as when you use the assistant's `say` method. In the above example, lazy lifecycle management is used for conveneince. See [lifecycle management](#lifecycle-management) for more information. !!! warning "Threads are locked while running" - When an assistant is running a thread, the thread is locked and no other messages can be added to it. This applies to both user and assistant messages. + When an assistant is running a thread, the thread is locked and no other messages can be added to it. This applies to both user and assistant messages. To end a run early, you must [use a custom tool](#ending-a-run-early). ### Reading messages diff --git a/src/marvin/beta/assistants/__init__.py b/src/marvin/beta/assistants/__init__.py index 7f2791983..29e401e60 100644 --- a/src/marvin/beta/assistants/__init__.py +++ b/src/marvin/beta/assistants/__init__.py @@ -1,4 +1,4 @@ -from .runs import Run +from .runs import Run, EndRun, ENDRUN_TOKEN from .threads import Thread from .assistants import Assistant from .handlers import PrintHandler diff --git a/src/marvin/beta/assistants/runs.py b/src/marvin/beta/assistants/runs.py index 64632170f..257387d61 100644 --- a/src/marvin/beta/assistants/runs.py +++ b/src/marvin/beta/assistants/runs.py @@ -8,7 +8,7 @@ import marvin.utilities.openai import marvin.utilities.tools -from marvin.tools.assistants import AssistantTool, EndRun +from marvin.tools.assistants import ENDRUN_TOKEN, AssistantTool, EndRun from marvin.types import Tool from marvin.utilities.asyncio import ExposeSyncMethodsMixin, expose_sync_method from marvin.utilities.logging import get_logger @@ -156,9 +156,7 @@ def _get_run_kwargs(self, thread: Thread = None, **run_kwargs) -> dict: return run_kwargs - async def get_tool_outputs( - self, run: OpenAIRun, as_strings: bool = True - ) -> list[dict[str, str]]: + async def get_tool_outputs(self, run: OpenAIRun) -> list[Any]: if run.status != "requires_action": return None, None if run.required_action.type == "submit_tool_outputs": @@ -172,8 +170,13 @@ async def get_tool_outputs( tools=tools, function_name=tool_call.function.name, function_arguments_json=tool_call.function.arguments, - return_string=True, ) + # functions can raise EndRun, return an EndRun, or return the endrun token + # to end the run + if isinstance(output, EndRun): + raise output + elif output == ENDRUN_TOKEN: + raise EndRun() except EndRun as exc: logger.debug(f"Ending run with data: {exc.data}") raise @@ -185,9 +188,6 @@ async def get_tool_outputs( ) tool_calls.append(tool_call) - if as_strings: - tool_outputs = marvin.utilities.tools.get_string_outputs(tool_outputs) - return tool_outputs async def run_async(self) -> "Run": @@ -221,16 +221,18 @@ async def run_async(self) -> "Run": await self._update_run_from_handler(handler) while handler.current_run.status in ["requires_action"]: - tool_outputs = await self.get_tool_outputs( - run=handler.current_run, as_strings=True - ) + tool_outputs = await self.get_tool_outputs(run=handler.current_run) + + string_outputs = [ + marvin.utilities.tools.output_to_string(o) for o in tool_outputs + ] handler = event_handler_class(**self.event_handler_kwargs) async with client.beta.threads.runs.submit_tool_outputs_stream( thread_id=self.thread.id, run_id=self.run.id, - tool_outputs=tool_outputs, + tool_outputs=string_outputs, event_handler=handler, ) as stream: await stream.until_done() diff --git a/src/marvin/tools/assistants.py b/src/marvin/tools/assistants.py index 176258ab1..99075eab8 100644 --- a/src/marvin/tools/assistants.py +++ b/src/marvin/tools/assistants.py @@ -7,6 +7,8 @@ AssistantTool = Union[RetrievalTool, CodeInterpreterTool, Tool] +ENDRUN_TOKEN = "<|ENDRUN|>" + class EndRun(Exception): """ diff --git a/src/marvin/utilities/tools.py b/src/marvin/utilities/tools.py index 529d2b093..aa0737a07 100644 --- a/src/marvin/utilities/tools.py +++ b/src/marvin/utilities/tools.py @@ -145,7 +145,6 @@ def call_function_tool( tools: list[FunctionTool], function_name: str, function_arguments_json: str, - return_string: bool = False, ) -> str: """ Helper function for calling a function tool from a list of tools, using the arguments @@ -183,21 +182,18 @@ def call_function_tool( return output -def get_string_outputs(tool_outputs: list[Any]) -> list[str]: +def output_to_string(output: Any) -> str: """ Function outputs must be provided as strings """ - string_outputs = [] - for o in tool_outputs: - if isinstance(o, None): - o = "" - elif not isinstance(o, str): - if isinstance(o, BaseModel): - o = o.model_dump_json() - else: - try: - o = json.dumps(o) - except json.JSONDecodeError: - o = str(o) - string_outputs.append(o) - return string_outputs + if isinstance(output, None): + output = "" + elif not isinstance(output, str): + if isinstance(output, BaseModel): + output = output.model_dump_json() + else: + try: + output = json.dumps(output) + except json.JSONDecodeError: + output = str(output) + return output From 966818eaffd335d6a601e4e3d33885c6743379fe Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Sat, 6 Apr 2024 14:07:51 -0400 Subject: [PATCH 4/5] Fix string conversion --- src/marvin/utilities/tools.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/marvin/utilities/tools.py b/src/marvin/utilities/tools.py index aa0737a07..39b977688 100644 --- a/src/marvin/utilities/tools.py +++ b/src/marvin/utilities/tools.py @@ -186,14 +186,11 @@ def output_to_string(output: Any) -> str: """ Function outputs must be provided as strings """ - if isinstance(output, None): + if output is None: output = "" elif not isinstance(output, str): - if isinstance(output, BaseModel): - output = output.model_dump_json() - else: - try: - output = json.dumps(output) - except json.JSONDecodeError: - output = str(output) + try: + output = TypeAdapter(type(output)).dump_json(output) + except Exception: + output = str(output) return output From 80b17321702e7f5dd737dff0dd788644b735a7ff Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Sat, 6 Apr 2024 14:23:11 -0400 Subject: [PATCH 5/5] Fix string handling --- src/marvin/beta/assistants/runs.py | 15 ++++++++------- src/marvin/types.py | 13 ++++++++++--- src/marvin/utilities/tools.py | 2 +- 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/src/marvin/beta/assistants/runs.py b/src/marvin/beta/assistants/runs.py index 257387d61..fabd17f57 100644 --- a/src/marvin/beta/assistants/runs.py +++ b/src/marvin/beta/assistants/runs.py @@ -156,7 +156,7 @@ def _get_run_kwargs(self, thread: Thread = None, **run_kwargs) -> dict: return run_kwargs - async def get_tool_outputs(self, run: OpenAIRun) -> list[Any]: + async def get_tool_outputs(self, run: OpenAIRun) -> list[str]: if run.status != "requires_action": return None, None if run.required_action.type == "submit_tool_outputs": @@ -177,14 +177,19 @@ async def get_tool_outputs(self, run: OpenAIRun) -> list[Any]: raise output elif output == ENDRUN_TOKEN: raise EndRun() + except EndRun as exc: logger.debug(f"Ending run with data: {exc.data}") raise except Exception as exc: output = f"Error calling function {tool_call.function.name}: {exc}" logger.error(output) + string_output = marvin.utilities.tools.output_to_string(output) tool_outputs.append( - dict(tool_call_id=tool_call.id, output=output or "") + dict( + tool_call_id=tool_call.id, + output=string_output, + ) ) tool_calls.append(tool_call) @@ -223,16 +228,12 @@ async def run_async(self) -> "Run": while handler.current_run.status in ["requires_action"]: tool_outputs = await self.get_tool_outputs(run=handler.current_run) - string_outputs = [ - marvin.utilities.tools.output_to_string(o) for o in tool_outputs - ] - handler = event_handler_class(**self.event_handler_kwargs) async with client.beta.threads.runs.submit_tool_outputs_stream( thread_id=self.thread.id, run_id=self.run.id, - tool_outputs=string_outputs, + tool_outputs=tool_outputs, event_handler=handler, ) as stream: await stream.until_done() diff --git a/src/marvin/types.py b/src/marvin/types.py index 1bdc2f30c..d2c564d81 100644 --- a/src/marvin/types.py +++ b/src/marvin/types.py @@ -1,11 +1,12 @@ import base64 import datetime +import inspect from pathlib import Path from typing import Any, Callable, Generic, Literal, Optional, TypeVar, Union import openai.types.chat from openai.types.chat import ChatCompletion -from pydantic import BaseModel, Field, PrivateAttr, computed_field +from pydantic import BaseModel, Field, PrivateAttr, computed_field, field_validator from typing_extensions import Annotated, Self from marvin.settings import settings @@ -32,12 +33,12 @@ class ResponseFormat(BaseModel): class MarvinType(BaseModel): # by default, mavin types are not allowed to have extra fields # because they are used for validation throughout the codebase - model_config = dict(extra="forbid") + model_config = dict(extra="forbid", validate_assignment=True) class Function(MarvinType, Generic[T]): name: str - description: Optional[str] + description: Optional[str] = Field(validate_default=True) parameters: dict[str, Any] model: Optional[type[T]] = Field(default=None, exclude=True, repr=False) @@ -45,6 +46,12 @@ class Function(MarvinType, Generic[T]): # Private field that holds the executable function, if available _python_fn: Optional[Callable[..., Any]] = PrivateAttr(default=None) + @field_validator("description", mode="before") + def _clean_description(cls, v): + if isinstance(v, str): + v = inspect.cleandoc(v) + return v + def validate_json(self: Self, json_data: Union[str, bytes, bytearray]) -> T: if self.model is None: raise ValueError("This Function was not initialized with a model.") diff --git a/src/marvin/utilities/tools.py b/src/marvin/utilities/tools.py index 39b977688..56884f522 100644 --- a/src/marvin/utilities/tools.py +++ b/src/marvin/utilities/tools.py @@ -190,7 +190,7 @@ def output_to_string(output: Any) -> str: output = "" elif not isinstance(output, str): try: - output = TypeAdapter(type(output)).dump_json(output) + output = TypeAdapter(type(output)).dump_json(output).decode() except Exception: output = str(output) return output