Skip to content

Commit

Permalink
Merge pull request #669 from PrefectHQ/restrict-ai-classifier-returns
Browse files Browse the repository at this point in the history
ai_classifier: no `list[str]`
  • Loading branch information
zzstoatzz authored Dec 6, 2023
2 parents 5c65d3b + 1055a46 commit ffc5e34
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 23 deletions.
7 changes: 2 additions & 5 deletions src/marvin/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,18 +78,15 @@ def create_tool_from_model(


def create_vocabulary_from_type(
vocabulary: Union[GenericAlias, type, list[str]],
vocabulary: Union[GenericAlias, type],
) -> list[str]:
if get_origin(vocabulary) == Literal:
return [str(token) for token in get_args(vocabulary)]
elif isinstance(vocabulary, type) and issubclass(vocabulary, Enum):
return [str(token) for token in list(vocabulary.__members__.keys())]
elif isinstance(vocabulary, list) and next(iter(get_args(list[str])), None) == str:
return [str(token) for token in vocabulary]
else:
raise TypeError(
f"Expected Literal or Enum or list[str], got {type(vocabulary)} with value"
f" {vocabulary}"
f"Expected Literal or Enum, got {type(vocabulary)} with value {vocabulary}"
)


Expand Down
18 changes: 0 additions & 18 deletions tests/components/test_ai_classifier.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from enum import Enum

import pytest
from marvin import ai_classifier
from typing_extensions import Literal

Expand All @@ -16,7 +15,6 @@ class GitHubIssueTag(Enum):
DOCS = "docs"


@pytest.mark.skip(reason="ai_classifier doesnt really work imo")
@pytest_mark_class("llm")
class TestAIClassifer:
class TestLiteral:
Expand Down Expand Up @@ -47,19 +45,3 @@ def labeler(text: str) -> GitHubIssueTag:
result = labeler("improve the docs you slugs")

assert result == GitHubIssueTag.DOCS

class TestList:
def test_ai_classifier_list_return_type(self):
@ai_classifier
def labeler(text: str) -> list[str]:
"""Select from the following GitHub issue tags
- bug
- feature
- enhancement
- docs
"""

result = labeler("i found a bug in the example from the docs")

assert set(result) == {"bug", "docs"}

0 comments on commit ffc5e34

Please sign in to comment.