Skip to content

Commit

Permalink
Merge pull request #615 from PrefectHQ/fix-ai-fn-response-model
Browse files Browse the repository at this point in the history
fix pydantic v1 behavior - allow `BaseModel` return for `ai_fn`
  • Loading branch information
zzstoatzz authored Oct 19, 2023
2 parents 778034b + a3b797b commit 0271fcc
Show file tree
Hide file tree
Showing 20 changed files with 113 additions and 73 deletions.
21 changes: 20 additions & 1 deletion .github/workflows/run-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,34 @@ permissions:

jobs:
run_tests:
name: ${{ matrix.test-type }} w/ python ${{ matrix.python-version }} on ${{ matrix.os }}
name: ${{ matrix.test-type }} w/ python ${{ matrix.python-version }} | pydantic ${{ matrix.pydantic_version }} on ${{ matrix.os }}
timeout-minutes: 15
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ['3.9', '3.10', '3.11']
test-type: ['not llm']
llm_model: ['openai/gpt-3.5-turbo']
pydantic_version: ['>=2.4.2']

include:
- python-version: '3.9'
os: 'ubuntu-latest'
test-type: 'llm'
llm_model: 'openai/gpt-3.5-turbo'
pydantic_version: '>=2.4.2'

- python-version: '3.9'
os: 'ubuntu-latest'
test-type: 'llm'
llm_model: 'openai/gpt-3.5-turbo'
pydantic_version: '<2'

- python-version: '3.9'
os: 'ubuntu-latest'
test-type: 'not llm'
llm_model: 'openai/gpt-3.5-turbo'
pydantic_version: '<2'

runs-on: ${{ matrix.os }}

Expand All @@ -60,6 +74,11 @@ jobs:
cache: "pip"
- name: Install Marvin
run: pip install ".[tests]"

- name: Install pydantic
run: pip install "pydantic${{ matrix.pydantic_version }}"


- name: Run ${{ matrix.test-type }} tests (${{ matrix.llm_model }})
run: pytest -vv -m "${{ matrix.test-type }}"
env:
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ dev = [
"mkdocstrings[python]~=0.22",
"pdbpp~=0.10",
"pre-commit>=2.21,<4.0",
"pydantic",
"ruff",
]
tests = [
Expand All @@ -48,8 +49,8 @@ tests = [
"pytest-rerunfailures>=10,<13",
"pytest-sugar~=0.9",
"pytest~=7.3.1",
"pydantic-settings>=2.0.0",
]

framework = [
"aiosqlite>=0.19.0",
"alembic>=1.11.1",
Expand Down
19 changes: 14 additions & 5 deletions src/marvin/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,36 +10,45 @@
get_origin,
)

from pydantic import BaseModel, create_model
from pydantic.version import VERSION as PYDANTIC_VERSION

_ModelT = TypeVar("_ModelT", bound="BaseModel")

PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")

if PYDANTIC_V2:
from pydantic.v1 import (
BaseSettings,
Field,
PrivateAttr,
SecretStr,
ValidationError,
validate_arguments,
)

SettingsConfigDict = BaseSettings.Config

from pydantic import field_validator # noqa # type: ignore
from pydantic import (
BaseModel,
Field,
create_model,
field_validator,
)

else:
from pydantic import ( # noqa # type: ignore
BaseSettings,
BaseModel,
create_model,
Field,
SecretStr,
validate_arguments,
validator as field_validator,
ValidationError,
PrivateAttr,
)

SettingsConfigDict = BaseSettings.Config

_ModelT = TypeVar("_ModelT", bound=BaseModel)


def model_dump(model: _ModelT, **kwargs: Any) -> dict[str, Any]:
if PYDANTIC_V2 and hasattr(model, "model_dump"):
Expand Down
17 changes: 9 additions & 8 deletions src/marvin/components/ai_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
from typing import Any, Callable, Optional, Union

from jsonpatch import JsonPatch
from pydantic import BaseModel, Field, validator

import marvin
from marvin._compat import PYDANTIC_V2, model_dump
from marvin._compat import PYDANTIC_V2, BaseModel, Field, field_validator, model_dump
from marvin.core.ChatCompletion.providers.openai import get_context_size
from marvin.openai import ChatCompletion
from marvin.prompts import library as prompt_library
Expand Down Expand Up @@ -218,17 +217,19 @@ class AIApplication(LoggerMixin, MarvinBaseModel):
state_enabled: bool = True
plan_enabled: bool = True

@validator("description")
@field_validator("description")
def validate_description(cls, v):
return inspect.cleandoc(v)

@validator("additional_prompts")
@field_validator("additional_prompts")
def validate_additional_prompts(cls, v):
if v is None:
v = []
return v

@validator("tools", pre=True, always=True)
@field_validator(
"tools", **(dict(pre=True, always=True) if not PYDANTIC_V2 else {})
)
def validate_tools(cls, v):
if v is None:
v = []
Expand All @@ -245,7 +246,7 @@ def validate_tools(cls, v):
raise ValueError(f"Tool {tool} is not a `Tool` or callable.")
return tools

@validator("name", always=True)
@field_validator("name")
def validate_name(cls, v):
if v is None:
v = cls.__name__
Expand Down Expand Up @@ -351,8 +352,8 @@ class JSONPatchModel(

op: str
path: str
value: Union[str, float, int, bool, list, dict] = None
from_: str = Field(None, alias="from")
value: Union[str, float, int, bool, list, dict, None] = None
from_: Optional[str] = Field(None, alias="from")


class UpdateState(Tool):
Expand Down
2 changes: 1 addition & 1 deletion src/marvin/components/ai_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from functools import partial
from typing import Any, Callable, Literal, Optional, TypeVar

from pydantic import BaseModel, Field
from typing_extensions import ParamSpec, Self

from marvin._compat import BaseModel, Field
from marvin.core.ChatCompletion import ChatCompletion
from marvin.core.ChatCompletion.abstract import AbstractChatCompletion
from marvin.prompts import Prompt, prompt_fn
Expand Down
29 changes: 19 additions & 10 deletions src/marvin/components/ai_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from functools import partial
from typing import Any, Awaitable, Callable, Generic, Optional, TypeVar, Union

from pydantic import BaseModel, Field
from typing_extensions import ParamSpec, Self

from marvin._compat import BaseModel, Field
from marvin.core.ChatCompletion import ChatCompletion
from marvin.core.ChatCompletion.abstract import AbstractChatCompletion
from marvin.prompts import Prompt, prompt_fn
Expand Down Expand Up @@ -42,7 +42,7 @@ def prompt_wrapper(*args: P.args, **kwargs: P.kwargs) -> None: # type: ignore #
{{'def' + ''.join(inspect.getsource(func).split('def')[1:])}}
The user will provide function inputs (if any) and you must respond with
the most likely result.
the most likely result, which must be valid, double-quoted JSON.
User: The function was called with the following inputs:
{% set sig = inspect.signature(func) %}
Expand Down Expand Up @@ -118,20 +118,29 @@ def call(
*args: P.args,
**kwargs: P.kwargs,
) -> Any:
return getattr(
self.as_chat_completion(*args, **kwargs).create().to_model(),
self.response_model_field_name or "output",
)
model_instance = self.as_chat_completion(*args, **kwargs).create().to_model()
response_model_field_name = self.response_model_field_name or "output"

if not (output := getattr(model_instance, response_model_field_name, None)):
return model_instance

return output

async def acall(
self,
*args: P.args,
**kwargs: P.kwargs,
) -> Any:
return getattr(
(await self.as_chat_completion(*args, **kwargs).acreate()).to_model(),
self.response_model_field_name or "output",
)
model_instance = (
await self.as_chat_completion(*args, **kwargs).acreate()
).to_model()

response_model_field_name = self.response_model_field_name or "output"

if not (output := getattr(model_instance, response_model_field_name, None)):
return model_instance

return output

def map(self, *map_args: list[Any], **map_kwargs: list[Any]):
"""
Expand Down
2 changes: 1 addition & 1 deletion src/marvin/components/ai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from functools import partial
from typing import Any, Callable, Optional, TypeVar

from pydantic import BaseModel
from typing_extensions import ParamSpec, Self

from marvin._compat import BaseModel
from marvin.core.ChatCompletion import ChatCompletion
from marvin.core.ChatCompletion.abstract import AbstractChatCompletion
from marvin.prompts import Prompt, prompt_fn
Expand Down
3 changes: 1 addition & 2 deletions src/marvin/components/ai_model_factory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Optional

from pydantic import BaseModel

from marvin._compat import BaseModel
from marvin.components.ai_model import ai_model


Expand Down
3 changes: 1 addition & 2 deletions src/marvin/components/library/ai_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
from typing import Optional

import httpx
from pydantic import BaseModel
from typing_extensions import Self

from marvin import ai_model
from marvin._compat import Field, SecretStr, field_validator
from marvin._compat import BaseModel, Field, SecretStr, field_validator
from marvin.settings import MarvinBaseSettings


Expand Down
3 changes: 2 additions & 1 deletion src/marvin/core/ChatCompletion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional, Any, TypeVar
from pydantic import BaseModel
from .abstract import AbstractChatCompletion

from marvin._compat import BaseModel
from marvin.settings import settings

T = TypeVar(
Expand Down
3 changes: 1 addition & 2 deletions src/marvin/core/ChatCompletion/abstract.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from abc import ABC, abstractmethod
from typing import Any, Generic, Optional, TypeVar

from marvin._compat import model_copy, model_dump
from marvin._compat import BaseModel, Field, model_copy, model_dump
from marvin.utilities.messages import Message
from pydantic import BaseModel, Field
from typing_extensions import Self

from .handlers import Request, Response, Turn
Expand Down
6 changes: 4 additions & 2 deletions src/marvin/core/ChatCompletion/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
overload,
)

from marvin._compat import cast_to_json, model_dump
from marvin._compat import BaseModel, Field, cast_to_json, model_dump
from marvin.utilities.async_utils import run_sync
from marvin.utilities.logging import get_logger
from marvin.utilities.messages import Message, Role
from pydantic import BaseModel, Field
from typing_extensions import ParamSpec

from .utils import parse_raw
Expand Down Expand Up @@ -84,6 +83,9 @@ class Choice(BaseModel):
index: int
finish_reason: str

class Config:
arbitrary_types_allowed = True


class Usage(BaseModel):
prompt_tokens: int
Expand Down
3 changes: 1 addition & 2 deletions src/marvin/core/ChatCompletion/providers/openai.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import inspect
from typing import Any, AsyncGenerator, Callable, Optional, TypeVar, Union

from marvin._compat import cast_to_json, model_dump
from marvin._compat import BaseModel, cast_to_json, model_dump
from marvin.settings import settings
from marvin.types import Function
from marvin.utilities.async_utils import create_task
from marvin.utilities.messages import Message
from marvin.utilities.streaming import StreamHandler
from openai.openai_object import OpenAIObject
from pydantic import BaseModel

from ..abstract import AbstractChatCompletion
from ..handlers import Request, Response, Usage
Expand Down
11 changes: 4 additions & 7 deletions src/marvin/openai/ChatCompletion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
from pydantic import BaseModel, Field, validator, Extra, BaseSettings, root_validator
from pydantic.main import ModelMetaclass

from typing import Any, Callable, List, Optional, Type, Union, Literal
from marvin import settings
from marvin.types import Function
from typing import Any, Callable, Optional
from operator import itemgetter
from marvin.utilities.module_loading import import_string
import warnings
import copy

from marvin import settings
from marvin._compat import BaseModel, Extra, Field
from marvin.types.request import Request as BaseRequest
from marvin.engine import ChatCompletionBase

Expand Down
6 changes: 1 addition & 5 deletions src/marvin/openai/Function/Registry/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
from fastapi.routing import APIRouter
from pydantic import BaseModel, validate_arguments
from marvin.utilities.types import function_to_model
from marvin.utilities.messages import Message
from typing import Any
from marvin.openai.Function import openai_fn
from openai.openai_object import OpenAIObject
from marvin.functions import FunctionRegistry
Expand Down
14 changes: 5 additions & 9 deletions src/marvin/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from ._compat import (
BaseSettings,
Field,
SecretStr,
model_dump,
)
Expand Down Expand Up @@ -155,22 +154,19 @@ class Settings(MarvinBaseSettings):
azure_openai: AzureOpenAI = AzureOpenAI()

# SLACK
slack_api_token: Optional[SecretStr] = Field(
default=None,
description="The Slack API token to use for the Slack client",
)
slack_api_token: Optional[SecretStr] = None

# TOOLS

# chroma
chroma_server_host: Optional[str] = Field(default=None)
chroma_server_http_port: Optional[int] = Field(default=None)
chroma_server_host: Optional[str] = None
chroma_server_http_port: Optional[int] = None

# github
github_token: Optional[SecretStr] = Field(default=None)
github_token: Optional[SecretStr] = None

# wolfram
wolfram_app_id: Optional[SecretStr] = Field(default=None)
wolfram_app_id: Optional[SecretStr] = None

def get_defaults(self, provider: Optional[str] = None) -> dict[str, Any]:
response: dict[str, Any] = {}
Expand Down
Loading

0 comments on commit 0271fcc

Please sign in to comment.