Skip to content

Commit

Permalink
Merge pull request #650 from PrefectHQ/bring-back-map-for-ai-fn
Browse files Browse the repository at this point in the history
bring back map for `AIFunction`
  • Loading branch information
zzstoatzz authored Nov 16, 2023
2 parents 7144c47 + 3f40b94 commit 4415618
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 130 deletions.
132 changes: 96 additions & 36 deletions src/marvin/components/ai_function.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import inspect
import json
from functools import partial, wraps
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -18,6 +18,11 @@

from marvin.components.prompt import PromptFunction
from marvin.serializers import create_tool_from_type
from marvin.utilities.asyncio import (
ExposeSyncMethodsMixin,
expose_sync_method,
run_async,
)
from marvin.utilities.jinja import (
BaseEnvironment,
)
Expand All @@ -30,7 +35,7 @@
P = ParamSpec("P")


class AIFunction(BaseModel, Generic[P, T]):
class AIFunction(BaseModel, Generic[P, T], ExposeSyncMethodsMixin):
fn: Optional[Callable[P, T]] = None
environment: Optional[BaseEnvironment] = None
prompt: Optional[str] = Field(default=inspect.cleandoc("""
Expand All @@ -57,14 +62,32 @@ class AIFunction(BaseModel, Generic[P, T]):

create: Optional[Callable[..., "ChatCompletion"]] = Field(default=None)

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
create = self.create
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Union[T, Awaitable[T]]:
if self.fn is None:
raise NotImplementedError
if create is None:
from marvin.settings import settings

create = settings.openai.chat.completions.create
from marvin import settings

is_async_fn = asyncio.iscoroutinefunction(self.fn)

call = "async_call" if is_async_fn else "sync_call"
create = (
self.create or settings.openai.chat.completions.acreate
if is_async_fn
else settings.openai.chat.completions.create
)

return getattr(self, call)(create, *args, **kwargs)

async def async_call(
self, acreate: Callable[..., Awaitable[Any]], *args: P.args, **kwargs: P.kwargs
) -> T:
_response = await acreate(**self.as_prompt(*args, **kwargs).serialize())
return self.parse(_response)

def sync_call(
self, create: Callable[..., Any], *args: P.args, **kwargs: P.kwargs
) -> T:
_response = create(**self.as_prompt(*args, **kwargs).serialize())
return self.parse(_response)

Expand Down Expand Up @@ -93,6 +116,46 @@ def parse(self, response: "ChatCompletion") -> T:
_arguments: str = json.dumps({self.field_name: json.loads(arguments)})
return getattr(tool.model.model_validate_json(_arguments), self.field_name)

@expose_sync_method("map")
async def amap(self, *map_args: list[Any], **map_kwargs: list[Any]) -> list[T]:
"""
Map the AI function over a sequence of arguments. Runs concurrently.
A `map` twin method is provided by the `expose_sync_method` decorator.
You can use `map` or `amap` synchronously or asynchronously, respectively,
regardless of whether the user function is synchronous or asynchronous.
Arguments should be provided as if calling the function normally, but
each argument must be a list. The function is called once for each item
in the list, and the results are returned in a list.
For example, fn.map([1, 2]) is equivalent to [fn(1), fn(2)].
fn.map([1, 2], x=['a', 'b']) is equivalent to [fn(1, x='a'), fn(2, x='b')].
"""
tasks: list[Any] = []
if map_args and map_kwargs:
max_length = max(
len(arg) for arg in (map_args + tuple(map_kwargs.values()))
)
elif map_args:
max_length = max(len(arg) for arg in map_args)
else:
max_length = max(len(v) for v in map_kwargs.values())

for i in range(max_length):
call_args = [arg[i] if i < len(arg) else None for arg in map_args]
call_kwargs = (
{k: v[i] if i < len(v) else None for k, v in map_kwargs.items()}
if map_kwargs
else {}
)

tasks.append(run_async(self, *call_args, **call_kwargs))

return await asyncio.gather(*tasks)

def as_prompt(
self,
*args: P.args,
Expand Down Expand Up @@ -153,33 +216,24 @@ def as_decorator(
model_description: str = "Formats the response.",
field_name: str = "data",
field_description: str = "The data to format.",
acreate: Optional[Callable[..., Awaitable[Any]]] = None,
**render_kwargs: Any,
) -> Union[Self, Callable[[Callable[P, T]], Self]]:
if fn is None:
return partial(
cls,
) -> Union[Callable[[Callable[P, T]], Self], Self]:
def decorator(func: Callable[P, T]) -> Self:
return cls(
fn=func,
environment=environment,
prompt=prompt,
model_name=model_name,
model_description=model_description,
name=model_name,
description=model_description,
field_name=field_name,
field_description=field_description,
acreate=acreate,
**({"prompt": prompt} if prompt else {}),
**render_kwargs,
)

return cls(
fn=fn,
environment=environment,
name=model_name,
description=model_description,
field_name=field_name,
field_description=field_description,
**({"prompt": prompt} if prompt else {}),
**render_kwargs,
)
if fn is not None:
return decorator(fn)

return decorator


@overload
Expand Down Expand Up @@ -221,23 +275,29 @@ def ai_fn(
field_name: str = "data",
field_description: str = "The data to format.",
**render_kwargs: Any,
) -> Union[Callable[[Callable[P, T]], Callable[P, T]], Callable[P, T],]:
def wrapper(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
return AIFunction[P, T].as_decorator(
func,
) -> Union[Callable[[Callable[P, T]], Callable[P, T]], Callable[P, T]]:
if fn is not None:
return AIFunction.as_decorator( # type: ignore
fn=fn,
environment=environment,
prompt=prompt,
model_name=model_name,
model_description=model_description,
field_name=field_name,
field_description=field_description,
**render_kwargs,
)(*args, **kwargs)

if fn is not None:
return wraps(fn)(partial(wrapper, fn))
)

def decorator(fn: Callable[P, T]) -> Callable[P, T]:
return wraps(fn)(partial(wrapper, fn))
def decorator(func: Callable[P, T]) -> Callable[P, T]:
return AIFunction.as_decorator( # type: ignore
fn=func,
environment=environment,
prompt=prompt,
model_name=model_name,
model_description=model_description,
field_name=field_name,
field_description=field_description,
**render_kwargs,
)

return decorator
Loading

0 comments on commit 4415618

Please sign in to comment.