diff --git a/src/marvin/components/ai_task.py b/src/marvin/beta/ai_tasks.py similarity index 78% rename from src/marvin/components/ai_task.py rename to src/marvin/beta/ai_tasks.py index 8949bb7e6..3ddb47e9e 100644 --- a/src/marvin/components/ai_task.py +++ b/src/marvin/beta/ai_tasks.py @@ -1,7 +1,10 @@ import functools +from contextlib import contextmanager from enum import Enum, auto from typing import Any, Callable, Generic, Optional, TypeVar +from prefect import flow as prefect_flow +from prefect import task as prefect_task from pydantic import BaseModel, Field from rich.prompt import Prompt from typing_extensions import ParamSpec @@ -11,7 +14,6 @@ from marvin.beta.assistants.runs import CancelRun from marvin.serializers import create_tool_from_type from marvin.tools.assistants import AssistantTools -from marvin.utilities.asyncio import run_sync from marvin.utilities.context import ScopedContext from marvin.utilities.jinja import Environment as JinjaEnvironment from marvin.utilities.tools import tool_from_function @@ -25,11 +27,16 @@ INSTRUCTIONS = """ # Workflow -You are an assistant working with a user to complete a series of tasks. The +You are an assistant working to complete a series of tasks. The tasks will change from time to time, which is why you may see messages that appear unrelated to the current task. Each task will be part of a continuous conversation, so do not continue to reintroduce yourself each time. +Note: Sometimes you will be able to complete a task without user input; other +times you will need to engage the user in conversation. Pay attention to your +instructions. + + ## Progress {% for task in tasks -%} - {{ task.name }}: {{ task.status }} @@ -41,7 +48,7 @@ Your job is to complete the "{{ name }}" task. -## Current task description +## Current task instructions {{ instructions }} @@ -54,7 +61,7 @@ {% endif %} -# Completing the task +## Completing a task After achieving your goal, you MUST call the `task_completed` tool to mark the task as complete and update these instructions to reflect the next one. The @@ -71,7 +78,7 @@ The user CAN NOT see what you post to `task_completed`. It is not a way to communicate with the user. -# Failing the task +## Failing a task It may take you a few tries to complete the task. However, if you are ultimately unable to work with the user to complete it, call the `task_failed` tool to mark @@ -79,7 +86,7 @@ a string describing why the task failed. {% if args or kwargs -%} -# Task inputs +## Task inputs In addition to the thread messages, the following parameters were provided: {% set sig = inspect.signature(func) -%} @@ -121,7 +128,9 @@ def __call__(self, *args: P.args, _thread_id: str = None, **kwargs: P.kwargs) -> if _thread_id is None: _thread_id = thread_context.get("thread_id") - return run_sync(self.call(*args, _thread_id=_thread_id, **kwargs)) + ptask = prefect_task(name=self.name)(self.call) + + return ptask(*args, _thread_id=_thread_id, **kwargs) async def call(self, *args, _thread_id: str = None, **kwargs): thread = Thread(id=_thread_id) @@ -142,8 +151,8 @@ async def call(self, *args, _thread_id: str = None, **kwargs): instructions = self.get_instructions( tasks=thread_context.get("tasks", []), iterations=iterations, - *args, - **kwargs, + args=args, + kwargs=kwargs, ) if iterations > 1: @@ -173,7 +182,11 @@ async def call(self, *args, _thread_id: str = None, **kwargs): return self.result def get_instructions( - self, tasks: list["AITask"], iterations: int, *args: P.args, **kwargs: P.kwargs + self, + tasks: list["AITask"], + iterations: int, + args: tuple[Any], + kwargs: dict[str, Any], ) -> str: return JinjaEnvironment.render( INSTRUCTIONS, @@ -192,11 +205,10 @@ def _task_completed_tool(self): _type=self.fn.__annotations__["return"], model_name="task_completed", model_description=( - "Use this tool to complete the objective and provide a result that" - " contains its result." + "Indicate that the task completed and produced the provided `result`." ), field_name="result", - field_description="The objective result", + field_description="The task result", ) def task_completed(result: T): @@ -239,5 +251,42 @@ def wrapper(*func_args, **func_kwargs): return decorator +@contextmanager +def ai_flow_context(thread_id: str = None, tasks: list[AITask] = None, **kwargs): + # create a new thread for the flow + thread = Thread(id=thread_id) + if thread_id is None: + thread.create() + + # create a holder for the tasks + tasks = tasks or [] + + # enter the thread context + with thread_context(thread_id=thread.id, tasks=tasks, **kwargs): + yield + + class AIFlow(BaseModel): - pass + name: Optional[str] = None + fn: Callable + + def __call__(self, *args, **kwargs): + pflow = prefect_flow(name=self.name)(self.fn) + # Set up the thread context and execute the flow + with ai_flow_context(): + return pflow(*args, **kwargs) + + +def ai_flow(*args, name=None): + def decorator(func): + @functools.wraps(func) + def wrapper(*func_args, **func_kwargs): + ai_flow_instance = AIFlow(fn=func, name=name or func.__name__) + return ai_flow_instance(*func_args, **func_kwargs) + + return wrapper + + if args and callable(args[0]): + return decorator(args[0]) + + return decorator