Skip to content

Commit

Permalink
move tasks to beta
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowin committed Nov 16, 2023
1 parent 12e1822 commit 3ddeb1a
Showing 1 changed file with 63 additions and 14 deletions.
77 changes: 63 additions & 14 deletions src/marvin/components/ai_task.py → src/marvin/beta/ai_tasks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import functools
from contextlib import contextmanager
from enum import Enum, auto
from typing import Any, Callable, Generic, Optional, TypeVar

from prefect import flow as prefect_flow
from prefect import task as prefect_task
from pydantic import BaseModel, Field
from rich.prompt import Prompt
from typing_extensions import ParamSpec
Expand All @@ -11,7 +14,6 @@
from marvin.beta.assistants.runs import CancelRun
from marvin.serializers import create_tool_from_type
from marvin.tools.assistants import AssistantTools
from marvin.utilities.asyncio import run_sync
from marvin.utilities.context import ScopedContext
from marvin.utilities.jinja import Environment as JinjaEnvironment
from marvin.utilities.tools import tool_from_function
Expand All @@ -25,11 +27,16 @@
INSTRUCTIONS = """
# Workflow
You are an assistant working with a user to complete a series of tasks. The
You are an assistant working to complete a series of tasks. The
tasks will change from time to time, which is why you may see messages that
appear unrelated to the current task. Each task will be part of a continuous
conversation, so do not continue to reintroduce yourself each time.
Note: Sometimes you will be able to complete a task without user input; other
times you will need to engage the user in conversation. Pay attention to your
instructions.
## Progress
{% for task in tasks -%}
- {{ task.name }}: {{ task.status }}
Expand All @@ -41,7 +48,7 @@
Your job is to complete the "{{ name }}" task.
## Current task description
## Current task instructions
{{ instructions }}
Expand All @@ -54,7 +61,7 @@
{% endif %}
# Completing the task
## Completing a task
After achieving your goal, you MUST call the `task_completed` tool to mark the
task as complete and update these instructions to reflect the next one. The
Expand All @@ -71,15 +78,15 @@
The user CAN NOT see what you post to `task_completed`. It is not a way to
communicate with the user.
# Failing the task
## Failing a task
It may take you a few tries to complete the task. However, if you are ultimately
unable to work with the user to complete it, call the `task_failed` tool to mark
the task as failed and move on to the next one. The payload to `task_failed` is
a string describing why the task failed.
{% if args or kwargs -%}
# Task inputs
## Task inputs
In addition to the thread messages, the following parameters were provided:
{% set sig = inspect.signature(func) -%}
Expand Down Expand Up @@ -121,7 +128,9 @@ def __call__(self, *args: P.args, _thread_id: str = None, **kwargs: P.kwargs) ->
if _thread_id is None:
_thread_id = thread_context.get("thread_id")

return run_sync(self.call(*args, _thread_id=_thread_id, **kwargs))
ptask = prefect_task(name=self.name)(self.call)

return ptask(*args, _thread_id=_thread_id, **kwargs)

async def call(self, *args, _thread_id: str = None, **kwargs):
thread = Thread(id=_thread_id)
Expand All @@ -142,8 +151,8 @@ async def call(self, *args, _thread_id: str = None, **kwargs):
instructions = self.get_instructions(
tasks=thread_context.get("tasks", []),
iterations=iterations,
*args,
**kwargs,
args=args,
kwargs=kwargs,
)

if iterations > 1:
Expand Down Expand Up @@ -173,7 +182,11 @@ async def call(self, *args, _thread_id: str = None, **kwargs):
return self.result

def get_instructions(
self, tasks: list["AITask"], iterations: int, *args: P.args, **kwargs: P.kwargs
self,
tasks: list["AITask"],
iterations: int,
args: tuple[Any],
kwargs: dict[str, Any],
) -> str:
return JinjaEnvironment.render(
INSTRUCTIONS,
Expand All @@ -192,11 +205,10 @@ def _task_completed_tool(self):
_type=self.fn.__annotations__["return"],
model_name="task_completed",
model_description=(
"Use this tool to complete the objective and provide a result that"
" contains its result."
"Indicate that the task completed and produced the provided `result`."
),
field_name="result",
field_description="The objective result",
field_description="The task result",
)

def task_completed(result: T):
Expand Down Expand Up @@ -239,5 +251,42 @@ def wrapper(*func_args, **func_kwargs):
return decorator


@contextmanager
def ai_flow_context(thread_id: str = None, tasks: list[AITask] = None, **kwargs):
# create a new thread for the flow
thread = Thread(id=thread_id)
if thread_id is None:
thread.create()

# create a holder for the tasks
tasks = tasks or []

# enter the thread context
with thread_context(thread_id=thread.id, tasks=tasks, **kwargs):
yield


class AIFlow(BaseModel):
pass
name: Optional[str] = None
fn: Callable

def __call__(self, *args, **kwargs):
pflow = prefect_flow(name=self.name)(self.fn)
# Set up the thread context and execute the flow
with ai_flow_context():
return pflow(*args, **kwargs)


def ai_flow(*args, name=None):
def decorator(func):
@functools.wraps(func)
def wrapper(*func_args, **func_kwargs):
ai_flow_instance = AIFlow(fn=func, name=name or func.__name__)
return ai_flow_instance(*func_args, **func_kwargs)

return wrapper

if args and callable(args[0]):
return decorator(args[0])

return decorator

0 comments on commit 3ddeb1a

Please sign in to comment.