diff --git a/libs/partners/cohere/langchain_cohere/multi_hop/agent.py b/libs/partners/cohere/langchain_cohere/multi_hop/agent.py index a058026df067d..bfefd5c08cfcc 100644 --- a/libs/partners/cohere/langchain_cohere/multi_hop/agent.py +++ b/libs/partners/cohere/langchain_cohere/multi_hop/agent.py @@ -6,6 +6,7 @@ from langchain_core.agents import AgentAction, AgentFinish from langchain_core.language_models import BaseLanguageModel from langchain_core.output_parsers import BaseOutputParser +from langchain_core.prompts import PromptTemplate from langchain_core.prompts.chat import ChatPromptTemplate from langchain_core.runnables import Runnable, RunnablePassthrough from langchain_core.tools import BaseTool @@ -188,14 +189,35 @@ def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]: def render_tool_description(tool: BaseTool) -> str: - """Render the tool name and description. TODO: better way to render?""" - template = """```python - def {TOOL_NAME}(): - \"\"\"{DESCRIPTION} - \"\"\" + """Render the tool in the style of a Python function.""" + function_signature = [] + args_description = [] + for parameter_name, parameter_definition in tool.args.items(): + if "default" in parameter_definition: + parameter_type = f"Optional[{parameter_definition.get("type", "str")}]" + else: + parameter_type = parameter_definition.get("type", "str") + function_signature += f"{parameter_name}: {parameter_type}" + args_description += f"{parameter_name} ({parameter_type}): {parameter_definition.get("description", "")}" + + if args_description: + args_description = """Args: + """+"\n ".join(args_description) + + template = PromptTemplate(template_format="f-string").from_template("""```python + def {TOOL_NAME}({SIGNATURE}) -> List[Dict]: + \"\"\" + {DESCRIPTION} + {ARGS_DESCRIPTION} + \"\"\" pass - ```""" - return template.format(TOOL_NAME=tool.name, DESCRIPTION=tool.description) + ```""") + return template.format( + TOOL_NAME=tool.name, + DESCRIPTION=tool.description, + SIGNATURE=", ".join(function_signature), + ARGS_DESCRIPTION=args_description, + ) def format_cohere_log_to_str(