Skip to content

Commit

Permalink
Merge pull request #843 from PrefectHQ/assistant-tests
Browse files Browse the repository at this point in the history
Speed up assistant tests
  • Loading branch information
jlowin authored Feb 12, 2024
2 parents 6c503d8 + 4276fcc commit 6722933
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 20 deletions.
12 changes: 6 additions & 6 deletions src/marvin/beta/assistants/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
from openai.types.beta.threads.runs import RunStep as OpenAIRunStep
from pydantic import BaseModel, Field, PrivateAttr, field_validator

import marvin.utilities.openai
import marvin.utilities.tools
from marvin.tools.assistants import AssistantTool, CancelRun
from marvin.types import Tool
from marvin.utilities.asyncio import ExposeSyncMethodsMixin, expose_sync_method
from marvin.utilities.logging import get_logger
from marvin.utilities.openai import get_openai_client

from .assistants import Assistant
from .threads import Thread
Expand Down Expand Up @@ -75,23 +75,23 @@ def format_tools(cls, tools: Union[None, list[Union[Tool, Callable]]]):
@expose_sync_method("refresh")
async def refresh_async(self):
"""Refreshes the run."""
client = get_openai_client()
client = marvin.utilities.openai.get_openai_client()
self.run = await client.beta.threads.runs.retrieve(
run_id=self.run.id, thread_id=self.thread.id
)

@expose_sync_method("cancel")
async def cancel_async(self):
"""Cancels the run."""
client = get_openai_client()
client = marvin.utilities.openai.get_openai_client()
await client.beta.threads.runs.cancel(
run_id=self.run.id, thread_id=self.thread.id
)

async def _handle_step_requires_action(
self,
) -> tuple[list[RequiredActionFunctionToolCall], list[dict[str, str]]]:
client = get_openai_client()
client = marvin.utilities.openai.get_openai_client()
if self.run.status != "requires_action":
return None, None
if self.run.required_action.type == "submit_tool_outputs":
Expand Down Expand Up @@ -146,7 +146,7 @@ def get_tools(self) -> list[AssistantTool]:

async def run_async(self) -> "Run":
"""Excutes a run asynchronously."""
client = get_openai_client()
client = marvin.utilities.openai.get_openai_client()

create_kwargs = {}

Expand Down Expand Up @@ -233,7 +233,7 @@ async def refresh_run_steps_async(self):
max_attempts = max_fetched / limit + 2

# Fetch the latest run steps
client = get_openai_client()
client = marvin.utilities.openai.get_openai_client()

response = await client.beta.threads.runs.steps.list(
run_id=self.run.id,
Expand Down
10 changes: 5 additions & 5 deletions src/marvin/beta/assistants/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
from openai.types.beta.threads import ThreadMessage
from pydantic import BaseModel, Field, PrivateAttr

import marvin.utilities.openai
from marvin.beta.assistants.formatting import pprint_message
from marvin.utilities.asyncio import (
ExposeSyncMethodsMixin,
expose_sync_method,
run_sync,
)
from marvin.utilities.logging import get_logger
from marvin.utilities.openai import get_openai_client
from marvin.utilities.pydantic import parse_as

logger = get_logger("Threads")
Expand Down Expand Up @@ -59,7 +59,7 @@ async def create_async(self, messages: list[str] = None):
raise ValueError("Thread has already been created.")
if messages is not None:
messages = [{"role": "user", "content": message} for message in messages]
client = get_openai_client()
client = marvin.utilities.openai.get_openai_client()
response = await client.beta.threads.create(messages=messages)
self.id = response.id
return self
Expand All @@ -71,7 +71,7 @@ async def add_async(
"""
Add a user message to the thread.
"""
client = get_openai_client()
client = marvin.utilities.openai.get_openai_client()

if self.id is None:
await self.create_async()
Expand Down Expand Up @@ -118,7 +118,7 @@ async def get_messages_async(

if self.id is None:
await self.create_async()
client = get_openai_client()
client = marvin.utilities.openai.get_openai_client()

response = await client.beta.threads.messages.list(
thread_id=self.id,
Expand All @@ -136,7 +136,7 @@ async def get_messages_async(

@expose_sync_method("delete")
async def delete_async(self):
client = get_openai_client()
client = marvin.utilities.openai.get_openai_client()
await client.beta.threads.delete(thread_id=self.id)
self.id = None

Expand Down
22 changes: 13 additions & 9 deletions tests/beta/assistants/test_assistants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import patch
from unittest.mock import AsyncMock, patch

import marvin
import openai
Expand All @@ -18,6 +18,11 @@ def mock_get_client(monkeypatch):
monkeypatch.setattr("marvin.utilities.openai.get_openai_client", mocked_client)


@pytest.fixture(autouse=True)
def mock_say(monkeypatch):
monkeypatch.setattr(client.beta.threads.runs, "create", AsyncMock())


class TestLifeCycle:
@patch.object(client.beta.assistants, "delete", wraps=client.beta.assistants.delete)
@patch.object(client.beta.assistants, "create", wraps=client.beta.assistants.create)
Expand All @@ -26,8 +31,7 @@ def test_interactive(self, mock_create, mock_delete):
ai = Assistant()
mock_create.assert_not_called()
assert not ai.id
response = ai.say("repeat the word hi")
assert response
ai.say("hi")
assert not ai.id
mock_create.assert_called()
mock_delete.assert_called()
Expand All @@ -40,9 +44,9 @@ def test_context_manager(self, mock_create, mock_delete):
with ai:
mock_create.assert_called()
assert ai.id
ai.say("repeat the word hi")
ai.say("hi")
mock_delete.assert_not_called()
ai.say("repeat the word hi")
ai.say("hi")
mock_delete.assert_not_called()
assert ai.id
assert not ai.id
Expand All @@ -58,9 +62,9 @@ def test_manual_lifecycle(self, mock_create, mock_delete):
ai.create()
mock_create.assert_called()
assert ai.id
ai.say("repeat the word hi")
ai.say("hi")
mock_delete.assert_not_called()
ai.say("repeat the word hi")
ai.say("hi")
mock_delete.assert_not_called()
assert ai.id
ai.delete()
Expand Down Expand Up @@ -88,9 +92,9 @@ def test_load_from_api(self):
mock_create.assert_not_called()

assert ai.id
ai.say("repeat the word hi")
ai.say("hi")
mock_delete.assert_not_called()
ai.say("repeat the word hi")
ai.say("hi")
mock_delete.assert_not_called()
assert ai.id
ai.delete()
Expand Down

0 comments on commit 6722933

Please sign in to comment.