Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve how instructions are provided to high-level fns #1022

Merged
merged 2 commits into from
Jan 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions src/marvin/fns/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@
- When providing a string response, do not return JSON or a quoted string
unless they provided instructions requiring it. If you do return JSON, it
must be valid and parseable including double quotes.
- When converting to bool, treat "truthy" values as true

"""
- When converting to bool, treat "truthy" values as true"""


async def cast_async(
Expand Down Expand Up @@ -69,14 +67,18 @@ async def cast_async(
True

"""
if target is str and instructions is None:
raise ValueError("Instructions are required when casting to str")

task_context = context or {}
task_context["Data to transform"] = data
prompt = PROMPT
if instructions:
task_context["Additional instructions"] = instructions
prompt += f"\n\nYou must follow these instructions for your transformation:\n{instructions}"

task = marvin.Task[target](
name="Cast Task",
instructions=PROMPT,
instructions=prompt,
context=task_context,
result_type=target,
agents=[agent] if agent else None,
Expand Down
11 changes: 6 additions & 5 deletions src/marvin/fns/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
as possible when labeling text. You use inference or deduction whenever
necessary to understand missing or omitted data. Classify the provided `data`,
text, or information as one of the provided labels. For boolean labels,
consider "truthy" or affirmative inputs to be "true".
"""
consider "truthy" or affirmative inputs to be "true"."""


@overload
Expand Down Expand Up @@ -106,9 +105,11 @@ async def classify_async(

"""
task_context = context or {}
task_context["Data to classify"] = data
task_context.update({"Data to classify": data})

prompt = PROMPT
if instructions:
task_context["Additional instructions"] = instructions
prompt += f"\n\nYou must follow these instructions for your classification:\n{instructions}"

# Convert Enum class to sequence of values if needed
if labels is bool or issubclass_safe(labels, enum.Enum):
Expand All @@ -121,7 +122,7 @@ async def classify_async(

task = marvin.Task[result_type](
name="Classification Task",
instructions=PROMPT,
instructions=prompt,
context=task_context,
result_type=result_type,
agents=[agent] if agent else None,
Expand Down
8 changes: 4 additions & 4 deletions src/marvin/fns/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
- When providing integers, do not write out any decimals at all
- Use deduction where appropriate e.g. "3 dollars fifty cents" is a single
value [3.5] not two values [3, 50] unless the user specifically asks for
each part.
"""
each part."""


async def extract_async(
Expand Down Expand Up @@ -60,12 +59,13 @@ async def extract_async(

task_context = context or {}
task_context["Data to extract"] = data
prompt = PROMPT
if instructions:
task_context["Additional instructions"] = instructions
prompt += f"\n\nYou must follow these instructions for your extraction:\n{instructions}"

task = marvin.Task[list[target]](
name="Extraction Task",
instructions=PROMPT,
instructions=prompt,
context=task_context,
result_type=list[target],
agents=[agent] if agent else None,
Expand Down
8 changes: 4 additions & 4 deletions src/marvin/fns/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@
If the user provides additional instructions or a description, assume they are
looking for examples that satisfy the description. Do not provide more
information than the user requests. For example, if they ask for various
technologies, give their names but do not explain what each technology is.
"""
technologies, give their names but do not explain what each technology is."""


async def generate_async(
Expand Down Expand Up @@ -62,12 +61,13 @@ async def generate_async(

task_context = context or {}
task_context["Number to generate"] = n
prompt = PROMPT
if instructions:
task_context["Additional instructions"] = instructions
prompt += f"\n\nYou must follow these instructions for your generation:\n{instructions}"

task = marvin.Task[list[target]](
name="Generation Task",
instructions=PROMPT,
instructions=prompt,
context=task_context,
result_type=conlist(target, min_length=n, max_length=n),
agents=[agent] if agent else None,
Expand Down
10 changes: 6 additions & 4 deletions src/marvin/fns/summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
- Focus on the most important information
- Use clear, concise language
- Adapt tone and style to match the content type
- Ensure the summary stands alone as a coherent piece of text
"""
- Ensure the summary stands alone as a coherent piece of text"""


async def summarize_async(
Expand Down Expand Up @@ -59,12 +58,15 @@ async def summarize_async(
"""
task_context = context or {}
task_context["Data to summarize"] = data
prompt = PROMPT
if instructions:
task_context["Additional instructions"] = instructions
prompt += (
f"\n\nYou must follow these instructions for your summary:\n{instructions}"
)

task = marvin.Task[str](
name="Summarize Task",
instructions=PROMPT,
instructions=prompt,
context=task_context,
result_type=str,
agents=[agent] if agent else None,
Expand Down
4 changes: 3 additions & 1 deletion src/marvin/templates/task.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
</instructions>
{% if task.context %}
<context>
{{ task.context }}
{% for key, value in task.context.items() %}
<{{ key }}>{{ value }}</{{ key }}>
{% endfor %}
</context>
{% endif %}

Expand Down
2 changes: 1 addition & 1 deletion tests/ai/fns/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def test_cast_text_with_subtle_instructions(self):
result = marvin.cast(
"My name is marvin",
str,
instructions="Rewrite with uppercase names e.g. JOHN, but leave all other text unchanged",
instructions="Change only names to uppercase (e.g. JOHN), but leave all other text unchanged",
)
assert result == "My name is MARVIN"

Expand Down
Loading