Skip to content

Commit

Permalink
Merge pull request #904 from PrefectHQ/end-runs
Browse files Browse the repository at this point in the history
Improve EndRun handling
jlowin authored Apr 7, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
2 parents dfa3136 + 676abca commit 76382aa
Showing 7 changed files with 85 additions and 20 deletions.
37 changes: 36 additions & 1 deletion docs/docs/interactive/assistants.md
Original file line number Diff line number Diff line change
@@ -228,6 +228,41 @@ Marvin makes it easy to give your assistants custom tools. To do so, pass one or
!!! success "Result"
![](/assets/images/docs/assistants/custom_tools.png)

#### Ending a run early

Normally, the assistant will continue to run until it decides to stop, which usually happens after generating a response. Sometimes it may be useful to end a run early, for example if the assistant uses a tool that indicates the conversation is over. To do this, you can raise an `EndRun` exception from within a tool. This will cause the assistant to cancel the current run and return control. EndRun exceptions can contain data.

There are three ways to raise an `EndRun` exception:

1. Raise the exception directly from the tool function:
```python
from marvin.beta.assistants import Assistant, EndRun

def my_tool():
raise EndRun(data="The final result")

ai = Assistant(tools=[my_tool])
```
1. Return the exception from the tool function. This is useful if e.g. your tools are wrapped in custom exception handlers:
```python
from marvin.beta.assistants import Assistant, EndRun

def my_tool():
return EndRun(data="The final result")

ai = Assistant(tools=[my_tool])
```
1. Return a special string value from the tool function. This is useful if you don't have full control over the tool itself, or need to ensure the tool output is JSON-compatible. Note that this approach does not allow you to attach any data to the exception:
```python
from marvin.beta.assistants import Assistant, ENDRUN_TOKEN

def my_tool():
return ENDRUN_TOKEN

ai = Assistant(tools=[my_tool])
```


### Lifecycle management

Assistants are Marvin objects that correspond to remote objects in the OpenAI API. You can not communicate with an assistant unless it has been registered with the API.
@@ -387,7 +422,7 @@ This will return a `Run` object that represents the OpenAI run. You can use this
When threads are `run` with an assistant, the same lifecycle management rules apply as when you use the assistant's `say` method. In the above example, lazy lifecycle management is used for conveneince. See [lifecycle management](#lifecycle-management) for more information.

!!! warning "Threads are locked while running"
When an assistant is running a thread, the thread is locked and no other messages can be added to it. This applies to both user and assistant messages.
When an assistant is running a thread, the thread is locked and no other messages can be added to it. This applies to both user and assistant messages. To end a run early, you must [use a custom tool](#ending-a-run-early).

### Reading messages

8 changes: 4 additions & 4 deletions src/marvin/beta/ai_flow/ai_task.py
Original file line number Diff line number Diff line change
@@ -8,7 +8,7 @@
from typing_extensions import ParamSpec

from marvin.beta.assistants import Assistant, Run, Thread
from marvin.beta.assistants.runs import CancelRun
from marvin.beta.assistants.runs import EndRun
from marvin.tools.assistants import AssistantTool
from marvin.utilities.context import ScopedContext
from marvin.utilities.jinja import Environment as JinjaEnvironment
@@ -238,7 +238,7 @@ def _task_completed_tool(self):

def task_completed():
self.status = Status.COMPLETED
raise CancelRun()
raise EndRun()

return task_completed

@@ -249,7 +249,7 @@ def task_completed():
def task_completed_with_result(result: T):
self.status = Status.COMPLETED
self.result = result
raise CancelRun()
raise EndRun()

tool.function._python_fn = task_completed_with_result

@@ -261,7 +261,7 @@ def task_failed(reason: str) -> None:
"""Indicate that the task failed for the provided `reason`."""
self.status = Status.FAILED
self.result = reason
raise CancelRun()
raise EndRun()

return tool_from_function(task_failed)

2 changes: 1 addition & 1 deletion src/marvin/beta/assistants/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .runs import Run
from .runs import Run, EndRun, ENDRUN_TOKEN
from .threads import Thread
from .assistants import Assistant
from .handlers import PrintHandler
24 changes: 17 additions & 7 deletions src/marvin/beta/assistants/runs.py
Original file line number Diff line number Diff line change
@@ -8,7 +8,7 @@

import marvin.utilities.openai
import marvin.utilities.tools
from marvin.tools.assistants import AssistantTool, CancelRun
from marvin.tools.assistants import ENDRUN_TOKEN, AssistantTool, EndRun
from marvin.types import Tool
from marvin.utilities.asyncio import ExposeSyncMethodsMixin, expose_sync_method
from marvin.utilities.logging import get_logger
@@ -156,7 +156,7 @@ def _get_run_kwargs(self, thread: Thread = None, **run_kwargs) -> dict:

return run_kwargs

async def get_tool_outputs(self, run: OpenAIRun) -> list[dict[str, str]]:
async def get_tool_outputs(self, run: OpenAIRun) -> list[str]:
if run.status != "requires_action":
return None, None
if run.required_action.type == "submit_tool_outputs":
@@ -170,16 +170,26 @@ async def get_tool_outputs(self, run: OpenAIRun) -> list[dict[str, str]]:
tools=tools,
function_name=tool_call.function.name,
function_arguments_json=tool_call.function.arguments,
return_string=True,
)
except CancelRun as exc:
# functions can raise EndRun, return an EndRun, or return the endrun token
# to end the run
if isinstance(output, EndRun):
raise output
elif output == ENDRUN_TOKEN:
raise EndRun()

except EndRun as exc:
logger.debug(f"Ending run with data: {exc.data}")
raise
except Exception as exc:
output = f"Error calling function {tool_call.function.name}: {exc}"
logger.error(output)
string_output = marvin.utilities.tools.output_to_string(output)
tool_outputs.append(
dict(tool_call_id=tool_call.id, output=output or "")
dict(
tool_call_id=tool_call.id,
output=string_output,
)
)
tool_calls.append(tool_call)

@@ -229,8 +239,8 @@ async def run_async(self) -> "Run":
await stream.until_done()
await self._update_run_from_handler(handler)

except CancelRun as exc:
logger.debug(f"`CancelRun` raised; ending run with data: {exc.data}")
except EndRun as exc:
logger.debug(f"`EndRun` raised; ending run with data: {exc.data}")
await self.cancel_async()
self.data = exc.data

4 changes: 3 additions & 1 deletion src/marvin/tools/assistants.py
Original file line number Diff line number Diff line change
@@ -7,8 +7,10 @@

AssistantTool = Union[RetrievalTool, CodeInterpreterTool, Tool]

ENDRUN_TOKEN = "<|ENDRUN|>"

class CancelRun(Exception):

class EndRun(Exception):
"""
A special exception that can be raised in a tool to end the run immediately.
"""
13 changes: 10 additions & 3 deletions src/marvin/types.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import base64
import datetime
import inspect
from pathlib import Path
from typing import Any, Callable, Generic, Literal, Optional, TypeVar, Union

import openai.types.chat
from openai.types.chat import ChatCompletion
from pydantic import BaseModel, Field, PrivateAttr, computed_field
from pydantic import BaseModel, Field, PrivateAttr, computed_field, field_validator
from typing_extensions import Annotated, Self

from marvin.settings import settings
@@ -32,19 +33,25 @@ class ResponseFormat(BaseModel):
class MarvinType(BaseModel):
# by default, mavin types are not allowed to have extra fields
# because they are used for validation throughout the codebase
model_config = dict(extra="forbid")
model_config = dict(extra="forbid", validate_assignment=True)


class Function(MarvinType, Generic[T]):
name: str
description: Optional[str]
description: Optional[str] = Field(validate_default=True)
parameters: dict[str, Any]

model: Optional[type[T]] = Field(default=None, exclude=True, repr=False)

# Private field that holds the executable function, if available
_python_fn: Optional[Callable[..., Any]] = PrivateAttr(default=None)

@field_validator("description", mode="before")
def _clean_description(cls, v):
if isinstance(v, str):
v = inspect.cleandoc(v)
return v

def validate_json(self: Self, json_data: Union[str, bytes, bytearray]) -> T:
if self.model is None:
raise ValueError("This Function was not initialized with a model.")
17 changes: 14 additions & 3 deletions src/marvin/utilities/tools.py
Original file line number Diff line number Diff line change
@@ -145,7 +145,6 @@ def call_function_tool(
tools: list[FunctionTool],
function_name: str,
function_arguments_json: str,
return_string: bool = False,
) -> str:
"""
Helper function for calling a function tool from a list of tools, using the arguments
@@ -180,6 +179,18 @@ def call_function_tool(
if len(truncated_output) < len(str(output)):
truncated_output += "..."
logger.debug_kv(f"{tool.function.name}", f"returned: {truncated_output}", "green")
if return_string and not isinstance(output, str):
output = json.dumps(output)
return output


def output_to_string(output: Any) -> str:
"""
Function outputs must be provided as strings
"""
if output is None:
output = ""
elif not isinstance(output, str):
try:
output = TypeAdapter(type(output)).dump_json(output).decode()
except Exception:
output = str(output)
return output

0 comments on commit 76382aa

Please sign in to comment.