From 14918130535c2806b47c05895ed778e43639bbda Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Tue, 2 Jan 2024 10:21:17 -0500 Subject: [PATCH] Assume `str` return annotation if none provided --- pyproject.toml | 1 - src/marvin/_mappings/types.py | 3 +++ src/marvin/components/prompt/fn.py | 5 ++++- tests/components/test_ai_functions.py | 24 ++++++++++++++++++++++++ 4 files changed, 31 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 732677241..8b9c3e2bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,6 @@ dev = [ "ipython", "mkdocs-autolinks-plugin~=0.7", "mkdocs-awesome-pages-plugin~=2.8", - "mkdocs-livereload", "mkdocs-markdownextradata-plugin~=0.2", "mkdocs-jupyter>=0.24.1", "mkdocs-material>=9.1.17", diff --git a/src/marvin/_mappings/types.py b/src/marvin/_mappings/types.py index e68e0bf49..a32472e04 100644 --- a/src/marvin/_mappings/types.py +++ b/src/marvin/_mappings/types.py @@ -25,6 +25,9 @@ def cast_type_to_model( else: metadata = FieldInfo(description=field_description) + if _type is None: + raise ValueError("No type provided; unable to create model for casting.") + return create_model( model_name, __doc__=model_description, diff --git a/src/marvin/components/prompt/fn.py b/src/marvin/components/prompt/fn.py index b6637fdde..1dfa70cb8 100644 --- a/src/marvin/components/prompt/fn.py +++ b/src/marvin/components/prompt/fn.py @@ -192,9 +192,12 @@ def wrapper(func: Callable[P, Any], *args: P.args, **kwargs_: P.kwargs) -> Self: signature = inspect.signature(func) params = signature.bind(*args, **kwargs_) params.apply_defaults() + _type = inspect.signature(func).return_annotation + if _type is inspect._empty: + _type = str toolset = cast_type_to_toolset( - _type=inspect.signature(func).return_annotation, + _type=_type, model_name=model_name, model_description=model_description, field_name=field_name, diff --git a/tests/components/test_ai_functions.py b/tests/components/test_ai_functions.py index 070a0b97d..79845ee08 100644 --- a/tests/components/test_ai_functions.py +++ b/tests/components/test_ai_functions.py @@ -41,6 +41,30 @@ async def list_fruit(n: int) -> list[str]: assert len(result) == 3 class TestAnnotations: + def test_no_annotations(self): + @ai_fn + def f(x): + """returns x + 1""" + + result = f(3) + assert result == "4" + + def test_arg_annotations(self): + @ai_fn + def f(x: int): + """returns x + 1""" + + result = f(3) + assert result == "4" + + def test_return_annotations(self): + @ai_fn + def f(x) -> int: + """returns x + 1""" + + result = f("3") + assert result == 4 + def test_list_fruit_with_generic_type_hints(self): @ai_fn def list_fruit(n: int) -> List[str]: