Skip to content

Commit

Permalink
work in progress - prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
harry-cohere committed Mar 27, 2024
1 parent 4ff583b commit f2d24a6
Showing 1 changed file with 37 additions and 23 deletions.
60 changes: 37 additions & 23 deletions libs/partners/cohere/langchain_cohere/multi_hop/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.prompts import PromptTemplate, BasePromptTemplate
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.tools import BaseTool
Expand All @@ -25,9 +25,7 @@ def create_cohere_multi_hop_agent(
tools: Sequence[BaseTool],
prompt: ChatPromptTemplate,
) -> Runnable:
multi_hop_prompt = multi_hop_prompt_template.partial(
tools="\n".join([render_tool_description(t) for t in tools]),
)
multi_hop_prompt = _render_prompt(tools=tools)
llm_with_tools = llm.bind(stop=["\nObservation:"])

agent = (
Expand All @@ -46,6 +44,12 @@ def create_cohere_multi_hop_agent(
return agent


def _render_prompt(tools: Sequence[BaseTool]) -> BasePromptTemplate:
return multi_hop_prompt_template.partial(
tools="\n".join([render_tool_description(t) for t in tools]),
)


class CohereToolsMultiHopAgentOutputParser(
BaseOutputParser[Union[List[AgentAction], AgentFinish]]
):
Expand Down Expand Up @@ -190,33 +194,43 @@ def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:

def render_tool_description(tool: BaseTool) -> str:
"""Render the tool in the style of a Python function."""
function_signature = []
args_description = []
for parameter_name, parameter_definition in tool.args.items():
if "default" in parameter_definition:
parameter_type = f"Optional[{parameter_definition.get("type", "str")}]"
def get_type(type_: str, is_optional: bool) -> str:
if is_optional:
return f"Optional[{type_}]"
else:
parameter_type = parameter_definition.get("type", "str")
function_signature += f"{parameter_name}: {parameter_type}"
args_description += f"{parameter_name} ({parameter_type}): {parameter_definition.get("description", "")}"

if args_description:
args_description = """Args:
"""+"\n ".join(args_description)
return type_

def function_signature(tool_: BaseTool) -> str:
signature = []
for parameter_name, parameter_definition in tool_.args.items():
type_ = get_type(parameter_definition.get("type"), "default" in parameter_definition)

Check failure on line 206 in libs/partners/cohere/langchain_cohere/multi_hop/agent.py

View workflow job for this annotation

GitHub Actions / cd libs/partners/cohere / make lint #3.11

Ruff (E501)

langchain_cohere/multi_hop/agent.py:206:89: E501 Line too long (97 > 88)
signature += f"{parameter_name}: {type_}"
return f"def {tool_.name}({",".join(signature)}) -> List[Dict]"

def args_description(tool_: BaseTool) -> str:
if not tool_.args:
return ""
rendered_args = []
for parameter_name, parameter_definition in tool_.args.items():
type_ = get_type(parameter_definition.get("type"), "default" in parameter_definition)

Check failure on line 215 in libs/partners/cohere/langchain_cohere/multi_hop/agent.py

View workflow job for this annotation

GitHub Actions / cd libs/partners/cohere / make lint #3.11

Ruff (E501)

langchain_cohere/multi_hop/agent.py:215:89: E501 Line too long (97 > 88)
rendered_args += f"{parameter_name} ({type_}): {parameter_definition.get("description")}"

Check failure on line 216 in libs/partners/cohere/langchain_cohere/multi_hop/agent.py

View workflow job for this annotation

GitHub Actions / cd libs/partners/cohere / make lint #3.11

Ruff (E501)

langchain_cohere/multi_hop/agent.py:216:89: E501 Line too long (101 > 88)
rendered_args = "\n ".join(rendered_args)

return """
Args:
""" + rendered_args

template = PromptTemplate(template_format="f-string").from_template("""```python
def {TOOL_NAME}({SIGNATURE}) -> List[Dict]:
{signature}
\"\"\"
{DESCRIPTION}
{ARGS_DESCRIPTION}
{description}{args_description}
\"\"\"
pass
```""")
return template.format(
TOOL_NAME=tool.name,
DESCRIPTION=tool.description,
SIGNATURE=", ".join(function_signature),
ARGS_DESCRIPTION=args_description,
signature=function_signature(tool),
description=tool.description,
args_description=args_description,
)


Expand Down

0 comments on commit f2d24a6

Please sign in to comment.