Skip to content

Commit

Permalink
top level .ai module
Browse files Browse the repository at this point in the history
  • Loading branch information
aaazzam committed Nov 16, 2023
1 parent 15f7d1b commit 12e1822
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 4 deletions.
7 changes: 7 additions & 0 deletions src/marvin/ai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from marvin.components.ai_classifier import ai_classifier as classifier
from marvin.components.ai_function import ai_fn as fn
from marvin.components.ai_image import create_image as image
from marvin.components.ai_model import ai_model as model
from marvin.components.speech import speak

__all__ = ["speak", "fn", "model", "image", "classifier"]
3 changes: 3 additions & 0 deletions src/marvin/components/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from regex import B
from .ai_function import ai_fn, AIFunction
from .ai_classifier import ai_classifier, AIClassifier
from .ai_model import ai_model
from .ai_image import ai_image, AIImage
from .speech import speak
from .prompt import prompt_fn, PromptFunction

__all__ = [
"ai_fn",
"ai_classifier",
"ai_model",
"ai_image",
"speak",
"AIImage",
"prompt_fn",
"AIFunction",
Expand Down
17 changes: 13 additions & 4 deletions src/marvin/components/ai_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,6 @@ class AIImage(BaseModel, Generic[P]):
fn: Optional[Callable[P, Any]] = None
environment: Optional[BaseEnvironment] = None
prompt: Optional[str] = Field(default=None)
name: str = "FormatResponse"
description: str = "Formats the response."
field_name: str = "data"
field_description: str = "The data to format."
render_kwargs: dict[str, Any] = Field(default_factory=dict)

generate: Optional[Callable[..., "ImagesResponse"]] = Field(default=None)
Expand Down Expand Up @@ -143,3 +139,16 @@ def decorator(fn: Callable[P, Any]) -> Callable[P, "ImagesResponse"]:
return wraps(fn)(partial(wrapper, fn))

return decorator


def create_image(
prompt: str,
environment: Optional[BaseEnvironment] = None,
generate: Optional[Callable[..., "ImagesResponse"]] = None,
**model_kwargs: Any,
) -> "ImagesResponse":
if generate is None:
from marvin.settings import settings

generate = settings.openai.images.generate
return generate(prompt=prompt, **model_kwargs)
73 changes: 73 additions & 0 deletions src/marvin/components/speech.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
Coroutine,
Literal,
Optional,
TypeVar,
)

from typing_extensions import ParamSpec

if TYPE_CHECKING:
from openai._base_client import HttpxBinaryResponseContent

T = TypeVar("T")

P = ParamSpec("P")


def speak(
input: str,
*,
create: Optional[Callable[..., "HttpxBinaryResponseContent"]] = None,
model: Optional[str] = "tts-1-hd",
voice: Optional[
Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"]
] = None,
response_format: Optional[Literal["mp3", "opus", "aac", "flac"]] = None,
speed: Optional[float] = None,
filepath: Path,
) -> None:
if create is None:
from marvin.settings import settings

create = settings.openai.audio.speech.create
return create(
input=input,
**({"model": model} if model else {}),
**({"voice": voice} if voice else {}),
**({"response_format": response_format} if response_format else {}),
**({"speed": speed} if speed else {}),
).stream_to_file(filepath)


async def aspeak(
input: str,
*,
acreate: Optional[
Callable[..., Coroutine[Any, Any, "HttpxBinaryResponseContent"]]
] = None,
model: Optional[str],
voice: Optional[
Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"]
] = None,
response_format: Optional[Literal["mp3", "opus", "aac", "flac"]] = None,
speed: Optional[float] = None,
filepath: Path,
) -> None:
if acreate is None:
from marvin.settings import settings

acreate = settings.openai.audio.speech.acreate
return (
await acreate(
input=input,
**({"model": model} if model else {}),
**({"voice": voice} if voice else {}),
**({"response_format": response_format} if response_format else {}),
**({"speed": speed} if speed else {}),
)
).stream_to_file(filepath)
40 changes: 40 additions & 0 deletions src/marvin/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

if TYPE_CHECKING:
from openai import AsyncClient, Client
from openai._base_client import HttpxBinaryResponseContent
from openai.types.chat import ChatCompletion
from openai.types.images_response import ImagesResponse

Expand Down Expand Up @@ -89,6 +90,40 @@ def generate(self, prompt: str, **kwargs: Any) -> "ImagesResponse":
)


class SpeechSettings(MarvinModelSettings):
model: str = Field(
default="tts-1-hd",
description="The default image model to use.",
)
voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] = Field(
default="alloy",
)
response_format: Literal["mp3", "opus", "aac", "flac"] = Field(default="mp3")
speed: float = Field(default=1.0)

async def acreate(self, input: str, **kwargs: Any) -> "HttpxBinaryResponseContent":
from marvin.settings import settings

return await settings.openai.async_client.audio.speech.create(
model=kwargs.get("model", self.model),
input=input,
voice=kwargs.get("voice", self.voice),
response_format=kwargs.get("response_format", self.response_format),
speed=kwargs.get("speed", self.speed),
)

def create(self, input: str, **kwargs: Any) -> "HttpxBinaryResponseContent":
from marvin.settings import settings

return settings.openai.client.audio.speech.create(
model=kwargs.get("model", self.model),
input=input,
voice=kwargs.get("voice", self.voice),
response_format=kwargs.get("response_format", self.response_format),
speed=kwargs.get("speed", self.speed),
)


class AssistantSettings(MarvinModelSettings):
model: str = Field(
default="gpt-4-1106-preview",
Expand All @@ -100,6 +135,10 @@ class ChatSettings(MarvinSettings):
completions: ChatCompletionSettings = Field(default_factory=ChatCompletionSettings)


class AudioSettings(MarvinSettings):
speech: SpeechSettings = Field(default_factory=SpeechSettings)


class OpenAISettings(MarvinSettings):
model_config = SettingsConfigDict(env_prefix="marvin_openai_")

Expand All @@ -115,6 +154,7 @@ class OpenAISettings(MarvinSettings):

chat: ChatSettings = Field(default_factory=ChatSettings)
images: ImageSettings = Field(default_factory=ImageSettings)
audio: AudioSettings = Field(default_factory=AudioSettings)
assistants: AssistantSettings = Field(default_factory=AssistantSettings)

@property
Expand Down

0 comments on commit 12e1822

Please sign in to comment.