Skip to content

Commit

Permalink
backend: Tool config major refactoring (#822)
Browse files Browse the repository at this point in the history
* replace with iD

* wip

* wip

* Wip

* more wip

* Fix coral web unit tests

* Resolve coral_web tests

* Fix last case

* Fix coral-web test

* Refactor get_available_tools to use function directly

* wip

* wip

* wip, community tools todo

* Fix tests

* Lint

* wip

* Fix all unit tests

* Remove makefile change

* Fix unit tests

* Fixing tests wip

* Fix lint
  • Loading branch information
tianjing-li authored Nov 6, 2024
1 parent 0bb1f67 commit e629a8a
Show file tree
Hide file tree
Showing 69 changed files with 1,000 additions and 957 deletions.
2 changes: 0 additions & 2 deletions docs/config_details/config_description.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
- redis - Redis configurations
- url - URL of the redis, for example, redis://:redis@redis:6379
- tools - Tool configurations
- enabled_tools - these are the tools that are enabled for the toolkit. The full list of tools can be found in the src/backend/config/tools.py file.
The community tools are listed in the src/community/config/tools.py file. Please note that the tools availability is checked too.
- python_interpreter - Python interpreter configurations
- url - URL of the python interpreter tool
- feature_flags - Feature flags configurations
Expand Down
50 changes: 23 additions & 27 deletions docs/custom_tool_guides/tool_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ from community.tools import BaseTool


class ArxivRetriever(BaseTool):
NAME = "arxiv"
ID = "arxiv"

def __init__(self):
self.client = ArxivAPIWrapper()
Expand All @@ -64,6 +64,27 @@ class ArxivRetriever(BaseTool):
def is_available(cls) -> bool:
return True

@classmethod
# You will need to add a tool definition here
def get_tool_definition(cls) -> ToolDefinition:
return ToolDefinition(
name=cls.ID,
display_name="Arxiv",
implementation=cls,
parameter_definitions={
"query": {
"description": "Query for retrieval.",
"type": "str",
"required": True,
}
},
is_visible=False,
is_available=cls.is_available(),
error_message=cls.generate_error_message(),
category=ToolCategory.DataLoader,
description="Retrieves documents from Arxiv.",
)

# Your tool needs to implement this call() method
def call(self, parameters: str, **kwargs: Any) -> List[Dict[str, Any]]:
result = self.client.run(parameters)
Expand All @@ -84,34 +105,9 @@ return [{"text": "The fox is blue", "url": "wikipedia.org/foxes", "title": "Colo

Next, add your tool class to the init file by locating it in `src/community/tools/__init__.py`. Import your tool here, then add it to the `__all__` list.

To enable your tool, you will need to go to the `configuration.yaml` file and add your tool's name to the list of `enabled_tools`. This tool name will correspond to the one defined in the `NAME` attribute of your class.

Finally, you will need to add your tool definition to the config file. Locate it in `src/community/config/tools.py`, and import your tool at the top with `from backend.tools import ..`.

In the ToolName enum, add your tool as an enum value. For example, `My_Tool = MyTool.NAME`.

In the `ALL_TOOLS` dictionary, add your tool definition. This should look like:

```python
ToolName.My_Tool: ManagedTool( # THE TOOLNAME HERE CORRESPONDS TO THE ENUM YOU DEFINED EARLIER
display_name="My Tool",
implementation=MyTool, # THIS IS THE CLASS YOU IMPORTED AT THE TOP
parameter_definitions={ # THESE ARE PARAMS THE MODEL WILL SEND TO YOUR TOOL, ADJUST AS NEEDED
"query": {
"description": "Query to search with",
"type": "str",
"required": True,
}
},
is_visible=True,
is_available=MyTool.is_available(),
auth_implementation=None, # EMPTY IF NO AUTH NEEDED
error_message="Something went wrong",
category=Category.DataLoader, # CHECK CATEGORY ENUM FOR POSSIBLE VALUES
description="An example definition to get you started.",
),
```

Finally, to enable your tool, add your tool as an enum value. For example, `My_Tool = MyToolClass`.

## Step 5: Test Your Tool!

Expand Down
2 changes: 1 addition & 1 deletion docs/how_to_guides.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ The core chat interface is the Coral frontend. To implement your own interface:
If you have already created a [connector](https://docs.cohere.com/docs/connectors), you can utilize it within the toolkit by following these steps:

1. Configure your connector using `ConnectorRetriever`.
2. Add its definition in [community/config/tools.py](https://github.com/cohere-ai/cohere-toolkit/blob/main/src/community/config/tools.py), following the `Arxiv` implementation, using the category `Category.DataLoader`.
2. Add its definition in [community/config/tools.py](https://github.com/cohere-ai/cohere-toolkit/blob/main/src/community/config/tools.py), following the `Arxiv` implementation, using the category `ToolCategory.DataLoader`.

You can now use both the Coral frontend and API with your connector.

Expand Down
16 changes: 8 additions & 8 deletions src/backend/chat/custom/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,19 @@
from backend.chat.custom.tool_calls import async_call_tools
from backend.chat.custom.utils import get_deployment
from backend.chat.enums import StreamEvent
from backend.config.tools import AVAILABLE_TOOLS
from backend.config.tools import get_available_tools
from backend.database_models.file import File
from backend.model_deployments.base import BaseDeployment
from backend.schemas.chat import ChatMessage, ChatRole, EventState
from backend.schemas.cohere_chat import CohereChatRequest
from backend.schemas.context import Context
from backend.schemas.tool import Category, Tool
from backend.schemas.tool import Tool, ToolCategory
from backend.services.chat import check_death_loop
from backend.services.file import get_file_service
from backend.tools.utils.tools_checkers import tool_has_category

MAX_STEPS = 15


class CustomChat(BaseChat):
"""Custom chat flow not using integrations for models."""

Expand Down Expand Up @@ -163,7 +162,7 @@ async def call_chat(
file_reader_tools_names = []
if managed_tools:
chat_request.tools = managed_tools
file_reader_tools_names = [tool.name for tool in managed_tools_full_schema if tool_has_category(tool, Category.FileLoader)]
file_reader_tools_names = [tool.name for tool in managed_tools_full_schema if tool_has_category(tool, ToolCategory.FileLoader)]

# Get files if available
all_files = []
Expand Down Expand Up @@ -248,17 +247,18 @@ def update_chat_history_with_tool_results(
chat_request.chat_history.extend(tool_results)

def get_managed_tools(self, chat_request: CohereChatRequest, full_schema=False):
available_tools = get_available_tools()
if full_schema:
return [
AVAILABLE_TOOLS.get(tool.name)
available_tools.get(tool.name)
for tool in chat_request.tools
if AVAILABLE_TOOLS.get(tool.name)
if available_tools.get(tool.name)
]

return [
Tool(**AVAILABLE_TOOLS.get(tool.name).model_dump())
Tool(**available_tools.get(tool.name).model_dump())
for tool in chat_request.tools
if AVAILABLE_TOOLS.get(tool.name)
if available_tools.get(tool.name)
]

def add_files_to_chat_history(
Expand Down
4 changes: 2 additions & 2 deletions src/backend/chat/custom/tool_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sqlalchemy.orm import Session

from backend.chat.collate import rerank_and_chunk, to_dict
from backend.config.tools import AVAILABLE_TOOLS
from backend.config.tools import get_available_tools
from backend.model_deployments.base import BaseDeployment
from backend.schemas.context import Context
from backend.services.logger.utils import LoggerFactory
Expand Down Expand Up @@ -76,7 +76,7 @@ async def _call_tool_async(
tool_call: dict,
deployment_model: BaseDeployment,
) -> List[Dict[str, Any]]:
tool = AVAILABLE_TOOLS.get(tool_call["name"])
tool = get_available_tools().get(tool_call["name"])
if not tool:
logger.info(
event=f"[Custom Chat] Tool not included in tools parameter: {tool_call['name']}",
Expand Down
18 changes: 15 additions & 3 deletions src/backend/cli/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ class BuildTarget(StrEnum):
PROD = "prod"


class ToolName(StrEnum):
class Tool(StrEnum):
PythonInterpreter = "Python Interpreter"
TavilyInternetSearch = "Tavily Internet Search"
Wolfram_Alpha = "Wolfram Alpha"



WELCOME_MESSAGE = r"""
Expand All @@ -50,18 +52,28 @@ class ToolName(StrEnum):


TOOLS = {
ToolName.PythonInterpreter: {
Tool.PythonInterpreter: {
"secrets": {
"PYTHON_INTERPRETER_URL": PYTHON_INTERPRETER_URL_DEFAULT,
},
},
ToolName.TavilyInternetSearch: {
Tool.TavilyInternetSearch: {
"secrets": {
"TAVILY_API_KEY": None,
},
},
}

# For main.py cli setup script
COMMUNITY_TOOLS = {
Tool.Wolfram_Alpha: {
"secrets": {
"WOLFRAM_APP_ID": None, # default value
},
},
}


ENV_YAML_CONFIG_MAPPING = {
"USE_COMMUNITY_FEATURES": {
"type": "config",
Expand Down
5 changes: 2 additions & 3 deletions src/backend/cli/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import argparse

from backend.cli.constants import TOOLS
from backend.cli.constants import COMMUNITY_TOOLS, TOOLS
from backend.cli.prompts import (
PROMPTS,
community_tools_prompt,
Expand All @@ -23,7 +23,6 @@
from community.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS_SETUP,
)
from community.config.tools import COMMUNITY_TOOLS_SETUP


def start():
Expand All @@ -43,7 +42,7 @@ def start():
# ENABLE COMMUNITY TOOLS
use_community_features = args.use_community and community_tools_prompt(secrets)
if use_community_features:
TOOLS.update(COMMUNITY_TOOLS_SETUP)
TOOLS.update(COMMUNITY_TOOLS)

# SET UP TOOLS
for name, configs in TOOLS.items():
Expand Down
8 changes: 0 additions & 8 deletions src/backend/config/configuration.template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,6 @@ database:
redis:
url: redis://:redis@redis:6379
tools:
enabled_tools:
- wikipedia
- search_file
- read_file
- toolkit_python_interpreter
- toolkit_calculator
- hybrid_web_search
- web_scrape
hybrid_web_search:
# List of web search tool names, eg: google_web_search, tavily_web_search
enabled_web_searches:
Expand Down
1 change: 0 additions & 1 deletion src/backend/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@ class HybridWebSearchSettings(BaseSettings, BaseModel):

class ToolSettings(BaseSettings, BaseModel):
model_config = SETTINGS_CONFIG
enabled_tools: Optional[List[str]] = None

python_interpreter: Optional[PythonToolSettings] = Field(
default=PythonToolSettings()
Expand Down
Loading

0 comments on commit e629a8a

Please sign in to comment.