diff --git a/cookbook/apps/multi_agent.py b/cookbook/apps/multi_agent.py new file mode 100644 index 000000000..c09f56972 --- /dev/null +++ b/cookbook/apps/multi_agent.py @@ -0,0 +1,27 @@ +from marvin import AIApplication + + +def get_foo(): + """A function that returns the value of foo.""" + return 42 + + +worker = AIApplication( + name="worker", + description="A simple worker application.", + plan_enabled=False, + state_enabled=False, + tools=[get_foo], +) + +router = AIApplication( + name="router", + description="routes user requests to the appropriate worker", + plan_enabled=False, + state_enabled=False, + tools=[worker], +) + +message = router("what is the value of foo?") + +assert "42" in message.content, "The answer should be 42." diff --git a/src/marvin/components/ai_application.py b/src/marvin/components/ai_application.py index 4d15c775d..4cdbedebc 100644 --- a/src/marvin/components/ai_application.py +++ b/src/marvin/components/ai_application.py @@ -238,7 +238,7 @@ def validate_tools(cls, v): # convert AI Applications and functions to tools for tool in v: if isinstance(tool, (AIApplication, Tool)): - tools.append(tool.as_function()) + tools.append(tool.as_function(description=tool.description)) elif callable(tool): tools.append(tool) else: @@ -305,11 +305,15 @@ async def run(self, input_text: str = None, model: str = None) -> Message: self.logger.debug_kv("AI response", last_message.content, key_style="blue") return last_message - def as_tool(self, name: str = None) -> Tool: - return AIApplicationTool(app=self, name=name) + def as_tool( + self, + name: Optional[str] = None, + description: Optional[str] = None, + ) -> Tool: + return AIApplicationTool(app=self, name=name, description=description) - def as_function(self, name: str = None) -> Callable: - return self.as_tool(name=name).as_function() + def as_function(self, name: str = None, description: str = None) -> Callable: + return self.as_tool(name=name, description=description).as_function() class AIApplicationTool(Tool): diff --git a/src/marvin/tools/base.py b/src/marvin/tools/base.py index 6f5823beb..d3c455084 100644 --- a/src/marvin/tools/base.py +++ b/src/marvin/tools/base.py @@ -43,9 +43,12 @@ def argument_schema(self) -> dict: schema.pop("title", None) return schema - def as_function(self) -> Function: - description = jinja_env.from_string(inspect.cleandoc(self.description or "")) - description = description.render(**self.dict(), TOOL=self) + def as_function(self, description: Optional[str] = None) -> Function: + if not description: + description = jinja_env.from_string( + inspect.cleandoc(self.description or "") + ) + description = description.render(**self.dict(), TOOL=self) def fn(*args, **kwargs): return self.run(*args, **kwargs) diff --git a/tests/test_components/test_ai_app.py b/tests/test_components/test_ai_app.py index 6eeefcbe3..e5ae37dab 100644 --- a/tests/test_components/test_ai_app.py +++ b/tests/test_components/test_ai_app.py @@ -165,6 +165,7 @@ def test_update_app_plan_non_existent_path(self): @pytest_mark_class("llm") class TestUpdatePlan: + @pytest.mark.flaky(max_runs=3) def test_keep_app_plan(self): app = AIApplication( name="Zoo planner app",