Skip to content

Commit

Permalink
ai_classifier hookups
Browse files Browse the repository at this point in the history
  • Loading branch information
aaazzam committed Nov 15, 2023
1 parent e8359f0 commit b376ddd
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 37 deletions.
60 changes: 29 additions & 31 deletions src/marvin/components/ai_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,17 @@
class AIClassifier(BaseModel, Generic[P, T]):
fn: Optional[Callable[P, T]] = None
environment: Optional[BaseEnvironment] = None
prompt: Optional[str] = Field(default=inspect.cleandoc("""
You are an expert classifier that always choose correctly.
- {{_doc}}
- You must classify `{{text}}` into one of the following classes:
{% for option in _options %}
Class {{ loop.index - 1}} (value: {{ option }})
{% endfor %}
ASSISTANT: The correct class label is Class
"""))
prompt: Optional[str] = Field(
default=inspect.cleandoc(
"You are an expert classifier that always choose correctly."
" \n- {{_doc}}"
" \n- You must classify `{{text}}` into one of the following classes:"
"{% for option in _options %}"
" Class {{ loop.index - 1}} (value: {{ option }})"
"{% endfor %}"
"ASSISTANT: The correct class label is Class"
)
)
enumerate: bool = True
encoder: Callable[[str], list[int]] = Field(default=None)
max_tokens: Optional[int] = 1
Expand Down Expand Up @@ -180,12 +182,11 @@ def ai_classifier(
*,
environment: Optional[BaseEnvironment] = None,
prompt: Optional[str] = None,
model_name: str = "FormatResponse",
model_description: str = "Formats the response.",
field_name: str = "data",
field_description: str = "The data to format.",
enumerate: bool = True,
encoder: Callable[[str], list[int]] = settings.openai.chat.completions.encoder,
max_tokens: Optional[int] = 1,
**render_kwargs: Any,
) -> Callable[[Callable[P, T]], Callable[P, list[T]]]:
) -> Callable[[Callable[P, T]], Callable[P, T]]:
pass


Expand All @@ -195,12 +196,11 @@ def ai_classifier(
*,
environment: Optional[BaseEnvironment] = None,
prompt: Optional[str] = None,
model_name: str = "FormatResponse",
model_description: str = "Formats the response.",
field_name: str = "data",
field_description: str = "The data to format.",
enumerate: bool = True,
encoder: Callable[[str], list[int]] = settings.openai.chat.completions.encoder,
max_tokens: Optional[int] = 1,
**render_kwargs: Any,
) -> Callable[P, list[T]]:
) -> Callable[P, T]:
pass


Expand All @@ -209,28 +209,26 @@ def ai_classifier(
*,
environment: Optional[BaseEnvironment] = None,
prompt: Optional[str] = None,
model_name: str = "FormatResponse",
model_description: str = "Formats the response.",
field_name: str = "data",
field_description: str = "The data to format.",
enumerate: bool = True,
encoder: Callable[[str], list[int]] = settings.openai.chat.completions.encoder,
max_tokens: Optional[int] = 1,
**render_kwargs: Any,
) -> Union[Callable[[Callable[P, T]], Callable[P, list[T]]], Callable[P, list[T]]]:
def wrapper(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> list[T]:
) -> Union[Callable[[Callable[P, T]], Callable[P, T]], Callable[P, T]]:
def wrapper(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
return AIClassifier[P, T].as_decorator(
func,
environment=environment,
prompt=prompt,
model_name=model_name,
model_description=model_description,
field_name=field_name,
field_description=field_description,
enumerate=enumerate,
encoder=encoder,
max_tokens=max_tokens,
**render_kwargs,
)(*args, **kwargs)
)(*args, **kwargs)[0]

if fn is not None:
return wraps(fn)(partial(wrapper, fn))

def decorator(fn: Callable[P, T]) -> Callable[P, list[T]]:
def decorator(fn: Callable[P, T]) -> Callable[P, T]:
return wraps(fn)(partial(wrapper, fn))

return decorator
6 changes: 3 additions & 3 deletions src/marvin/components/ai_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,11 @@ def parse(self, response: "ChatCompletion") -> T:
model_description=self.description,
field_name=self.field_name,
field_description=self.field_description,
).function.model
if not tool:
).function
if not tool or not tool.model:
raise NotImplementedError

return getattr(tool.model_validate_json(arguments), self.field_name)
return getattr(tool.model.model_validate_json(arguments), self.field_name)

def as_prompt(
self,
Expand Down
5 changes: 3 additions & 2 deletions src/marvin/components/prompt_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
from pydantic import BaseModel
from typing_extensions import Self

from marvin import settings
from marvin.requests import BaseMessage as Message
from marvin.requests import Prompt
from marvin.serializers import (
create_grammar_from_vocabulary,
create_tool_from_type,
create_vocabulary_from_type,
)
from marvin.settings import settings
from marvin.utilities.jinja import (
BaseEnvironment,
Transcript,
Expand Down Expand Up @@ -176,6 +176,7 @@ def wrapper(func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> Self:
field_name=field_name,
field_description=field_description,
)

messages = Transcript(
content=prompt or func.__doc__ or ""
).render_to_messages(
Expand All @@ -193,7 +194,7 @@ def wrapper(func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> Self:
messages=messages,
tool_choice={
"type": "function",
"function": {"name": tool.function.name},
"function": {"name": getattr(tool.function, "name", model_name)},
},
tools=[tool],
)
Expand Down
2 changes: 1 addition & 1 deletion src/marvin/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pydantic import BaseModel, Field
from typing_extensions import Annotated, Self

from marvin import settings
from marvin.settings import settings

T = TypeVar("T", bound=BaseModel)

Expand Down

0 comments on commit b376ddd

Please sign in to comment.