From 9c20ab1ad994809aeb86f3b34713b380a69d0025 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Sun, 18 Feb 2024 16:01:47 -0500 Subject: [PATCH] Allow passing instructions to @model --- docs/docs/text/transformation.md | 16 ++++++++ src/marvin/ai/text.py | 16 +++++++- tests/ai/test_models.py | 69 -------------------------------- 3 files changed, 30 insertions(+), 71 deletions(-) diff --git a/docs/docs/text/transformation.md b/docs/docs/text/transformation.md index 9695f80b7..633a63295 100644 --- a/docs/docs/text/transformation.md +++ b/docs/docs/text/transformation.md @@ -116,6 +116,22 @@ Location('CHI') ## Model parameters You can pass parameters to the underlying API via the `model_kwargs` argument of `cast` or `@model`. These parameters are passed directly to the API, so you can use any supported parameter. +### Instructions + +You can pass instructions to steer model transformation via the `instructions` parameter: + +```python +@marvin.model(instructions='Always generate locations in California') +class Location(BaseModel): + city: str + state: str + +Location('a large city') +# Location(city='Los Angeles', state='California') +``` + +Note that instructions are set at the class level, so they will apply to all instances of the model. To customize instructions on a per-instance basis, use `cast` with the `instructions` parameter instead. + ## Async support If you are using `marvin` in an async environment, you can use `cast_async`: diff --git a/src/marvin/ai/text.py b/src/marvin/ai/text.py index 2bffcf594..875625834 100644 --- a/src/marvin/ai/text.py +++ b/src/marvin/ai/text.py @@ -570,6 +570,7 @@ def __init__( self, text: Optional[str] = None, *, + instructions: Optional[str] = None, model_kwargs: Optional[dict] = None, client: Optional[MarvinClient] = None, **kwargs, @@ -580,6 +581,7 @@ def __init__( Args: text (str, optional): The natural language string to convert into an instance of the model. Defaults to None. + instructions (str, optional): Specific instructions for the conversion. model_kwargs (dict, optional): Additional keyword arguments for the language model. Defaults to None. **kwargs: Additional keyword arguments to pass to the model's constructor. @@ -587,7 +589,11 @@ def __init__( ai_kwargs = kwargs if text is not None: ai_kwargs = cast( - text, type(self), model_kwargs=model_kwargs, client=client + text, + type(self), + instructions=instructions, + model_kwargs=model_kwargs, + client=client, ).model_dump() ai_kwargs.update(kwargs) super().__init__(**ai_kwargs) @@ -644,6 +650,7 @@ def new(cls, value): def model( type_: Union[Type[M], None] = None, + instructions: Optional[str] = None, model_kwargs: Optional[dict] = None, client: Optional[MarvinClient] = None, ) -> Union[Type[M], Callable[[Type[M]], Type[M]]]: @@ -656,6 +663,7 @@ def model( Args: type_ (Union[Type[M], None], optional): The type of the Pydantic model. Defaults to None. + instructions (str, optional): Specific instructions for the conversion. model_kwargs (dict, optional): Additional keyword arguments for the language model. Defaults to None. @@ -669,7 +677,11 @@ class WrappedModel(Model, cls): @wraps(cls.__init__) def __init__(self, *args, **kwargs): super().__init__( - *args, model_kwargs=model_kwargs, client=client, **kwargs + *args, + instructions=instructions, + model_kwargs=model_kwargs, + client=client, + **kwargs, ) WrappedModel.__name__ = cls.__name__ diff --git a/tests/ai/test_models.py b/tests/ai/test_models.py index 25007543b..6d457712a 100644 --- a/tests/ai/test_models.py +++ b/tests/ai/test_models.py @@ -156,20 +156,7 @@ class Fruit(BaseModel): assert isinstance(fruit, Fruit) -@pytest.mark.skip(reason="old behavior, may revisit") class TestInstructions: - def test_instructions_error(self): - @marvin.model - class Test(BaseModel): - text: str - - with pytest.raises( - ValueError, match="(Received `instructions` but this model)" - ): - Test("Hello!", instructions="Translate to French") - with pytest.raises(ValueError, match="(Received `model` but this model)"): - Test("Hello!", model=None) - def test_instructions(self): @marvin.model class Text(BaseModel): @@ -186,62 +173,6 @@ class Text(BaseModel): t2 = Text("Hello") assert t2.text == "Bonjour" - def test_follow_instance_instructions(self): - @marvin.model - class Test(BaseModel): - text: str - - t1 = Test("Hello") - assert t1.text == "Hello" - - # this model is identical except it has an instruction - @marvin.model - class Test(BaseModel): - text: str - - t2 = Test("Hello", instructions_="first translate the text to French") - assert t2.text == "Bonjour" - - def test_follow_global_and_instance_instructions(self): - @marvin.model(instructions="Always set color_1 to 'red'") - class Test(BaseModel): - color_1: str - color_2: str - - t1 = Test("Hello", instructions_="Always set color_2 to 'blue'") - assert t1 == Test(color_1="red", color_2="blue") - - def test_follow_docstring_and_global_and_instance_instructions(self): - @marvin.model(instructions="Always set color_1 to 'red'") - class Test(BaseModel): - """Always set color_3 to 'orange'""" - - color_1: str - color_2: str - color_3: str - - t1 = Test("Hello", instructions_="Always set color_2 to 'blue'") - assert t1 == Test(color_1="red", color_2="blue", color_3="orange") - - def test_follow_multiple_instructions(self): - # ensure that instructions don't bleed to other invocations - @marvin.model - class Translation(BaseModel): - """Translates from one language to another language""" - - original_text: str - translated_text: str - - t1 = Translation("Hello, world!", instructions_="Translate to French") - t2 = Translation("Hello, world!", instructions_="Translate to German") - - assert t1 == Translation( - original_text="Hello, world!", translated_text="Bonjour, monde!" - ) - assert t2 == Translation( - original_text="Hello, world!", translated_text="Hallo, Welt!" - ) - class TestAsync: async def test_basic_async(self):