Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Assistants API #742

Merged
merged 3 commits into from
Jan 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 131 additions & 93 deletions docs/ai/interactive/assistants.md

Large diffs are not rendered by default.

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/assets/images/ai/assistants/custom_tools.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/assets/images/ai/assistants/quickstart.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/assets/images/ai/assistants/sin_x.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed docs/assets/images/ai/assistants/using_tools.png
Binary file not shown.
136 changes: 0 additions & 136 deletions src/marvin/beta/assistants/README.md

This file was deleted.

71 changes: 43 additions & 28 deletions src/marvin/beta/assistants/assistants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, Callable, Optional, Union

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, PrivateAttr

import marvin.utilities.tools
from marvin.tools.assistants import AssistantTool
Expand All @@ -11,7 +11,7 @@
run_sync,
)
from marvin.utilities.logging import get_logger
from marvin.utilities.openai import get_client
from marvin.utilities.openai import get_openai_client

from .threads import Thread

Expand All @@ -29,6 +29,8 @@ class Assistant(BaseModel, ExposeSyncMethodsMixin):
tools: list[Union[AssistantTool, Callable]] = []
file_ids: list[str] = []
metadata: dict[str, str] = {}
# context level tracks nested assistant contexts
_context_level: int = PrivateAttr(0)

default_thread: Thread = Field(
default_factory=Thread,
Expand Down Expand Up @@ -57,47 +59,61 @@ async def say_async(
self,
message: str,
file_paths: Optional[list[str]] = None,
thread: Optional[Thread] = None,
**run_kwargs,
) -> "Run":
"""
A convenience method for adding a user message to the assistant's
default thread, running the assistant, and returning the assistant's
messages.
"""
thread = thread or self.default_thread

last_message = await thread.get_messages_async(limit=1)
if last_message:
last_msg_id = last_message[0].id
else:
last_msg_id = None

# post the message
if message:
await self.default_thread.add_async(message, file_paths=file_paths)
await thread.add_async(message, file_paths=file_paths)

run = await self.default_thread.run_async(
assistant=self,
**run_kwargs,
)
return run
# run the thread
async with self:
await thread.run_async(assistant=self, **run_kwargs)

# load all messages, including the user message
response_messages = await thread.get_messages_async(after_message=last_msg_id)
return response_messages

def __enter__(self):
self.create()
return self
return run_sync(self.__aenter__())

def __exit__(self, exc_type, exc_val, exc_tb):
self.delete()
# If an exception has occurred, you might want to handle it or pass it through
# Returning False here will re-raise any exception that occurred in the context
return False
return run_sync(self.__aexit__(exc_type, exc_val, exc_tb))

async def __aenter__(self):
await self.create_async()
self._context_level += 1
# if this is the outermost context and no ID is set, create the assistant
if self.id is None and self._context_level == 1:
await self.create_async()
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.delete_async()
# If an exception has occurred, you might want to handle it or pass it through
# Returning False here will re-raise any exception that occurred in the context
# If this is the outermost context, delete the assistant
if self._context_level == 1:
await self.delete_async()
self._context_level -= 1
return False

@expose_sync_method("create")
async def create_async(self):
if self.id is not None:
raise ValueError("Assistant has already been created.")
client = get_client()
raise ValueError(
"Assistant has an ID and has already been created in the OpenAI API."
)
client = get_openai_client()
response = await client.beta.assistants.create(
**self.model_dump(
include={"name", "model", "metadata", "file_ids", "metadata"}
Expand All @@ -106,25 +122,24 @@ async def create_async(self):
instructions=self.get_instructions(),
)
self.id = response.id
self.clear_default_thread()

@expose_sync_method("delete")
async def delete_async(self):
if not self.id:
raise ValueError("Assistant has not been created.")
client = get_client()
raise ValueError("Assistant has no ID and doesn't exist in the OpenAI API.")
client = get_openai_client()
await client.beta.assistants.delete(assistant_id=self.id)
self.id = None

@classmethod
def load(cls, assistant_id: str):
return run_sync(cls.load_async(assistant_id))
def load(cls, assistant_id: str, **kwargs):
return run_sync(cls.load_async(assistant_id, **kwargs))

@classmethod
async def load_async(cls, assistant_id: str):
client = get_client()
async def load_async(cls, assistant_id: str, **kwargs):
client = get_openai_client()
response = await client.beta.assistants.retrieve(assistant_id=assistant_id)
return cls.model_validate(response)
return cls(**(response.model_dump() | kwargs))

def chat(self, thread: Thread = None):
if thread is None:
Expand Down
55 changes: 29 additions & 26 deletions src/marvin/beta/assistants/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from marvin.types import Tool
from marvin.utilities.asyncio import ExposeSyncMethodsMixin, expose_sync_method
from marvin.utilities.logging import get_logger
from marvin.utilities.openai import get_client
from marvin.utilities.openai import get_openai_client

from .assistants import Assistant
from .threads import Thread
Expand Down Expand Up @@ -54,20 +54,20 @@ def format_tools(cls, tools: Union[None, list[Union[Tool, Callable]]]):

@expose_sync_method("refresh")
async def refresh_async(self):
client = get_client()
client = get_openai_client()
self.run = await client.beta.threads.runs.retrieve(
run_id=self.run.id, thread_id=self.thread.id
)

@expose_sync_method("cancel")
async def cancel_async(self):
client = get_client()
client = get_openai_client()
await client.beta.threads.runs.cancel(
run_id=self.run.id, thread_id=self.thread.id
)

async def _handle_step_requires_action(self):
client = get_client()
client = get_openai_client()
if self.run.status != "requires_action":
return
if self.run.required_action.type == "submit_tool_outputs":
Expand Down Expand Up @@ -117,7 +117,7 @@ def get_tools(self) -> list[AssistantTool]:
return tools

async def run_async(self) -> "Run":
client = get_client()
client = get_openai_client()

create_kwargs = {}

Expand All @@ -127,30 +127,33 @@ async def run_async(self) -> "Run":
if self.tools is not None or self.additional_tools is not None:
create_kwargs["tools"] = self.get_tools()

self.run = await client.beta.threads.runs.create(
thread_id=self.thread.id, assistant_id=self.assistant.id, **create_kwargs
)

self.assistant.pre_run_hook(run=self)
async with self.assistant:
self.run = await client.beta.threads.runs.create(
thread_id=self.thread.id,
assistant_id=self.assistant.id,
**create_kwargs,
)

try:
while self.run.status in ("queued", "in_progress", "requires_action"):
if self.run.status == "requires_action":
await self._handle_step_requires_action()
await asyncio.sleep(0.1)
self.assistant.pre_run_hook(run=self)

try:
while self.run.status in ("queued", "in_progress", "requires_action"):
if self.run.status == "requires_action":
await self._handle_step_requires_action()
await asyncio.sleep(0.1)
await self.refresh_async()
except CancelRun as exc:
logger.debug(f"`CancelRun` raised; ending run with data: {exc.data}")
await client.beta.threads.runs.cancel(
run_id=self.run.id, thread_id=self.thread.id
)
self.data = exc.data
await self.refresh_async()
except CancelRun as exc:
logger.debug(f"`CancelRun` raised; ending run with data: {exc.data}")
await client.beta.threads.runs.cancel(
run_id=self.run.id, thread_id=self.thread.id
)
self.data = exc.data
await self.refresh_async()

if self.run.status == "failed":
logger.debug(f"Run failed. Last error was: {self.run.last_error}")
if self.run.status == "failed":
logger.debug(f"Run failed. Last error was: {self.run.last_error}")

self.assistant.post_run_hook(run=self)
self.assistant.post_run_hook(run=self)
return self


Expand Down Expand Up @@ -193,7 +196,7 @@ async def refresh_run_steps_async(self):
max_attempts = max_fetched / limit + 2

# Fetch the latest run steps
client = get_client()
client = get_openai_client()

response = await client.beta.threads.runs.steps.list(
run_id=self.run.id,
Expand Down
Loading