Skip to content

Commit

Permalink
Update speech
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowin committed Jan 4, 2024
1 parent 44d88c6 commit 6abfb32
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 13 deletions.
4 changes: 3 additions & 1 deletion src/marvin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .settings import settings

from .components import fn, classifier, image
from .components import fn, classifier, image, speech
from .components.prompt.fn import prompt_fn

# compatibility with Marvin v1
Expand All @@ -14,6 +14,8 @@
__all__ = [
"fn",
"classifier",
"image",
"speech",
"prompt_fn",
"settings",
]
4 changes: 2 additions & 2 deletions src/marvin/components/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from .function import fn, Function
from .classifier import classifier, Classifier
from .model import model
from .image import image, Image
from .image import image
from .speech import speech
from .prompt.fn import prompt_fn, PromptFunction

__all__ = [
Expand All @@ -10,7 +11,6 @@
"model",
"image",
"prompt_fn",
"Image",
"Function",
"Classifier",
"PromptFunction",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,22 @@
P = ParamSpec("P")


class AISpeechKwargs(TypedDict):
class SpeechKwargs(TypedDict):
environment: NotRequired[BaseEnvironment]
prompt: NotRequired[str]
client: NotRequired[Client]
aclient: NotRequired[AsyncClient]


class AISpeechKwargsDefaults(BaseModel):
class SpeechKwargsDefaults(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
environment: Optional[BaseEnvironment] = None
prompt: Optional[str] = SPEECH_PROMPT
client: Optional[Client] = None
aclient: Optional[AsyncClient] = None


class AISpeech(BaseModel, Generic[P]):
class Speech(BaseModel, Generic[P]):
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
fn: Optional[Callable[P, Any]] = None
environment: Optional[BaseEnvironment] = None
Expand Down Expand Up @@ -82,7 +82,7 @@ def as_prompt(
@classmethod
def as_decorator(
cls: type[Self],
**kwargs: Unpack[AISpeechKwargs],
**kwargs: Unpack[SpeechKwargs],
) -> Callable[P, Self]:
pass

Expand All @@ -91,15 +91,15 @@ def as_decorator(
def as_decorator(
cls: type[Self],
fn: Callable[P, Any],
**kwargs: Unpack[AISpeechKwargs],
**kwargs: Unpack[SpeechKwargs],
) -> Self:
pass

@classmethod
def as_decorator(
cls: type[Self],
fn: Optional[Callable[P, Any]] = None,
**kwargs: Unpack[AISpeechKwargs],
**kwargs: Unpack[SpeechKwargs],
) -> Union[Self, Callable[[Callable[P, Any]], Self]]:
passed_kwargs: dict[str, Any] = {
k: v for k, v in kwargs.items() if v is not None
Expand All @@ -116,9 +116,9 @@ def as_decorator(
)


def ai_speech(
def speech(
fn: Optional[Callable[P, Any]] = None,
**kwargs: Unpack[AISpeechKwargs],
**kwargs: Unpack[SpeechKwargs],
) -> Union[
Callable[
[Callable[P, Any]],
Expand All @@ -129,8 +129,8 @@ def ai_speech(
def wrapper(
func: Callable[P, Any], *args_: P.args, **kwargs_: P.kwargs
) -> Union[AudioResponse, Coroutine[Any, Any, AudioResponse]]:
f = AISpeech[P].as_decorator(
func, **AISpeechKwargsDefaults(**kwargs).model_dump(exclude_none=True)
f = Speech[P].as_decorator(
func, **SpeechKwargsDefaults(**kwargs).model_dump(exclude_none=True)
)
return f(*args_, **kwargs_)

Expand Down

0 comments on commit 6abfb32

Please sign in to comment.