Merge pull request #742 from PrefectHQ/easy-assistants
Improve Assistants API
jlowin authored Jan 14, 2024
2 parents c00415c + 0e306dd commit 84fc5a3
Showing 13 changed files with 224 additions and 298 deletions.
224 changes: 131 additions & 93 deletions docs/ai/interactive/assistants.md

Large diffs are not rendered by default.

Binary file added docs/assets/images/ai/assistants/custom_tools.png
Binary file modified docs/assets/images/ai/assistants/quickstart.png
Binary file added docs/assets/images/ai/assistants/sin_x.png
Binary file removed docs/assets/images/ai/assistants/using_tools.png
136 changes: 0 additions & 136 deletions src/marvin/beta/assistants/README.md

This file was deleted.

71 changes: 43 additions & 28 deletions src/marvin/beta/assistants/assistants.py
@@ -1,6 +1,6 @@
 from typing import TYPE_CHECKING, Callable, Optional, Union

-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, PrivateAttr

 import marvin.utilities.tools
 from marvin.tools.assistants import AssistantTool
@@ -11,7 +11,7 @@
     run_sync,
 )
 from marvin.utilities.logging import get_logger
-from marvin.utilities.openai import get_client
+from marvin.utilities.openai import get_openai_client

 from .threads import Thread

@@ -29,6 +29,8 @@ class Assistant(BaseModel, ExposeSyncMethodsMixin):
     tools: list[Union[AssistantTool, Callable]] = []
     file_ids: list[str] = []
     metadata: dict[str, str] = {}
+    # context level tracks nested assistant contexts
+    _context_level: int = PrivateAttr(0)

     default_thread: Thread = Field(
         default_factory=Thread,
@@ -57,47 +59,61 @@ async def say_async(
         self,
         message: str,
         file_paths: Optional[list[str]] = None,
+        thread: Optional[Thread] = None,
         **run_kwargs,
     ) -> "Run":
         """
         A convenience method for adding a user message to the assistant's
         default thread, running the assistant, and returning the assistant's
         messages.
         """
+        thread = thread or self.default_thread
+
+        last_message = await thread.get_messages_async(limit=1)
+        if last_message:
+            last_msg_id = last_message[0].id
+        else:
+            last_msg_id = None
+
+        # post the message
         if message:
-            await self.default_thread.add_async(message, file_paths=file_paths)
+            await thread.add_async(message, file_paths=file_paths)

-        run = await self.default_thread.run_async(
-            assistant=self,
-            **run_kwargs,
-        )
-        return run
+        # run the thread
+        async with self:
+            await thread.run_async(assistant=self, **run_kwargs)
+
+        # load all messages, including the user message
+        response_messages = await thread.get_messages_async(after_message=last_msg_id)
+        return response_messages

     def __enter__(self):
-        self.create()
-        return self
+        return run_sync(self.__aenter__())

     def __exit__(self, exc_type, exc_val, exc_tb):
-        self.delete()
-        # If an exception has occurred, you might want to handle it or pass it through
-        # Returning False here will re-raise any exception that occurred in the context
-        return False
+        return run_sync(self.__aexit__(exc_type, exc_val, exc_tb))

     async def __aenter__(self):
-        await self.create_async()
+        self._context_level += 1
+        # if this is the outermost context and no ID is set, create the assistant
+        if self.id is None and self._context_level == 1:
+            await self.create_async()
         return self

     async def __aexit__(self, exc_type, exc_val, exc_tb):
-        await self.delete_async()
-        # If an exception has occurred, you might want to handle it or pass it through
-        # Returning False here will re-raise any exception that occurred in the context
+        # If this is the outermost context, delete the assistant
+        if self._context_level == 1:
+            await self.delete_async()
+        self._context_level -= 1
         return False

     @expose_sync_method("create")
     async def create_async(self):
         if self.id is not None:
-            raise ValueError("Assistant has already been created.")
-        client = get_client()
+            raise ValueError(
+                "Assistant has an ID and has already been created in the OpenAI API."
+            )
+        client = get_openai_client()
         response = await client.beta.assistants.create(
             **self.model_dump(
                 include={"name", "model", "metadata", "file_ids", "metadata"}
@@ -106,25 +122,24 @@ async def create_async(self):
             instructions=self.get_instructions(),
         )
         self.id = response.id
-        self.clear_default_thread()

     @expose_sync_method("delete")
     async def delete_async(self):
         if not self.id:
-            raise ValueError("Assistant has not been created.")
-        client = get_client()
+            raise ValueError("Assistant has no ID and doesn't exist in the OpenAI API.")
+        client = get_openai_client()
         await client.beta.assistants.delete(assistant_id=self.id)
         self.id = None

     @classmethod
-    def load(cls, assistant_id: str):
-        return run_sync(cls.load_async(assistant_id))
+    def load(cls, assistant_id: str, **kwargs):
+        return run_sync(cls.load_async(assistant_id, **kwargs))

     @classmethod
-    async def load_async(cls, assistant_id: str):
-        client = get_client()
+    async def load_async(cls, assistant_id: str, **kwargs):
+        client = get_openai_client()
         response = await client.beta.assistants.retrieve(assistant_id=assistant_id)
-        return cls.model_validate(response)
+        return cls(**(response.model_dump() | kwargs))

     def chat(self, thread: Thread = None):
         if thread is None:
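
A quick usage sketch of the reworked lifecycle (assuming marvin is installed, OPENAI_API_KEY is set, and that `say` is the sync wrapper expose_sync_method generates for `say_async`); because of the nested-context counter, the `async with self:` inside `say_async` neither re-creates nor prematurely deletes the assistant:

from marvin.beta.assistants import Assistant

# Entering the context creates the assistant in the OpenAI API once;
# exiting the outermost context deletes it and resets its ID to None.
with Assistant(name="Marvin", instructions="Answer tersely.") as ai:
    # say() now returns the messages added to the thread after the user
    # message, rather than the bare Run object it used to return.
    messages = ai.say("What is 2 + 2?")
    for message in messages:
        print(message)
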
55 changes: 29 additions & 26 deletions src/marvin/beta/assistants/runs.py
@@ -10,7 +10,7 @@
 from marvin.types import Tool
 from marvin.utilities.asyncio import ExposeSyncMethodsMixin, expose_sync_method
 from marvin.utilities.logging import get_logger
-from marvin.utilities.openai import get_client
+from marvin.utilities.openai import get_openai_client

 from .assistants import Assistant
 from .threads import Thread
@@ -54,20 +54,20 @@ def format_tools(cls, tools: Union[None, list[Union[Tool, Callable]]]):

     @expose_sync_method("refresh")
     async def refresh_async(self):
-        client = get_client()
+        client = get_openai_client()
         self.run = await client.beta.threads.runs.retrieve(
             run_id=self.run.id, thread_id=self.thread.id
         )

     @expose_sync_method("cancel")
     async def cancel_async(self):
-        client = get_client()
+        client = get_openai_client()
         await client.beta.threads.runs.cancel(
             run_id=self.run.id, thread_id=self.thread.id
         )

     async def _handle_step_requires_action(self):
-        client = get_client()
+        client = get_openai_client()
         if self.run.status != "requires_action":
             return
         if self.run.required_action.type == "submit_tool_outputs":
@@ -117,7 +117,7 @@ def get_tools(self) -> list[AssistantTool]:
         return tools

     async def run_async(self) -> "Run":
-        client = get_client()
+        client = get_openai_client()

         create_kwargs = {}

@@ -127,30 +127,33 @@ async def run_async(self) -> "Run":
         if self.tools is not None or self.additional_tools is not None:
             create_kwargs["tools"] = self.get_tools()

-        self.run = await client.beta.threads.runs.create(
-            thread_id=self.thread.id, assistant_id=self.assistant.id, **create_kwargs
-        )
-
-        self.assistant.pre_run_hook(run=self)
+        async with self.assistant:
+            self.run = await client.beta.threads.runs.create(
+                thread_id=self.thread.id,
+                assistant_id=self.assistant.id,
+                **create_kwargs,
+            )

-        try:
-            while self.run.status in ("queued", "in_progress", "requires_action"):
-                if self.run.status == "requires_action":
-                    await self._handle_step_requires_action()
-                await asyncio.sleep(0.1)
+            self.assistant.pre_run_hook(run=self)
+
+            try:
+                while self.run.status in ("queued", "in_progress", "requires_action"):
+                    if self.run.status == "requires_action":
+                        await self._handle_step_requires_action()
+                    await asyncio.sleep(0.1)
+                    await self.refresh_async()
+            except CancelRun as exc:
+                logger.debug(f"`CancelRun` raised; ending run with data: {exc.data}")
+                await client.beta.threads.runs.cancel(
+                    run_id=self.run.id, thread_id=self.thread.id
+                )
+                self.data = exc.data
                 await self.refresh_async()
-        except CancelRun as exc:
-            logger.debug(f"`CancelRun` raised; ending run with data: {exc.data}")
-            await client.beta.threads.runs.cancel(
-                run_id=self.run.id, thread_id=self.thread.id
-            )
-            self.data = exc.data
-            await self.refresh_async()

-        if self.run.status == "failed":
-            logger.debug(f"Run failed. Last error was: {self.run.last_error}")
+            if self.run.status == "failed":
+                logger.debug(f"Run failed. Last error was: {self.run.last_error}")

-        self.assistant.post_run_hook(run=self)
+            self.assistant.post_run_hook(run=self)
         return self


@@ -193,7 +196,7 @@ async def refresh_run_steps_async(self):
         max_attempts = max_fetched / limit + 2

         # Fetch the latest run steps
-        client = get_client()
+        client = get_openai_client()

         response = await client.beta.threads.runs.steps.list(
             run_id=self.run.id,
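
A sketch of what wrapping the run in `async with self.assistant:` buys (illustrative; it assumes Thread is importable from marvin.beta.assistants and that Thread.run_async returns the Run object whose polling loop appears above): an assistant that was never explicitly created is created just before the run, polled to completion, and deleted again when the context exits.

import asyncio

from marvin.beta.assistants import Assistant, Thread

async def main():
    ai = Assistant(name="Ephemeral", instructions="Answer in one word.")
    async with Thread() as thread:
        await thread.add_async("What color is the sky?")
        # Run.run_async enters `async with self.assistant`, so the API-side
        # assistant exists only for the duration of this call.
        run = await thread.run_async(assistant=ai)
        print(run.run.status)  # e.g. "completed"
        print(ai.id)           # None again: deleted when the inner context exited

asyncio.run(main())
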
25 changes: 15 additions & 10 deletions src/marvin/beta/assistants/threads.py
@@ -9,9 +9,10 @@
 from marvin.utilities.asyncio import (
     ExposeSyncMethodsMixin,
     expose_sync_method,
+    run_sync,
 )
 from marvin.utilities.logging import get_logger
-from marvin.utilities.openai import get_client
+from marvin.utilities.openai import get_openai_client
 from marvin.utilities.pydantic import parse_as

 logger = get_logger("Threads")
@@ -27,13 +28,17 @@ class Thread(BaseModel, ExposeSyncMethodsMixin):
     messages: list[ThreadMessage] = Field([], repr=False)

     def __enter__(self):
-        self.create()
-        return self
+        return run_sync(self.__aenter__)

     def __exit__(self, exc_type, exc_val, exc_tb):
-        self.delete()
-        # If an exception has occurred, you might want to handle it or pass it through
-        # Returning False here will re-raise any exception that occurred in the context
+        return run_sync(self.__aexit__, exc_type, exc_val, exc_tb)
+
+    async def __aenter__(self):
+        await self.create_async()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        await self.delete_async()
         return False

     @expose_sync_method("create")
@@ -45,7 +50,7 @@ async def create_async(self, messages: list[str] = None):
             raise ValueError("Thread has already been created.")
         if messages is not None:
             messages = [{"role": "user", "content": message} for message in messages]
-        client = get_client()
+        client = get_openai_client()
         response = await client.beta.threads.create(messages=messages)
         self.id = response.id
         return self
@@ -57,7 +62,7 @@ async def add_async(
         """
         Add a user message to the thread.
         """
-        client = get_client()
+        client = get_openai_client()

         if self.id is None:
             await self.create_async()
@@ -85,7 +90,7 @@ async def get_messages_async(
     ) -> list[Union[ThreadMessage, dict]]:
         if self.id is None:
             await self.create_async()
-        client = get_client()
+        client = get_openai_client()

         response = await client.beta.threads.messages.list(
             thread_id=self.id,
@@ -103,7 +108,7 @@ async def get_messages_async(

     @expose_sync_method("delete")
     async def delete_async(self):
-        client = get_client()
+        client = get_openai_client()
         await client.beta.threads.delete(thread_id=self.id)
         self.id = None

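Thread now mirrors Assistant's context-manager protocol, so a thread can be scoped to a block and cleaned up automatically. A minimal sketch using the async methods that appear in this commit (output is illustrative):

import asyncio

from marvin.beta.assistants import Assistant, Thread

async def main():
    async with Assistant(name="Marvin") as ai, Thread() as thread:
        # __aenter__ already created the thread; add_async would create it lazily otherwise
        await thread.add_async("Give me a one-line greeting.")
        await thread.run_async(assistant=ai)
        for message in await thread.get_messages_async():
            print(message)

asyncio.run(main())
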
4 changes: 2 additions & 2 deletions src/marvin/cli/__init__.py
@@ -3,7 +3,7 @@
 from rich.console import Console
 from typing import Optional
 from marvin.utilities.asyncio import run_sync
-from marvin.utilities.openai import get_client
+from marvin.utilities.openai import get_openai_client
 from marvin.cli.version import display_version

 app = typer.Typer()
@@ -27,7 +27,7 @@ def main(


 async def process_stdin(model: str, max_tokens: int):
-    client = get_client()
+    client = get_openai_client()
     content = sys.stdin.read()
     last_chunk_ended_with_space = False

4 changes: 2 additions & 2 deletions src/marvin/types.py
@@ -52,11 +52,11 @@ class ToolSet(BaseModel, Generic[T]):


 class RetrievalTool(Tool[T]):
-    type: str = Field(default="retrieval")
+    type: Literal["retrieval"] = "retrieval"


 class CodeInterpreterTool(Tool[T]):
-    type: str = Field(default="code_interpreter")
+    type: Literal["code_interpreter"] = "code_interpreter"


 class FunctionCall(BaseModel):
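
With Literal annotations, pydantic rejects any value other than the canonical tool-type string, and the field can double as a discriminator. A small check of the intended behavior (assuming Tool's remaining fields are optional, as they appear to be):

from marvin.types import CodeInterpreterTool, RetrievalTool

print(CodeInterpreterTool().type)       # "code_interpreter"
print(RetrievalTool(type="retrieval"))  # validates
# RetrievalTool(type="code_interpreter") would now raise a ValidationError
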
3 changes: 2 additions & 1 deletion src/marvin/utilities/openai.py
@@ -1,12 +1,13 @@
 """Module for working with OpenAI."""

+import asyncio
 from functools import lru_cache
 from typing import Optional

 from openai import AsyncClient


-def get_client() -> AsyncClient:
+def get_openai_client() -> AsyncClient:
     """
     Retrieves an OpenAI client with the given api key and organization.

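Since every call site now imports the renamed helper, downstream code does the same. A minimal sketch (the function and the model-listing call are illustrative, not part of this diff; requires OPENAI_API_KEY):

import asyncio

from marvin.utilities.openai import get_openai_client

async def count_models() -> int:
    client = get_openai_client()
    models = await client.models.list()
    return len(models.data)

print(asyncio.run(count_models()))
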