Skip to content

Commit

Permalink
revert weird change
Browse files Browse the repository at this point in the history
hehheeh
  • Loading branch information
zzstoatzz committed Jan 22, 2025
1 parent 980e096 commit 5daf6c2
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 28 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ source = "vcs"
extend-select = ["I"]

[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401", "I001"]
"__init__.py" = ["F401", "I001", "RUF013", "UP045"]

[build-system]
requires = ["hatchling>=1.21.0", "hatch-vcs>=0.4.0"]
Expand Down
4 changes: 2 additions & 2 deletions src/marvin/fns/cast.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, TypeVar, get_args
from typing import Any, TypeVar

import marvin
from marvin.agents.agent import Agent
Expand Down Expand Up @@ -80,7 +80,7 @@ async def cast_async(
name="Cast Task",
instructions=prompt,
context=task_context,
result_type=t[0] if (t := get_args(target)) else target,
result_type=target,
agents=[agent] if agent else None,
)

Expand Down
3 changes: 1 addition & 2 deletions src/marvin/utilities/jsonschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
Mapping,
Optional,
Type,
TypedDict,
Union,
)

Expand All @@ -60,7 +59,7 @@
model_validator,
)
from pydantic_core import to_json
from typing_extensions import NotRequired
from typing_extensions import NotRequired, TypedDict

__all__ = ["jsonschema_to_type", "JSONSchema"]

Expand Down
56 changes: 33 additions & 23 deletions tests/ai/fns/test_cast.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from typing import Any

import pytest
from pydantic import BaseModel, Field
Expand All @@ -14,29 +15,38 @@ class Location(BaseModel):


class TestBuiltins:
def test_cast_text_to_int(self):
result = marvin.cast("one", int)
assert result == 1

def test_cast_text_to_list_of_ints(self):
result = marvin.cast("one, TWO, three", list[int])
assert result == [1, 2, 3]

def test_cast_text_to_list_of_ints_2(self):
result = marvin.cast("4 and 5 then 6", list[int])
assert result == [4, 5, 6]

def test_cast_text_to_list_of_floats(self):
result = marvin.cast("1.0, 2.0, 3.0", list[float])
assert result == [1.0, 2.0, 3.0]

def test_cast_text_to_bool(self):
result = marvin.cast("no", bool)
assert result is False

def test_cast_text_to_bool_with_true(self):
result = marvin.cast("yes", bool)
assert result is True
@pytest.mark.parametrize(
"input_text, target_type, expected_result",
[
("one", int, 1),
("one, TWO, three", list[int], [1, 2, 3]),
("4 and 5 then 6", list[int], [4, 5, 6]),
("1.0, 2.0, 3.0", list[float], [1.0, 2.0, 3.0]),
],
ids=[
"cast_text_to_int",
"cast_text_to_list_of_ints",
"cast_text_to_list_of_ints_2",
"cast_text_to_list_of_floats",
],
)
def test_cast(self, input_text: str, target_type: type, expected_result: Any):
result = marvin.cast(input_text, target_type)
assert result == expected_result

@pytest.mark.parametrize(
"input_text, expected_result",
[
("no", False),
("yes", True),
],
ids=[
"cast_text_to_bool_false",
"cast_text_to_bool_true",
],
)
def test_cast_text_to_bool(self, input_text: str, expected_result: bool):
assert marvin.cast(input_text, bool) is expected_result

def test_str_not_json(self):
result = marvin.cast(
Expand Down

0 comments on commit 5daf6c2

Please sign in to comment.