diff --git a/src/marvin/serializers.py b/src/marvin/serializers.py index 122b85cb7..8a7c7422f 100644 --- a/src/marvin/serializers.py +++ b/src/marvin/serializers.py @@ -78,18 +78,15 @@ def create_tool_from_model( def create_vocabulary_from_type( - vocabulary: Union[GenericAlias, type, list[str]], + vocabulary: Union[GenericAlias, type], ) -> list[str]: if get_origin(vocabulary) == Literal: return [str(token) for token in get_args(vocabulary)] elif isinstance(vocabulary, type) and issubclass(vocabulary, Enum): return [str(token) for token in list(vocabulary.__members__.keys())] - elif isinstance(vocabulary, list) and next(iter(get_args(list[str])), None) == str: - return [str(token) for token in vocabulary] else: raise TypeError( - f"Expected Literal or Enum or list[str], got {type(vocabulary)} with value" - f" {vocabulary}" + f"Expected Literal or Enum, got {type(vocabulary)} with value {vocabulary}" ) diff --git a/tests/components/test_ai_classifier.py b/tests/components/test_ai_classifier.py index 287a8b33c..034f53bf0 100644 --- a/tests/components/test_ai_classifier.py +++ b/tests/components/test_ai_classifier.py @@ -1,6 +1,5 @@ from enum import Enum -import pytest from marvin import ai_classifier from typing_extensions import Literal @@ -16,7 +15,6 @@ class GitHubIssueTag(Enum): DOCS = "docs" -@pytest.mark.skip(reason="ai_classifier doesnt really work imo") @pytest_mark_class("llm") class TestAIClassifer: class TestLiteral: @@ -47,19 +45,3 @@ def labeler(text: str) -> GitHubIssueTag: result = labeler("improve the docs you slugs") assert result == GitHubIssueTag.DOCS - - class TestList: - def test_ai_classifier_list_return_type(self): - @ai_classifier - def labeler(text: str) -> list[str]: - """Select from the following GitHub issue tags - - - bug - - feature - - enhancement - - docs - """ - - result = labeler("i found a bug in the example from the docs") - - assert set(result) == {"bug", "docs"}