diff --git a/src/controlflow/llm/rules.py b/src/controlflow/llm/rules.py index 259a02d6..ca3a706e 100644 --- a/src/controlflow/llm/rules.py +++ b/src/controlflow/llm/rules.py @@ -1,3 +1,4 @@ +import textwrap from typing import Optional from langchain_anthropic import ChatAnthropic @@ -16,6 +17,8 @@ class LLMRules(ControlFlowModel): necessary. """ + model: Optional[BaseChatModel] + # require at least one non-system message require_at_least_one_message: bool = False @@ -41,10 +44,30 @@ class LLMRules(ControlFlowModel): # the name associated with a message must conform to a specific format require_message_name_format: Optional[str] = None + def model_instructions(self) -> Optional[list[str]]: + pass + class OpenAIRules(LLMRules): require_message_name_format: str = r"[^a-zA-Z0-9_-]" + model: ChatOpenAI + + def model_instructions(self) -> list[str]: + instructions = [] + if self.model.model_name.endswith("gpt-4o-mini"): + instructions.append( + textwrap.dedent( + """ + You can only provide a single result for each task, and a + task can only be marked successful one time. Do not make + multiple tool calls in parallel to supply multiple results + to the same task. 
+ """ + ) + ) + return instructions + class AnthropicRules(LLMRules): require_at_least_one_message: bool = True @@ -56,8 +79,8 @@ class AnthropicRules(LLMRules): def rules_for_model(model: BaseChatModel) -> LLMRules: if isinstance(model, (ChatOpenAI, AzureChatOpenAI)): - return OpenAIRules() + return OpenAIRules(model=model) elif isinstance(model, ChatAnthropic): - return AnthropicRules() + return AnthropicRules(model=model) else: - return LLMRules() + return LLMRules(model=model) diff --git a/src/controlflow/orchestration/orchestrator.py b/src/controlflow/orchestration/orchestrator.py index b11c970b..6fc4ed7a 100644 --- a/src/controlflow/orchestration/orchestrator.py +++ b/src/controlflow/orchestration/orchestrator.py @@ -392,11 +392,13 @@ def compile_prompt(self) -> str: """ from controlflow.orchestration.prompt_templates import ( InstructionsTemplate, + LLMInstructionsTemplate, TasksTemplate, ToolTemplate, ) tools = self.get_tools() + llm_rules = self.agent.get_llm_rules() prompts = [ self.agent.get_prompt(), @@ -404,7 +406,11 @@ def compile_prompt(self) -> str: TasksTemplate(tasks=self.get_tasks("ready")).render(), ToolTemplate(tools=tools).render(), InstructionsTemplate(instructions=get_instructions()).render(), + LLMInstructionsTemplate( + instructions=llm_rules.model_instructions() + ).render(), ] + prompt = "\n\n".join([p for p in prompts if p]) return prompt diff --git a/src/controlflow/orchestration/prompt_templates.py b/src/controlflow/orchestration/prompt_templates.py index 47b73b5e..8b9418aa 100644 --- a/src/controlflow/orchestration/prompt_templates.py +++ b/src/controlflow/orchestration/prompt_templates.py @@ -78,6 +78,14 @@ def should_render(self) -> bool: return bool(self.instructions) +class LLMInstructionsTemplate(Template): + template_path: str = "llm_instructions.jinja" + instructions: list[str] = [] + + def should_render(self) -> bool: + return bool(self.instructions) + + class ToolTemplate(Template): template_path: str = "tools.jinja" tools: 
list[Tool] diff --git a/src/controlflow/orchestration/prompt_templates/instructions.jinja b/src/controlflow/orchestration/prompt_templates/instructions.jinja index b6d068ec..8151ec15 100644 --- a/src/controlflow/orchestration/prompt_templates/instructions.jinja +++ b/src/controlflow/orchestration/prompt_templates/instructions.jinja @@ -1,6 +1,6 @@ # Instructions -You must follow these instructions. Note that instructions can be changed at any time. +You must follow these instructions at all times. Note that instructions can be changed at any time. {% for instruction in instructions %} - {{ instruction }} diff --git a/src/controlflow/orchestration/prompt_templates/llm_instructions.jinja b/src/controlflow/orchestration/prompt_templates/llm_instructions.jinja new file mode 100644 index 00000000..e1c0f346 --- /dev/null +++ b/src/controlflow/orchestration/prompt_templates/llm_instructions.jinja @@ -0,0 +1,9 @@ +# LLM Instructions + +These instructions are specific to your LLM model. You must follow them when interacting with the orchestrator and +other agents. + +{% for instruction in instructions %} +- {{ instruction }} + +{% endfor %} \ No newline at end of file diff --git a/src/controlflow/orchestration/prompt_templates/tasks.jinja b/src/controlflow/orchestration/prompt_templates/tasks.jinja index 338fbde2..780712a8 100644 --- a/src/controlflow/orchestration/prompt_templates/tasks.jinja +++ b/src/controlflow/orchestration/prompt_templates/tasks.jinja @@ -23,11 +23,18 @@ The following tasks are active: {% endfor %} -Only agents assigned to a task are able to mark the task as complete. You must use a tool to end your turn to let other -agents participate. If you are asked to talk to other agents, post messages. Do not impersonate another agent! Do not -impersonate the orchestrator! +Only agents assigned to a task are able to mark the task as complete. You must +use a tool to end your turn to let other agents participate. 
If you are asked to +talk to other agents, post messages. Do not impersonate another agent! Do not +impersonate the orchestrator! If you have been assigned a task, then you (and +other agents) must have the resources, knowledge, or tools required to complete +it. + +A task can only be marked complete one time. Do not attempt to mark a task +successful more than once. Even if the `result_type` does not appear to match +the objective, you must supply a single compatible result. Only mark a task +failed if there is a technical error or issue preventing completion. -Only mark a task failed if there is a technical error or issue preventing completion. ## Task hierarchy diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py index ff0f6221..76253ee0 100644 --- a/src/controlflow/tasks/task.py +++ b/src/controlflow/tasks/task.py @@ -1,4 +1,5 @@ import datetime +import textwrap import warnings from contextlib import ExitStack, contextmanager from enum import Enum @@ -511,7 +512,10 @@ def create_success_tool(self) -> Tool: Create an agent-compatible tool for marking this task as successful. """ options = {} - instructions = None + instructions = textwrap.dedent(""" + Use this tool to mark the task as successful and provide a result. + This tool can only be used one time per task. + """) result_schema = None # if the result_type is a tuple of options, then we want the LLM to provide @@ -532,10 +536,12 @@ def create_success_tool(self) -> Tool: options_str = "\n\n".join( f"Option {i}: {option}" for i, option in serialized_options.items() ) - instructions = f""" + instructions += "\n\n" + textwrap.dedent(""" Provide a single integer as the result, corresponding to the index - of your chosen option. Your options are: {options_str} - """ + of your chosen option. Your options are: + + {options_str} + """).format(options_str=options_str) # otherwise try to load the schema for the result type elif self.result_type is not None: