From dacf2dafb7898ca474d3e87b2b9aa5243143c135 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Mon, 12 Feb 2024 13:23:15 -0500 Subject: [PATCH 1/2] Speed up assistant tests --- src/marvin/beta/assistants/runs.py | 12 ++++++------ src/marvin/beta/assistants/threads.py | 10 +++++----- tests/beta/assistants/test_assistants.py | 21 +++++++++++++-------- 3 files changed, 24 insertions(+), 19 deletions(-) diff --git a/src/marvin/beta/assistants/runs.py b/src/marvin/beta/assistants/runs.py index 7f46756e9..3e77c2266 100644 --- a/src/marvin/beta/assistants/runs.py +++ b/src/marvin/beta/assistants/runs.py @@ -8,12 +8,12 @@ from openai.types.beta.threads.runs import RunStep as OpenAIRunStep from pydantic import BaseModel, Field, PrivateAttr, field_validator +import marvin.utilities.openai import marvin.utilities.tools from marvin.tools.assistants import AssistantTool, CancelRun from marvin.types import Tool from marvin.utilities.asyncio import ExposeSyncMethodsMixin, expose_sync_method from marvin.utilities.logging import get_logger -from marvin.utilities.openai import get_openai_client from .assistants import Assistant from .threads import Thread @@ -75,7 +75,7 @@ def format_tools(cls, tools: Union[None, list[Union[Tool, Callable]]]): @expose_sync_method("refresh") async def refresh_async(self): """Refreshes the run.""" - client = get_openai_client() + client = marvin.utilities.openai.get_openai_client() self.run = await client.beta.threads.runs.retrieve( run_id=self.run.id, thread_id=self.thread.id ) @@ -83,7 +83,7 @@ async def refresh_async(self): @expose_sync_method("cancel") async def cancel_async(self): """Cancels the run.""" - client = get_openai_client() + client = marvin.utilities.openai.get_openai_client() await client.beta.threads.runs.cancel( run_id=self.run.id, thread_id=self.thread.id ) @@ -91,7 +91,7 @@ async def cancel_async(self): async def _handle_step_requires_action( self, ) -> tuple[list[RequiredActionFunctionToolCall], list[dict[str, str]]]: - client = get_openai_client() + client = marvin.utilities.openai.get_openai_client() if self.run.status != "requires_action": return None, None if self.run.required_action.type == "submit_tool_outputs": @@ -146,7 +146,7 @@ def get_tools(self) -> list[AssistantTool]: async def run_async(self) -> "Run": """Excutes a run asynchronously.""" - client = get_openai_client() + client = marvin.utilities.openai.get_openai_client() create_kwargs = {} @@ -233,7 +233,7 @@ async def refresh_run_steps_async(self): max_attempts = max_fetched / limit + 2 # Fetch the latest run steps - client = get_openai_client() + client = marvin.utilities.openai.get_openai_client() response = await client.beta.threads.runs.steps.list( run_id=self.run.id, diff --git a/src/marvin/beta/assistants/threads.py b/src/marvin/beta/assistants/threads.py index a6a547b76..42704edf7 100644 --- a/src/marvin/beta/assistants/threads.py +++ b/src/marvin/beta/assistants/threads.py @@ -5,6 +5,7 @@ from openai.types.beta.threads import ThreadMessage from pydantic import BaseModel, Field, PrivateAttr +import marvin.utilities.openai from marvin.beta.assistants.formatting import pprint_message from marvin.utilities.asyncio import ( ExposeSyncMethodsMixin, @@ -12,7 +13,6 @@ run_sync, ) from marvin.utilities.logging import get_logger -from marvin.utilities.openai import get_openai_client from marvin.utilities.pydantic import parse_as logger = get_logger("Threads") @@ -59,7 +59,7 @@ async def create_async(self, messages: list[str] = None): raise ValueError("Thread has already been created.") if messages is not None: messages = [{"role": "user", "content": message} for message in messages] - client = get_openai_client() + client = marvin.utilities.openai.get_openai_client() response = await client.beta.threads.create(messages=messages) self.id = response.id return self @@ -71,7 +71,7 @@ async def add_async( """ Add a user message to the thread. """ - client = get_openai_client() + client = marvin.utilities.openai.get_openai_client() if self.id is None: await self.create_async() @@ -118,7 +118,7 @@ async def get_messages_async( if self.id is None: await self.create_async() - client = get_openai_client() + client = marvin.utilities.openai.get_openai_client() response = await client.beta.threads.messages.list( thread_id=self.id, @@ -136,7 +136,7 @@ async def get_messages_async( @expose_sync_method("delete") async def delete_async(self): - client = get_openai_client() + client = marvin.utilities.openai.get_openai_client() await client.beta.threads.delete(thread_id=self.id) self.id = None diff --git a/tests/beta/assistants/test_assistants.py b/tests/beta/assistants/test_assistants.py index 073672fe3..4b183514c 100644 --- a/tests/beta/assistants/test_assistants.py +++ b/tests/beta/assistants/test_assistants.py @@ -1,4 +1,4 @@ -from unittest.mock import patch +from unittest.mock import AsyncMock, patch import marvin import openai @@ -18,6 +18,11 @@ def mock_get_client(monkeypatch): monkeypatch.setattr("marvin.utilities.openai.get_openai_client", mocked_client) +@pytest.fixture(autouse=True) +def mock_say(monkeypatch): + monkeypatch.setattr(client.beta.threads.runs, "create", AsyncMock()) + + class TestLifeCycle: @patch.object(client.beta.assistants, "delete", wraps=client.beta.assistants.delete) @patch.object(client.beta.assistants, "create", wraps=client.beta.assistants.create) @@ -26,7 +31,7 @@ def test_interactive(self, mock_create, mock_delete): ai = Assistant() mock_create.assert_not_called() assert not ai.id - response = ai.say("repeat the word hi") + response = ai.say("hi") assert response assert not ai.id mock_create.assert_called() @@ -40,9 +45,9 @@ def test_context_manager(self, mock_create, mock_delete): with ai: mock_create.assert_called() assert ai.id - ai.say("repeat the word hi") + ai.say("hi") mock_delete.assert_not_called() - ai.say("repeat the word hi") + ai.say("hi") mock_delete.assert_not_called() assert ai.id assert not ai.id @@ -58,9 +63,9 @@ def test_manual_lifecycle(self, mock_create, mock_delete): ai.create() mock_create.assert_called() assert ai.id - ai.say("repeat the word hi") + ai.say("hi") mock_delete.assert_not_called() - ai.say("repeat the word hi") + ai.say("hi") mock_delete.assert_not_called() assert ai.id ai.delete() @@ -88,9 +93,9 @@ def test_load_from_api(self): mock_create.assert_not_called() assert ai.id - ai.say("repeat the word hi") + ai.say("hi") mock_delete.assert_not_called() - ai.say("repeat the word hi") + ai.say("hi") mock_delete.assert_not_called() assert ai.id ai.delete() From 4276fccea7d1c49303902465894044207f24b7fb Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Mon, 12 Feb 2024 13:29:02 -0500 Subject: [PATCH 2/2] Update test_assistants.py --- tests/beta/assistants/test_assistants.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/beta/assistants/test_assistants.py b/tests/beta/assistants/test_assistants.py index 4b183514c..1c4ec197b 100644 --- a/tests/beta/assistants/test_assistants.py +++ b/tests/beta/assistants/test_assistants.py @@ -31,8 +31,7 @@ def test_interactive(self, mock_create, mock_delete): ai = Assistant() mock_create.assert_not_called() assert not ai.id - response = ai.say("hi") - assert response + ai.say("hi") assert not ai.id mock_create.assert_called() mock_delete.assert_called()