Skip to content

Commit

Permalink
Run pre-commit formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
salman1993 committed Feb 8, 2024
1 parent 2d7e134 commit ecda861
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 10 deletions.
14 changes: 8 additions & 6 deletions src/marvin/beta/assistants/assistants.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import TYPE_CHECKING, Callable, Optional, Union

from openai.types.beta.threads.required_action_function_tool_call import (
RequiredActionFunctionToolCall,
)
from pydantic import BaseModel, Field, PrivateAttr
from openai.types.beta.threads.required_action_function_tool_call import RequiredActionFunctionToolCall

import marvin.utilities.tools
from marvin.tools.assistants import AssistantTool
Expand Down Expand Up @@ -170,9 +172,9 @@ def pre_run_hook(self, run: "Run"):
pass

def post_run_hook(
self,
run: "Run",
tool_calls: Optional[list[RequiredActionFunctionToolCall]] = None,
tool_outputs: Optional[list[dict[str, str]]] = None
):
self,
run: "Run",
tool_calls: Optional[list[RequiredActionFunctionToolCall]] = None,
tool_outputs: Optional[list[dict[str, str]]] = None,
):
pass
17 changes: 13 additions & 4 deletions src/marvin/beta/assistants/runs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import asyncio
from typing import Any, Callable, Optional, Union

from openai.types.beta.threads.required_action_function_tool_call import (
RequiredActionFunctionToolCall,
)
from openai.types.beta.threads.run import Run as OpenAIRun
from openai.types.beta.threads.runs import RunStep as OpenAIRunStep
from pydantic import BaseModel, Field, PrivateAttr, field_validator
Expand All @@ -11,7 +14,6 @@
from marvin.utilities.asyncio import ExposeSyncMethodsMixin, expose_sync_method
from marvin.utilities.logging import get_logger
from marvin.utilities.openai import get_openai_client
from openai.types.beta.threads.required_action_function_tool_call import RequiredActionFunctionToolCall

from .assistants import Assistant
from .threads import Thread
Expand Down Expand Up @@ -86,7 +88,9 @@ async def cancel_async(self):
run_id=self.run.id, thread_id=self.thread.id
)

async def _handle_step_requires_action(self) -> tuple[list[RequiredActionFunctionToolCall], list[dict[str, str]]]:
async def _handle_step_requires_action(
self,
) -> tuple[list[RequiredActionFunctionToolCall], list[dict[str, str]]]:
client = get_openai_client()
if self.run.status != "requires_action":
return None, None
Expand Down Expand Up @@ -167,7 +171,10 @@ async def run_async(self) -> "Run":
try:
while self.run.status in ("queued", "in_progress", "requires_action"):
if self.run.status == "requires_action":
tool_calls, tool_outputs = await self._handle_step_requires_action()
(
tool_calls,
tool_outputs,
) = await self._handle_step_requires_action()
await asyncio.sleep(0.1)
await self.refresh_async()
except CancelRun as exc:
Expand All @@ -181,7 +188,9 @@ async def run_async(self) -> "Run":
if self.run.status == "failed":
logger.debug(f"Run failed. Last error was: {self.run.last_error}")

self.assistant.post_run_hook(run=self, tool_calls=tool_calls, tool_outputs=tool_outputs)
self.assistant.post_run_hook(
run=self, tool_calls=tool_calls, tool_outputs=tool_outputs
)
return self


Expand Down

0 comments on commit ecda861

Please sign in to comment.