Add support for classifying objects and returning indices #907

Merged
merged 9 commits into from
Apr 17, 2024
Changes from 7 commits
34 changes: 21 additions & 13 deletions docs/docs/text/classification.md
Expand Up @@ -9,15 +9,14 @@ Marvin has a powerful classification tool that can be used to categorize text in
</p>
</div>


!!! example "Example: categorize user feedback"
Categorize user feedback into labels such as "bug", "feature request", or "inquiry":

```python
import marvin

category = marvin.classify(
"The app crashes when I try to upload a file.",
labels=["bug", "feature request", "inquiry"]
)
```
Expand All @@ -28,30 +27,28 @@ Marvin has a powerful classification tool that can be used to categorize text in
assert category == "bug"
```


<div class="admonition info">
<p class="admonition-title">How it works</p>
<p>
Marvin enumerates your options, and uses a <a href="https://twitter.com/AAAzzam/status/1669753721574633473">clever logit bias trick</a> to force the LLM to deductively choose the index of the best option given your provided input. It then returns the choice associated with that index.
</p>
</div>
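
The index-selection idea described in the note above can be sketched as follows. This is a minimal illustration, not Marvin's implementation: `toy_encoder` and `build_logit_bias` are hypothetical names, the one-token-per-character encoder is a stand-in for a real tokenizer, and the bias value of 100 is an assumption based on the usual convention for OpenAI-style `logit_bias` parameters.

```python
def toy_encoder(text: str) -> list[int]:
    """Hypothetical tokenizer: one token id per character (real tokenizers differ)."""
    return [ord(ch) for ch in text]


def build_logit_bias(labels: list[str], encoder=toy_encoder, bias: int = 100) -> dict[int, int]:
    """Bias the model toward the tokens that spell each label's index."""
    logit_bias: dict[int, int] = {}
    for index in range(len(labels)):
        for token_id in encoder(str(index)):
            logit_bias[token_id] = bias
    return logit_bias


labels = ["bug", "feature request", "inquiry"]
logit_bias = build_logit_bias(labels)

# With a one-token reply, the model is steered toward answering "0", "1", or "2".
model_output = "0"  # stand-in for the model's single-token reply
choice = labels[int(model_output)]
```

Because the reply is only an index, the label text itself never has to be generated by the model; it is looked up afterwards.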


## Providing labels

Marvin's classification tool is designed to accommodate a variety of label formats, each suited to different use cases.

### Lists

When quick, ad-hoc categorization is required, a simple list of strings is the most straightforward approach. For example:
When quick, ad-hoc categorization is required, a simple list of values is the most straightforward approach. The result of the classifier is the matching label from the list. Marvin will attempt to convert your labels to strings if they are not already strings. If you are trying to classify complex objects that have unusual or no simple string representation, consider manually creating labels and [classifying by index](#returning-indices).

!!! example "Example: sentiment analysis"

```python
import marvin

sentiment = marvin.classify(
"Marvin is so easy to use!",
labels=["positive", "negative", "meh"]
)
```
Expand All @@ -61,7 +58,6 @@ When quick, ad-hoc categorization is required, a simple list of strings is the m
assert sentiment == "positive"
```


### Enums

For applications where classification labels are more structured and recurring, Enums provide an organized and maintainable solution:
Expand Down Expand Up @@ -107,6 +103,18 @@ request = marvin.classify("Reset my password", RequestType)
assert request == "account issue"
```

## Returning indices

In some cases, you may want to return the index of the selected label rather than the label itself:

```python
result = marvin.classify(
"Reset my password",
["support request", "account issue", "general inquiry"],
return_index=True,
)
assert result == 1
```
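
Returning an index is most useful when the things being chosen between are not naturally strings. A hedged sketch of that pattern: write one label per object by hand, then use the returned index to look up the original object. The `Queue` class and the `fake_classify` stub are hypothetical; the stub stands in for `marvin.classify(..., return_index=True)` so the snippet runs without an API call.

```python
from dataclasses import dataclass


@dataclass
class Queue:
    name: str
    owner: str


queues = [
    Queue("support", owner="alice"),
    Queue("accounts", owner="bob"),
    Queue("general", owner="carol"),
]
# Hand-written labels, one per object, in the same order as `queues`.
labels = ["support request", "account issue", "general inquiry"]


def fake_classify(data: str, labels: list[str], return_index: bool = False):
    """Stub standing in for marvin.classify; always picks 'account issue'."""
    index = labels.index("account issue")
    return index if return_index else labels[index]


index = fake_classify("Reset my password", labels, return_index=True)
selected = queues[index]  # the original object, recovered by position
```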

## Providing instructions

Expand Down Expand Up @@ -160,7 +168,6 @@ assert predicted_sentiment == "Positive"

While the primary focus is on the `classify` function, Marvin also includes the `classifier` decorator. Applied to Enums, it enables them to be used as classifiers that can be instantiated with natural language. This interface is particularly handy when dealing with a fixed set of labels commonly reused in your application.


```python
@marvin.classifier
class IssueType(Enum):
Expand All @@ -175,6 +182,7 @@ assert issue == IssueType.BUG
While convenient for certain scenarios, it's recommended to use the `classify` function for its greater flexibility and broader application range.

## Model parameters

You can pass parameters to the underlying API via the `model_kwargs` argument of `classify` or `@classifier`. These parameters are passed directly to the API, so you can use any supported parameter.
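
One detail worth noting: `classify_async` passes `model_kwargs | dict(temperature=0)` to the model, and with Python's dict union the right-hand operand wins on key conflicts. The sketch below only illustrates that merge; `merge_kwargs` is a hypothetical helper name, not part of Marvin, and the `seed` value is an arbitrary example.

```python
from typing import Optional


def merge_kwargs(model_kwargs: Optional[dict] = None) -> dict:
    """Mirror the `model_kwargs | dict(temperature=0)` expression."""
    model_kwargs = model_kwargs or {}
    # The right-hand operand of `|` takes precedence on key conflicts.
    return model_kwargs | dict(temperature=0)


merged = merge_kwargs({"seed": 42, "temperature": 0.7})
```

Under this merge, a caller-supplied `temperature` is replaced by 0 for classification, while other keys pass through unchanged.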

## Best practices
Expand All @@ -190,9 +198,9 @@ If you are using Marvin in an async environment, you can use `classify_async`:

```python
result = await marvin.classify_async(
"The app crashes when I try to upload a file.",
labels=["bug", "feature request", "inquiry"]
)

assert result == "bug"
```
37 changes: 29 additions & 8 deletions src/marvin/_mappings/types.py
Expand Up @@ -2,7 +2,7 @@
from types import GenericAlias
from typing import Any, Callable, Literal, Optional, Union, get_args, get_origin

from pydantic import BaseModel, create_model
from pydantic import BaseModel, TypeAdapter, create_model
from pydantic.fields import FieldInfo

from marvin.settings import settings
Expand Down Expand Up @@ -79,20 +79,41 @@ def cast_type_to_toolset(
)


def cast_type_to_labels(
type_: Union[type, GenericAlias],
) -> list[str]:
def cast_type_to_labels(type_: Union[type, GenericAlias]) -> list[str]:
"""
Converts a type to a string list of its possible values.
"""
if get_origin(type_) == Literal:
return [str(token) for token in get_args(type_)]
elif isinstance(type_, type) and issubclass(type_, Enum):
members: list[str] = [
member_values: list[str] = [
option.value for option in getattr(type_, "__members__", {}).values()
]
return members
return member_values
elif isinstance(type_, list):
# typeadapter handles all types known to Pydantic
try:
return [TypeAdapter(type(t)).dump_json(t).decode() for t in type_]
except Exception as exc:
raise ValueError(f"Unable to cast type to labels: {exc}")
elif type_ is bool:
return ["false", "true"]
else:
raise TypeError(f"Expected Literal, Enum, bool, or list, got {type_}.")


def cast_type_to_list(type_: Union[type, GenericAlias]) -> list:
"""
Converts a type to a list of its possible values.
"""
if get_origin(type_) == Literal:
return [token for token in get_args(type_)]
elif isinstance(type_, type) and issubclass(type_, Enum):
return list(type_)
elif isinstance(type_, list):
return [str(token) for token in type_]
return type_
elif type_ is bool:
return ["false", "true"]
return [False, True]
else:
raise TypeError(f"Expected Literal, Enum, bool, or list, got {type_}.")
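
The two helpers in this hunk produce parallel lists: the original values and one string label per value, in the same order. A stdlib-only sketch of that relationship follows. It uses `json.dumps` as a stand-in for Pydantic's `TypeAdapter(...).dump_json(...)`, so serialization of types that JSON cannot represent will differ from the real code; `to_values` and `to_labels` are hypothetical names.

```python
import json
from enum import Enum
from typing import Literal, get_args, get_origin


def to_values(type_) -> list:
    """Return the possible values of a Literal, Enum, bool, or list."""
    if get_origin(type_) is Literal:
        return list(get_args(type_))
    if isinstance(type_, type) and issubclass(type_, Enum):
        return list(type_)
    if isinstance(type_, list):
        return type_
    if type_ is bool:
        return [False, True]
    raise TypeError(f"Expected Literal, Enum, bool, or list, got {type_}.")


def to_labels(type_) -> list[str]:
    """Return one string label per value, in the same order as to_values."""
    if get_origin(type_) is Literal:
        return [str(value) for value in get_args(type_)]
    if isinstance(type_, type) and issubclass(type_, Enum):
        return [member.value for member in type_]
    if isinstance(type_, list):
        # Stand-in for TypeAdapter(type(value)).dump_json(value).decode()
        return [json.dumps(value) for value in type_]
    if type_ is bool:
        return ["false", "true"]
    raise TypeError(f"Expected Literal, Enum, bool, or list, got {type_}.")


class Color(Enum):
    RED = "red"
    BLUE = "blue"
```

Keeping the two lists aligned is what allows a chosen index to be mapped back to the original value rather than to its string form.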

46 changes: 37 additions & 9 deletions src/marvin/ai/text.py
Expand Up @@ -26,6 +26,7 @@
from marvin._mappings.types import (
cast_labels_to_grammar,
cast_type_to_labels,
cast_type_to_list,
)
from marvin.ai.prompts.text_prompts import (
CAST_PROMPT,
Expand Down Expand Up @@ -169,6 +170,7 @@ async def _generate_typed_llm_response_with_tool(
async def _generate_typed_llm_response_with_logit_bias(
prompt_template: str,
prompt_kwargs: dict,
return_index: bool = False,
encoder: Callable[[str], list[int]] = None,
max_tokens: int = 1,
model_kwargs: dict = None,
Expand All @@ -191,6 +193,8 @@ async def _generate_typed_llm_response_with_logit_bias(
Args:
prompt_template (str): The template for the prompt.
prompt_kwargs (dict): Additional keyword arguments for the prompt.
return_index (bool, optional): Whether to return the index of the label
instead of the label itself. Defaults to False.
encoder (Callable[[str], list[int]], optional): The encoder function to
use for the generation. Defaults to None.
max_tokens (int, optional): The maximum number of tokens for the
Expand All @@ -207,6 +211,7 @@ async def _generate_typed_llm_response_with_logit_bias(
if "labels" not in prompt_kwargs:
raise ValueError("Labels must be provided as a kwarg to the prompt template.")
labels = prompt_kwargs["labels"]
label_list = cast_type_to_list(labels)
label_strings = cast_type_to_labels(labels)
grammar = cast_labels_to_grammar(
labels=label_strings, encoder=encoder, max_tokens=max_tokens
Expand All @@ -222,11 +227,13 @@ async def _generate_typed_llm_response_with_logit_bias(
# the response contains a single number representing the index of the chosen
label_index = int(response.response.choices[0].message.content)

if return_index:
return label_index

if labels is bool:
return bool(label_index)

result = label_strings[label_index]
return labels(result) if isinstance(labels, type) else result
return label_list[label_index]
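
The post-processing at the end of this function can be summarized in a small standalone sketch. `resolve_choice` is a hypothetical name, `raw` stands in for `response.response.choices[0].message.content`, and `values` plays the role of `label_list`.

```python
def resolve_choice(raw: str, values: list, return_index: bool = False):
    """Turn the model's single-token reply into an index or the matching value."""
    label_index = int(raw)
    if return_index:
        return label_index
    return values[label_index]


picked = resolve_choice("1", ["support request", "account issue", "general inquiry"])
```

Since the value list for `bool` is `[False, True]`, indexing into it gives the same result as `bool(label_index)`.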


async def cast_async(
Expand Down Expand Up @@ -260,6 +267,9 @@ async def cast_async(
"""
model_kwargs = model_kwargs or {}

if not isinstance(data, str):
data = marvin.utilities.tools.output_to_string(data)

if target is None and instructions is None:
raise ValueError("Must provide either a target type or instructions.")
elif target is None:
Expand Down Expand Up @@ -320,11 +330,16 @@ async def extract_async(
Returns:
list: A list of extracted entities of the specified type.
"""
model_kwargs = model_kwargs or {}

if target is None and instructions is None:
raise ValueError("Must provide either a target type or instructions.")
elif target is None:
target = str
model_kwargs = model_kwargs or {}

if not isinstance(data, str):
data = marvin.utilities.tools.output_to_string(data)

return await _generate_typed_llm_response_with_tool(
prompt_template=EXTRACT_PROMPT,
prompt_kwargs=dict(data=data, instructions=instructions),
Expand All @@ -338,9 +353,10 @@ async def classify_async(
data: str,
labels: Union[Enum, list[T], type],
instructions: str = None,
return_index: bool = False,
model_kwargs: dict = None,
client: Optional[AsyncMarvinClient] = None,
) -> T:
) -> Union[T, int]:
"""
Classifies the provided data based on the provided labels.

Expand All @@ -354,18 +370,23 @@ async def classify_async(
labels (Union[Enum, list[T], type]): The labels to classify the data into.
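instructions (str, optional): Specific instructions for the
classification. Defaults to None.
return_index (bool, optional): Whether to return the index of the label
instead of the label itself. Defaults to False.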
model_kwargs (dict, optional): Additional keyword arguments for the
language model. Defaults to None.
client (AsyncMarvinClient, optional): The client to use for the AI function.

Returns:
T: The label that the data was classified into.
Union[T, int]: The label or index that the data was classified into.
"""

model_kwargs = model_kwargs or {}
if not isinstance(data, str):
data = marvin.utilities.tools.output_to_string(data)

return await _generate_typed_llm_response_with_logit_bias(
prompt_template=CLASSIFY_PROMPT,
prompt_kwargs=dict(data=data, labels=labels, instructions=instructions),
return_index=return_index,
model_kwargs=model_kwargs | dict(temperature=0),
client=client,
)
Expand Down Expand Up @@ -754,9 +775,10 @@ def classify(
data: str,
labels: Union[Enum, list[T], type],
instructions: str = None,
return_index: bool = False,
model_kwargs: dict = None,
client: Optional[AsyncMarvinClient] = None,
) -> T:
) -> Union[T, int]:
"""
Classifies the provided data based on the provided labels.

Expand All @@ -770,18 +792,20 @@ def classify(
labels (Union[Enum, list[T], type]): The labels to classify the data into.
instructions (str, optional): Specific instructions for the
classification. Defaults to None.
return_index (bool, optional): Whether to return the index of the label
instead of the label itself. Defaults to False.
model_kwargs (dict, optional): Additional keyword arguments for the
language model. Defaults to None.
client (AsyncMarvinClient, optional): The client to use for the AI function.

Returns:
T: The label that the data was classified into.
Union[T, int]: The label or index that the data was classified into.
"""
return run_sync(
classify_async(
data=data,
labels=labels,
instructions=instructions,
return_index=return_index,
model_kwargs=model_kwargs,
client=client,
)
Expand Down Expand Up @@ -878,15 +902,17 @@ async def classify_async_map(
data: list[str],
labels: Union[Enum, list[T], type],
instructions: Optional[str] = None,
return_index: bool = False,
model_kwargs: Optional[dict] = None,
client: Optional[AsyncMarvinClient] = None,
) -> list[T]:
) -> list[Union[T, int]]:
return await map_async(
fn=classify_async,
map_kwargs=dict(data=data),
unmapped_kwargs=dict(
labels=labels,
instructions=instructions,
return_index=return_index,
model_kwargs=model_kwargs,
client=client,
),
Expand All @@ -897,14 +923,16 @@ def classify_map(
data: list[str],
labels: Union[Enum, list[T], type],
instructions: Optional[str] = None,
return_index: bool = False,
model_kwargs: Optional[dict] = None,
client: Optional[AsyncMarvinClient] = None,
) -> list[T]:
) -> list[Union[T, int]]:
return run_sync(
classify_async_map(
data=data,
labels=labels,
instructions=instructions,
return_index=return_index,
model_kwargs=model_kwargs,
client=client,
)
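
The map variants forward `return_index` as an unmapped keyword argument, so every item in `data` is classified with the same setting. The sketch below shows that shape under stated assumptions: the internals of Marvin's `map_async` are not visible in this diff, so `map_over_data` is a hypothetical helper built on `asyncio.gather`, and `fake_classify_async` is a stub that stands in for `classify_async`.

```python
import asyncio


async def fake_classify_async(data: str, labels: list[str], return_index: bool = False):
    """Stub standing in for classify_async: picks the first label found in the text."""
    for index, label in enumerate(labels):
        if label in data:
            return index if return_index else label
    return 0 if return_index else labels[0]


async def map_over_data(fn, data: list[str], **unmapped_kwargs) -> list:
    """Call fn once per item concurrently, passing the same extra kwargs each time."""
    return await asyncio.gather(*(fn(item, **unmapped_kwargs) for item in data))


results = asyncio.run(
    map_over_data(
        fake_classify_async,
        ["found a bug today", "general inquiry about pricing"],
        labels=["bug", "inquiry"],
        return_index=True,
    )
)
```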