From dc93d97d01fd24c4ed7643708573357c622db787 Mon Sep 17 00:00:00 2001 From: Nathan Nowack Date: Thu, 16 Nov 2023 04:19:22 -0600 Subject: [PATCH 1/4] bring back map --- src/marvin/components/ai_function.py | 125 +++++++++++++++++++-------- 1 file changed, 89 insertions(+), 36 deletions(-) diff --git a/src/marvin/components/ai_function.py b/src/marvin/components/ai_function.py index 5d820805c..26d4af106 100644 --- a/src/marvin/components/ai_function.py +++ b/src/marvin/components/ai_function.py @@ -1,6 +1,6 @@ +import asyncio import inspect import json -from functools import partial, wraps from typing import ( TYPE_CHECKING, Any, @@ -18,6 +18,11 @@ from marvin.components.prompt import PromptFunction from marvin.serializers import create_tool_from_type +from marvin.utilities.asyncio import ( + ExposeSyncMethodsMixin, + expose_sync_method, + run_async, +) from marvin.utilities.jinja import ( BaseEnvironment, ) @@ -30,7 +35,7 @@ P = ParamSpec("P") -class AIFunction(BaseModel, Generic[P, T]): +class AIFunction(BaseModel, Generic[P, T], ExposeSyncMethodsMixin): fn: Optional[Callable[P, T]] = None environment: Optional[BaseEnvironment] = None prompt: Optional[str] = Field(default=inspect.cleandoc(""" @@ -57,14 +62,28 @@ class AIFunction(BaseModel, Generic[P, T]): create: Optional[Callable[..., "ChatCompletion"]] = Field(default=None) - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: - create = self.create + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Union[T, Awaitable[T]]: if self.fn is None: raise NotImplementedError - if create is None: - from marvin.settings import settings - create = settings.openai.chat.completions.create + from marvin import settings + + is_async_fn = asyncio.iscoroutinefunction(self.fn) + + call = "async_call" if is_async_fn else "sync_call" + create = ( + self.create or settings.openai.chat.completions.acreate + if is_async_fn + else settings.openai.chat.completions.create + ) + + return getattr(self, call)(create, *args, **kwargs) + + async def async_call(self, acreate, *args: P.args, **kwargs: P.kwargs) -> T: + _response = await acreate(**self.as_prompt(*args, **kwargs).serialize()) + return self.parse(_response) + + def sync_call(self, create, *args: P.args, **kwargs: P.kwargs) -> T: _response = create(**self.as_prompt(*args, **kwargs).serialize()) return self.parse(_response) @@ -93,6 +112,43 @@ def parse(self, response: "ChatCompletion") -> T: _arguments: str = json.dumps({self.field_name: json.loads(arguments)}) return getattr(tool.model.model_validate_json(_arguments), self.field_name) + @expose_sync_method("map") + async def amap(self, *map_args: list[Any], **map_kwargs: list[Any]) -> list[T]: + """ + Map the AI function over a sequence of arguments. Runs concurrently. + + Arguments should be provided as if calling the function normally, but + each argument must be a list. The function is called once for each item + in the list, and the results are returned in a list. + + This method should be called synchronously. + + For example, fn.map([1, 2]) is equivalent to [fn(1), fn(2)]. + + fn.map([1, 2], x=['a', 'b']) is equivalent to [fn(1, x='a'), fn(2, x='b')]. + """ + tasks: list[Any] = [] + if map_args and map_kwargs: + max_length = max( + len(arg) for arg in (map_args + tuple(map_kwargs.values())) + ) + elif map_args: + max_length = max(len(arg) for arg in map_args) + else: + max_length = max(len(v) for v in map_kwargs.values()) + + for i in range(max_length): + call_args = [arg[i] if i < len(arg) else None for arg in map_args] + call_kwargs = ( + {k: v[i] if i < len(v) else None for k, v in map_kwargs.items()} + if map_kwargs + else {} + ) + + tasks.append(run_async(self, *call_args, **call_kwargs)) + + return await asyncio.gather(*tasks) + def as_prompt( self, *args: P.args, @@ -153,33 +209,24 @@ def as_decorator( model_description: str = "Formats the response.", field_name: str = "data", field_description: str = "The data to format.", - acreate: Optional[Callable[..., Awaitable[Any]]] = None, **render_kwargs: Any, - ) -> Union[Self, Callable[[Callable[P, T]], Self]]: - if fn is None: - return partial( - cls, + ) -> Callable[..., Self]: + def decorator(func: Callable[P, T]) -> Self: + return cls( + fn=func, environment=environment, - prompt=prompt, - model_name=model_name, - model_description=model_description, + name=model_name, + description=model_description, field_name=field_name, field_description=field_description, - acreate=acreate, **({"prompt": prompt} if prompt else {}), **render_kwargs, ) - return cls( - fn=fn, - environment=environment, - name=model_name, - description=model_description, - field_name=field_name, - field_description=field_description, - **({"prompt": prompt} if prompt else {}), - **render_kwargs, - ) + if fn is not None: + return decorator(fn) + + return decorator @overload @@ -221,10 +268,10 @@ def ai_fn( field_name: str = "data", field_description: str = "The data to format.", **render_kwargs: Any, -) -> Union[Callable[[Callable[P, T]], Callable[P, T]], Callable[P, T],]: - def wrapper(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: - return AIFunction[P, T].as_decorator( - func, +) -> Union[Callable[[Callable[P, T]], AIFunction[P, T]], AIFunction[P, T]]: + if fn is not None: + return AIFunction.as_decorator( + fn=fn, environment=environment, prompt=prompt, model_name=model_name, @@ -232,12 +279,18 @@ def wrapper(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: field_name=field_name, field_description=field_description, **render_kwargs, - )(*args, **kwargs) - - if fn is not None: - return wraps(fn)(partial(wrapper, fn)) + ) - def decorator(fn: Callable[P, T]) -> Callable[P, T]: - return wraps(fn)(partial(wrapper, fn)) + def decorator(func: Callable[P, T]) -> AIFunction[P, T]: + return AIFunction.as_decorator( + fn=func, + environment=environment, + prompt=prompt, + model_name=model_name, + model_description=model_description, + field_name=field_name, + field_description=field_description, + **render_kwargs, + ) return decorator From 2634789bd57b3b046dd8099071c90ec7c5723c40 Mon Sep 17 00:00:00 2001 From: Nathan Nowack Date: Thu, 16 Nov 2023 04:26:59 -0600 Subject: [PATCH 2/4] update ai_fn tests --- tests/test_components/test_ai_functions.py | 197 +++++++++++---------- 1 file changed, 103 insertions(+), 94 deletions(-) diff --git a/tests/test_components/test_ai_functions.py b/tests/test_components/test_ai_functions.py index 5b7860202..ffd5ab184 100644 --- a/tests/test_components/test_ai_functions.py +++ b/tests/test_components/test_ai_functions.py @@ -38,120 +38,129 @@ async def list_fruit(n: int) -> list[str]: result = await coro assert len(result) == 3 - def test_list_fruit_with_generic_type_hints(self): - @ai_fn - def list_fruit(n: int) -> List[str]: - """Returns a list of `n` fruit""" - - result = list_fruit(3) - assert len(result) == 3 - - def test_basemodel_return_annotation(self): - class Fruit(BaseModel): - name: str - color: str - - @ai_fn - def get_fruit(description: str) -> Fruit: - """Returns a fruit with the provided description""" - - fruit = get_fruit("loved by monkeys") - assert fruit.name.lower() == "banana" - assert fruit.color.lower() == "yellow" - - @pytest.mark.parametrize("name,expected", [("banana", True), ("car", False)]) - def test_bool_return_annotation(self, name, expected): - @ai_fn - def is_fruit(name: str) -> bool: - """Returns True if the provided name is a fruit""" - - assert is_fruit(name) == expected - - def test_plain_dict_return_type(self): - @ai_fn - def get_fruit(name: str) -> dict: - """Returns a fruit with the provided name and color""" - - fruit = get_fruit("banana") - assert fruit["name"].lower() == "banana" - assert fruit["color"].lower() == "yellow" - - def test_annotated_dict_return_type(self): - @ai_fn - def get_fruit(name: str) -> dict[str, str]: - """Returns a fruit with the provided name and color""" - - fruit = get_fruit("banana") - assert fruit["name"].lower() == "banana" - assert fruit["color"].lower() == "yellow" - - def test_generic_dict_return_type(self): - @ai_fn - def get_fruit(name: str) -> Dict[str, str]: - """Returns a fruit with the provided name and color""" - - fruit = get_fruit("banana") - assert fruit["name"].lower() == "banana" - assert fruit["color"].lower() == "yellow" - - def test_int_return_type(self): - @ai_fn - def get_fruit(name: str) -> int: - """Returns the number of letters in the provided fruit name""" - - assert get_fruit("banana") == 6 - - def test_float_return_type(self): - @ai_fn - def get_fruit(name: str) -> float: - """Returns the number of letters in the provided fruit name""" - - assert get_fruit("banana") == 6.0 - - def test_tuple_return_type(self): - @ai_fn - def get_fruit(name: str) -> tuple: - """Returns the number of letters in the provided fruit name""" - - assert get_fruit("banana") == (6,) - - def test_set_return_type(self): - @ai_fn - def get_fruit(name: str) -> set: - """Returns the letters in the provided fruit name""" - - assert get_fruit("banana") == {"a", "b", "n"} - - def test_frozenset_return_type(self): - @ai_fn - def get_fruit(name: str) -> frozenset: - """Returns the letters in the provided fruit name""" - - assert get_fruit("banana") == frozenset({"a", "b", "n"}) + class TestAnnotations: + def test_list_fruit_with_generic_type_hints(self): + @ai_fn + def list_fruit(n: int) -> List[str]: + """Returns a list of `n` fruit""" + + result = list_fruit(3) + assert len(result) == 3 + + def test_basemodel_return_annotation(self): + class Fruit(BaseModel): + name: str + color: str + + @ai_fn + def get_fruit(description: str) -> Fruit: + """Returns a fruit with the provided description""" + + fruit = get_fruit("loved by monkeys") + assert fruit.name.lower() == "banana" + assert fruit.color.lower() == "yellow" + + @pytest.mark.parametrize("name,expected", [("banana", True), ("car", False)]) + def test_bool_return_annotation(self, name, expected): + @ai_fn + def is_fruit(name: str) -> bool: + """Returns True if the provided name is a fruit""" + + assert is_fruit(name) == expected + + def test_plain_dict_return_type(self): + @ai_fn + def get_fruit(name: str) -> dict: + """Returns a fruit with the provided name and color""" + + fruit = get_fruit("banana") + assert fruit["name"].lower() == "banana" + assert fruit["color"].lower() == "yellow" + + def test_annotated_dict_return_type(self): + @ai_fn + def get_fruit(name: str) -> dict[str, str]: + """Returns a fruit with the provided name and color""" + + fruit = get_fruit("banana") + assert fruit["name"].lower() == "banana" + assert fruit["color"].lower() == "yellow" + + def test_generic_dict_return_type(self): + @ai_fn + def get_fruit(name: str) -> Dict[str, str]: + """Returns a fruit with the provided name and color""" + + fruit = get_fruit("banana") + assert fruit["name"].lower() == "banana" + assert fruit["color"].lower() == "yellow" + + def test_int_return_type(self): + @ai_fn + def get_fruit(name: str) -> int: + """Returns the number of letters in the provided fruit name""" + + assert get_fruit("banana") == 6 + + def test_float_return_type(self): + @ai_fn + def get_fruit(name: str) -> float: + """Returns the number of letters in the provided fruit name""" + + assert get_fruit("banana") == 6.0 + + def test_tuple_return_type(self): + @ai_fn + def get_fruit(name: str) -> tuple: + """Returns a tuple of fruit""" + + assert get_fruit("alphabet of fruit, first 3") == ( + "apple", + "banana", + "cherry", + ) + + @pytest.skip(reason="TODO") + def test_set_return_type(self): + @ai_fn + def get_fruit_letters(name: str) -> set: + """Returns the letters in the provided fruit name""" + + assert get_fruit_letters("banana") == {"a", "b", "n"} + + @pytest.skip(reason="TODO") + def test_frozenset_return_type(self): + @ai_fn + def get_fruit_letters(name: str) -> frozenset: + """Returns the letters in the provided fruit name""" + + assert get_fruit_letters("orange") == frozenset( + {"a", "e", "g", "n", "o", "r"} + ) @pytest_mark_class("llm") class TestAIFunctionsMap: def test_map(self): - result = list_fruit_color.map([2, 3]) + result = list_fruit.map([2, 3]) assert len(result) == 2 assert len(result[0]) == 2 assert len(result[1]) == 3 async def test_amap(self): - result = await list_fruit_color.amap([2, 3]) + result = await list_fruit.amap([2, 3]) assert len(result) == 2 assert len(result[0]) == 2 assert len(result[1]) == 3 def test_map_kwargs(self): - result = list_fruit_color.map(n=[2, 3]) + result = list_fruit.map(n=[2, 3]) assert len(result) == 2 assert len(result[0]) == 2 assert len(result[1]) == 3 def test_map_kwargs_and_args(self): - result = list_fruit_color.map([2, 3], color=[None, "red"]) + result = list_fruit_color.map([2, 3], color=["green", "red"]) assert len(result) == 2 assert len(result[0]) == 2 assert len(result[1]) == 3 From 8808800a2b0565456431c943d977404a7ce5a6f6 Mon Sep 17 00:00:00 2001 From: Nathan Nowack Date: Thu, 16 Nov 2023 05:58:09 -0600 Subject: [PATCH 3/4] clarify map __doc__ --- src/marvin/components/ai_function.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/marvin/components/ai_function.py b/src/marvin/components/ai_function.py index 26d4af106..7b030ab1f 100644 --- a/src/marvin/components/ai_function.py +++ b/src/marvin/components/ai_function.py @@ -117,12 +117,15 @@ async def amap(self, *map_args: list[Any], **map_kwargs: list[Any]) -> list[T]: """ Map the AI function over a sequence of arguments. Runs concurrently. + A `map` twin method is provided by the `expose_sync_method` decorator. + + You can use `map` or `amap` synchronously or asynchronously, respectively, + regardless of whether the user function is synchronous or asynchronous. + Arguments should be provided as if calling the function normally, but each argument must be a list. The function is called once for each item in the list, and the results are returned in a list. - This method should be called synchronously. - For example, fn.map([1, 2]) is equivalent to [fn(1), fn(2)]. fn.map([1, 2], x=['a', 'b']) is equivalent to [fn(1, x='a'), fn(2, x='b')]. From 9f9cd380fc3bd7114a775e52f47fe839389f9934 Mon Sep 17 00:00:00 2001 From: Nathan Nowack Date: Thu, 16 Nov 2023 13:21:02 -0600 Subject: [PATCH 4/4] try to fix typing --- src/marvin/components/ai_function.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/marvin/components/ai_function.py b/src/marvin/components/ai_function.py index 7b030ab1f..474d817d2 100644 --- a/src/marvin/components/ai_function.py +++ b/src/marvin/components/ai_function.py @@ -79,11 +79,15 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Union[T, Awaitable[T]]: return getattr(self, call)(create, *args, **kwargs) - async def async_call(self, acreate, *args: P.args, **kwargs: P.kwargs) -> T: + async def async_call( + self, acreate: Callable[..., Awaitable[Any]], *args: P.args, **kwargs: P.kwargs + ) -> T: _response = await acreate(**self.as_prompt(*args, **kwargs).serialize()) return self.parse(_response) - def sync_call(self, create, *args: P.args, **kwargs: P.kwargs) -> T: + def sync_call( + self, create: Callable[..., Any], *args: P.args, **kwargs: P.kwargs + ) -> T: _response = create(**self.as_prompt(*args, **kwargs).serialize()) return self.parse(_response) @@ -213,7 +217,7 @@ def as_decorator( field_name: str = "data", field_description: str = "The data to format.", **render_kwargs: Any, - ) -> Callable[..., Self]: + ) -> Union[Callable[[Callable[P, T]], Self], Self]: def decorator(func: Callable[P, T]) -> Self: return cls( fn=func, @@ -271,9 +275,9 @@ def ai_fn( field_name: str = "data", field_description: str = "The data to format.", **render_kwargs: Any, -) -> Union[Callable[[Callable[P, T]], AIFunction[P, T]], AIFunction[P, T]]: +) -> Union[Callable[[Callable[P, T]], Callable[P, T]], Callable[P, T]]: if fn is not None: - return AIFunction.as_decorator( + return AIFunction.as_decorator( # type: ignore fn=fn, environment=environment, prompt=prompt, @@ -284,8 +288,8 @@ def ai_fn( **render_kwargs, ) - def decorator(func: Callable[P, T]) -> AIFunction[P, T]: - return AIFunction.as_decorator( + def decorator(func: Callable[P, T]) -> Callable[P, T]: + return AIFunction.as_decorator( # type: ignore fn=func, environment=environment, prompt=prompt,