diff --git a/docs/docs/text/classification.md b/docs/docs/text/classification.md
index d44f075bd..a88c4ef41 100644
--- a/docs/docs/text/classification.md
+++ b/docs/docs/text/classification.md
@@ -9,15 +9,14 @@ Marvin has a powerful classification tool that can be used to categorize text in
-
!!! example "Example: categorize user feedback"
- Categorize user feedback into labels such as "bug", "feature request", or "inquiry":
-
+Categorize user feedback into labels such as "bug", "feature request", or "inquiry":
+
```python
import marvin
category = marvin.classify(
- "The app crashes when I try to upload a file.",
+ "The app crashes when I try to upload a file.",
labels=["bug", "feature request", "inquiry"]
)
```
@@ -28,7 +27,6 @@ Marvin has a powerful classification tool that can be used to categorize text in
assert category == "bug"
```
-
How it works
@@ -36,22 +34,21 @@ Marvin has a powerful classification tool that can be used to categorize text in
-
## Providing labels
Marvin's classification tool is designed to accommodate a variety of label formats, each suited to different use cases.
### Lists
-When quick, ad-hoc categorization is required, a simple list of strings is the most straightforward approach. For example:
+When quick, ad-hoc categorization is required, a simple list of values is the most straightforward approach. The result of the classifier is the matching label from the list. If your labels are not already strings, Marvin converts them to strings before providing them to the LLM, but the original (potentially non-string) label is what gets returned as your result.
!!! example "Example: sentiment analysis"
-
+
```python
import marvin
sentiment = marvin.classify(
- "Marvin is so easy to use!",
+ "Marvin is so easy to use!",
labels=["positive", "negative", "meh"]
)
```
@@ -61,6 +58,25 @@ When quick, ad-hoc categorization is required, a simple list of strings is the m
assert sentiment == "positive"
```
+#### Lists of objects
+
+Marvin's classification tool can also handle lists of objects, in which case it will return the object that best matches the input. For example, here we use a text prompt to select a single person from a list of people:
+
+```python
+import marvin
+from pydantic import BaseModel
+
+class Person(BaseModel):
+ name: str
+ age: int
+
+alice = Person(name="Alice", age=45)
+bob = Person(name="Bob", age=16)
+
+result = marvin.classify('who is a teenager?', [alice, bob])
+assert result is bob
+```
+
### Enums
@@ -101,12 +117,25 @@ In scenarios where labels are part of the function signatures or need to be infe
from typing import Literal
import marvin
-RequestType = Literal["support request", "account issue", "general inquiry"]
+RequestType = Literal["billing issue", "support request", "general inquiry"]
+
request = marvin.classify("Reset my password", RequestType)
-assert request == "account issue"
+assert request == "support request"
```
+## Returning indices
+
+In some cases, you may want to return the index of the selected label rather than the label itself. To do so, pass `return_index=True`:
+
+```python
+result = marvin.classify(
+ "Reset my password",
+ ["billing issue", "support request", "general inquiry"],
+ return_index=True,
+)
+assert result == 1
+```
## Providing instructions
@@ -160,7 +189,6 @@ assert predicted_sentiment == "Positive"
While the primary focus is on the `classify` function, Marvin also includes the `classifier` decorator. Applied to Enums, it enables them to be used as classifiers that can be instantiated with natural language. This interface is particularly handy when dealing with a fixed set of labels commonly reused in your application.
-
```python
@marvin.classifier
class IssueType(Enum):
@@ -175,6 +203,7 @@ assert issue == IssueType.BUG
While convenient for certain scenarios, it's recommended to use the `classify` function for its greater flexibility and broader application range.
## Model parameters
+
You can pass parameters to the underlying API via the `model_kwargs` argument of `classify` or `@classifier`. These parameters are passed directly to the API, so you can use any supported parameter.
## Best practices
@@ -190,9 +219,9 @@ If you are using Marvin in an async environment, you can use `classify_async`:
```python
result = await marvin.classify_async(
- "The app crashes when I try to upload a file.",
+ "The app crashes when I try to upload a file.",
labels=["bug", "feature request", "inquiry"]
-)
+)
assert result == "bug"
```
diff --git a/src/marvin/_mappings/types.py b/src/marvin/_mappings/types.py
index f4ac636ed..a2f2cb804 100644
--- a/src/marvin/_mappings/types.py
+++ b/src/marvin/_mappings/types.py
@@ -2,7 +2,7 @@
from types import GenericAlias
from typing import Any, Callable, Literal, Optional, Union, get_args, get_origin
-from pydantic import BaseModel, create_model
+from pydantic import BaseModel, TypeAdapter, create_model
from pydantic.fields import FieldInfo
from marvin.settings import settings
@@ -79,20 +79,41 @@ def cast_type_to_toolset(
)
-def cast_type_to_labels(
- type_: Union[type, GenericAlias],
-) -> list[str]:
+def cast_type_to_labels(type_: Union[type, GenericAlias]) -> list[str]:
+ """
+ Converts a type (Literal, Enum, bool) or a list of values to a list of label strings, one per possible value.
+ """
if get_origin(type_) == Literal:
return [str(token) for token in get_args(type_)]
elif isinstance(type_, type) and issubclass(type_, Enum):
- members: list[str] = [
+ member_values: list[str] = [
option.value for option in getattr(type_, "__members__", {}).values()
]
- return members
+ return member_values
+ elif isinstance(type_, list):
+ # TypeAdapter serializes each item to JSON according to the item's own type
+ try:
+ return [TypeAdapter(type(t)).dump_json(t).decode() for t in type_]
+ except Exception as exc:
+ raise ValueError(f"Unable to cast type to labels: {exc}") from exc
+ elif type_ is bool:
+ return ["false", "true"]  # strings only, matching the declared list[str] return type
+ else:
+ raise TypeError(f"Expected Literal, Enum, bool, or list, got {type_}.")
+
+
+def cast_type_to_list(type_: Union[type, GenericAlias]) -> list:
+ """
+ Converts a type (Literal, Enum, bool) or a list to a list of its possible values, without converting them to strings.
+ """
+ if get_origin(type_) == Literal:
+ return list(get_args(type_))
+ elif isinstance(type_, type) and issubclass(type_, Enum):
+ return list(type_)  # iterating an Enum class yields its members
elif isinstance(type_, list):
- return [str(token) for token in type_]
+ return type_  # returned as-is (the same list object, not a copy)
elif type_ is bool:
- return ["false", "true"]
+ return [False, True]
else:
raise TypeError(f"Expected Literal, Enum, bool, or list, got {type_}.")
diff --git a/src/marvin/ai/text.py b/src/marvin/ai/text.py
index 9246ab222..a90e3ceeb 100644
--- a/src/marvin/ai/text.py
+++ b/src/marvin/ai/text.py
@@ -26,6 +26,7 @@
from marvin._mappings.types import (
cast_labels_to_grammar,
cast_type_to_labels,
+ cast_type_to_list,
)
from marvin.ai.prompts.text_prompts import (
CAST_PROMPT,
@@ -169,6 +170,7 @@ async def _generate_typed_llm_response_with_tool(
async def _generate_typed_llm_response_with_logit_bias(
prompt_template: str,
prompt_kwargs: dict,
+ return_index: bool = False,
encoder: Callable[[str], list[int]] = None,
max_tokens: int = 1,
model_kwargs: dict = None,
@@ -191,6 +193,8 @@ async def _generate_typed_llm_response_with_logit_bias(
Args:
prompt_template (str): The template for the prompt.
prompt_kwargs (dict): Additional keyword arguments for the prompt.
+ return_index (bool, optional): Whether to return the index of the label
+ instead of the label itself.
encoder (Callable[[str], list[int]], optional): The encoder function to
use for the generation. Defaults to None.
max_tokens (int, optional): The maximum number of tokens for the
@@ -207,6 +211,7 @@ async def _generate_typed_llm_response_with_logit_bias(
if "labels" not in prompt_kwargs:
raise ValueError("Labels must be provided as a kwarg to the prompt template.")
labels = prompt_kwargs["labels"]
+ label_list = cast_type_to_list(labels)
label_strings = cast_type_to_labels(labels)
grammar = cast_labels_to_grammar(
labels=label_strings, encoder=encoder, max_tokens=max_tokens
@@ -222,11 +227,13 @@ async def _generate_typed_llm_response_with_logit_bias(
# the response contains a single number representing the index of the chosen
label_index = int(response.response.choices[0].message.content)
+ if return_index:
+ return label_index
+
if labels is bool:
return bool(label_index)
- result = label_strings[label_index]
- return labels(result) if isinstance(labels, type) else result
+ return label_list[label_index]
async def cast_async(
@@ -260,6 +267,9 @@ async def cast_async(
"""
model_kwargs = model_kwargs or {}
+ if not isinstance(data, str):
+ data = marvin.utilities.tools.output_to_string(data)
+
if target is None and instructions is None:
raise ValueError("Must provide either a target type or instructions.")
elif target is None:
@@ -320,11 +330,16 @@ async def extract_async(
Returns:
list: A list of extracted entities of the specified type.
"""
+ model_kwargs = model_kwargs or {}
+
if target is None and instructions is None:
raise ValueError("Must provide either a target type or instructions.")
elif target is None:
target = str
- model_kwargs = model_kwargs or {}
+
+ if not isinstance(data, str):
+ data = marvin.utilities.tools.output_to_string(data)
+
return await _generate_typed_llm_response_with_tool(
prompt_template=EXTRACT_PROMPT,
prompt_kwargs=dict(data=data, instructions=instructions),
@@ -338,9 +353,10 @@ async def classify_async(
data: str,
labels: Union[Enum, list[T], type],
instructions: str = None,
+ return_index: bool = False,
model_kwargs: dict = None,
client: Optional[AsyncMarvinClient] = None,
-) -> T:
+) -> Union[T, int]:
"""
Classifies the provided data based on the provided labels.
@@ -354,18 +370,23 @@ async def classify_async(
labels (Union[Enum, list[T], type]): The labels to classify the data into.
instructions (str, optional): Specific instructions for the
classification. Defaults to None.
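+ return_index (bool, optional): Whether to return the index of the label instead of the label itself.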
model_kwargs (dict, optional): Additional keyword arguments for the
language model. Defaults to None.
client (AsyncMarvinClient, optional): The client to use for the AI function.
Returns:
- T: The label that the data was classified into.
+ Union[T, int]: The label or index that the data was classified into.
"""
model_kwargs = model_kwargs or {}
+ if not isinstance(data, str):
+ data = marvin.utilities.tools.output_to_string(data)
+
return await _generate_typed_llm_response_with_logit_bias(
prompt_template=CLASSIFY_PROMPT,
prompt_kwargs=dict(data=data, labels=labels, instructions=instructions),
+ return_index=return_index,
model_kwargs=model_kwargs | dict(temperature=0),
client=client,
)
@@ -754,9 +775,10 @@ def classify(
data: str,
labels: Union[Enum, list[T], type],
instructions: str = None,
+ return_index: bool = False,
model_kwargs: dict = None,
client: Optional[AsyncMarvinClient] = None,
-) -> T:
+) -> Union[T, int]:
"""
Classifies the provided data based on the provided labels.
@@ -770,18 +792,20 @@ def classify(
labels (Union[Enum, list[T], type]): The labels to classify the data into.
instructions (str, optional): Specific instructions for the
classification. Defaults to None.
+ return_index (bool, optional): Whether to return the index of the label instead of the label itself.
model_kwargs (dict, optional): Additional keyword arguments for the
language model. Defaults to None.
client (AsyncMarvinClient, optional): The client to use for the AI function.
Returns:
- T: The label that the data was classified into.
+ Union[T, int]: The label or index that the data was classified into.
"""
return run_sync(
classify_async(
data=data,
labels=labels,
instructions=instructions,
+ return_index=return_index,
model_kwargs=model_kwargs,
client=client,
)
@@ -878,15 +902,17 @@ async def classify_async_map(
data: list[str],
labels: Union[Enum, list[T], type],
instructions: Optional[str] = None,
+ return_index: bool = False,
model_kwargs: Optional[dict] = None,
client: Optional[AsyncMarvinClient] = None,
-) -> list[T]:
+) -> list[Union[T, int]]:
return await map_async(
fn=classify_async,
map_kwargs=dict(data=data),
unmapped_kwargs=dict(
labels=labels,
instructions=instructions,
+ return_index=return_index,
model_kwargs=model_kwargs,
client=client,
),
@@ -897,14 +923,16 @@ def classify_map(
data: list[str],
labels: Union[Enum, list[T], type],
instructions: Optional[str] = None,
+ return_index: bool = False,
model_kwargs: Optional[dict] = None,
client: Optional[AsyncMarvinClient] = None,
-) -> list[T]:
+) -> list[Union[T, int]]:
return run_sync(
classify_async_map(
data=data,
labels=labels,
instructions=instructions,
+ return_index=return_index,
model_kwargs=model_kwargs,
client=client,
)
diff --git a/src/marvin/beta/vision.py b/src/marvin/beta/vision.py
index a17089229..42baefe61 100644
--- a/src/marvin/beta/vision.py
+++ b/src/marvin/beta/vision.py
@@ -280,9 +280,10 @@ async def classify_async(
labels: Union[Enum, list[T], type],
images: Union[Union[str, Path], list[Union[str, Path]]] = None,
instructions: str = None,
+ return_index: bool = False,
vision_model_kwargs: dict = None,
model_kwargs: dict = None,
-) -> T:
+) -> Union[T, int]:
"""
Classifies provided data and/or images into one of the specified labels.
Args:
@@ -290,11 +291,12 @@ async def classify_async(
labels (Union[Enum, list[T], type]): Labels to classify into.
images (Union[Union[str, Path], list[Union[str, Path]]], optional): Additional images for classification.
instructions (str, optional): Instructions for the classification.
+ return_index (bool, optional): Whether to return the index of the label instead of the label itself.
vision_model_kwargs (dict, optional): Arguments for the vision model.
model_kwargs (dict, optional): Arguments for the language model.
Returns:
- T: Label that the data/images were classified into.
+ Union[T, int]: Label or index that the data/images were classified into.
"""
async def marvin_call(x):
@@ -302,6 +304,7 @@ async def marvin_call(x):
data=x,
labels=labels,
instructions=instructions,
+ return_index=return_index,
model_kwargs=model_kwargs,
)
@@ -415,9 +418,10 @@ def classify(
labels: Union[Enum, list[T], type],
images: Union[Image, list[Image]] = None,
instructions: str = None,
+ return_index: bool = False,
vision_model_kwargs: dict = None,
model_kwargs: dict = None,
-) -> T:
+) -> Union[T, int]:
"""
Classifies provided data and/or images into one of the specified labels synchronously.
@@ -426,11 +430,12 @@ def classify(
labels (Union[Enum, list[T], type]): Labels to classify into.
images (Union[Image, list[Image]], optional): Additional images for classification.
instructions (str, optional): Instructions for the classification.
+ return_index (bool, optional): Whether to return the index of the label instead of the label itself.
vision_model_kwargs (dict, optional): Arguments for the vision model.
model_kwargs (dict, optional): Arguments for the language model.
Returns:
- T: Label that the data/images were classified into.
+ Union[T, int]: Label or index that the data/images were classified into.
"""
return run_sync(
classify_async(
@@ -438,6 +443,7 @@ def classify(
labels=labels,
images=images,
instructions=instructions,
+ return_index=return_index,
vision_model_kwargs=vision_model_kwargs,
model_kwargs=model_kwargs,
)
@@ -449,14 +455,16 @@ async def classify_async_map(
data: list[Union[str, Image]],
labels: Union[Enum, list[T], type],
instructions: Optional[str] = None,
+ return_index: bool = False,
model_kwargs: Optional[dict] = None,
-) -> list[T]:
+) -> list[Union[T, int]]:
return await map_async(
fn=classify_async,
map_kwargs=dict(data=data),
unmapped_kwargs=dict(
labels=labels,
instructions=instructions,
+ return_index=return_index,
model_kwargs=model_kwargs,
),
)
@@ -466,13 +474,15 @@ def classify_map(
data: list[Union[str, Image]],
labels: Union[Enum, list[T], type],
instructions: Optional[str] = None,
+ return_index: bool = False,
model_kwargs: Optional[dict] = None,
-) -> list[T]:
+) -> list[Union[T, int]]:
return run_sync(
classify_async_map(
data=data,
labels=labels,
instructions=instructions,
+ return_index=return_index,
model_kwargs=model_kwargs,
)
)
diff --git a/tests/ai/beta/vision/test_classify.py b/tests/ai/beta/vision/test_classify.py
index 50411412e..042a61637 100644
--- a/tests/ai/beta/vision/test_classify.py
+++ b/tests/ai/beta/vision/test_classify.py
@@ -95,6 +95,14 @@ async def test_ny(self):
assert result == "urban"
+class TestReturnIndex:
+ def test_return_index(self):
+ result = marvin.beta.classify(
+ "This is a great feature!", ["bad", "good"], return_index=True
+ )
+ assert result == 1  # 1 is the position of "good" in the labels list
+
+
class TestMapping:
def test_map(self):
ny = marvin.beta.Image(
diff --git a/tests/ai/test_classify.py b/tests/ai/test_classify.py
index 652f821a8..0a25a7ea8 100644
--- a/tests/ai/test_classify.py
+++ b/tests/ai/test_classify.py
@@ -3,8 +3,9 @@
import marvin
import pytest
+from pydantic import BaseModel
-Sentiment = Literal["Positive", "Negative"]
+Sentiment = Literal["Negative", "Positive"]
class GitHubIssueTag(Enum):
@@ -57,13 +58,33 @@ def test_classify_number(self):
result = marvin.classify(0, ["letter", "number"])
assert result == "number"
+ def test_classify_object(self):
+ """
+ Test that classify returns the original label object itself (checked by identity) when labels are a list of objects
+ """
+
+ class Person(BaseModel):
+ name: str
+ age: int
+
+ p1 = Person(name="Alice", age=30)
+ p2 = Person(name="Bob", age=25)
+ p3 = Person(name="Charlie", age=35)
+
+ result = marvin.classify("a person in wonderland", [p1, p2, p3])
+ assert result is p1
+
class TestBool:
def test_classify_positive_sentiment(self):
result = marvin.classify("This is a great feature!", bool)
assert result is True
def test_classify_negative_sentiment(self):
- result = marvin.classify("This feature is terrible!", bool)
+ result = marvin.classify(
+ "This feature is terrible!",
+ bool,
+ instructions="Is the sentiment positive?",
+ )
assert result is False
def test_classify_falseish(self):
@@ -82,6 +103,13 @@ async def test_classify_positive_sentiment(self):
result = await marvin.classify_async("This is a great feature!", bool)
assert result is True
+ class TestReturnIndex:
+ def test_return_index(self):
+ result = marvin.classify(
+ "This is a great feature!", ["bad", "good"], return_index=True
+ )
+ assert result == 1  # 1 is the position of "good" in the labels list
+
class TestExamples:
async def test_hogwarts_sorting_hat(self):
description = "Brave, daring, chivalrous, and sometimes a bit reckless."
@@ -116,6 +144,17 @@ def router(transcript: str) -> Department:
assert router(user_input).value == expected_selection
+ class TestConvertInputData:
+ def test_convert_input_data(self):  # the data argument is a BaseModel instance, not a str
+ class Name(BaseModel):
+ first: str
+ last: str
+
+ result = marvin.classify(
+ Name(first="Alice", last="Smith"), ["Alice", "Bob"]
+ )
+ assert result == "Alice"  # the label matching the model's first name
+
class TestMapping:
def test_classify_map(self):
@@ -138,3 +177,9 @@ async def test_async_classify_map(self):
)
assert isinstance(result, list)
assert result == ["Positive", "Negative"]
+
+ def test_classify_return_index(self):
+ result = marvin.classify.map(
+ ["This is great!", "This is terrible!"], Sentiment, return_index=True
+ )
+ assert result == [1, 0]  # with Sentiment = Literal["Negative", "Positive"]: 1 is "Positive", 0 is "Negative"