Skip to content

Commit

Permalink
Support assistants v2
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowin committed Apr 18, 2024
1 parent f1500a0 commit b8c8d57
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 36 deletions.
12 changes: 5 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ dependencies = [
"httpx>=0.24.1",
"jinja2>=3.1.2",
"jsonpatch>=1.33",
"openai>=1.16.0",
"openai>=1.21.0",
"prompt-toolkit>=3.0.33",
"pydantic>=2.4.2",
"pydantic_settings",
Expand Down Expand Up @@ -59,7 +59,7 @@ tests = [
"pytest-xdist",
]
audio = [
"SpeechRecognition>=3.10",
"SpeechRecognition>=3.10",
"PyAudio>=0.2.11",
# playsound reqs
"playsound >= 1.0",
Expand All @@ -69,9 +69,7 @@ audio = [
"pydub>=0.25",
"simpleaudio>=1.0",
]
video = [
"opencv-python >= 4.5",
]
video = ["opencv-python >= 4.5"]
slackbot = ["marvin[prefect]", "numpy", "raggy", "turbopuffer"]

[project.urls]
Expand All @@ -92,8 +90,8 @@ write_to = "src/marvin/_version.py"
[tool.pytest.ini_options]
markers = [
"llm: indicates that a test calls an LLM (may be slow).",
"no_llm: indicates that a test does not require an LLM."
]
"no_llm: indicates that a test does not require an LLM.",
]
timeout = 20
testpaths = ["tests"]

Expand Down
2 changes: 1 addition & 1 deletion src/marvin/beta/assistants/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
from .assistants import Assistant
from .handlers import PrintHandler
from .formatting import pprint_messages, pprint_steps, pprint_run
from marvin.tools.assistants import Retrieval, CodeInterpreter
from marvin.tools.assistants import FileSearch, CodeInterpreter
31 changes: 22 additions & 9 deletions src/marvin/beta/assistants/assistants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

from openai import AsyncAssistantEventHandler
from prompt_toolkit import PromptSession
Expand Down Expand Up @@ -47,20 +47,21 @@ class Assistant(BaseModel, ExposeSyncMethodsMixin):
id (str): The unique identifier of the assistant. None if the assistant
hasn't been created yet.
name (str): The name of the assistant.
description (str): A description of the assistant.
instructions (str): Instructions for the assistant.
model (str): The model used by the assistant.
metadata (dict): Additional data about the assistant.
file_ids (list): List of file IDs associated with the assistant.
tools (list): List of tools used by the assistant.
instructions (str): Instructions for the assistant.
tool_resources (dict): dict of tool resources associated with the assistant.
metadata (dict): Additional data about the assistant.
"""

id: Optional[str] = None
name: str = "Assistant"
description: Optional[str] = None
instructions: Optional[str] = Field(None)
model: str = Field(None, validate_default=True)
instructions: Optional[str] = Field(None, repr=False)
tools: list[Union[AssistantTool, Callable]] = []
file_ids: list[str] = []
tool_resources: dict[str, Any] = {}
metadata: dict[str, str] = {}
# context level tracks nested assistant contexts
_context_level: int = PrivateAttr(0)
Expand Down Expand Up @@ -99,7 +100,8 @@ def get_instructions(self, thread: Thread = None) -> str:
async def say_async(
self,
message: str,
file_paths: Optional[list[str]] = None,
code_interpreter_files: Optional[list[str]] = None,
file_search_files: Optional[list[str]] = None,
thread: Optional[Thread] = None,
event_handler_class: type[AsyncAssistantEventHandler] = NOT_PROVIDED,
**run_kwargs,
Expand All @@ -110,7 +112,11 @@ async def say_async(
event_handler_class = default_run_handler_class()

# post the message
user_message = await thread.add_async(message, file_paths=file_paths)
user_message = await thread.add_async(
message,
code_interpreter_files=code_interpreter_files,
file_search_files=file_search_files,
)

from marvin.beta.assistants.runs import Run

Expand Down Expand Up @@ -159,7 +165,14 @@ async def create_async(self, _auto_delete: bool = False):
client = marvin.utilities.openai.get_openai_client()
response = await client.beta.assistants.create(
**self.model_dump(
include={"name", "model", "metadata", "file_ids", "metadata"}
include={
"name",
"model",
"description",
"metadata",
"tool_resources",
"metadata",
}
),
tools=[tool.model_dump() for tool in self.get_tools()],
instructions=self.get_instructions(),
Expand Down
16 changes: 9 additions & 7 deletions src/marvin/beta/assistants/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def download_temp_file(file_id: str):

client = get_openai_client(is_async=False)
file = client.files.retrieve(file_id)
if file.purpose == "assistants":
return
filename = Path(file.filename).name
response = client.files.content(file_id)

Expand Down Expand Up @@ -76,8 +78,8 @@ def format_code_interpreter_tool_call(step, tool_call):
else:
for o in tool_call.code_interpreter.outputs:
if o.type == "image":
local_file_path = download_temp_file(o.image.file_id)
attachments.append(local_file_path)
if local_file_path := download_temp_file(o.image.file_id):
attachments.append(local_file_path)
status = ":heavy_check_mark: Finished!"
content = [status, "\n", code]
if attachments:
Expand Down Expand Up @@ -182,12 +184,12 @@ def format_message(message: Message) -> Panel:
elif item.type == "image_file":
# Use the download_temp_file function to download the file and get
# the local path
local_file_path = download_temp_file(item.image_file.file_id)
attachments.append(local_file_path)
if local_file_path := download_temp_file(item.image_file.file_id):
attachments.append(local_file_path)

for file_id in message.file_ids:
local_file_path = download_temp_file(file_id)
attachments.append(local_file_path)
for attachment in message.attachments:
if local_file_path := download_temp_file(attachment.file_id):
attachments.append(local_file_path)

if attachments:
content.append(
Expand Down
24 changes: 18 additions & 6 deletions src/marvin/beta/assistants/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,11 @@ async def create_async(self, messages: list[str] = None):

@expose_sync_method("add")
async def add_async(
self, message: str, file_paths: Optional[list[str]] = None, role: str = "user"
self,
message: str,
role: str = "user",
code_interpreter_files: Optional[list[str]] = None,
file_search_files: Optional[list[str]] = None,
) -> Message:
"""
Add a user message to the thread.
Expand All @@ -73,15 +77,23 @@ async def add_async(
await self.create_async()

# Upload files and collect their IDs
file_ids = []
for file_path in file_paths or []:
with open(file_path, mode="rb") as file:
attachments = []
for fp in code_interpreter_files or []:
with open(fp, mode="rb") as file:
response = await client.files.create(file=file, purpose="assistants")
attachments.append(
dict(file_id=response.id, tools=[dict(type="code_interpreter")])
)
for fp in file_search_files or []:
with open(fp, mode="rb") as file:
response = await client.files.create(file=file, purpose="assistants")
file_ids.append(response.id)
attachments.append(
dict(file_id=response.id, tools=[dict(type="file_search")])
)

# Create the message with the attached files
response = await client.beta.threads.messages.create(
thread_id=self.id, role=role, content=message, file_ids=file_ids
thread_id=self.id, role=role, content=message, attachments=attachments
)
return response

Expand Down
6 changes: 3 additions & 3 deletions src/marvin/tools/assistants.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Any, Union

from marvin.types import CodeInterpreterTool, RetrievalTool, Tool
from marvin.types import CodeInterpreterTool, FileSearchTool, Tool

Retrieval = RetrievalTool()
FileSearch = FileSearchTool()
CodeInterpreter = CodeInterpreterTool()

AssistantTool = Union[RetrievalTool, CodeInterpreterTool, Tool]
AssistantTool = Union[FileSearchTool, CodeInterpreterTool, Tool]

ENDRUN_TOKEN = "<|ENDRUN|>"

Expand Down
6 changes: 3 additions & 3 deletions src/marvin/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ class ToolSet(MarvinType, Generic[T]):
tool_choice: Optional[Union[Literal["auto"], dict[str, Any]]] = None


class RetrievalTool(Tool):
type: Literal["retrieval"] = "retrieval"
class FileSearchTool(Tool):
type: Literal["file_search"] = "file_search"


class CodeInterpreterTool(Tool):
Expand Down Expand Up @@ -244,7 +244,7 @@ class AssistantMessage(BaseMessage):
created_at: int
assistant_id: Optional[str] = None
run_id: Optional[str] = None
file_ids: list[str] = []
attachments: list[dict] = []
metadata: dict[str, Any] = {}


Expand Down

0 comments on commit b8c8d57

Please sign in to comment.