Skip to content

Commit

Permalink
Merge pull request #32 from Forethought-Technologies/feat/yi/tool_desp
Browse files Browse the repository at this point in the history
Allow adding tool description using arg_description in Tool for function calling using openai
  • Loading branch information
yyiilluu authored Jun 19, 2023
2 parents 749ccd9 + a7096c6 commit 74bf5b3
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from __future__ import annotations

import logging
from string import Template
from typing import Any, Dict, List, Optional, Union

from colorama import Fore

from autochain.agent.base_agent import BaseAgent
from autochain.agent.message import ChatMessageHistory, UserMessage, SystemMessage
from autochain.agent.message import ChatMessageHistory, SystemMessage
from autochain.agent.openai_funtions_agent.output_parser import (
OpenAIFunctionOutputParser,
)
Expand Down
21 changes: 16 additions & 5 deletions autochain/models/chat_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def convert_message_to_dict(message: BaseMessage) -> dict:
def convert_tool_to_dict(tool: Tool):
"""Convert tool into function parameter for openai"""
inspection = inspect.getfullargspec(tool.func)
arg_description = tool.arg_description or {}

def _type_to_string(t: type) -> str:
prog = re.compile(r"<class '(\w+)'>")
Expand All @@ -79,17 +80,27 @@ def _type_to_string(t: type) -> str:

return str(t)

def _format_property(t: type, arg_desp: str):
p = {"type": _type_to_string(t)}
if arg_desp:
p["description"] = arg_desp

return p

arg_annotations = inspection.annotations
if arg_annotations:
properties = {
arg: {"type": _type_to_string(t)} for arg, t in arg_annotations.items()
arg: _format_property(t, arg_description.get(arg))
for arg, t in arg_annotations.items()
}
else:
properties = {arg: {"type": "string"} for arg, t in inspection.args}
properties = {
arg: _format_property(str, arg_description.get(arg))
for arg in inspection.args
}

required_args = (
inspection.args[: len(inspection.defaults)] if inspection.defaults else []
)
default_args = inspection.defaults or []
required_args = inspection.args[: len(inspection.args) - len(default_args)]

output = {
"name": tool.name,
Expand Down
18 changes: 17 additions & 1 deletion autochain/tools/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Base implementation for tools or skills."""
from __future__ import annotations

import inspect
from abc import ABC
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union

Expand All @@ -25,6 +26,8 @@ class Tool(ABC, BaseModel):
You can provide few-shot examples as a part of the description.
"""

arg_description: Optional[Dict[str, Any]] = None

args_schema: Optional[Type[BaseModel]] = None
"""Pydantic model class to validate and parse the tool's input arguments."""

Expand All @@ -41,8 +44,21 @@ class Tool(ABC, BaseModel):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
if values.get("func") and not values.get("name"):
func = values.get("func")
if func and not values.get("name"):
values["name"] = values["func"].__name__

# check if all args from arg_description exist in func args
if values.get("arg_description") and func:
inspection = inspect.getfullargspec(func)
override_args = set(values["arg_description"].keys())
args = set(inspection.args)
override_without_args = override_args - args
if len(override_without_args) > 0:
raise ValueError(
f"Provide arg description for not existed args: {override_without_args}"
)

return values

def _parse_input(
Expand Down
83 changes: 82 additions & 1 deletion tests/models/test_chat_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,23 @@
from unittest import mock

import pytest
from autochain.tools.base import Tool

from autochain.agent.message import UserMessage
from autochain.models.base import LLMResult
from autochain.models.chat_openai import ChatOpenAI
from autochain.models.chat_openai import ChatOpenAI, convert_tool_to_dict


def sample_tool_func_no_type(k, *arg, **kwargs):
return f"run with {k}"


def sample_tool_func_with_type(k: int, *arg, **kwargs):
return str(k + 1)


def sample_tool_func_with_type_default(k: int, d: int = 1, *arg, **kwargs):
return str(k + d + 1)


@pytest.fixture
Expand All @@ -29,3 +42,71 @@ def test_chat_completion(openai_completion_fixture):
assert isinstance(response, LLMResult)
assert len(response.generations) == 1
assert response.generations[0].message.content == "generated message"


def test_convert_tool_to_dict():
no_type_tool = Tool(
func=sample_tool_func_no_type,
description="""This is just a dummy tool without typing info""",
)

tool_dict = convert_tool_to_dict(no_type_tool)

assert tool_dict == {
"name": "sample_tool_func_no_type",
"description": "This is just a " "dummy tool without typing info",
"parameters": {
"type": "object",
"properties": {"k": {"type": "string"}},
"required": ["k"],
},
}

with_type_tool = Tool(
func=sample_tool_func_with_type,
description="""This is just a dummy tool with typing info""",
)

with_type_tool_dict = convert_tool_to_dict(with_type_tool)
assert with_type_tool_dict == {
"name": "sample_tool_func_with_type",
"description": "This is just a dummy tool with typing info",
"parameters": {
"type": "object",
"properties": {"k": {"type": "int"}},
"required": ["k"],
},
}

with_type_default_tool = Tool(
func=sample_tool_func_with_type_default,
description="""This is just a dummy tool with typing info""",
)

with_type_default_tool_dict = convert_tool_to_dict(with_type_default_tool)
assert with_type_default_tool_dict == {
"name": "sample_tool_func_with_type_default",
"description": "This is just a dummy tool with typing info",
"parameters": {
"type": "object",
"properties": {"k": {"type": "int"}, "d": {"type": "int"}},
"required": ["k"],
},
}

with_type_and_desp_tool = Tool(
func=sample_tool_func_with_type,
description="""This is just a dummy tool with typing info""",
arg_description={"k": "key of the arg"},
)

with_type_and_desp_tool_dict = convert_tool_to_dict(with_type_and_desp_tool)
assert with_type_and_desp_tool_dict == {
"name": "sample_tool_func_with_type",
"description": "This is just a dummy tool with typing info",
"parameters": {
"type": "object",
"properties": {"k": {"type": "int", "description": "key of the arg"}},
"required": ["k"],
},
}
47 changes: 47 additions & 0 deletions tests/tools/test_base_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import pytest

from autochain.tools.base import Tool


def sample_tool_func(k, *arg, **kwargs):
return f"run with {k}"


def test_run_tool():
tool = Tool(
func=sample_tool_func,
description="""This is just a dummy tool""",
)

output = tool.run("test")
assert output == "run with test"


def test_tool_name_override():
new_test_name = "new_name"
tool = Tool(
name=new_test_name,
func=sample_tool_func,
description="""This is just a dummy tool""",
)

assert tool.name == new_test_name


def test_arg_description():
valid_arg_description = {"k": "key of the arg"}

invalid_arg_description = {"not_k": "key of the arg"}

_ = Tool(
func=sample_tool_func,
description="""This is just a dummy tool""",
arg_description=valid_arg_description,
)

with pytest.raises(ValueError):
_ = Tool(
func=sample_tool_func,
description="""This is just a dummy tool""",
arg_description=invalid_arg_description,
)

0 comments on commit 74bf5b3

Please sign in to comment.