From 615cadf41c7bb8a8eed0430d1334682b3fb9c754 Mon Sep 17 00:00:00 2001 From: Nathan Nowack Date: Tue, 5 Dec 2023 17:45:30 -0600 Subject: [PATCH 1/3] no `list[str]` --- src/marvin/serializers.py | 7 ++----- tests/components/test_ai_classifier.py | 18 ------------------ 2 files changed, 2 insertions(+), 23 deletions(-) diff --git a/src/marvin/serializers.py b/src/marvin/serializers.py index 122b85cb7..8a7c7422f 100644 --- a/src/marvin/serializers.py +++ b/src/marvin/serializers.py @@ -78,18 +78,15 @@ def create_tool_from_model( def create_vocabulary_from_type( - vocabulary: Union[GenericAlias, type, list[str]], + vocabulary: Union[GenericAlias, type], ) -> list[str]: if get_origin(vocabulary) == Literal: return [str(token) for token in get_args(vocabulary)] elif isinstance(vocabulary, type) and issubclass(vocabulary, Enum): return [str(token) for token in list(vocabulary.__members__.keys())] - elif isinstance(vocabulary, list) and next(iter(get_args(list[str])), None) == str: - return [str(token) for token in vocabulary] else: raise TypeError( - f"Expected Literal or Enum or list[str], got {type(vocabulary)} with value" - f" {vocabulary}" + f"Expected Literal or Enum, got {type(vocabulary)} with value {vocabulary}" ) diff --git a/tests/components/test_ai_classifier.py b/tests/components/test_ai_classifier.py index 781fdc53d..034f53bf0 100644 --- a/tests/components/test_ai_classifier.py +++ b/tests/components/test_ai_classifier.py @@ -1,6 +1,5 @@ from enum import Enum -import pytest from marvin import ai_classifier from typing_extensions import Literal @@ -46,20 +45,3 @@ def labeler(text: str) -> GitHubIssueTag: result = labeler("improve the docs you slugs") assert result == GitHubIssueTag.DOCS - - class TestList: - @pytest.mark.skip(reason="TODO: fix this") - def test_ai_classifier_list_return_type(self): - @ai_classifier - def labeler(text: str) -> list[str]: - """Select from the following GitHub issue tags - - - bug - - feature - - enhancement - - docs - """ - - result = labeler("i found a bug in the example from the docs") - - assert set(result) == {"bug", "docs"} From be5c592e7fc85a279c3a149990269f1b3f9fc4c9 Mon Sep 17 00:00:00 2001 From: Nathan Nowack Date: Tue, 5 Dec 2023 17:59:21 -0600 Subject: [PATCH 2/3] import pytest --- tests/components/test_ai_classifier.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/components/test_ai_classifier.py b/tests/components/test_ai_classifier.py index 9feab57ba..9a5774c5c 100644 --- a/tests/components/test_ai_classifier.py +++ b/tests/components/test_ai_classifier.py @@ -1,5 +1,6 @@ from enum import Enum +import pytest from marvin import ai_classifier from typing_extensions import Literal From 1055a46fa5268b50cb94acab04f7617b1df340c4 Mon Sep 17 00:00:00 2001 From: Nathan Nowack Date: Tue, 5 Dec 2023 18:20:10 -0600 Subject: [PATCH 3/3] rm skip --- tests/components/test_ai_classifier.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/components/test_ai_classifier.py b/tests/components/test_ai_classifier.py index 9a5774c5c..034f53bf0 100644 --- a/tests/components/test_ai_classifier.py +++ b/tests/components/test_ai_classifier.py @@ -1,6 +1,5 @@ from enum import Enum -import pytest from marvin import ai_classifier from typing_extensions import Literal @@ -16,7 +15,6 @@ class GitHubIssueTag(Enum): DOCS = "docs" -@pytest.mark.skip(reason="ai_classifier doesnt really work imo") @pytest_mark_class("llm") class TestAIClassifer: class TestLiteral: