Skip to content

Commit

Permalink
Merge pull request #708 from PrefectHQ/fix-no-return-error
Browse files Browse the repository at this point in the history
Assume `str` return annotation if none provided
  • Loading branch information
jlowin authored Jan 2, 2024
2 parents 905e710 + 1491813 commit 7295c29
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 2 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ dev = [
"ipython",
"mkdocs-autolinks-plugin~=0.7",
"mkdocs-awesome-pages-plugin~=2.8",
"mkdocs-livereload",
"mkdocs-markdownextradata-plugin~=0.2",
"mkdocs-jupyter>=0.24.1",
"mkdocs-material>=9.1.17",
Expand Down
3 changes: 3 additions & 0 deletions src/marvin/_mappings/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ def cast_type_to_model(
else:
metadata = FieldInfo(description=field_description)

if _type is None:
raise ValueError("No type provided; unable to create model for casting.")

return create_model(
model_name,
__doc__=model_description,
Expand Down
5 changes: 4 additions & 1 deletion src/marvin/components/prompt/fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,12 @@ def wrapper(func: Callable[P, Any], *args: P.args, **kwargs_: P.kwargs) -> Self:
signature = inspect.signature(func)
params = signature.bind(*args, **kwargs_)
params.apply_defaults()
_type = inspect.signature(func).return_annotation
if _type is inspect._empty:
_type = str

toolset = cast_type_to_toolset(
_type=inspect.signature(func).return_annotation,
_type=_type,
model_name=model_name,
model_description=model_description,
field_name=field_name,
Expand Down
24 changes: 24 additions & 0 deletions tests/components/test_ai_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,30 @@ async def list_fruit(n: int) -> list[str]:
assert len(result) == 3

class TestAnnotations:
def test_no_annotations(self):
@ai_fn
def f(x):
"""returns x + 1"""

result = f(3)
assert result == "4"

def test_arg_annotations(self):
@ai_fn
def f(x: int):
"""returns x + 1"""

result = f(3)
assert result == "4"

def test_return_annotations(self):
@ai_fn
def f(x) -> int:
"""returns x + 1"""

result = f("3")
assert result == 4

def test_list_fruit_with_generic_type_hints(self):
@ai_fn
def list_fruit(n: int) -> List[str]:
Expand Down

0 comments on commit 7295c29

Please sign in to comment.