Skip to content

Commit

Permalink
Merge pull request #907 from PrefectHQ/return-index
Browse files Browse the repository at this point in the history
Add support for classifying objects and returning indices
  • Loading branch information
jlowin authored Apr 17, 2024
2 parents 4864b44 + 1fd3773 commit f1500a0
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 39 deletions.
57 changes: 43 additions & 14 deletions docs/docs/text/classification.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@ Marvin has a powerful classification tool that can be used to categorize text in
</p>
</div>


!!! example "Example: categorize user feedback"
Categorize user feedback into labels such as "bug", "feature request", or "inquiry":
Categorize user feedback into labels such as "bug", "feature request", or "inquiry":

```python
import marvin

category = marvin.classify(
"The app crashes when I try to upload a file.",
"The app crashes when I try to upload a file.",
labels=["bug", "feature request", "inquiry"]
)
```
Expand All @@ -28,30 +27,28 @@ Marvin has a powerful classification tool that can be used to categorize text in
assert category == "bug"
```


<div class="admonition info">
<p class="admonition-title">How it works</p>
<p>
Marvin enumerates your options, and uses a <a href="https://twitter.com/AAAzzam/status/1669753721574633473">clever logit bias trick</a> to force the LLM to deductively choose the index of the best option given your provided input. It then returns the choice associated with that index.
</p>
</div>


## Providing labels

Marvin's classification tool is designed to accommodate a variety of label formats, each suited to different use cases.

### Lists

When quick, ad-hoc categorization is required, a simple list of strings is the most straightforward approach. For example:
When quick, ad-hoc categorization is required, a simple list of values is the most straightforward approach. The result of the classifier is the matching label from the list. Marvin will attempt to convert your labels to strings if they are not already strings in order to provide them to the LLM, though the original (potentially non-string) labels will be returned as your result.

!!! example "Example: sentiment analysis"

```python
import marvin

sentiment = marvin.classify(
"Marvin is so easy to use!",
"Marvin is so easy to use!",
labels=["positive", "negative", "meh"]
)
```
Expand All @@ -61,6 +58,25 @@ When quick, ad-hoc categorization is required, a simple list of strings is the m
assert sentiment == "positive"
```

#### Lists of objects

Marvin's classification tool can also handle lists of objects, in which case it will return the object that best matches the input. For example, here we use a text prompt to select a single person from a list of people:

```python
import marvin
from pydantic import BaseModel

class Person(BaseModel):
name: str
age: int

alice = Person(name="Alice", age=45)
bob = Person(name="Bob", age=16)

result = marvin.classify('who is a teenager?', [alice, bob])
assert result is bob
```


### Enums

Expand Down Expand Up @@ -101,12 +117,25 @@ In scenarios where labels are part of the function signatures or need to be infe
from typing import Literal
import marvin

RequestType = Literal["support request", "account issue", "general inquiry"]
RequestType = Literal["billing issue", "support request", "general inquiry"]


request = marvin.classify("Reset my password", RequestType)
assert request == "account issue"
assert request == "support request"
```

## Returning indices

In some cases, you may want to return the index of the selected label rather than the label itself:

```python
result = marvin.classify(
"Reset my password",
["billing issue", "support request", "general inquiry"],
return_index=True,
)
assert result == 1
```

## Providing instructions

Expand Down Expand Up @@ -160,7 +189,6 @@ assert predicted_sentiment == "Positive"

While the primary focus is on the `classify` function, Marvin also includes the `classifier` decorator. Applied to Enums, it enables them to be used as classifiers that can be instantiated with natural language. This interface is particularly handy when dealing with a fixed set of labels commonly reused in your application.


```python
@marvin.classifier
class IssueType(Enum):
Expand All @@ -175,6 +203,7 @@ assert issue == IssueType.BUG
While convenient for certain scenarios, it's recommended to use the `classify` function for its greater flexibility and broader application range.

## Model parameters

You can pass parameters to the underlying API via the `model_kwargs` argument of `classify` or `@classifier`. These parameters are passed directly to the API, so you can use any supported parameter.

## Best practices
Expand All @@ -190,9 +219,9 @@ If you are using Marvin in an async environment, you can use `classify_async`:

```python
result = await marvin.classify_async(
"The app crashes when I try to upload a file.",
"The app crashes when I try to upload a file.",
labels=["bug", "feature request", "inquiry"]
)
)

assert result == "bug"
```
Expand Down
37 changes: 29 additions & 8 deletions src/marvin/_mappings/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from types import GenericAlias
from typing import Any, Callable, Literal, Optional, Union, get_args, get_origin

from pydantic import BaseModel, create_model
from pydantic import BaseModel, TypeAdapter, create_model
from pydantic.fields import FieldInfo

from marvin.settings import settings
Expand Down Expand Up @@ -79,20 +79,41 @@ def cast_type_to_toolset(
)


def cast_type_to_labels(
type_: Union[type, GenericAlias],
) -> list[str]:
def cast_type_to_labels(type_: Union[type, GenericAlias]) -> list[str]:
"""
Converts a type to a string list of its possible values.
"""
if get_origin(type_) == Literal:
return [str(token) for token in get_args(type_)]
elif isinstance(type_, type) and issubclass(type_, Enum):
members: list[str] = [
member_values: list[str] = [
option.value for option in getattr(type_, "__members__", {}).values()
]
return members
return member_values
elif isinstance(type_, list):
# typeadapter handles all types known to Pydantic
try:
return [TypeAdapter(type(t)).dump_json(t).decode() for t in type_]
except Exception as exc:
raise ValueError(f"Unable to cast type to labels: {exc}")
elif type_ is bool:
return ["false", "true"], [False, True]
else:
raise TypeError(f"Expected Literal, Enum, bool, or list, got {type_}.")


def cast_type_to_list(type_: Union[type, GenericAlias]) -> list:
"""
Converts a type to a list of its possible values.
"""
if get_origin(type_) == Literal:
return [token for token in get_args(type_)]
elif isinstance(type_, type) and issubclass(type_, Enum):
return list(type_)
elif isinstance(type_, list):
return [str(token) for token in type_]
return type_
elif type_ is bool:
return ["false", "true"]
return [False, True]
else:
raise TypeError(f"Expected Literal, Enum, bool, or list, got {type_}.")

Expand Down
46 changes: 37 additions & 9 deletions src/marvin/ai/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from marvin._mappings.types import (
cast_labels_to_grammar,
cast_type_to_labels,
cast_type_to_list,
)
from marvin.ai.prompts.text_prompts import (
CAST_PROMPT,
Expand Down Expand Up @@ -169,6 +170,7 @@ async def _generate_typed_llm_response_with_tool(
async def _generate_typed_llm_response_with_logit_bias(
prompt_template: str,
prompt_kwargs: dict,
return_index: bool = False,
encoder: Callable[[str], list[int]] = None,
max_tokens: int = 1,
model_kwargs: dict = None,
Expand All @@ -191,6 +193,8 @@ async def _generate_typed_llm_response_with_logit_bias(
Args:
prompt_template (str): The template for the prompt.
prompt_kwargs (dict): Additional keyword arguments for the prompt.
return_index (bool, optional): Whether to return the index of the label
instead of the label itself.
encoder (Callable[[str], list[int]], optional): The encoder function to
use for the generation. Defaults to None.
max_tokens (int, optional): The maximum number of tokens for the
Expand All @@ -207,6 +211,7 @@ async def _generate_typed_llm_response_with_logit_bias(
if "labels" not in prompt_kwargs:
raise ValueError("Labels must be provided as a kwarg to the prompt template.")
labels = prompt_kwargs["labels"]
label_list = cast_type_to_list(labels)
label_strings = cast_type_to_labels(labels)
grammar = cast_labels_to_grammar(
labels=label_strings, encoder=encoder, max_tokens=max_tokens
Expand All @@ -222,11 +227,13 @@ async def _generate_typed_llm_response_with_logit_bias(
# the response contains a single number representing the index of the chosen
label_index = int(response.response.choices[0].message.content)

if return_index:
return label_index

if labels is bool:
return bool(label_index)

result = label_strings[label_index]
return labels(result) if isinstance(labels, type) else result
return label_list[label_index]


async def cast_async(
Expand Down Expand Up @@ -260,6 +267,9 @@ async def cast_async(
"""
model_kwargs = model_kwargs or {}

if not isinstance(data, str):
data = marvin.utilities.tools.output_to_string(data)

if target is None and instructions is None:
raise ValueError("Must provide either a target type or instructions.")
elif target is None:
Expand Down Expand Up @@ -320,11 +330,16 @@ async def extract_async(
Returns:
list: A list of extracted entities of the specified type.
"""
model_kwargs = model_kwargs or {}

if target is None and instructions is None:
raise ValueError("Must provide either a target type or instructions.")
elif target is None:
target = str
model_kwargs = model_kwargs or {}

if not isinstance(data, str):
data = marvin.utilities.tools.output_to_string(data)

return await _generate_typed_llm_response_with_tool(
prompt_template=EXTRACT_PROMPT,
prompt_kwargs=dict(data=data, instructions=instructions),
Expand All @@ -338,9 +353,10 @@ async def classify_async(
data: str,
labels: Union[Enum, list[T], type],
instructions: str = None,
return_index: bool = False,
model_kwargs: dict = None,
client: Optional[AsyncMarvinClient] = None,
) -> T:
) -> Union[T, int]:
"""
Classifies the provided data based on the provided labels.
Expand All @@ -354,18 +370,23 @@ async def classify_async(
labels (Union[Enum, list[T], type]): The labels to classify the data into.
instructions (str, optional): Specific instructions for the
classification. Defaults to None.
model_kwargs (dict, optional): Additional keyword arguments for the
language model. Defaults to None.
client (AsyncMarvinClient, optional): The client to use for the AI function.
Returns:
T: The label that the data was classified into.
Union[T, int]: The label or index that the data was classified into.
"""

model_kwargs = model_kwargs or {}
if not isinstance(data, str):
data = marvin.utilities.tools.output_to_string(data)

return await _generate_typed_llm_response_with_logit_bias(
prompt_template=CLASSIFY_PROMPT,
prompt_kwargs=dict(data=data, labels=labels, instructions=instructions),
return_index=return_index,
model_kwargs=model_kwargs | dict(temperature=0),
client=client,
)
Expand Down Expand Up @@ -754,9 +775,10 @@ def classify(
data: str,
labels: Union[Enum, list[T], type],
instructions: str = None,
return_index: bool = False,
model_kwargs: dict = None,
client: Optional[AsyncMarvinClient] = None,
) -> T:
) -> Union[T, int]:
"""
Classifies the provided data based on the provided labels.
Expand All @@ -770,18 +792,20 @@ def classify(
labels (Union[Enum, list[T], type]): The labels to classify the data into.
instructions (str, optional): Specific instructions for the
classification. Defaults to None.
return_index (bool, optional): Whether to return the index of the label instead of the label itself.
model_kwargs (dict, optional): Additional keyword arguments for the
language model. Defaults to None.
client (AsyncMarvinClient, optional): The client to use for the AI function.
Returns:
T: The label that the data was classified into.
Union[T, int]: The label or index that the data was classified into.
"""
return run_sync(
classify_async(
data=data,
labels=labels,
instructions=instructions,
return_index=return_index,
model_kwargs=model_kwargs,
client=client,
)
Expand Down Expand Up @@ -878,15 +902,17 @@ async def classify_async_map(
data: list[str],
labels: Union[Enum, list[T], type],
instructions: Optional[str] = None,
return_index: bool = False,
model_kwargs: Optional[dict] = None,
client: Optional[AsyncMarvinClient] = None,
) -> list[T]:
) -> list[Union[T, int]]:
return await map_async(
fn=classify_async,
map_kwargs=dict(data=data),
unmapped_kwargs=dict(
labels=labels,
instructions=instructions,
return_index=return_index,
model_kwargs=model_kwargs,
client=client,
),
Expand All @@ -897,14 +923,16 @@ def classify_map(
data: list[str],
labels: Union[Enum, list[T], type],
instructions: Optional[str] = None,
return_index: bool = False,
model_kwargs: Optional[dict] = None,
client: Optional[AsyncMarvinClient] = None,
) -> list[T]:
) -> list[Union[T, int]]:
return run_sync(
classify_async_map(
data=data,
labels=labels,
instructions=instructions,
return_index=return_index,
model_kwargs=model_kwargs,
client=client,
)
Expand Down
Loading

0 comments on commit f1500a0

Please sign in to comment.